add Mamba-MoE training support

Open lchu6 opened this issue 1 year ago • 2 comments

A detailed list of TODOs

Mamba repo

  • [x] create a Mamba-MoE branch in Mamba repo @fabianlim

FMS-FSDP repo

  • [x] add mamba moe configs
  • [x] modify loss to moe-load-balancing-loss
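
For reference, one common form of an MoE load-balancing loss is the Switch Transformer-style auxiliary term `n_expert * sum_i(f_i * P_i)`, where `f_i` is the fraction of routed token slots assigned to expert `i` and `P_i` is the mean router probability for expert `i`. The sketch below is plain Python for illustration only. That the loss used here resembles this form is an assumption, and the function names are hypothetical, not taken from the Mamba-MoE branch or from fms-fsdp.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one token's router logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits: Sequence[Sequence[float]], top_k: int) -> float:
    """Switch-style auxiliary loss: n_expert * sum_i(f_i * P_i).

    router_logits holds one list of per-expert logits per token.
    This is an illustrative formulation, not the fork's implementation.
    """
    n_tokens = len(router_logits)
    n_expert = len(router_logits[0])
    counts = [0] * n_expert       # how many token slots each expert receives
    prob_sums = [0.0] * n_expert  # running sum of router probabilities

    for logits in router_logits:
        probs = softmax(logits)
        # Route this token to its top_k highest-probability experts.
        chosen = sorted(range(n_expert), key=lambda i: probs[i], reverse=True)[:top_k]
        for i in chosen:
            counts[i] += 1
        for i, p in enumerate(probs):
            prob_sums[i] += p

    frac_routed = [c / (n_tokens * top_k) for c in counts]  # f_i, sums to 1
    mean_prob = [s / n_tokens for s in prob_sums]           # P_i, sums to 1
    return n_expert * sum(f * p for f, p in zip(frac_routed, mean_prob))


# Four tokens, four experts, top-1 routing.
balanced = [[2.0 if e == t else 0.0 for e in range(4)] for t in range(4)]
collapsed = [[2.0, 0.0, 0.0, 0.0] for _ in range(4)]
print(load_balancing_loss(balanced, top_k=1))
print(load_balancing_loss(collapsed, top_k=1))
```

The term equals 1 when both the routed fractions and the mean probabilities are uniform, and it grows as tokens concentrate on fewer experts, which is why it is typically added to the language-modeling loss with a small coefficient.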

Job yaml file

  • [x] add extra dependencies for mamba moe
  • [x] switch to the mamba moe fork

Prepare config

  • [x] 30b, 8 active experts, 64 total experts
  • [x] 120b, 16 active experts, 256 total experts

lchu6 · Jan 23 '25 18:01

cc @raghukiran1224

lchu6 · Jan 23 '25 18:01

30b:

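>>> import torch
>>> from mamba_ssm.models.config_mamba import MambaConfig
>>> from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
>>> model_config = {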
...             "d_model": 6144,
...             "d_intermediate": 336,
...             "n_layer": 48,
...             "vocab_size": 128256,
...             "ssm_cfg": {"layer": "Mamba2"},
...             "attn_layer_idx": [9, 18, 27, 36, 45],
...             "attn_cfg": {
...                 "causal": True,
...                 "d_conv": 0,
...                 "head_dim": 128,
...                 "num_heads": 48,
...                 "num_heads_kv": 8,
...                 "out_proj_bias": False,
...                 "qkv_proj_bias": False,
...                 "rotary_emb_dim": 64,
...             },
...             "mlp_cfg": {"n_expert": 64, "load_balancing_loss": True, "top_k": 8},
...             "rms_norm": True,
...             "residual_in_fp32": True,
...             "fused_add_norm": True,
...             "pad_vocab_size_multiple": 16,
...             "tie_embeddings": False,
...         }
>>> mamba_config = MambaConfig(**model_config)
>>> with torch.device("meta"):
...     model = MambaLMHeadModel(mamba_config)
... 
>>> sum(p.numel() for p in model.parameters() if p.requires_grad)
30903152576

120b:

>>> model_config = {
...             "d_model": 8192,
...             "d_intermediate": 112,
...             "n_layer": 108,
...             "vocab_size": 128256,
...             "ssm_cfg": {"layer": "Mamba2"},
...             "attn_layer_idx": [9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99],
...             "attn_cfg": {
...                 "causal": True,
...                 "d_conv": 0,
...                 "head_dim": 128,
...                 "num_heads": 64,
...                 "num_heads_kv": 8,
...                 "out_proj_bias": False,
...                 "qkv_proj_bias": False,
...                 "rotary_emb_dim": 64,
...             },
...             "mlp_cfg": {"n_expert": 256, "load_balancing_loss": True, "top_k": 16},
...             "rms_norm": True,
...             "residual_in_fp32": True,
...             "fused_add_norm": True,
...             "pad_vocab_size_multiple": 16,
...             "tie_embeddings": False,
...         }
>>> mamba_config = MambaConfig(**model_config)
>>> with torch.device("meta"):
...     model = MambaLMHeadModel(mamba_config)
... 
>>> sum(p.numel() for p in model.parameters() if p.requires_grad)
119339460608
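
To see where these parameters sit, here is a rough breakdown in plain Python. It assumes upstream `mamba_ssm` Mamba2 defaults (`d_state=128`, `d_conv=4`, `expand=2`, `headdim=64`, `ngroups=1`), bias-free projections, two RMSNorm weights per block plus a final norm, a gated MLP with three weight matrices per expert, and no separate router term. None of this is confirmed for the Mamba-MoE fork in this thread, and the helper names are hypothetical; under these assumptions the sketch reproduces both totals reported above.

```python
def mamba2_mixer_params(d_model, d_state=128, d_conv=4, expand=2, headdim=64, ngroups=1):
    """Parameter count of one Mamba2 mixer (assumed upstream defaults)."""
    d_inner = expand * d_model
    nheads = d_inner // headdim
    conv_dim = d_inner + 2 * ngroups * d_state
    in_proj = d_model * (2 * d_inner + 2 * ngroups * d_state + nheads)
    conv1d = conv_dim * d_conv + conv_dim  # depthwise conv weights + bias
    per_head = 3 * nheads                  # dt_bias, A_log, D
    gated_norm = d_inner                   # RMSNorm weight before out_proj
    out_proj = d_inner * d_model
    return in_proj + conv1d + per_head + gated_norm + out_proj


def attention_params(d_model, head_dim, num_heads, num_heads_kv):
    """Grouped-query attention: fused QKV projection plus output projection."""
    qkv = d_model * head_dim * (num_heads + 2 * num_heads_kv)
    out = head_dim * num_heads * d_model
    return qkv + out


def estimate_params(cfg):
    """Return total, expert, and per-token active parameter estimates."""
    d, n_layer = cfg["d_model"], cfg["n_layer"]
    n_attn = len(cfg["attn_layer_idx"])
    attn, mlp = cfg["attn_cfg"], cfg["mlp_cfg"]

    # Round the vocabulary up to the padding multiple.
    mult = cfg["pad_vocab_size_multiple"]
    vocab = -(-cfg["vocab_size"] // mult) * mult
    embeddings = vocab * d * (1 if cfg["tie_embeddings"] else 2)

    norms = 2 * d * n_layer + d  # two norms per block, one final norm
    mixers = (n_layer - n_attn) * mamba2_mixer_params(d) + n_attn * attention_params(
        d, attn["head_dim"], attn["num_heads"], attn["num_heads_kv"]
    )
    # Assumed gated MLP: three d_model x d_intermediate matrices per expert.
    experts = n_layer * mlp["n_expert"] * 3 * d * cfg["d_intermediate"]

    total = embeddings + norms + mixers + experts
    # Only top_k of n_expert experts run for each token.
    active = total - experts + experts * mlp["top_k"] // mlp["n_expert"]
    return {"total": total, "experts": experts, "active": active}


def make_cfg(d_model, d_intermediate, n_layer, attn_idx, num_heads, n_expert, top_k):
    """Keep only the fields from the configs above that affect the count."""
    return {
        "d_model": d_model, "d_intermediate": d_intermediate, "n_layer": n_layer,
        "vocab_size": 128256, "pad_vocab_size_multiple": 16, "tie_embeddings": False,
        "attn_layer_idx": attn_idx,
        "attn_cfg": {"head_dim": 128, "num_heads": num_heads, "num_heads_kv": 8},
        "mlp_cfg": {"n_expert": n_expert, "top_k": top_k},
    }


CFG_30B = make_cfg(6144, 336, 48, [9, 18, 27, 36, 45], 48, 64, 8)
CFG_120B = make_cfg(8192, 112, 108, list(range(9, 100, 9)), 64, 256, 16)

for name, cfg in (("30b", CFG_30B), ("120b", CFG_120B)):
    print(name, estimate_params(cfg))
```

The `active` figure is only an estimate of the parameters touched per token (all non-expert weights plus `top_k / n_expert` of the expert weights); it depends on the same assumptions as the total.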

lchu6 · Jan 23 '25 21:01