
Support LM-loss and knowledge distillation together

Open RaymondLi0 opened this issue 9 months ago • 0 comments

🎯 Goal (What & Why)

Knowledge distillation was added in #229, but enabling it currently disables the standard LM loss. Enabling knowledge distillation and the standard LM loss together would make it possible to use distillation to keep the model from diverging too far from a base model, for example when training a model with MTP. In other scenarios, adding the LM loss to distillation could be beneficial when the teacher was not trained on a particular dataset. @nitsanluke

🚀 Execution Plan

Currently, we pass a single target to the LM heads. It carries either labels from the dataset or logits from a reference model. Allow passing both logits and labels to the model, and use all available targets to compute the loss: distillation, LM loss, both, or none.
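A minimal sketch of the loss selection described above, in plain Python so it runs without dependencies. It assumes per-token logits are lists of floats and uses KL(teacher ‖ student) as the distillation term; the names `combined_loss`, `lm_weight`, and `distillation_weight` are hypothetical and are not Fast-LLM's API.

```python
import math
from typing import List, Optional, Sequence

Logits = List[float]


def log_softmax(logits: Sequence[float]) -> Logits:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def lm_loss(student_logits: Sequence[Logits], labels: Sequence[int]) -> float:
    # Mean next-token cross-entropy against the dataset labels.
    total = sum(-log_softmax(logits)[label] for logits, label in zip(student_logits, labels))
    return total / len(labels)


def distillation_loss(student_logits: Sequence[Logits], teacher_logits: Sequence[Logits]) -> float:
    # Mean per-token KL(teacher || student); an assumed choice of distillation objective.
    total = 0.0
    for s, t in zip(student_logits, teacher_logits):
        log_s, log_t = log_softmax(s), log_softmax(t)
        total += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_t, log_s))
    return total / len(student_logits)


def combined_loss(
    student_logits: Sequence[Logits],
    labels: Optional[Sequence[int]] = None,
    teacher_logits: Optional[Sequence[Logits]] = None,
    lm_weight: float = 1.0,
    distillation_weight: float = 1.0,
) -> Optional[float]:
    # Use every target that is available: LM loss, distillation, both, or neither.
    terms = []
    if labels is not None:
        terms.append(lm_weight * lm_loss(student_logits, labels))
    if teacher_logits is not None:
        terms.append(distillation_weight * distillation_loss(student_logits, teacher_logits))
    return sum(terms) if terms else None
```

Returning `None` when neither target is present covers the "none" case (e.g. pure inference); the two weights are one plausible way to balance the terms when both are active.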

Step 1: What is the smallest working version?

Distillation and LM loss can be used together. The student can have multiple prediction heads; only the first head is trained with the distillation loss (assuming the teacher has only one prediction head).
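A sketch of this smallest working version with several prediction heads, again in plain Python. It assumes MTP-style heads where head `k` at position `t` predicts `tokens[t + k + 1]`, that the teacher's single head is aligned with the student's first head, and it reuses the assumed KL objective; `multi_head_loss` and its parameters are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple

Logits = List[float]


def log_softmax(logits: Sequence[float]) -> Logits:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def cross_entropy(logits: Logits, label: int) -> float:
    return -log_softmax(logits)[label]


def kl_teacher_student(student: Logits, teacher: Logits) -> float:
    # KL(teacher || student) for a single position (assumed distillation objective).
    log_s, log_t = log_softmax(student), log_softmax(teacher)
    return sum(math.exp(t) * (t - s) for t, s in zip(log_t, log_s))


def multi_head_loss(
    head_logits: Sequence[Sequence[Logits]],
    tokens: Sequence[int],
    teacher_logits: Optional[Sequence[Logits]] = None,
    distillation_weight: float = 1.0,
) -> Tuple[float, List[float]]:
    per_head = []
    for k, logits_k in enumerate(head_logits):
        # Head k is only scored where its target tokens[t + k + 1] exists.
        n = len(tokens) - (k + 1)
        if n <= 0:
            raise ValueError(f"sequence too short for head {k}")
        loss_k = sum(cross_entropy(logits_k[t], tokens[t + k + 1]) for t in range(n)) / n
        if k == 0 and teacher_logits is not None:
            # The teacher has a single next-token head, so only head 0 is distilled,
            # over every position for which the teacher provides logits.
            kd = sum(kl_teacher_student(logits_k[t], teacher_logits[t]) for t in range(len(teacher_logits)))
            loss_k += distillation_weight * kd / len(teacher_logits)
        per_head.append(loss_k)
    return sum(per_head), per_head
```

Returning the per-head losses alongside the total makes it easy to log how much the distillation term adds to the first head compared with the LM-only heads.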

Step 2: What additional optimizations are possible (but optional)?

(List potential refinements that can be added in later PRs if needed.)

📌 Acceptance Criteria (Must-Haves for Completion)

  • The feature must be functional and tested.
  • The implementation must be documented in practical terms.
  • The PR must include a performance/impact summary.
  • No refactors unless directly necessary for feature completion.

🛠️ Project Management

  • [x] Assign the project to the Fast-LLM project.
  • [ ] Set the Estimate field (in days) in the GitHub project.
  • [ ] Use the Size field to categorize the PR size (Small/Medium/Large).
  • [ ] Assign an owner when opening the issue.

RaymondLi0, May 09 '25 14:05