[feat] Option to disable top-k routing weights normalization
🧐 Problem Description
OLMoE disables normalization of the top-k routing probabilities: the router applies a softmax over all experts and uses the selected top-k probabilities directly, without renormalizing them to sum to 1. No clear motivation or ablation is given for this choice. DeepSeekMoE also disables top-k normalization, while Mixtral-8x7b-v0.1 normalizes the selected weights.
💡 Proposed Solution
Apply the softmax before the `torch.topk` call in https://github.com/ServiceNow/Fast-LLM/blob/51d57158d625883da189bcce3af3c8908e527824/fast_llm/layers/transformer/mixture_of_experts.py#L167, so that the selected top-k weights are used without renormalization.
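To illustrate the change, here is a minimal standalone sketch of the two routing conventions (not the actual Fast-LLM routing code; function and variable names here are made up):

```python
import torch

def route_topk_then_softmax(logits: torch.Tensor, top_k: int):
    # Normalized routing: pick the top-k logits, then softmax over only those,
    # so the selected weights always sum to 1 per token.
    top_logits, experts = torch.topk(logits, top_k, dim=-1)
    weights = torch.softmax(top_logits, dim=-1)
    return weights, experts

def route_softmax_then_topk(logits: torch.Tensor, top_k: int):
    # Proposed option (OLMoE / DeepSeekMoE style): softmax over *all* experts
    # before torch.topk, keeping the selected probabilities as-is
    # (they generally sum to less than 1).
    probs = torch.softmax(logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)
    return weights, experts

logits = torch.randn(4, 8)  # (tokens, experts)
w_norm, _ = route_topk_then_softmax(logits, top_k=2)
w_raw, _ = route_softmax_then_topk(logits, top_k=2)
print(w_norm.sum(dim=-1))  # all ones
print(w_raw.sum(dim=-1))   # each entry < 1
```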
🔄 Alternatives Considered
Keep normalizing the top-k scores as usual, since there is no clear motivation for disabling it. Conveniently, this behavior is config-driven in the HF implementation of OLMoE.
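If this lands as a feature, it could be exposed as a config flag in the same spirit as the HF OLMoE implementation (which, if I recall correctly, calls it `norm_topk_prob`). A hypothetical sketch of such a toggle:

```python
import torch

def route(logits: torch.Tensor, top_k: int, normalize_topk: bool = True):
    # Hypothetical config-driven router: softmax over all experts, select top-k,
    # and only renormalize the selected weights when normalize_topk is True.
    probs = torch.softmax(logits, dim=-1)
    weights, experts = torch.topk(probs, top_k, dim=-1)
    if normalize_topk:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, experts

logits = torch.randn(4, 8)  # (tokens, experts)
# Mixtral-like behavior (default): selected weights sum to 1.
w_mixtral, _ = route(logits, top_k=2, normalize_topk=True)
# OLMoE / DeepSeekMoE-like behavior: selected weights left unnormalized.
w_olmoe, _ = route(logits, top_k=2, normalize_topk=False)
```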
📈 Potential Benefits
No clear benefits. If anything, it could slow down training slightly, since the softmax would now be applied to the logits of all experts rather than only the selected top-k.
📝 Additional Context
See the OLMoE implementation for reference.
Reply from the OLMoE authors, for reference; there may not be strong reasons for this choice at the moment:
https://huggingface.co/allenai/OLMoE-1B-7B-0924/discussions/4#674e9e974232e14c900fa528