[feat] Track entropy and MI of routing distribution for topk MoE
✨ Description
To better detect potential routing collapse and better understand the routing distribution, we track the average entropy and the mutual information of the routing probabilities.
Collapsed routing has low entropy and low mutual information. A healthy, specialised router has low entropy and high mutual information: each token is routed confidently, and the routing differs considerably across tokens.
More specifically, the mutual information is the difference between:
- the entropy of the average routing distribution across all tokens, and
- the average of the per-token routing entropies.
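Written out, with $p_n$ the routing distribution of token $n$ over $E$ experts and $N$ tokens in the batch (notation chosen here for illustration, not taken from the code):

```math
H(p) = -\sum_{e=1}^{E} p_e \log p_e, \qquad
\bar{p} = \frac{1}{N}\sum_{n=1}^{N} p_n, \qquad
\mathrm{MI} = H(\bar{p}) - \frac{1}{N}\sum_{n=1}^{N} H(p_n)
```

This is the mutual information between the token index, taken as uniform over the batch, and the chosen expert. It is always non-negative because entropy is concave.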
🔍 Type of change
Select all that apply:
- [ ] 🐛 Bug fix (non-breaking change that addresses a specific issue)
- [x] 🚀 New feature (non-breaking change that adds functionality)
- [ ] ⚠️ Breaking change (a change that could affect existing functionality)
- [ ] 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
- [ ] 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
- [ ] 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
- [ ] 📝 Documentation change (updates documentation, including new content or typo fixes)
- [ ] 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)
📝 Changes
- Added the calculation of both metrics in `mixture_of_experts.py`. They are computed only for the topk routing type.
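A minimal plain-Python sketch of the two metrics follows, assuming the routing probabilities are a softmax over each token's router logits. The actual change operates on tensors inside `mixture_of_experts.py`; the names used here (`routing_stats`, `router_entropy`, `router_mutual_info`) are hypothetical and only illustrate the arithmetic, including the collapsed and specialised cases described above.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]


def entropy(probs: Sequence[float]) -> float:
    # Shannon entropy in nats; terms with p == 0 contribute 0 by convention.
    return -sum(p * math.log(p) for p in probs if p > 0)


def routing_stats(router_logits: Sequence[Sequence[float]]) -> Dict[str, float]:
    """Hypothetical helper: average per-token entropy and mutual information."""
    if not router_logits:
        raise ValueError("need at least one token")
    probs = [softmax(row) for row in router_logits]
    num_tokens, num_experts = len(probs), len(probs[0])
    # Average of the individual token entropies.
    avg_entropy = sum(entropy(p) for p in probs) / num_tokens
    # Average routing distribution across all tokens.
    mean_dist = [sum(p[e] for p in probs) / num_tokens for e in range(num_experts)]
    # MI = entropy of the mean distribution minus the mean entropy.
    mutual_info = entropy(mean_dist) - avg_entropy
    return {"router_entropy": avg_entropy, "router_mutual_info": mutual_info}


# Collapsed: every token confidently picks expert 0.
collapsed = [[10.0, 0.0, 0.0, 0.0]] * 4
# Specialised: token n confidently picks expert n.
specialised = [[10.0 if e == n else 0.0 for e in range(4)] for n in range(4)]

print(routing_stats(collapsed))    # low entropy, mutual information near 0
print(routing_stats(specialised))  # low entropy, mutual information near log(4)
```

Both cases have the same low per-token entropy; only the mutual information separates them, which is why the two metrics are tracked together.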
✅ Checklist
General
- [x] 📜 I have read and followed the contributing guidelines.
- [x] 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
- [x] 🎉 The functionality is complete, and I have tested the changes.
- [x] ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
- [x] 🧩 I have commented my code, especially in hard-to-understand areas.
Testing
- [x] 🧪 I have added or updated tests to cover my changes.
- [ ] ✔️ New and existing tests pass locally with my changes.
- [x] 🚦 I have tested these changes on GPUs and verified training stability.
- [x] 🏋️ I have tested the changes on realistic training workloads, if applicable.
Performance Impact
- [ ] 📊 I have run benchmarks where applicable to evaluate the performance impact.
- [ ] ✅ The benchmarks show no performance regression.
- [ ] 🚀 The benchmarks indicate a potential performance improvement.
- [ ] ⚠️ The benchmarks indicate a potential performance degradation.
- [ ] 📈 I have provided benchmark results and detailed any performance impact below, if applicable.
📊 Performance Impact Details
I am not 100% sure there is no performance impact: the statistics are computed on every forward pass through the router.
🗒️ Additional Notes
Yes @tscholak, addressed. Using the metrics dict instead.
@oleksost Are you still working on this?
@jlamypoirier yes, I will address your comments today. Sorry, it was deprioritised in favour of Mamba.
@jlamypoirier I think I addressed all the comments.