Incorrect reference responses when using PEFT with PPOTrainer
Below is a snippet from ppo_trainer.py (PPOTrainer.generate):
if generate_ref_response:
    ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
    response = self._generate_batched(
        self.model,
        query_tensor,
        length_sampler=length_sampler,
        batch_size=batch_size,
        return_prompt=return_prompt,
        **generation_kwargs,
    )
    if generate_ref_response:
        ref_response = self._generate_batched(
            ref_model,
            query_tensor,
            length_sampler=length_sampler,
            batch_size=batch_size,
            return_prompt=return_prompt,
            **generation_kwargs,
        )
When training with PEFT, ref_model is the same object as the policy (self.model). When the reference log-probabilities are computed in step(), that object is instead called inside a context that disables the adapters:
with torch.no_grad():
    all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
        self.model,
        queries,
        responses,
        model_inputs,
        response_masks=response_masks,
        return_logits=full_kl_penalty,
    )
    with self.optional_peft_ctx():
        ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
            self.model if self.is_peft_model else self.ref_model,
            queries,
            responses,
            model_inputs,
            return_logits=full_kl_penalty,
        )
However, the code that generates reference responses does not use this context. As a result, the reference responses logged in the tables come from the optimized RL model (base model plus trained adapters) rather than from the reference model.
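To make the failure mode concrete, here is a minimal pure-Python sketch. ToyPeftModel, generate_with_ref and the module-level optional_peft_ctx helper are hypothetical stand-ins written for illustration only; they are not TRL or PEFT APIs, although the toy imitates the disable_adapter() context manager that PEFT's PeftModel exposes. The sketch shows that generating "reference" responses from the same object with the adapter still active just reproduces the policy's output, whereas wrapping the call in an adapter-disabling context yields the base model's output. In ppo_trainer.py the analogous change would presumably be to wrap the reference-generation calls (for example the _generate_batched(ref_model, ...) call above) in with self.optional_peft_ctx():, but that is an untested suggestion.

```python
from contextlib import contextmanager, nullcontext


class ToyPeftModel:
    """Hypothetical stand-in for a PEFT-wrapped policy: fixed base behaviour
    plus a trainable adapter that can be switched off temporarily."""

    def __init__(self):
        self.adapter_enabled = True
        self.adapter_bias = 0  # stands in for trained adapter weights

    @contextmanager
    def disable_adapter(self):
        # Turn the adapter off for the duration of the block, then restore it.
        previous = self.adapter_enabled
        self.adapter_enabled = False
        try:
            yield
        finally:
            self.adapter_enabled = previous

    def generate(self, query):
        # The base model echoes the query; the adapter appends " but" tokens,
        # mimicking a policy that was rewarded for saying "but".
        extra = self.adapter_bias if self.adapter_enabled else 0
        return query + " but" * extra


def optional_peft_ctx(model, is_peft_model):
    # Hypothetical helper mirroring the idea behind PPOTrainer.optional_peft_ctx:
    # disable adapters only when the policy is a PEFT model.
    return model.disable_adapter() if is_peft_model else nullcontext()


def generate_with_ref(model, ref_model, queries, is_peft_model, fixed=True):
    responses = [model.generate(q) for q in queries]
    ref = model if is_peft_model else ref_model
    if fixed:
        # Fixed path: reference responses come from the base model.
        with optional_peft_ctx(model, is_peft_model):
            ref_responses = [ref.generate(q) for q in queries]
    else:
        # Buggy path: adapters stay active, so "reference" == policy output.
        ref_responses = [ref.generate(q) for q in queries]
    return responses, ref_responses


policy = ToyPeftModel()
policy.adapter_bias = 2  # pretend RL training pushed the adapter toward "but"

resp, ref_buggy = generate_with_ref(policy, None, ["hello"], is_peft_model=True, fixed=False)
_, ref_fixed = generate_with_ref(policy, None, ["hello"], is_peft_model=True, fixed=True)
print(resp)       # ['hello but but']
print(ref_buggy)  # ['hello but but']
print(ref_fixed)  # ['hello']
```

Because the context manager restores the previous adapter state on exit, later policy generations and training steps are unaffected by the reference pass.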
To reproduce, run any PPOTrainer training loop with PEFT and your logging software of choice (my setup uses WandB), then look at the table of responses: the reference responses are drawn from the same distribution as the model responses. Below is a screenshot from a dummy run where I rewarded the model for outputting the word "but." The reference responses should not change over the course of training.
I observed the same problem with DPOTrainer: setting generate_during_eval=True in DPOConfig produces reference outputs from the current model being trained.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Closing, as PPOTrainer has been deprecated (replaced by PPOv2Trainer).