DeepSpeedExamples

fix: the source part should not participate in loss calculation in SFT stage

Open · xffxff opened this issue · 1 comment

Fixes https://github.com/microsoft/DeepSpeedExamples/issues/660

In the SFT stage, the source (prompt) part must not contribute to the loss; only the completion part should. To address this, I set the labels of the source tokens to -100, which is the default "ignore index" (ignore_index) of torch.nn.CrossEntropyLoss. Both OPT and LLaMA models compute their loss with torch.nn.CrossEntropyLoss, in OPTForCausalLM and LlamaForCausalLM respectively. The loss computation therefore needs no changes: positions labeled -100 are skipped automatically, so the source part is excluded as intended.
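A minimal sketch of the masking idea, assuming PyTorch is installed. The helper names build_sft_example and causal_lm_loss are hypothetical and are not the DeepSpeed-Chat code. The loss function mirrors the one-position label shift that Hugging Face causal-LM classes apply before calling the cross-entropy loss. Random logits stand in for a model's output.

```python
import torch
import torch.nn.functional as F

# -100 is the default ignore_index of torch.nn.CrossEntropyLoss / F.cross_entropy.
IGNORE_INDEX = -100


def build_sft_example(prompt_ids, completion_ids):
    """Hypothetical helper: concatenate prompt and completion ids and mask the
    prompt positions in the labels so they are ignored by the loss."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return torch.tensor(input_ids), torch.tensor(labels)


def causal_lm_loss(logits, labels):
    """Next-token loss: the logits at position t are scored against label t+1."""
    shift_logits = logits[:-1, :]
    shift_labels = labels[1:]
    # ignore_index defaults to -100, so masked prompt labels are skipped
    # and the mean is taken over completion tokens only.
    return F.cross_entropy(shift_logits, shift_labels)


torch.manual_seed(0)
vocab_size = 10
prompt = [1, 2, 3]
completion = [4, 5]
input_ids, labels = build_sft_example(prompt, completion)
logits = torch.randn(len(input_ids), vocab_size)  # stand-in for model output

loss = causal_lm_loss(logits, labels)

# The same value, computed explicitly over the positions that predict
# completion tokens (from the last prompt position up to the second-last token).
start = len(prompt) - 1
manual = F.cross_entropy(logits[start:-1], input_ids[len(prompt):])
```

Because -100 is already the default, no extra argument to the loss function is needed. Masking the labels alone is enough to drop the prompt tokens from both the sum and the averaging denominator.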

Training loss of opt-350m on the Dahoas/rm-static dataset: [image not preserved]

xffxff, Oct 10 '23 11:10

  1. The function has since moved to "DeepSpeedExamples/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py".
  2. This solution seems to assume a single-turn conversation; please also consider multi-turn conversations.
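One possible reading of the multi-turn request, sketched under stated assumptions: each conversation is already split into (role, token_ids) segments, and the "assistant" role marks the responses to train on. The helper build_multiturn_labels and the role strings are hypothetical. Role tags, separators and tokenization details are ignored. Every assistant turn keeps its labels, and every other turn is masked with -100.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_multiturn_labels(turns):
    """Hypothetical helper. turns: list of (role, token_ids) in conversation order.

    Tokens from "assistant" turns keep their ids as labels; all other turns
    (user prompts, system text) get IGNORE_INDEX so they add nothing to the loss.
    """
    input_ids, labels = [], []
    for role, token_ids in turns:
        input_ids.extend(token_ids)
        if role == "assistant":
            labels.extend(token_ids)
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))
    return input_ids, labels


conversation = [
    ("user", [11, 12]),
    ("assistant", [21, 22, 23]),
    ("user", [13]),
    ("assistant", [24]),
]
input_ids, labels = build_multiturn_labels(conversation)
# input_ids -> [11, 12, 21, 22, 23, 13, 24]
# labels    -> [-100, -100, 21, 22, 23, -100, 24]
```

Unlike a single prompt/completion split, this keeps earlier assistant replies as training targets while still masking every user turn.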

AndyW-llm, Nov 28 '23 20:11