Question about value indexing in batched_forward_pass() function
Thank you for your great work!
I read issue #15, but I still don't understand why the values should be shifted left in `PPOTrainer.batched_forward_pass()`: https://github.com/lvwerra/trl/blob/master/trl/ppo.py#L203
At #L201, `start` already points to the first position in the model's output whose predicted next token belongs to the response.
Also, the original code at https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/policy.py#L125 appears to index the logprobs and values at the same position.
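To make the indexing argument concrete, here is a minimal pure-Python sketch. It assumes a causal LM where the output at position `i` (the logits and, for a value head on the same hidden state, the value) depends only on tokens `0..i`. It also assumes logprobs are computed by scoring `logits[:, :-1]` against `input_ids[:, 1:]`, so `logprobs[i]` scores token `i + 1`. The helper `aligned_steps` and the toy tokens are hypothetical and only illustrate the index arithmetic.

```python
def aligned_steps(query, response):
    """Pair each response token with the context it was predicted from.

    Assumes a causal LM: the output at position i (logits and, for a value
    head on the same hidden state, the value) depends only on tokens 0..i,
    and logprobs[i] scores token i + 1.
    """
    input_ids = list(query) + list(response)
    # Hypothetical: first position whose prediction is a response token.
    start = len(query) - 1
    # One past the last position whose prediction is a response token.
    end = len(input_ids) - 1
    steps = []
    for i in range(start, end):
        state = input_ids[: i + 1]  # context seen by both logprobs[i] and values[i]
        action = input_ids[i + 1]   # token scored by logprobs[i]
        steps.append((i, state, action))
    return steps


for i, state, action in aligned_steps(["q0", "q1", "q2"], ["r0", "r1"]):
    print(i, state, "->", action)
```

Under these assumptions, the value of the state in which each response token is chosen sits at the same index as that token's logprob, so slicing both with `[start:end]` would line them up. A slice shifted left by one would pair each logprob with the value of a context that is missing its last token.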
Thank you!