[GRPO] initial GRPO trainer
Implementation of the DeepSeekMath GRPO: https://arxiv.org/pdf/2402.03300
Still a work in progress
- Will be adding iterative reward model training
- Only outcome supervision is enabled; process supervision will be implemented later
closes #2103
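For reference, the outcome-supervision advantage in the paper normalizes each sampled completion's reward by the mean and std of its group. A minimal plain-Python sketch of that step (the function name and the `1e-4` epsilon are my own choices, not from the PR):

```python
def group_normalized_advantages(rewards, group_size):
    """Normalize each reward against its own group's statistics.

    rewards: flat list where each consecutive `group_size` entries are
    completions sampled from the same prompt.
    """
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = sum(group) / len(group)
        # population std; a small epsilon guards against all-equal groups
        std = (sum((r - mean) ** 2 for r in group) / len(group)) ** 0.5
        advantages.extend((r - mean) / (std + 1e-4) for r in group)
    return advantages
```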
Thank you for working on this nifty algorithm @saisurbehera ! I see you're basing your implementation on PPOTrainer but we've recently overhauled our RL implementations to be more aligned with the rest of the library, e.g. here's the new PPO version: https://github.com/huggingface/trl/blob/main/trl/trainer/ppov2_trainer.py
Would you mind adapting your implementation to this new API? Since GRPO is somewhat similar to RLOO, you might find it is possible to copy-paste a large part of that code: https://github.com/huggingface/trl/blob/main/trl/trainer/rloo_trainer.py
Sure, I can make the changes to align with PPOTrainer v2.
Hello @lewtun ,
I ported the implementation to the new methodology; it was much simpler than the first version. I still have to do some validation and testing.
Thanks for your contribution!
I've also implemented a version of the GRPO trainer. Instead of using a for loop as in https://github.com/saisurbehera/trl/blob/grpo/trl/trainer/grpo_trainer.py#L380, I view the scores as (-1, sampling_group_size), calculate the normalized_group_scores in a tensor-friendly way, and then view the result back to the original shape. I'm not sure whether this helps performance.
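For concreteness, the view-based normalization described above might look like the sketch below (this is my own illustration, not the exact code from either branch; note `torch.std` defaults to the unbiased estimator, and the `1e-4` epsilon is an assumption):

```python
import torch


def normalized_group_scores(scores, group_size):
    """Vectorized group normalization: reshape, normalize per row, reshape back.

    scores: 1-D tensor where each consecutive `group_size` entries are
    completions sampled from the same prompt.
    """
    grouped = scores.view(-1, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)  # unbiased by default
    return ((grouped - mean) / (std + 1e-4)).view(-1)
```

The reshape avoids a Python-level loop over groups, so the whole batch is normalized in a handful of kernel launches.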
I think we should test it out. Thanks a lot for the change!
Overall, I think most of my work here is done, given the current limits of TRL. We would need more extensive changes to add PRM and reward-model training.
Sorry for not spending more time on it; I've been busy with work and family.
Hi, thanks for the PR! It would be great to have GRPO; looking forward to it!
@saisurbehera curious if it's ready to test? want to try my hand at hacking in PRM rewards
Go ahead
Hi, are there any updates? I would appreciate it if this could be merged!
Let me work on it over the weekend to verify. Sorry for the delay.
Looking forward to it!
REINFORCE++ is better than GRPO: https://www.researchgate.net/publication/387487679_REINFORCE_A_SIMPLE_AND_EFFICIENT_APPROACH_FOR_ALIGNING_LARGE_LANGUAGE_MODELS
Hi, are there any updates?
Sorry for the late response. My code works now. The problem is that the new model has very high KL divergence from the reference model, and the scores don't look right compared to RLOO. I have to debug why. Sorry for the delay.
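In case it helps with the KL debugging: the DeepSeekMath paper penalizes with the unbiased per-token estimator `pi_ref/pi_theta - log(pi_ref/pi_theta) - 1`, which is always non-negative, so a blown-up KL usually means the policy log-probs have drifted far from the reference. A minimal sketch (function name is mine):

```python
import math


def per_token_kl(logp, ref_logp):
    """Unbiased per-token KL estimate used in the DeepSeekMath GRPO objective.

    logp:     log-prob of the token under the current policy
    ref_logp: log-prob of the same token under the reference model
    """
    log_ratio = ref_logp - logp
    # exp(log_ratio) - log_ratio - 1 >= 0, with equality iff the probs match
    return math.exp(log_ratio) - log_ratio - 1.0
```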
It's OK, looking forward to the fix!
Do you have an example training dataset?