Support LM-loss and knowledge distillation together
🎯 Goal (What & Why)
Knowledge distillation was added in #229, but it currently disables the standard LM loss. Enabling both at once would let distillation keep the model from diverging too much from a base model, for example when training a model with MTP. In other scenarios, adding the LM loss to distillation could help when the teacher was not trained on a particular dataset. @nitsanluke
🚀 Execution Plan
Currently, we pass a single target to the LM heads. It carries either labels from the dataset or logits from a reference model.
Allow passing both logits and labels to the model, and use all available targets to compute the loss: distillation, LM loss, both, or neither.
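As a rough illustration of the plan above, the sketch below computes a per-token loss from whichever targets are available. It is not the Fast-LLM implementation: the function names, the `lm_weight` and `distillation_weight` parameters, and the choice of cross-entropy against the teacher's softmax as the distillation term are all assumptions made for the example.

```python
import math
from typing import Optional, Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    # Numerically stable log-softmax over a single token's vocabulary logits.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def lm_loss(student_logits: Sequence[float], label: int) -> float:
    # Standard cross-entropy against a hard label from the dataset.
    return -log_softmax(student_logits)[label]


def distillation_loss(
    student_logits: Sequence[float], teacher_logits: Sequence[float]
) -> float:
    # Assumed form: cross-entropy against the teacher's softmax distribution.
    teacher_probs = [math.exp(x) for x in log_softmax(teacher_logits)]
    student_log_probs = log_softmax(student_logits)
    return -sum(p * q for p, q in zip(teacher_probs, student_log_probs))


def combined_loss(
    student_logits: Sequence[float],
    label: Optional[int] = None,
    teacher_logits: Optional[Sequence[float]] = None,
    lm_weight: float = 1.0,  # hypothetical weighting knob
    distillation_weight: float = 1.0,  # hypothetical weighting knob
) -> float:
    # Use every target that was provided: LM loss, distillation, both, or neither.
    loss = 0.0
    if label is not None:
        loss += lm_weight * lm_loss(student_logits, label)
    if teacher_logits is not None:
        loss += distillation_weight * distillation_loss(student_logits, teacher_logits)
    return loss
```

Calling `combined_loss(logits, label=3)` gives the plain LM loss, `combined_loss(logits, teacher_logits=t)` gives distillation only, and passing both sums the two weighted terms.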
Step 1: What is the smallest working version?
Distillation and LM loss can be used together. The student can have multiple prediction heads. Only the first head is trained with the distillation loss (assuming the teacher has only one prediction head).
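One way to picture the per-head behavior described above is a small planning helper. This is a hypothetical sketch, not Fast-LLM code: `plan_head_losses` and the loss names are invented for illustration, and it assumes every head receives the LM loss whenever labels are available.

```python
def plan_head_losses(
    num_heads: int, has_labels: bool, has_teacher_logits: bool
) -> list[list[str]]:
    # Returns, for each prediction head, the list of loss terms applied to it.
    plan = []
    for head_index in range(num_heads):
        losses = []
        if has_labels:
            # Assumption: all heads are trained with the LM loss.
            losses.append("lm_loss")
        if has_teacher_logits and head_index == 0:
            # The teacher has a single head, so only the first student head
            # can be matched against its logits.
            losses.append("distillation")
        plan.append(losses)
    return plan
```

For a student with three heads and both targets available, the first head gets both terms and the remaining heads get only the LM loss.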
Step 2: What additional optimizations are possible (but optional)?
(List potential refinements that can be added in later PRs if needed.)
📌 Acceptance Criteria (Must-Haves for Completion)
- The feature must be functional and tested.
- The implementation must be documented in practical terms.
- The PR must include a performance/impact summary.
- No refactors unless directly necessary for feature completion.
🛠️ Project Management
- [x] Assign the project to the Fast-LLM project.
- [ ] Set the `Estimate` field (in days) in the GitHub project.
- [ ] Use the `Size` field to categorize the PR size (Small/Medium/Large).
- [ ] Assign an owner when opening the issue.