Mixtral past_key_values and output_router_logits incompatible
### System Info

transformers==4.40.2, Python 3.11.8
### Who can help?
@ArthurZucker
### Information
- [ ] The official example scripts
- [X] My own modified scripts
### Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
### Reproduction
```python
from transformers import MixtralConfig, MixtralForCausalLM, AutoTokenizer
import torch

# Initializing a smaller version of Mixtral for faster execution
configuration = MixtralConfig(
    hidden_size=256,
    intermediate_size=896,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=8,
    num_local_experts=4,
    num_experts_per_tok=1,
)
model = MixtralForCausalLM(configuration)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

prompt = "This is a test"
tokenized = tokenizer(prompt, return_tensors="pt")
output = model(**tokenized, output_router_logits=True)
key_values = output.past_key_values
logits = output.logits
next_token_logits = logits[..., -1, :]

# Softmax over the vocabulary, then sample the next token id
softmaxed = torch.nn.functional.softmax(next_token_logits, dim=-1)
sampled = torch.multinomial(softmaxed.squeeze(), num_samples=1)
ids = sampled.item()

# Extend the attention mask to cover the newly sampled token
attention_mask = torch.cat([tokenized["attention_mask"], torch.tensor([[1]])], dim=-1)

# This second call, combining past_key_values with output_router_logits, fails
next_output = model(
    torch.tensor([[ids]]),
    attention_mask=attention_mask,
    past_key_values=key_values,
    output_router_logits=True,
)
```
### Expected behavior
It seems that this is the same underlying issue as in #29087: I would expect `past_key_values` to work with `output_router_logits`.
So what happens?
- Without past key values (and with multiple input ids), `all_router_logits` has the proper sequence length, so in `load_balancing_loss_func` the line `num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)` correctly evaluates the number of hidden layers.
- If past key values are used, `all_router_logits` has a sequence length of 1, but since the attention mask still covers the whole sequence (from which `sequence_length` is inferred), the number of hidden layers evaluates to a small value or 0, leading to the same error as in #29087.
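The mismatch in that integer division can be illustrated numerically (a sketch using the shapes from the repro above; the 8 layers come from the config, and a 5-token attention mask is assumed for illustration):

```python
# Sketch of the shape mismatch described above; numbers are illustrative.
num_hidden_layers = 8
batch_size = 1
sequence_length = 5  # inferred from the full attention mask

# Without past key values: each layer contributes batch * seq rows of
# gate logits, so the division recovers the true layer count.
rows = num_hidden_layers * batch_size * sequence_length
assert rows // (batch_size * sequence_length) == 8

# With past key values: each layer contributes only batch * 1 rows,
# because router logits exist only for the single new token.
rows_cached = num_hidden_layers * batch_size * 1
inferred = rows_cached // (batch_size * sequence_length)
print(inferred)  # 8 // 5 == 1, far from the real 8-layer count
```

With an even longer prompt the division rounds down to 0, which is what triggers the downstream error.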
Instead, I would like `load_balancing_loss_func` to be able to deal with the case where the `gate_logits` passed in have shape `[batch_size * 1, num_experts]` instead of `[batch_size * sequence_length, num_experts]`.
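A minimal sketch of what such a shape-agnostic loss could look like (this is an illustration of the usual top-k load-balancing formulation, not the library's actual implementation; it assumes each element of `gate_logits` has shape `[tokens, num_experts]` for any number of tokens):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss_sketch(gate_logits, num_experts, top_k=1):
    # gate_logits: tuple of per-layer tensors, each [tokens, num_experts];
    # `tokens` may be batch*seq_len (full pass) or batch*1 (cached step).
    concatenated = torch.cat(list(gate_logits), dim=0)
    routing_weights = F.softmax(concatenated, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts)
    # Fraction of routing slots assigned to each expert
    tokens_per_expert = expert_mask.float().mean(dim=0)
    # Mean router probability mass on each expert
    router_prob_per_expert = routing_weights.mean(dim=0)
    return (tokens_per_expert * router_prob_per_expert.unsqueeze(0)).sum() * num_experts
```

Because nothing here infers `sequence_length` from the attention mask, the same function handles both the full prompt and the single-token cached step.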
Hey! The `generate` function is not supposed to be used for training, which is why we don't test `past_key_values` together with `output_router_logits`. Though they're actually not that incompatible (you might want to look at the distribution of the router logits during generation). Do you want to open a PR for a fix?