activation-level distillation
✨ Description
Closes #385
TODOs:
- [ ] TP / sequence-tensor parallel: inconsistent gradients between TP=1 and TP=2
- [ ] Add tests: train with student == teacher, check that all losses are 0 and the gradients as well.
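The student == teacher test above can be sketched as follows. This is a minimal, hypothetical version (plain `nn.Linear` stand-ins rather than the actual models): with identical weights, the logit-distillation loss should be exactly 0 and the student's gradients should vanish up to floating-point noise.

```python
import copy
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher = torch.nn.Linear(8, 16)
student = copy.deepcopy(teacher)  # student initialized from the teacher

x = torch.randn(4, 8)
with torch.no_grad():
    teacher_logits = teacher(x)
student_logits = student(x)

# Forward KL between teacher and student output distributions.
loss = F.kl_div(
    F.log_softmax(student_logits, dim=-1),
    F.log_softmax(teacher_logits, dim=-1),
    log_target=True,
    reduction="batchmean",
)
loss.backward()

assert loss.item() == 0.0  # identical weights -> exactly zero loss
# Gradients are zero up to floating-point round-off.
assert all(p.grad.abs().max().item() < 1e-6 for p in student.parameters())
```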
Sanity checks:
- Loading student and teacher from the same pretrained model gives 0 loss ✔️. But the loss then increases to a small value instead of staying at 0.
- Distilling from scratch with the same architecture doesn't lead to 0 loss (orange).
- Distilling the pretrained model, but with a sliding window, leads to low loss, lower with a larger window (green and purple).
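For reference, a minimal sketch of what an activation-level distillation loss could look like (an assumed form, not necessarily the exact one in this PR): mean-squared error between student and teacher hidden states, averaged over the distilled layers. With identical activations it is exactly 0, matching the first sanity check.

```python
import torch
import torch.nn.functional as F

def activation_distillation_loss(student_acts, teacher_acts):
    # Each entry is a (batch, seq, hidden) activation tensor from one layer.
    # The teacher side is detached so gradients only flow into the student.
    return sum(
        F.mse_loss(s, t.detach()) for s, t in zip(student_acts, teacher_acts)
    ) / len(student_acts)

student_acts = [torch.randn(2, 4, 8, requires_grad=True) for _ in range(3)]
teacher_acts = [a.detach().clone() for a in student_acts]
loss = activation_distillation_loss(student_acts, teacher_acts)
assert loss.item() == 0.0  # identical activations -> zero loss
```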
| | lm-loss | Logit distillation | Logit + Activation distillation |
|---|---|---|---|
| Tokens/s/gpu | 3500 | 2900 | 2800 |
| max_reserved (GB) | 44 | 77 | 78 |
One caveat: distillation seems to hit memory spikes at specific points in training, so actual usage was lower most of the time:
🔍 Type of change
Select all that apply:
- [ ] 🐛 Bug fix (non-breaking change that addresses a specific issue)
- [x] 🚀 New feature (non-breaking change that adds functionality)
- [ ] ⚠️ Breaking change (a change that could affect existing functionality)
- [ ] 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
- [ ] 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
- [ ] 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
- [ ] 📝 Documentation change (updates documentation, including new content or typo fixes)
- [ ] 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)
Testing
- [ ] 🧪 I have added or updated tests to cover my changes.
- [ ] ✔️ New and existing tests pass locally with my changes.
- [ ] 🚦 I have tested these changes on GPUs and verified training stability.
- [ ] 🏋️ I have tested the changes on realistic training workloads, if applicable.
Performance Impact
- [ ] 📊 I have run benchmarks where applicable to evaluate the performance impact.
- [ ] ✅ The benchmarks show no performance regression.
- [ ] 🚀 The benchmarks indicate a potential performance improvement.
- [ ] ⚠️ The benchmarks indicate a potential performance degradation.
- [ ] 📈 I have provided benchmark results and detailed any performance impact below, if applicable.
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
Great progress! Did you freeze everything except the randomly initialized mixers?
Another sanity check: Student is initialized from the teacher, but with randomly initialized attention layers. Then we distill activations while freezing the MLPs.
The loss (grey) is still quite far from 0 (similar to the orange run, where the student is initialized completely from scratch).
Resetting and distilling only one layer, freezing the rest of the model gives satisfactory results:
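The "reset and distill one layer" setup can be sketched like this (a hypothetical illustration with `nn.Linear` stand-ins for transformer layers, and a made-up `reset_index`): reinitialize a single layer in the student and freeze every other parameter, so only the reset layer is trained during distillation.

```python
import torch

# Stand-in "model": a stack of 4 layers.
model = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
reset_index = 2  # the one layer to reinitialize and distill

for i, layer in enumerate(model):
    if i == reset_index:
        layer.reset_parameters()  # random re-initialization
    for p in layer.parameters():
        # Freeze everything except the reset layer.
        p.requires_grad = i == reset_index

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable == ["2.weight", "2.bias"]  # only the reset layer trains
```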
Note that some changes were required to allow loading a pretrained model while freezing certain layers (#394).
Now getting similar loss curves with TP=1, TP=2, and STP=2.
Thank you for the reviews! The comments are addressed, could you have another look? @jlamypoirier