
[GRPO] initial GRPO trainer

Open saisurbehera opened this issue 1 year ago • 3 comments

Implementation of the DeepSeekMath GRPO: https://arxiv.org/pdf/2402.03300

Still a work in progress

  • Will be adding iterative reward model training
  • Only outcome supervision has been enabled; process supervision will be implemented later
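For context, outcome supervision in GRPO (per the DeepSeekMath paper) gives every token of a completion the same advantage: the completion's scalar reward normalized by the mean and standard deviation of its sampling group. A minimal loop-based sketch of that idea, with hypothetical names (this is not the PR's actual code):

```python
import torch

def grpo_outcome_advantages(rewards: torch.Tensor, group_size: int) -> torch.Tensor:
    """Normalize each scalar reward within its sampling group.

    rewards: flat tensor of shape (num_prompts * group_size,), one outcome
    reward per sampled completion, with each group stored contiguously.
    """
    advantages = torch.empty_like(rewards)
    for start in range(0, rewards.numel(), group_size):
        group = rewards[start:start + group_size]
        # (r - mean) / std within the group; eps guards identical rewards
        advantages[start:start + group_size] = (
            (group - group.mean()) / (group.std() + 1e-8)
        )
    return advantages
```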

closes #2103

saisurbehera avatar Aug 21 '24 02:08 saisurbehera

Thank you for working on this nifty algorithm @saisurbehera ! I see you're basing your implementation on PPOTrainer but we've recently overhauled our RL implementations to be more aligned with the rest of the library, e.g. here's the new PPO version: https://github.com/huggingface/trl/blob/main/trl/trainer/ppov2_trainer.py

Would you mind adapting your implementation to this new API? Since GRPO is somewhat similar to RLOO, you might find it is possible to copy-paste a large part of that code: https://github.com/huggingface/trl/blob/main/trl/trainer/rloo_trainer.py

lewtun avatar Aug 21 '24 08:08 lewtun

Sure, I can make the changes to match the new PPOv2Trainer.

saisurbehera avatar Aug 21 '24 14:08 saisurbehera

Hello @lewtun ,

I ported the code to the new methodology; it was much simpler than the first version. I still need to do some validation and testing.

saisurbehera avatar Aug 22 '24 03:08 saisurbehera

Thanks for your contribution!

I've also implemented a version of the GRPO trainer. Instead of using a for loop as in https://github.com/saisurbehera/trl/blob/grpo/trl/trainer/grpo_trainer.py#L380, I directly view the rewards as (-1, sampling_group_size), calculate the normalized_group_scores in a tensor-friendly way, and then view the result back to the original shape. I'm not sure whether this helps performance.
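The reshape trick described above can be sketched as follows (a hedged illustration; `sampling_group_size` and the function name are assumptions, not the actual implementation):

```python
import torch

def grpo_outcome_advantages_vectorized(
    rewards: torch.Tensor, sampling_group_size: int
) -> torch.Tensor:
    """Tensor-friendly group normalization: view the flat rewards as
    (-1, group_size), normalize along the group dimension, then view
    the result back to the original flat shape."""
    grouped = rewards.view(-1, sampling_group_size)  # (num_prompts, group_size)
    mean = grouped.mean(dim=-1, keepdim=True)
    std = grouped.std(dim=-1, keepdim=True)
    normalized = (grouped - mean) / (std + 1e-8)     # normalized_group_scores
    return normalized.view(rewards.shape)            # back to original shape
```

On GPU this replaces a Python-level loop over groups with a single batched reduction, which is usually faster for large batches.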

Namco0816 avatar Nov 28 '24 06:11 Namco0816

I think we should test it out. Thanks a lot for the change.

Overall, I think most of what I can do within the current limits of trl is done. We need more extensive changes to add PRM and reward model training.

Sorry for not spending more time on it; I've been busy with work and family.

saisurbehera avatar Nov 29 '24 01:11 saisurbehera

Hi, thanks for the PR! It would be great to have GRPO, and I'm looking forward to it!

fzyzcjy avatar Nov 30 '24 06:11 fzyzcjy

@saisurbehera curious if it's ready to test? want to try my hand at hacking in PRM rewards

rawsh avatar Dec 17 '24 10:12 rawsh

Go ahead

saisurbehera avatar Dec 17 '24 15:12 saisurbehera

Hi, are there any updates? I would appreciate it if this could be merged!

fzyzcjy avatar Dec 22 '24 09:12 fzyzcjy

Let me work on this over the weekend to verify. Sorry for the delay.

saisurbehera avatar Dec 25 '24 19:12 saisurbehera

Looking forward to it!

fzyzcjy avatar Dec 25 '24 23:12 fzyzcjy

REINFORCE++ is better than GRPO: https://www.researchgate.net/publication/387487679_REINFORCE_A_SIMPLE_AND_EFFICIENT_APPROACH_FOR_ALIGNING_LARGE_LANGUAGE_MODELS

hijkzzz avatar Dec 27 '24 03:12 hijkzzz

Hi, are there any updates?

fzyzcjy avatar Dec 31 '24 23:12 fzyzcjy

Sorry for the late response. The code works now, but the new model has very high KL divergence from the reference model, and the scores don't look right compared to RLOO. I need to debug why. Sorry for the delay.
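For debugging runs like this, it can help to log the per-token KL against the reference model explicitly. The DeepSeekMath paper estimates the KL term with the unbiased "k3" estimator rather than the plain log-ratio; a minimal sketch, with illustrative variable names:

```python
import torch

def k3_kl_estimate(policy_logprobs: torch.Tensor,
                   ref_logprobs: torch.Tensor) -> torch.Tensor:
    """Per-token KL(policy || ref) via the k3 estimator:
    exp(ref - policy) - (ref - policy) - 1, which is always >= 0."""
    log_ratio = ref_logprobs - policy_logprobs
    return torch.exp(log_ratio) - log_ratio - 1.0
```

Logging the mean of this over sampled tokens each step makes a KL blow-up visible early; a steadily growing value usually points at a missing or mis-scaled KL penalty in the loss.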

saisurbehera avatar Jan 02 '25 03:01 saisurbehera

It's OK and looking forward to the fix!

fzyzcjy avatar Jan 02 '25 03:01 fzyzcjy

Do you have an example training dataset?

ehartford avatar Jan 05 '25 20:01 ehartford