trl

Incorrect reference responses when using PEFT with PPOTrainer

Open • Sean-OB opened this issue 1 year ago • 2 comments

Below is a snippet from PPOTrainer.generate in ppo_trainer.py:


if generate_ref_response:
    # For PEFT models, ref_model is self.model with its adapters still active.
    ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
    response = self._generate_batched(
        self.model,
        query_tensor,
        length_sampler=length_sampler,
        batch_size=batch_size,
        return_prompt=return_prompt,
        **generation_kwargs,
    )
    if generate_ref_response:
        ref_response = self._generate_batched(
            ref_model,
            query_tensor,
            length_sampler=length_sampler,
            batch_size=batch_size,
            return_prompt=return_prompt,
            **generation_kwargs,
        )

When training with PEFT, there is no separate reference model: the reference is the policy model itself, and reference log-probabilities are computed by calling it inside a context that disables the adapters:

with torch.no_grad():
    all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
        self.model,
        queries,
        responses,
        model_inputs,
        response_masks=response_masks,
        return_logits=full_kl_penalty,
    )
    # optional_peft_ctx() disables the adapters for PEFT models (a no-op otherwise).
    with self.optional_peft_ctx():
        ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
            self.model if self.is_peft_model else self.ref_model,
            queries,
            responses,
            model_inputs,
            return_logits=full_kl_penalty,
        )
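The effect of this context can be illustrated without trl or peft. The sketch below is a toy stand-in, not the library's implementation: ToyLoraModel, its disable_adapter() method and the free function optional_peft_ctx are hypothetical names chosen to mirror the pattern (peft models do expose a disable_adapter() context manager, but the toy only imitates that behaviour). The model output is the base weight plus an adapter delta, and the context temporarily removes the delta.

```python
from contextlib import contextmanager, nullcontext


class ToyLoraModel:
    """Hypothetical stand-in for a PEFT-wrapped model: y = (w_base + delta) * x."""

    def __init__(self, w_base: float, delta: float = 0.0):
        self.w_base = w_base        # frozen base weight (plays the reference model)
        self.delta = delta          # trainable adapter weight (updated by PPO)
        self.adapter_enabled = True

    def forward(self, x: float) -> float:
        w = self.w_base + (self.delta if self.adapter_enabled else 0.0)
        return w * x

    @contextmanager
    def disable_adapter(self):
        # Temporarily bypass the adapter, restoring the previous state on exit.
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield
        finally:
            self.adapter_enabled = previous


def optional_peft_ctx(model, is_peft_model: bool):
    # For PEFT models, "reference" means "same weights with adapters off";
    # otherwise a separate ref_model is used and no context is needed.
    return model.disable_adapter() if is_peft_model else nullcontext()


policy = ToyLoraModel(w_base=1.0, delta=0.5)  # adapter has drifted during training
policy_out = policy.forward(2.0)               # base + adapter
with optional_peft_ctx(policy, is_peft_model=True):
    ref_out = policy.forward(2.0)              # base only
after = policy.forward(2.0)                    # adapter re-enabled on exit
```

Here policy_out and after are 3.0 while ref_out is 2.0: only code run inside the context sees the reference weights.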

However, the code that generates reference responses does not enter this context. As a result, the reference responses logged in the tables come from the optimized RL model (adapters enabled) rather than from the reference model.
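A minimal sketch of one possible fix, assuming it is enough to enter the same adapter-disabling context around reference generation. It uses a hypothetical toy policy (ToyPolicy) that greedily picks the highest-scoring token, so responses are deterministic; the generate function and its fix flag are illustrative and do not reflect trl's real generate signature. With fix=False (the behaviour described above) the "reference" response follows the trained adapter; with fix=True it stays pinned to the base model.

```python
from contextlib import contextmanager, nullcontext

VOCAB = ["the", "but", "and"]


class ToyPolicy:
    """Hypothetical PEFT-style policy: token score = base score + adapter score."""

    def __init__(self):
        self.base = {"the": 1.0, "but": 0.0, "and": 0.5}  # frozen reference behaviour
        self.adapter = {tok: 0.0 for tok in VOCAB}        # trained by "PPO"
        self.adapter_enabled = True

    def generate(self, query: str) -> str:
        # Greedy single-token decoding; the query is ignored in this toy.
        def score(tok):
            bonus = self.adapter[tok] if self.adapter_enabled else 0.0
            return self.base[tok] + bonus
        return max(VOCAB, key=score)

    @contextmanager
    def disable_adapter(self):
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield
        finally:
            self.adapter_enabled = previous


def generate(model, query, is_peft_model=True, ref_model=None, fix=True):
    response = model.generate(query)
    ref = model if is_peft_model else ref_model
    # Fixed path: mirror the training forward pass and disable the adapters.
    ctx = model.disable_adapter() if (fix and is_peft_model) else nullcontext()
    with ctx:
        ref_response = ref.generate(query)
    return response, ref_response


policy = ToyPolicy()
before = generate(policy, "q", fix=True)

policy.adapter["but"] = 2.0  # simulate training that rewards "but"
buggy = generate(policy, "q", fix=False)
fixed = generate(policy, "q", fix=True)
```

After the simulated update, buggy is ("but", "but") while fixed is ("but", "the"): only the fixed path keeps the reference response identical to the one produced before training.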

To reproduce, run any PPOTrainer training loop with your logging software of choice (my setup uses WandB) and look at the table of responses. The reference responses will be drawn from the same distribution as the model responses. Below is a screenshot from a dummy run where I rewarded the model for outputting the word "but." The reference responses should not change over the course of training.

[Screenshot: logged table of model and reference responses from the dummy run]

Sean-OB, Jul 24 '24 17:07

I observed the same problem with DPOTrainer: setting generate_during_eval=True in DPOConfig produces reference outputs from the current model being trained.

skylooop, Jul 26 '24 08:07

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot], Aug 24 '24 15:08

Closing, as PPOTrainer has been deprecated (replaced by PPOv2Trainer).

qgallouedec, Oct 20 '24 15:10