Some Bugs in JetMoE
System Info
transformers version: 4.43.0.dev0 (installed from source)
Who can help?
@ArthurZucker
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Outline: There are a couple of bugs that prevent JetMoE from outputting the router (gating) logits and from computing the auxiliary load-balancing loss (aux_loss).
- Code: I want to output the logits of the gating.

```python
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModelForSequenceClassification,
)
import os
import torch

BASE_DIR = "model_ckpt"
# from jetmoe import JetMoEForCausalLM, JetMoEConfig, JetMoEForSequenceClassification
model_name = os.path.join(BASE_DIR, "jetmoe-8b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)
output = model.forward(
    torch.zeros(32, 12, device="cuda", dtype=torch.long),
    output_router_logits=True,
    return_dict=True,
)
```
- It will report an error:

```text
Traceback (most recent call last):
  File "/home/ubuntu/ssk/test_jetmoe.py", line 18, in <module>
    output = model.forward(
  File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 1365, in forward
    self.num_experts,
  File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'JetMoeForCausalLM' object has no attribute 'num_experts'
```
- Analysis: After examining the code (https://github.com/huggingface/transformers/blob/main/src/transformers/models/jetmoe/modeling_jetmoe.py), I found several mistakes:
  - `self.num_experts` and `self.num_experts_per_tok` are not defined in the `JetMoeForCausalLM` class.
  - The code does not pass the `output_router_logits` argument to the forward function of `self.model` in the `JetMoeForCausalLM` class (see lines 1310 and 1341 of modeling_jetmoe.py).
  - The `JetMoeForSequenceClassification` class misses the aux_loss calculation and also forgets to pass the `output_router_logits` argument to `self.model.forward` (a sketch of a possible fix is given after the quick fix below).
- Quick fix of the `JetMoeForCausalLM` class:
  - Add `self.num_experts = config.num_local_experts` and `self.num_experts_per_tok = config.num_experts_per_tok` in the `__init__` function of `JetMoeForCausalLM`.
  - Pass `output_router_logits` to `self.model.forward` (line 1331):

```python
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
    output_router_logits=output_router_logits,  # Add this line.
)
```
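For the `JetMoeForSequenceClassification` class, I expect the same two changes (defining the expert counts in `__init__` and forwarding `output_router_logits`), plus an aux_loss computation mirroring the one in `JetMoeForCausalLM.forward`. The sketch below is only an illustration of what I have in mind; it assumes the module-level `load_balancing_loss_func` helper exists in modeling_jetmoe.py and that `transformer_outputs` is the variable holding the base model outputs, and the coefficient attribute name is a guess that should be double-checked against main.

```python
# Sketch only (not verified against main): aux_loss handling for
# JetMoeForSequenceClassification, mirroring the pattern used in the causal LM head.
# Requires the same __init__ additions as above and output_router_logits being
# forwarded to self.model.
aux_loss = None
if output_router_logits:
    aux_loss = load_balancing_loss_func(
        transformer_outputs.router_logits if return_dict else transformer_outputs[-1],
        self.num_experts,          # assumed to be set from config.num_local_experts
        self.num_experts_per_tok,  # assumed to be set from config.num_experts_per_tok
        attention_mask,
    )
    if labels is not None and loss is not None:
        # Coefficient attribute name is an assumption; check what JetMoeConfig actually exposes.
        loss += self.config.aux_loss_coef * aux_loss.to(loss.device)
```

Whether to also expose aux_loss on the returned output (as the causal LM head does) is a design choice to settle in the PR.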
Expected behavior
A solution has been described in the Reproduction section above.
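Concretely, once the fixes are in, I would expect the reproduction above to run and expose the router logits (and aux_loss) on the output, along these lines; the exact shapes are an assumption on my side:

```python
# Hypothetical behavior after the fix, reusing the model loaded in the reproduction above.
output = model(
    torch.zeros(32, 12, device="cuda", dtype=torch.long),
    output_router_logits=True,
    return_dict=True,
)

# One router-logits tensor per MoE layer (assumed shape: [batch * seq_len, num_experts]).
print(len(output.router_logits), output.router_logits[0].shape)

# The auxiliary load-balancing loss should also be populated.
print(output.aux_loss)
```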
Thanks @Phoenix-Shen! Let me cc @yikangshen, who contributed the model.
Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?
Ok, I've fixed all the bugs and am ready to submit a PR.
Thanks, reviewed!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.