Nemotron-H support
🎯 Goal (What & Why)
Add support for training Nemotron-H models.
Nemotron-H is a family of hybrid SSM-Transformer models (8B, 47B, 56B) trained by NVIDIA in FP8 on 20T tokens. They combine strong accuracy with up to 3x faster inference, making them attractive candidates for instruction tuning, continual pretraining, and up-cycling (e.g., via added layers). NVIDIA plans to release weights and reference implementations (HuggingFace, NeMo, Megatron-LM), but nothing is public yet.
We want to be ready to support Nemotron-H in Fast-LLM once details are available. These models require both Transformer and Mamba-2 layers, and both the forward and backward passes must be implemented. FP8 support is likely not required for initial work (we'll use BF16) and is tracked separately in #63. Mamba-2 support is tracked in #68 and is a prerequisite for this ticket.
Partially enabled by #242.
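One way to represent the interleaved block structure is a compact pattern string expanded into per-layer types. The sketch below is purely illustrative: the pattern alphabet, function names, and layer labels are assumptions for discussion, not the released Nemotron-H spec or Fast-LLM's actual config schema.

```python
# Hypothetical sketch: expand a block-pattern string into per-layer kinds for a
# hybrid SSM-Transformer stack. Symbols are assumptions, not the NVIDIA spec.
LAYER_KINDS = {"M": "mamba2", "*": "attention", "-": "mlp"}

def expand_pattern(pattern: str) -> list[str]:
    """Map each pattern character to a layer kind, rejecting unknown symbols."""
    unknown = set(pattern) - set(LAYER_KINDS)
    if unknown:
        raise ValueError(f"unknown block symbols: {sorted(unknown)}")
    return [LAYER_KINDS[c] for c in pattern]

# Example: a small hybrid stack, mostly Mamba-2 with sparse attention layers.
print(expand_pattern("MM*-MM*-"))
```

A validated pattern like this could later be mapped onto Fast-LLM layer constructors once the real architecture (layer counts and ordering) is published.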
🚀 Execution Plan
Step 1: What is the smallest working version?
- Wait for NVIDIA to release architectural details.
- Implement Nemotron-H block structure (interleaved Mamba-2, MLP, and Transformer layers).
- Use BF16 as default precision for now.
- Validate by running a short training run in Fast-LLM with Nemotron-H-8B or 47B weights (if available).
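Beyond a short training run, the new backward pass can be sanity-checked by comparing analytic gradients against finite differences. The sketch below shows the pattern on a toy stand-in function; the names are illustrative, and a real check would target the Mamba-2 and attention kernels with framework tooling.

```python
# Hypothetical sketch of a backward-pass sanity check: compare an analytic
# gradient against a central finite difference. The toy function stands in
# for a real layer; all names here are illustrative.

def toy_layer(x: float, w: float) -> float:
    """A stand-in 'layer': y = w * x + x**2."""
    return w * x + x * x

def toy_layer_grad_x(x: float, w: float) -> float:
    """Analytic dy/dx for the toy layer (the 'backward pass' under test)."""
    return w + 2.0 * x

def finite_diff_grad_x(x: float, w: float, eps: float = 1e-6) -> float:
    """Central finite-difference approximation of dy/dx."""
    return (toy_layer(x + eps, w) - toy_layer(x - eps, w)) / (2.0 * eps)

x, w = 0.7, 1.3
assert abs(toy_layer_grad_x(x, w) - finite_diff_grad_x(x, w)) < 1e-5
```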
Step 2: What additional optimizations are possible (but optional)?
- Add support for FP8 training (#63) once the NVIDIA recipe is available and stable.
- Extend support for VLM variants if we decide to do multimodal training.
- Optimize long-context support (128k tokens) once basic functionality is in place.
📌 Acceptance Criteria
- Nemotron-H models (at least 8B) can be loaded and trained in Fast-LLM.
- Both forward and backward passes are implemented and tested.
- Training is functional in BF16.
- Code is documented clearly and includes a usage example.
- Performance/impact summary is provided in the PR.
- FP8 support is out of scope for this ticket (tracked in #63) and must not block the work here.
🛠️ Project Management
- [x] Assign the project to the Fast-LLM project.
- [ ] Set the `Estimate` field (in days) in the GitHub project.
- [x] Use the `Size` field to categorize the PR size (Small/Medium/Large).
- [ ] Assign an owner when opening the issue.