
[trainer] feat: Self-Normalized Importance Sampling

EduardDurech opened this issue 3 months ago · 7 comments

Self-Normalized Importance Sampling (SNIS) for the rollout/backward policy mismatch; adds `algorithm.rollout_is_self_norm`.

SNIS applied to `rollout_is_weights`:

  • geo_mean: per-sequence geometric mean
  • seq-mean-token-mean / seq-mean-token-sum: per-sequence masked mean/sum
  • token-mean, seq-mean-token-sum-norm: global denominator

Given $w_i=\dfrac{p(x_i)}{q(x_i)}$, the self-normalized estimator is

$$\widehat{\mu}_{\text{SNIS}}=\frac{\sum_{i=1}^{N} w_i\cdot f(x_i)}{\sum_{i=1}^{N} w_i}$$
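As an illustrative sketch (not the verl implementation), the self-normalized estimator is a few lines of NumPy; `weights` and `values` stand in for $w_i$ and $f(x_i)$:

```python
import numpy as np

def snis_estimate(weights, values):
    """Self-normalized IS estimate of E_p[f(x)] from samples
    x_i ~ q, with weights w_i = p(x_i)/q(x_i)."""
    weights = np.asarray(weights, dtype=float)
    values = np.asarray(values, dtype=float)
    # Dividing by the weight sum makes the estimator invariant
    # to a constant rescaling of the weights.
    return np.sum(weights * values) / np.sum(weights)

# Rescaling all weights by a constant leaves the estimate unchanged.
print(snis_estimate([1.0, 2.0, 1.0], [3.0, 0.0, 3.0]))     # 1.5
print(snis_estimate([10.0, 20.0, 10.0], [3.0, 0.0, 3.0]))  # 1.5
```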

```yaml
algorithm:
  rollout_is: true
  rollout_is_self_norm: true
```


Experimental: only `geo_mean` has been properly tested, so please test for yourself. Most of these variants are not standard SNIS.

Sequence index $b$, token index $t$, mask $m_{bt}\in\{0,1\}$, per-token IS weights $w_{bt}>0$.

Per-sequence $w'_{bt}=\tfrac{w_{bt}}{d_b}$

  • geo_mean $\quad d_b=\exp\Bigg(\frac{\sum_t m_{bt}\cdot \log w_{bt}}{\sum_t m_{bt}}\Bigg)$
  • seq-mean-token-mean $\quad d_b=\frac{\sum_t m_{bt}\cdot w_{bt}}{\sum_t m_{bt}}$
  • seq-mean-token-sum $\quad d_b=\sum_t m_{bt}\cdot w_{bt}$
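The three per-sequence denominators above are masked reductions over the token axis. A hedged NumPy sketch (not the verl code; the function and variable names are assumptions):

```python
import numpy as np

def per_sequence_denominator(weights, mask, mode):
    """Per-sequence denominator d_b for normalizing token-level
    IS weights w_bt; weights and mask have shape (B, T)."""
    token_counts = mask.sum(axis=-1)  # valid tokens per sequence
    if mode == "geo_mean":
        # exp of the masked mean of log-weights: the geometric mean
        return np.exp((mask * np.log(weights)).sum(axis=-1) / token_counts)
    if mode == "seq-mean-token-mean":
        return (mask * weights).sum(axis=-1) / token_counts
    if mode == "seq-mean-token-sum":
        return (mask * weights).sum(axis=-1)
    raise ValueError(mode)

w = np.array([[2.0, 8.0, 1.0]])
m = np.array([[1.0, 1.0, 0.0]])  # last token masked out
print(per_sequence_denominator(w, m, "geo_mean"))  # [4.]  (sqrt(2*8))
```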

Global $w'_{bt}=\tfrac{w_{bt}}{d}$

  • token_mean $\quad d=\frac{\sum_{b,t} m_{bt}\cdot w_{bt}}{\sum_{b,t} m_{bt}}$
  • seq-mean-token-sum-norm, with $T$ the token-dimension length `weights_full.shape[-1]` $\quad d=\frac{\sum_{b,t} m_{bt}\cdot w_{bt}}{T}$
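Both global modes share the same masked sum and differ only in the denominator. Again an illustrative NumPy sketch under the same shape assumptions, not the verl implementation:

```python
import numpy as np

def global_denominator(weights, mask, mode):
    """Single scalar denominator d shared by all tokens;
    weights and mask have shape (B, T)."""
    masked_sum = (mask * weights).sum()
    if mode == "token_mean":
        # masked mean over all (b, t) positions
        return masked_sum / mask.sum()
    if mode == "seq-mean-token-sum-norm":
        # divide by the padded token-dimension length T
        return masked_sum / weights.shape[-1]
    raise ValueError(mode)

w = np.array([[2.0, 4.0], [6.0, 1.0]])
m = np.array([[1.0, 1.0], [1.0, 0.0]])
print(global_denominator(w, m, "token_mean"))               # 4.0  (12/3)
print(global_denominator(w, m, "seq-mean-token-sum-norm"))  # 6.0  (12/2)
```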

EduardDurech avatar Oct 31 '25 18:10 EduardDurech

There's a breaking change in https://github.com/volcengine/verl/pull/3984

wuxibin89 avatar Nov 03 '25 02:11 wuxibin89

@szrlee Can you help review this PR?

wuxibin89 avatar Nov 03 '25 02:11 wuxibin89

In general this wraps any IS implementation, and it doesn't seem much needs to change from the conflicting PR other than the config and checking that the proper parameters are passed to `dp_actor` from `ray_trainer`. If @szrlee can confirm that his PR will be stable, I can refactor, but I won't refactor a second time.

EduardDurech avatar Nov 03 '25 10:11 EduardDurech

Since `rollout_is_batch_normalize` exists, we have to make sure it's normalized across DP; then it should be fine.

EduardDurech avatar Nov 16 '25 15:11 EduardDurech

@szrlee

EduardDurech avatar Nov 22 '25 19:11 EduardDurech

@szrlee

Thank you for the refactor. I will check later.

szrlee avatar Nov 23 '25 07:11 szrlee

@tongyx361 it looks good to me.

szrlee avatar Nov 24 '25 08:11 szrlee