fix: the source part should not participate in loss calculation in SFT stage
Fixes https://github.com/microsoft/DeepSpeedExamples/issues/660
In the SFT stage, the source (prompt) part must not contribute to the loss; only the completion part should be considered. To address this, the labels for the source tokens are set to -100, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`. Since both the OPT and LLaMA models compute their loss with `torch.nn.CrossEntropyLoss` (see `OPTForCausalLM` and `LlamaForCausalLM`), the loss computation itself needs no modification: tokens labeled -100 are ignored automatically.
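A minimal sketch of the masking idea, kept in plain Python for clarity (the function name `mask_source_labels` is illustrative, not the actual helper in the repo): labels are copied from the input ids, and the first `source_len` positions are overwritten with -100 so `CrossEntropyLoss` skips them.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_source_labels(input_ids, source_len):
    """Return labels for causal-LM SFT: a copy of input_ids with the
    source/prompt tokens replaced by IGNORE_INDEX, so that only the
    completion tokens contribute to the loss. (Illustrative sketch.)"""
    labels = list(input_ids)
    for i in range(min(source_len, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels

# Example: a 4-token sequence whose first 2 tokens are the prompt.
print(mask_source_labels([5, 6, 7, 8], 2))  # [-100, -100, 7, 8]
```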
Training loss of opt-350m on the Dahoas/rm-static dataset:
- The function has since been moved to "DeepSpeedExamples/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py"
- This solution seems to assume a single-turn conversation; please also consider the multi-turn case, where every assistant turn (and only those turns) should contribute to the loss.