direct-preference-optimization

Computing logps faster

Open • alexvishnevskiy opened this issue 1 year ago • 3 comments

Hi, great work! The results and research in this area are truly amazing. I have a question about `concatenated_forward`. As I understand it, we only need the log-probabilities (logps) of the chosen and rejected responses. Why can't each batch element be a single sequence [prompt + chosen_response + rejected_response] instead of two sequences [prompt + chosen_response, prompt + rejected_response]? With an appropriate attention mask, it should still be possible to compute the logps of the chosen and rejected responses without either one attending to the other. Correct me if I'm wrong, thanks!
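
The packed layout proposed above would need a custom attention mask in which the rejected tokens can see the prompt but not the chosen response, plus position ids that restart after the prompt for the rejected segment. Below is a minimal PyTorch sketch for a single unpadded example. `shared_prompt_mask` and `shared_prompt_position_ids` are hypothetical helpers, not part of the repository. Whether a particular model accepts such a 2D mask is a separate question and is not addressed here.

```python
import torch


def shared_prompt_mask(prompt_len: int, chosen_len: int, rejected_len: int) -> torch.Tensor:
    """Hypothetical helper: boolean [T, T] mask (True = may attend) for one
    packed sequence laid out as [prompt | chosen | rejected]."""
    total = prompt_len + chosen_len + rejected_len
    # Start from an ordinary causal (lower-triangular) mask.
    mask = torch.tril(torch.ones(total, total, dtype=torch.bool))
    chosen_start = prompt_len
    rejected_start = prompt_len + chosen_len
    # Rejected tokens keep access to the prompt but must not see any chosen token.
    # (Chosen tokens already cannot see rejected ones, because of causality.)
    mask[rejected_start:, chosen_start:rejected_start] = False
    return mask


def shared_prompt_position_ids(prompt_len: int, chosen_len: int, rejected_len: int) -> torch.Tensor:
    """Hypothetical helper: positions for the rejected segment restart right after
    the prompt, so each response gets the positions it would have on its own."""
    prompt = torch.arange(prompt_len)
    chosen = torch.arange(prompt_len, prompt_len + chosen_len)
    rejected = torch.arange(prompt_len, prompt_len + rejected_len)
    return torch.cat([prompt, chosen, rejected])


mask = shared_prompt_mask(prompt_len=2, chosen_len=3, rejected_len=2)
position_ids = shared_prompt_position_ids(2, 3, 2)
print(position_ids.tolist())  # [0, 1, 2, 3, 4, 2, 3]
```

One extra wrinkle with this layout: the logit at the last prompt position predicts the first token of both responses. The first rejected token's log-prob therefore has to be read from position `prompt_len - 1`, not from the position just before it in the packed sequence, so the usual shift-by-one label alignment no longer works at the chosen/rejected boundary.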

```python
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn

# Method of the trainer class. `concatenated_inputs` and `_get_batch_logps` are
# helpers defined elsewhere in the repository.
def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    concatenated_batch = concatenated_inputs(batch)
    all_logits = model(concatenated_batch['concatenated_input_ids'], attention_mask=concatenated_batch['concatenated_attention_mask']).logits.to(torch.float32)
    all_logps = _get_batch_logps(all_logits, concatenated_batch['concatenated_labels'], average_log_prob=False)
    # The first batch_size rows are the chosen examples; the remaining rows are the rejected ones.
    chosen_logps = all_logps[:batch['chosen_input_ids'].shape[0]]
    rejected_logps = all_logps[batch['chosen_input_ids'].shape[0]:]
    return chosen_logps, rejected_logps
```
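
Roughly, the concatenation step pads the chosen and rejected sequences to a common length and then stacks them along the batch dimension. The log-probabilities are then summed per row over the response tokens. The sketch below illustrates that bookkeeping under a few assumptions. `pad_to`, `concat_chosen_rejected` and `sequence_logps` are hypothetical names, not the repository's helpers. Prompt and padding positions are assumed to carry the label -100, and all-zero logits stand in for a model's output.

```python
import torch
import torch.nn.functional as F


def pad_to(x: torch.Tensor, length: int, value: int) -> torch.Tensor:
    """Right-pad the last (sequence) dimension of a [B, T] tensor to `length`."""
    return F.pad(x, (0, length - x.shape[-1]), value=value)


def concat_chosen_rejected(chosen_ids, chosen_labels, rejected_ids, rejected_labels, pad_id=0):
    """Hypothetical helper: stack chosen and rejected examples along dim=0."""
    T = max(chosen_ids.shape[1], rejected_ids.shape[1])
    ids = torch.cat([pad_to(chosen_ids, T, pad_id), pad_to(rejected_ids, T, pad_id)], dim=0)
    labels = torch.cat([pad_to(chosen_labels, T, -100), pad_to(rejected_labels, T, -100)], dim=0)
    attention_mask = torch.cat(
        [pad_to(torch.ones_like(chosen_ids), T, 0), pad_to(torch.ones_like(rejected_ids), T, 0)], dim=0
    )
    return ids, labels, attention_mask


def sequence_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum of log p(token_t | tokens before t) over positions whose label != -100."""
    labels = labels[:, 1:].clone()  # the logit at position t predicts token t + 1
    logits = logits[:, :-1, :]
    loss_mask = labels != -100
    labels[~loss_mask] = 0  # any valid index; these positions are masked out below
    per_token = torch.gather(logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    return (per_token * loss_mask).sum(-1)


# Two examples per side; the chosen responses are shorter than the rejected ones.
chosen_ids = torch.tensor([[11, 12, 13, 14], [21, 22, 23, 24]])
chosen_labels = torch.tensor([[-100, -100, 13, 14], [-100, -100, 23, 24]])
rejected_ids = torch.tensor([[11, 12, 15, 16, 17, 18], [21, 22, 25, 26, 27, 28]])
rejected_labels = torch.tensor([[-100, -100, 15, 16, 17, 18], [-100, -100, 25, 26, 27, 28]])

ids, labels, attention_mask = concat_chosen_rejected(chosen_ids, chosen_labels, rejected_ids, rejected_labels)
vocab_size = 32
logits = torch.zeros(ids.shape[0], ids.shape[1], vocab_size)  # uniform stand-in for model(...).logits
all_logps = sequence_logps(logits, labels)
chosen_logps, rejected_logps = all_logps[: chosen_ids.shape[0]], all_logps[chosen_ids.shape[0] :]
```

Padding to the longer of the two lengths is what allows a single [2B, T] batch. The padded positions carry label -100, so they contribute nothing to the summed logps.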

alexvishnevskiy • Mar 09 '24 03:03

Also, how do you ensure that during the model's forward step, prompt + rejected does not attend to the chosen response? Where in the code is this handled?

bhavyashahh • Jul 29 '24 21:07

I think you are misunderstanding the implementation. The chosen and rejected inputs are concatenated along the batch dimension (dim=0), so the logps for both come out of one forward pass instead of two. They are not concatenated along the sequence dimension, so they cannot attend to each other.
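
A toy check of this point: in batched attention, each row of the batch only mixes with itself. A chosen example therefore produces the same output whether it is run alone or stacked with a rejected example along dim=0. Concatenating along the sequence dimension instead lets later tokens see earlier ones under a plain causal mask. The sketch assumes a bare single-head causal attention with no learned projections as a stand-in for a real transformer layer, and random tensors as stand-ins for embedded prompt + response sequences.

```python
import math

import torch


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head causal self-attention over x of shape [B, T, D], no projections."""
    T, D = x.shape[1], x.shape[2]
    # [B, T, T]: batched matmul, so batch row b only ever interacts with row b.
    scores = x @ x.transpose(-2, -1) / math.sqrt(D)
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))
    return scores.softmax(dim=-1) @ x


torch.manual_seed(0)
chosen = torch.randn(1, 5, 8)    # stand-in for embedded prompt + chosen
rejected = torch.randn(1, 5, 8)  # stand-in for embedded prompt + rejected

alone = causal_self_attention(chosen)
batch_concat = causal_self_attention(torch.cat([chosen, rejected], dim=0))  # shape [2, 5, 8]
seq_concat = causal_self_attention(torch.cat([chosen, rejected], dim=1))    # shape [1, 10, 8]

print(torch.allclose(alone[0], batch_concat[0]))  # True: batch rows do not interact
# The rejected half of the sequence-concatenated input attended to the chosen tokens:
print(torch.allclose(causal_self_attention(rejected)[0], seq_concat[0, 5:]))  # False
```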

cthorrez • Jul 31 '24 04:07

Yes, I hadn't read carefully enough to notice that the concatenation is on dim=0. Thank you for pointing that out.

bhavyashahh • Jul 31 '24 06:07