Fix shapes in modular_gpt_oss.py
What does this PR do?
This PR fixes a comment in modular_gpt_oss.py that gives the wrong shape in the description for GptOssExperts. It changes the annotated shape for routing_experts from (batch_size * token_num, num_experts) to (batch_size * token_num, top_k). Judging from the git blame, I think the original GPT OSS implementation used the former shape, and the code was later modified without the comment being updated. I also edited a few other MoE shape annotations in this file for consistency (assuming num_tokens = batch_size * seq_len). Please let me know if the changes make sense!
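To illustrate why the second dimension is top_k rather than num_experts, here is a minimal sketch in plain Python. It is not the code from modular_gpt_oss.py: the helper `route_top_k`, the variable names, and the sizes are all hypothetical, and nested lists stand in for tensors. The router scores every expert for each token, but only the indices of the top_k selected experts are kept.

```python
import random


def route_top_k(router_logits, top_k):
    """Return, for each token, the indices of its top_k highest-scoring experts.

    Hypothetical helper for illustration only; not the GPT OSS implementation.
    """
    indices = []
    for token_logits in router_logits:
        # Rank expert ids by score, highest first, and keep the first top_k.
        ranked = sorted(
            range(len(token_logits)),
            key=lambda expert: token_logits[expert],
            reverse=True,
        )
        indices.append(ranked[:top_k])
    return indices


# Assumed sizes, chosen only for this example.
batch_size, seq_len, num_experts, top_k = 2, 3, 8, 4
num_tokens = batch_size * seq_len

rng = random.Random(0)
# One score per (token, expert): shape (num_tokens, num_experts).
router_logits = [
    [rng.random() for _ in range(num_experts)] for _ in range(num_tokens)
]

# One expert id per (token, selected slot): shape (num_tokens, top_k).
router_indices = route_top_k(router_logits, top_k)

logits_shape = (len(router_logits), len(router_logits[0]))
indices_shape = (len(router_indices), len(router_indices[0]))
print(logits_shape, indices_shape)  # (6, 8) (6, 4)
```

The logits keep one column per expert, while the selected indices keep one column per chosen expert, which is the distinction the corrected comment is meant to capture.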
Before submitting
- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
@ArthurZucker
[For maintainers] Suggested jobs to run (before merge)
run-slow: gpt_oss
Done!