[trainer] feat: Self-Normalized Importance Sampling
Self-normalized importance sampling (SNIS) for the rollout/backward mismatch; adds `algorithm.rollout_is_self_norm`.
SNIS is applied to `rollout_is_weights`.
• `geo_mean`: per-sequence geometric mean
• `seq-mean-token-mean` / `seq-mean-token-sum`: per-sequence masked mean/sum
• `token-mean`, `seq-mean-token-sum-norm`: global denominator
Given $w_i=\dfrac{p(x_i)}{q(x_i)}$, the self-normalized estimator is
$$\widehat{\mu}_{\text{SNIS}}=\frac{\sum_{i=1}^{N} w_i\, f(x_i)}{\sum_{i=1}^{N} w_i}$$
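A minimal sketch of the estimator above (the function name `snis_estimate` and the toy inputs are illustrative, not from this PR):

```python
import numpy as np

def snis_estimate(f_values, log_p, log_q):
    """Self-normalized IS estimate of E_p[f] from samples x_i ~ q.

    w_i = p(x_i) / q(x_i); dividing by sum(w) cancels any unknown
    normalizing constant in p, at the cost of a small bias.
    """
    w = np.exp(log_p - log_q)            # importance weights w_i
    return np.sum(w * f_values) / np.sum(w)

# When p == q every weight is 1 and SNIS reduces to the plain mean.
vals = np.array([1.0, 2.0, 3.0, 4.0])
same = np.zeros(4)
print(snis_estimate(vals, same, same))   # → 2.5
```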
Example:

```yaml
algorithm:
  rollout_is: true
  rollout_is_self_norm: true
```
Experimental: only `geo_mean` has been properly tested, so please validate before use. Most of these variants are not standard SNIS.
Notation: sequence index $b$, token index $t$, mask $m_{bt}\in\{0,1\}$, per-token IS weights $w_{bt}>0$.
Per-sequence $w'_{bt}=\tfrac{w_{bt}}{d_b}$:
- `geo_mean`: $d_b=\exp\left(\frac{\sum_t m_{bt}\log w_{bt}}{\sum_t m_{bt}}\right)$
- `seq-mean-token-mean`: $d_b=\frac{\sum_t m_{bt}\, w_{bt}}{\sum_t m_{bt}}$
- `seq-mean-token-sum`: $d_b=\sum_t m_{bt}\, w_{bt}$
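A sketch of the per-sequence `geo_mean` denominator in PyTorch (function name and toy tensors are illustrative; the actual implementation lives in the PR):

```python
import torch

def self_norm_per_seq_geo_mean(w, mask):
    """Per-sequence SNIS with the geo_mean denominator.

    w:    [B, T] per-token IS weights (w_bt > 0)
    mask: [B, T] token mask (m_bt in {0, 1})
    d_b = exp(masked mean of log w_bt); returns w_bt / d_b.
    """
    tok_counts = mask.sum(dim=-1, keepdim=True).clamp(min=1)
    log_d = (mask * w.log()).sum(dim=-1, keepdim=True) / tok_counts
    return w / log_d.exp()

w = torch.tensor([[2.0, 8.0, 1.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])
# d_b = exp((log 2 + log 8) / 2) = 4; masked-out positions are
# still divided but carry no gradient/loss weight downstream.
print(self_norm_per_seq_geo_mean(w, mask))  # → [[0.5, 2.0, 0.25]]
```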
Global $w'_{bt}=\tfrac{w_{bt}}{d}$:
- `token_mean`: $d=\frac{\sum_{b,t} m_{bt}\, w_{bt}}{\sum_{b,t} m_{bt}}$
- `seq-mean-token-sum-norm`: given $T$, the token-dimension length (`weights_full.shape[-1]`): $d=\frac{\sum_{b,t} m_{bt}\, w_{bt}}{T}$
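The global `token_mean` variant can be sketched the same way; after dividing by $d$, the masked mean of the weights is exactly 1 (names and inputs are again illustrative):

```python
import torch

def self_norm_global_token_mean(w, mask):
    """Global SNIS with the token_mean denominator.

    d = sum_{b,t} m_bt * w_bt / sum_{b,t} m_bt; returns w / d,
    a single scalar denominator shared by every token in the batch.
    """
    d = (mask * w).sum() / mask.sum().clamp(min=1)
    return w / d

w = torch.tensor([[1.0, 3.0], [2.0, 6.0]])
mask = torch.ones_like(w)
# d = (1 + 3 + 2 + 6) / 4 = 3.0
print(self_norm_global_token_mean(w, mask))  # → [[1/3, 1.0], [2/3, 2.0]]
```

Note that with data parallelism this global sum must cover the whole batch, not just one DP shard, which is the cross-DP normalization concern raised below.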
There is a breaking change in https://github.com/volcengine/verl/pull/3984.
@szrlee Can you help review this PR?
In general this wraps any IS implementation, and not much seems to need changing from the conflicting PR beyond the config and checking that the proper parameters are passed from `ray_trainer` to `dp_actor`. If @szrlee can confirm his PR will be stable, I can refactor, but I won't refactor a second time.
Since there's `rollout_is_batch_normalize`, we have to make sure it's normalized across DP; then it should be fine.
@szrlee
Thank you for the refactor. I will check later.
@tongyx361 it looks good to me.