fms-fsdp
add Mamba-MoE training support
A detailed list of TODOs
Mamba repo
- [x] create a Mamba-MoE branch in the Mamba repo @fabianlim
FMS-FSDP repo
- [x] add Mamba-MoE configs
- [x] modify the training loss to include the MoE load-balancing loss (see the sketch below the TODO list)
Job YAML file
- [x] add extra dependencies for Mamba-MoE
- [x] switch to the Mamba-MoE fork
Prepare configs
- [x] 30b: 8 active experts, 64 total experts
- [x] 120b: 16 active experts, 256 total experts
cc @raghukiran1224
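
The load-balancing term in the TODO above isn't defined in this issue, so here is a minimal sketch of a Switch-Transformer-style auxiliary loss, assuming the router exposes per-token logits over the experts; the function name, signature, and the 0.01 coefficient are illustrative, not the fork's actual API.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, n_expert: int, top_k: int) -> torch.Tensor:
    """Switch-Transformer-style auxiliary loss for one MoE layer.

    router_logits: (num_tokens, n_expert) raw router scores.
    The term is minimized when tokens are spread evenly across experts.
    """
    probs = F.softmax(router_logits, dim=-1)                  # (tokens, experts)
    _, selected = probs.topk(top_k, dim=-1)                   # (tokens, top_k) expert indices
    # fraction of routing slots assigned to each expert
    dispatch = F.one_hot(selected, n_expert).float().mean(dim=(0, 1))
    # mean router probability mass per expert
    importance = probs.mean(dim=0)
    return n_expert * torch.sum(dispatch * importance)

# In the training step, the per-layer terms would be summed and added to the
# LM cross-entropy with a small coefficient (value illustrative):
# loss = ce_loss + 0.01 * sum(aux_losses)
```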
30b:
```python
>>> import torch
>>> from mamba_ssm.models.config_mamba import MambaConfig
>>> from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # Mamba-MoE fork
>>> model_config = {
... "d_model": 6144,
... "d_intermediate": 336,
... "n_layer": 48,
... "vocab_size": 128256,
... "ssm_cfg": {"layer": "Mamba2"},
... "attn_layer_idx": [9, 18, 27, 36, 45],
... "attn_cfg": {
... "causal": True,
... "d_conv": 0,
... "head_dim": 128,
... "num_heads": 48,
... "num_heads_kv": 8,
... "out_proj_bias": False,
... "qkv_proj_bias": False,
... "rotary_emb_dim": 64,
... },
... "mlp_cfg": {"n_expert": 64, "load_balancing_loss": True, "top_k": 8},
... "rms_norm": True,
... "residual_in_fp32": True,
... "fused_add_norm": True,
... "pad_vocab_size_multiple": 16,
... "tie_embeddings": False,
... }
>>> mamba_config = MambaConfig(**model_config)
>>> with torch.device("meta"):
... model = MambaLMHeadModel(mamba_config)
...
>>> sum(p.numel() for p in model.parameters() if p.requires_grad)
30903152576
```
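
Note that the model above lives on the meta device, so it has no real storage; it is only useful for counting parameters and inspecting shapes. Below is a minimal sketch of the generic PyTorch pattern for materializing such a model (fms-fsdp's own meta-device initialization may handle this differently):

```python
# Generic meta-device materialization (sketch only; not fms-fsdp's actual init path).
model = model.to_empty(device="cpu")        # allocate real but uninitialized storage
for module in model.modules():              # re-run each submodule's own init where it exists
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
```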
120b:
```python
>>> import torch
>>> from mamba_ssm.models.config_mamba import MambaConfig
>>> from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # Mamba-MoE fork
>>> model_config = {
... "d_model": 8192,
... "d_intermediate": 112,
... "n_layer": 108,
... "vocab_size": 128256,
... "ssm_cfg": {"layer": "Mamba2"},
... "attn_layer_idx": [9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99],
... "attn_cfg": {
... "causal": True,
... "d_conv": 0,
... "head_dim": 128,
... "num_heads": 64,
... "num_heads_kv": 8,
... "out_proj_bias": False,
... "qkv_proj_bias": False,
... "rotary_emb_dim": 64,
... },
... "mlp_cfg": {"n_expert": 256, "load_balancing_loss": True, "top_k": 16},
... "rms_norm": True,
... "residual_in_fp32": True,
... "fused_add_norm": True,
... "pad_vocab_size_multiple": 16,
... "tie_embeddings": False,
... }
>>> mamba_config = MambaConfig(**model_config)
>>> with torch.device("meta"):
... model = MambaLMHeadModel(mamba_config)
...
>>> sum(p.numel() for p in model.parameters() if p.requires_grad)
119339460608
```
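
For a rough sense of where these totals come from, the back-of-the-envelope arithmetic below assumes each expert is a gated MLP with about 3 * d_model * d_intermediate weights (fc1 projecting to 2 * d_intermediate plus fc2 projecting back) and ignores the router and biases; the fork's actual expert shape may differ.

```python
# Rough split of total vs. per-token-active expert parameters (assumed gated-MLP experts).
def expert_param_split(d_model, d_intermediate, n_layer, n_expert, top_k):
    per_expert = 3 * d_model * d_intermediate
    return n_layer * n_expert * per_expert, n_layer * top_k * per_expert

print(expert_param_split(6144, 336, 48, 64, 8))     # 30b:  (19025362944, 2378170368)
print(expert_param_split(8192, 112, 108, 256, 16))  # 120b: (76101451776, 4756340736)
```

Under that assumption the experts account for roughly 19B of the 30.9B total and 76B of the 119.3B total, with only a few billion expert parameters active per token.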