Some Bugs in JetMoE

Open · Phoenix-Shen opened this issue 1 year ago · 4 comments

System Info

transformers version: 4.43.0.dev0 (installed from source)

Who can help?

@ArthurZucker

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Outline: A couple of bugs prevent JetMoE from outputting the gating (router) logits and from calculating aux_loss.

  1. Code: I want to output the logits of the gating network.
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModelForSequenceClassification,
)
import os
import torch

BASE_DIR = "model_ckpt"
# from jetmoe import JetMoEForCausalLM, JetMoEConfig, JetMoEForSequenceClassification
model_name = os.path.join(BASE_DIR, "jetmoe-8b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

output = model.forward(
    torch.zeros(32, 12, device="cuda", dtype=torch.long),
    output_router_logits=True,
    return_dict=True,
)
  2. Running it raises the following error:
      Traceback (most recent call last):
        File "/home/ubuntu/ssk/test_jetmoe.py", line 18, in <module>
          output = model.forward(
        File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
          output = module._old_forward(*args, **kwargs)
        File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 1365, in forward
          self.num_experts,
        File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
          raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
      AttributeError: 'JetMoeForCausalLM' object has no attribute 'num_experts'

  3. Analysis: After examining the code (https://github.com/huggingface/transformers/blob/main/src/transformers/models/jetmoe/modeling_jetmoe.py), I found several mistakes:

  • self.num_experts and self.num_experts_per_tok are not defined in the JetMoeForCausalLM class.
  • The JetMoeForCausalLM class does not pass the output_router_logits argument to the forward function of self.model (see lines 1310 and 1341 of modeling_jetmoe.py).
  • The JetMoeForSequenceClassification class does not calculate aux_loss at all, and it also does not pass the output_router_logits argument to self.model.forward.
  4. Quick fix for the JetMoeForCausalLM class:
  • Add self.num_experts = config.num_local_experts and self.num_experts_per_tok = config.num_experts_per_tok to the __init__ method of JetMoeForCausalLM.
  • Pass output_router_logits to self.model.forward (line 1331):
          # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
          outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
              output_router_logits=output_router_logits,  # Add this line.
          )
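To illustrate the two changes without loading the real checkpoint, here is a minimal pure-Python sketch. FakeConfig, FakeInnerModel, BuggyCausalLM and FixedCausalLM are hypothetical stand-ins, not transformers classes; they only mimic the attribute access and argument forwarding described in the quick fix. In the real model, the missing attribute falls through to torch.nn.Module.__getattr__, which raises the AttributeError seen in the traceback; a plain Python class raises an equivalent error.

```python
from typing import Any, Dict, List, Optional


class FakeConfig:
    """Hypothetical stand-in for JetMoeConfig holding only the two fields the fix reads."""

    def __init__(self, num_local_experts: int = 8, num_experts_per_tok: int = 2) -> None:
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok


class FakeInnerModel:
    """Hypothetical stand-in for self.model: returns router logits only when asked."""

    def __init__(self, num_experts: int) -> None:
        self.num_experts = num_experts

    def __call__(self, output_router_logits: Optional[bool] = None, **kwargs: Any) -> Dict[str, Any]:
        logits: Optional[List[List[float]]] = None
        if output_router_logits:
            logits = [[0.0] * self.num_experts]  # one token, uniform logits
        return {"router_logits": logits}


class BuggyCausalLM:
    """Mimics both reported bugs: no num_experts attribute, flag not forwarded."""

    def __init__(self, config: FakeConfig) -> None:
        self.model = FakeInnerModel(config.num_local_experts)

    def forward(self, output_router_logits: bool = False) -> Dict[str, Any]:
        outputs = self.model()  # output_router_logits is silently dropped
        if output_router_logits:
            # Reading self.num_experts here fails, as in the traceback.
            return {"router_logits": outputs["router_logits"], "num_experts": self.num_experts}
        return outputs


class FixedCausalLM:
    """Applies the two changes from the quick fix."""

    def __init__(self, config: FakeConfig) -> None:
        self.model = FakeInnerModel(config.num_local_experts)
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

    def forward(self, output_router_logits: bool = False) -> Dict[str, Any]:
        # Forward the flag so the inner model actually produces router logits.
        outputs = self.model(output_router_logits=output_router_logits)
        if output_router_logits:
            return {"router_logits": outputs["router_logits"], "num_experts": self.num_experts}
        return outputs


config = FakeConfig()
try:
    BuggyCausalLM(config).forward(output_router_logits=True)
except AttributeError as err:
    print(err)  # 'BuggyCausalLM' object has no attribute 'num_experts'

fixed = FixedCausalLM(config).forward(output_router_logits=True)
print(fixed["num_experts"], len(fixed["router_logits"][0]))  # 8 8
```

Both changes are needed: storing the sizes alone would avoid the AttributeError, but without forwarding the flag the inner model would still return no router logits to compute the auxiliary loss from.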
    

Expected behavior

The solution has been described in the previous section.
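The report gives no code for the aux_loss that JetMoeForSequenceClassification skips. As a rough illustration only, the sketch below computes a Mixtral-style router load-balancing loss in pure Python, with num_experts and top_k playing the roles of self.num_experts and self.num_experts_per_tok. It assumes router_logits is a flat list of per-token logit rows (all layers concatenated) and ignores attention-mask handling; it is not the transformers helper itself, and load_balancing_loss is a hypothetical name.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits: Sequence[Sequence[float]], num_experts: int, top_k: int) -> float:
    """Hypothetical sketch: num_experts * sum_{k,e} f[k][e] * P[e]."""
    num_tokens = len(router_logits)
    probs = [softmax(row) for row in router_logits]
    # f[k][e]: fraction of tokens whose k-th choice is expert e.
    f = [[0.0] * num_experts for _ in range(top_k)]
    # P[e]: mean router probability assigned to expert e.
    P = [0.0] * num_experts
    for p in probs:
        chosen = sorted(range(num_experts), key=lambda e: p[e], reverse=True)[:top_k]
        for k, e in enumerate(chosen):
            f[k][e] += 1.0 / num_tokens
        for e in range(num_experts):
            P[e] += p[e] / num_tokens
    overall = sum(f[k][e] * P[e] for k in range(top_k) for e in range(num_experts))
    return overall * num_experts


balanced = [[0.0] * 8 for _ in range(4)]
print(load_balancing_loss(balanced, num_experts=8, top_k=2))  # 2.0

collapsed = [[10.0, 9.0] + [0.0] * 6 for _ in range(4)]
print(load_balancing_loss(collapsed, num_experts=8, top_k=2))  # much larger than 2.0
```

With uniform logits every expert gets the same probability, so the loss equals top_k; routing that collapses onto a few experts drives it up. In training, such a term is typically scaled by a small coefficient and added to the task loss.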

Phoenix-Shen, Jul 04 '24 09:07

Thanks @Phoenix-Shen! Let me cc @yikangshen, who has contributed the model.

LysandreJik, Jul 04 '24 12:07

Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?

yikangshen, Jul 06 '24 20:07

Ok, I've fixed all the bugs and am ready to submit a PR.

Phoenix-Shen, Jul 07 '24 05:07

Thanks, reviewed!

ArthurZucker, Jul 10 '24 10:07

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Aug 04 '24 08:08