# Activation/feature-level distillation

## 🎯 Goal (What & Why)

Add activation-level (feature-level) distillation, which typically yields better student performance than logit-level distillation alone.
## 🚀 Execution Plan

### Step 1: What is the smallest working version?
- Distill based on all mixer-layer outputs.
- Support the case where the student has the same number of layers as the teacher.
- Use MSE loss.
- Add a single coefficient that balances feature-level distillation vs logit-level.
As a first version: $$L = L_{\text{logit}} + \lambda L_{\text{activation}}$$ with $$L_{\text{activation}} = \frac{1}{N} \sum_{i=1}^{N} \|T_i(\mathbf{x}) - S_i(\mathbf{x})\|_2^2,$$ where $T_i(\mathbf{x})$ and $S_i(\mathbf{x})$ denote the outputs of the $i$-th layer's mixer in the teacher and student models, respectively.
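The first-version loss could be sketched as follows. This is a minimal illustrative sketch, not Fast-LLM code: the function name is hypothetical, `lam` plays the role of $\lambda$, and KL divergence is assumed for the logit-level term (the issue does not specify it).

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits,
                      student_feats, teacher_feats, lam=1.0):
    """Combined logit- and activation-level distillation loss (sketch).

    student_feats / teacher_feats: lists of per-layer mixer outputs,
    one tensor per layer, with matching shapes (same layer count assumed).
    """
    # Logit-level term: KL divergence from the teacher's distribution
    # to the student's (an assumed choice for this sketch).
    logit_loss = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction="batchmean",
    )
    # Activation-level term: MSE averaged over the N matched layers.
    act_loss = sum(
        F.mse_loss(s, t) for s, t in zip(student_feats, teacher_feats)
    ) / len(student_feats)
    return logit_loss + lam * act_loss
```

With identical student and teacher outputs, both terms vanish, which is a quick sanity check for the implementation.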
Details:

- The teacher stores its intermediate activations in `kwargs`.
- The student uses `kwargs` to compute the activation-distillation losses, which it stores in `losses`.
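Independently of Fast-LLM's actual `kwargs`/`losses` plumbing, the teacher-side capture can be sketched with generic PyTorch forward hooks. The helper below is hypothetical; a dict stands in for the role `kwargs` plays in the description above.

```python
import torch

def capture_mixer_outputs(model, layer_names):
    """Attach forward hooks that record selected submodule outputs.

    Returns a dict (playing the role of `kwargs`) that fills with the
    named layers' outputs on each forward pass, plus the hook handles.
    """
    captured = {}
    handles = []
    modules = dict(model.named_modules())
    for name in layer_names:
        def hook(module, inputs, output, name=name):
            # Teacher activations are distillation targets; no grad needed.
            captured[name] = output.detach()
        handles.append(modules[name].register_forward_hook(hook))
    return captured, handles
```

The student side would then read these captured tensors and accumulate the per-layer losses into its `losses` store.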
### Step 2: What additional optimizations are possible (but optional)?
- Support tensor parallelism with sequence parallelism. (This is not actually optional, but can be done in a second step.)
- Make the layers used for distillation configurable. For example, distill only on mixer-layer outputs, or also on MLP outputs, etc. Pass a `{student layer name -> teacher layer name}` mapping to select the layers used for distillation.
- Make the loss function configurable: MSE, cosine similarity, possibly others.
## 📌 Acceptance Criteria (Must-Haves for Completion)
- The feature must be functional and tested.
- The implementation must be documented in practical terms.
- The PR must include a performance/impact summary.
- No refactors unless directly necessary for feature completion.
## 🛠️ Project Management
- [ ] Assign the project to the Fast-LLM project.
- [ ] Set the `Estimate` field (in days) in the GitHub project.
- [ ] Use the `Size` field to categorize the PR size (Small/Medium/Large).
- [ ] Assign an owner when opening the issue.