Mixtral past_key_values and output_router_logits incompatible

Open sorgfresser opened this issue 1 year ago • 1 comments

System Info

transformers==4.40.2, Python 3.11.8

Who can help?

@ArthurZucker

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

from transformers import MixtralConfig, MixtralForCausalLM, AutoTokenizer
import torch
# Initializing a smaller version of Mixtral for faster execution
configuration = MixtralConfig(
    hidden_size=256,
    intermediate_size=896,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_local_experts=4,
    num_experts_per_tok=1,
)

model = MixtralForCausalLM(configuration)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
prompt = "This is a test"
tokenized = tokenizer(prompt, return_tensors="pt")

# First forward pass over the full prompt: this works and returns the cache
output = model(**tokenized, output_router_logits=True)
key_values = output.past_key_values
logits = output.logits
next_token_logits = logits[..., -1, :]
# Softmax
softmaxed = torch.nn.functional.softmax(next_token_logits, dim=-1)
# Sample
sampled = torch.multinomial(softmaxed.squeeze(), num_samples=1)
ids = sampled.item()

# Extend the attention mask by one position for the newly sampled token
attention_mask = torch.cat([tokenized["attention_mask"], torch.tensor([[1]])], dim=-1)
# Second forward pass with the cache and a single new token: this call fails
next_output = model(
    torch.tensor([[ids]]),
    attention_mask=attention_mask,
    past_key_values=key_values,
    output_router_logits=True,
)

Expected behavior

This appears to be the same underlying issue as #29087. I would expect past_key_values to work together with output_router_logits. Here is what happens instead:

  1. Without past key values (and with multiple input ids), all_router_logits has the full sequence length, so in load_balancing_loss_func the expression num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) correctly evaluates to the number of hidden layers.
  2. When past key values are used, all_router_logits has a sequence length of 1, but the attention mask still covers the whole sequence, and sequence_length is inferred from that mask. num_hidden_layers therefore evaluates to a small value or 0, leading to the same error as in #29087.
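The arithmetic behind the two cases can be sketched in plain Python. This is only an illustration of the integer division quoted above, not the library's code: the helper name inferred_num_layers is hypothetical, and the prompt length of 5 tokens is an assumed value rather than the measured tokenization of "This is a test".

```python
def inferred_num_layers(router_rows: int, batch_size: int, mask_seq_len: int) -> int:
    """Mirror the division quoted above (hypothetical helper, not a transformers API).

    router_rows stands in for concatenated_gate_logits.shape[0];
    mask_seq_len is the sequence length read from the attention mask.
    """
    return router_rows // (batch_size * mask_seq_len)


NUM_LAYERS = 8   # num_hidden_layers from the reproduction config
BATCH_SIZE = 1
PROMPT_LEN = 5   # assumed token count, for illustration only

# Case 1: full prompt, no cache. Each layer routes every prompt token,
# so the concatenated router logits have NUM_LAYERS * BATCH_SIZE * PROMPT_LEN rows.
full_pass = inferred_num_layers(NUM_LAYERS * BATCH_SIZE * PROMPT_LEN, BATCH_SIZE, PROMPT_LEN)

# Case 2: cached step. Each layer routes only the single new token,
# but the attention mask now spans PROMPT_LEN + 1 positions.
cached_step = inferred_num_layers(NUM_LAYERS * BATCH_SIZE * 1, BATCH_SIZE, PROMPT_LEN + 1)

# Case 2 with a longer context (9 cached tokens + 1 new): the division truncates to 0.
long_context = inferred_num_layers(NUM_LAYERS * BATCH_SIZE * 1, BATCH_SIZE, 9 + 1)

print(full_pass, cached_step, long_context)
```

Under these assumptions the full pass recovers all 8 layers, while the cached step yields 1 and the longer context yields 0, which is the "small value or 0" described in point 2.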

Instead, I would like load_balancing_loss_func to handle the case where the gate_logits passed in have shape [batch_size X 1, num_experts] rather than [batch_size X sequence_length, num_experts].
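One possible direction, as a rough sketch only: read the number of layers and the routed sequence length from the gate logits themselves, then trim the attention mask to match. The helper align_mask_to_router_logits is hypothetical and not part of transformers. It assumes gate_logits is a tuple with one [batch_size * routed_seq_len, num_experts] tensor per layer, and that the newly routed tokens correspond to the last columns of the attention mask.

```python
import torch


def align_mask_to_router_logits(gate_logits, attention_mask):
    """Hypothetical helper: derive shapes from the router logits, not the mask."""
    batch_size = attention_mask.shape[0]
    # One tensor per layer, so the layer count needs no division at all.
    num_hidden_layers = len(gate_logits)
    # Rows per layer = batch_size * (tokens routed in this forward pass).
    routed_seq_len = gate_logits[0].shape[0] // batch_size
    # Assumption: the routed tokens are the trailing positions of the mask.
    aligned_mask = attention_mask[:, -routed_seq_len:]
    return num_hidden_layers, routed_seq_len, aligned_mask


# Cached decoding step: 8 layers, batch of 1, one new token, 4 experts.
gate_logits = tuple(torch.randn(1 * 1, 4) for _ in range(8))
# Mask over 5 prompt tokens (assumed count) plus the 1 new token.
attention_mask = torch.ones(1, 6, dtype=torch.long)

layers, seq_len, mask = align_mask_to_router_logits(gate_logits, attention_mask)
print(layers, seq_len, tuple(mask.shape))
```

With the mask trimmed this way, the expand and reshape over (num_hidden_layers, batch_size, sequence_length, ...) would see consistent sizes in both the cached and uncached cases. Whether this matches what the maintainers would accept is an open question.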

sorgfresser, May 09 '24 15:05

Hey! The generate function is not supposed to work for training, which is why we don't test past key values together with output router logits. The two are not really incompatible, though: you might want to look at the distribution of the router logits during generation. Do you want to open a PR for a fix?

ArthurZucker, May 15 '24 08:05