Question about average_log_prob
In your implementation, average_log_prob is set to False, which means the log prob is not normalized by length.
Is there a reason to set average_log_prob to False? Does the model perform better with average_log_prob=False than with average_log_prob=True?
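For reference, here is a minimal sketch (my own paraphrase, not the exact code in this repo) of what the flag controls when collapsing per-token log probs into a single per-sequence score:

```python
import torch
import torch.nn.functional as F

def sequence_log_prob(logits, labels, mask, average_log_prob=False):
    """Collapse per-token log probs into one log prob per sequence.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len);
    mask: (batch, seq_len), 1 for response tokens, 0 for prompt/padding.
    (Shifting logits/labels by one position is omitted for brevity.)
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # log prob the model assigns to each actual token
    per_token = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
    if average_log_prob:
        # length-normalized: mean over response tokens
        return (per_token * mask).sum(-1) / mask.sum(-1)
    # average_log_prob=False: plain sum, not normalized by length
    return (per_token * mask).sum(-1)
```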
Hey @eric-mitchell, I've been wondering the same thing. Any info you could share on this?
+1 to this. By setting average_log_prob to False, we are incentivizing longer answers/responses from the model, which seems concerning to me. @eric-mitchell any insight into why we want to promote longer answers from DPO?
I have another possible explanation from the math.
We understand that, by the chain rule, $\pi_\theta(y_w) = \pi_\theta(y_w^1)\, \pi_\theta(y_w^2 | y_w^1) \cdots \pi_\theta(y_w^{n_w} | y_w^1\cdots y_w^{n_w-1})$, where $y_w^i$ denotes the $i$-th token of $y_w$ and $n_w$ is its length.
Consequently, $\pi_\theta(y_w | x) = \pi_\theta(x | y_w)\pi_\theta(y_w)/\pi_\theta(x) = \pi_\theta(x | y_w) \pi_\theta(y_w^1) \pi_\theta(y_w^2 | y_w^1) \cdots \pi_\theta(y_w^{n_w} | y_w^1\cdots y_w^{n_w-1}) / {\pi_\theta(x)}$.
Similarly, $\pi_\theta(y_l | x) = {\pi_\theta(x | y_l)\pi_\theta(y_l)} / {\pi_\theta(x)} = \pi_\theta(x | y_l){\pi_\theta(y_l^1) \pi_\theta(y_l^2 | y_l^1) \cdots \pi_\theta(y_l^{n_l} | y_l^1\cdots y_l^{n_l-1})} / {\pi_\theta(x)}$.
Thus, $\log \pi_\theta(y_w | x) - \log \pi_\theta(y_l | x)$ $= \log \pi_\theta (x|y_w) + \sum_{i=1}^{n_{w}} \log \pi_\theta(y_w^i | y_w^1 \cdots y_w^{i-1}) - \log \pi_\theta (x|y_l) - \sum_{i=1}^{n_l} \log \pi_\theta(y_l^i | y_l^1 \cdots y_l^{i-1})$
Assuming $\log \pi_\theta (x|y_w) \approx \log \pi_\theta (x|y_l)$, this difference reduces to the difference of the summed token log probs, which is exactly what the code computes.
Even if the assumption $\log \pi_\theta (x|y_w) \approx \log \pi_\theta (x|y_l)$ is not made, the above derivation indicates that a sum of log probabilities is a more justifiable choice than an average of log probabilities.
BTW, since $\pi_\theta(\cdot) \in [0, 1]$, we have $\log \pi_\theta (\cdot) \le 0$. Therefore, with more tokens the sum of log probs becomes smaller, so the sum of log probs would not introduce a bias that makes the generated sentences longer.
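To make the last point concrete, here is a toy check with made-up numbers: if the per-token log prob is roughly constant, the summed score keeps shrinking as the response gets longer, while the averaged score stays flat.

```python
per_token_logp = -0.5  # hypothetical log prob per token, for illustration only

for n_tokens in (10, 20, 40):
    summed = per_token_logp * n_tokens  # sum of log probs: -5.0, -10.0, -20.0
    averaged = per_token_logp           # mean of log probs: always -0.5
    print(n_tokens, summed, averaged)
```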
@Shenzhi-Wang Hi, thank you for your explanation. However, setting average=True just introduces a fixed divisor into your equations; it shouldn't affect the overall properties of the equation, so why would it be less favorable? Could you explain why average=True might not be as good?
Hi yata0,
We believe that directly using average=True leads to a vanishing supervision signal and thus hurts performance, as shown in Figure 2 of our paper: https://arxiv.org/pdf/2406.10957. You can add a further scaling factor to enlarge the signal.
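A rough sketch of that idea (the `scale` hyperparameter name is mine, and this is only an illustration, not the exact formulation from the paper): with length-averaged log probs the log-ratio margins shrink, so an extra multiplier can be used to restore the magnitude of the signal.

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps,
             beta=0.1, scale=1.0):
    # Standard DPO loss on (possibly length-averaged) sequence log probs.
    # When the log probs are averaged, set scale > 1 to enlarge the signal.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * scale * (pi_logratios - ref_logratios))
```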
best, Junru