Removing model layers throws an index error.
Feature request
Hello,
When I try to remove a layer from the Llama models using the code snippet below, I get an index error (pasted below the snippet). From what I can tell, the layer_idx attribute of self_attn is used for the KV cache during generation, and these indices are not updated automatically when layers are removed. I believe the same behaviour holds in other models (e.g. gemma-2b). Apologies if there is an existing way to remove layers; I'm posting this after an extensive search.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "meta-llama/Meta-Llama-3-8B"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch_dtype,
)
# drop decoder layer 16
model.model.layers = torch.nn.ModuleList([layer for i, layer in enumerate(model.model.layers) if i != 16])
prompt = "hello"
tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
output = model(tokenized, return_dict=True, output_hidden_states=True)
IndexError Traceback (most recent call last)
Cell In[132], line 14
12 prompt = "hello"
13 tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
---> 14 output = model(tokenized, return_dict=True, output_hidden_states=True)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1208, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1205 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1207 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1208 outputs = self.model(
1209 input_ids=input_ids,
1210 attention_mask=attention_mask,
1211 position_ids=position_ids,
1212 past_key_values=past_key_values,
1213 inputs_embeds=inputs_embeds,
1214 use_cache=use_cache,
1215 output_attentions=output_attentions,
1216 output_hidden_states=output_hidden_states,
1217 return_dict=return_dict,
1218 cache_position=cache_position,
1219 )
1221 hidden_states = outputs[0]
1222 if self.config.pretraining_tp > 1:
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1018, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1007 layer_outputs = self._gradient_checkpointing_func(
1008 decoder_layer.__call__,
1009 hidden_states,
(...)
1015 cache_position,
1016 )
1017 else:
-> 1018 layer_outputs = decoder_layer(
1019 hidden_states,
1020 attention_mask=causal_mask,
1021 position_ids=position_ids,
1022 past_key_value=past_key_values,
1023 output_attentions=output_attentions,
1024 use_cache=use_cache,
1025 cache_position=cache_position,
1026 )
1028 hidden_states = layer_outputs[0]
1030 if use_cache:
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:741, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
738 hidden_states = self.input_layernorm(hidden_states)
740 # Self Attention
--> 741 hidden_states, self_attn_weights, present_key_value = self.self_attn(
742 hidden_states=hidden_states,
743 attention_mask=attention_mask,
744 position_ids=position_ids,
745 past_key_value=past_key_value,
746 output_attentions=output_attentions,
747 use_cache=use_cache,
748 cache_position=cache_position,
749 **kwargs,
750 )
751 hidden_states = residual + hidden_states
753 # Fully Connected
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:653, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
650 if past_key_value is not None:
651 # sin and cos are specific to RoPE models; cache_position needed for the static cache
652 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 653 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
655 key_states = repeat_kv(key_states, self.num_key_value_groups)
656 value_states = repeat_kv(value_states, self.num_key_value_groups)
File ~/.local/lib/python3.10/site-packages/transformers/cache_utils.py:149, in DynamicCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
146 self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
147 self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
--> 149 return self.key_cache[layer_idx], self.value_cache[layer_idx]
IndexError: list index out of range
Motivation
I think it'd be fantastic to decouple the layer_idx variable somehow to allow easy removal of entire blocks. I imagine this would be useful for the general research community to experiment with these models.
Your contribution
I'm not very familiar with the inner workings of the library, but I'd be happy to make a PR if you can give me some high-level suggestions on how to make this change. Thanks!
cc @ArthurZucker
@candemircan I'm pretty sure I ran into something similar when trying to chop/remove the majority of a model for local dev, and it makes what you are trying to do somewhat impossible out of the box.
There seem to be two possibilities. One is to figure out which args you need to pass so you can use the model without much modification; for instance, with what you posted, passing use_cache=False works (but other models use layer_idx for the cache or other parts in ways that are also not decoupled, so you may need to figure out other kwargs as well):
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
# remove layer 16; with use_cache=False below, nothing indexes the stale layer_idx values
model.model.layers.pop(16)
prompt = "hello"
tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
output = model(tokenized, use_cache=False, return_dict=True, output_hidden_states=True)
should work. The other approach, which is less flexible about which layers you drop but seems less prone to breaking across the various forward implementations (and might not be helpful for what you are trying to do), is to modify the config you pass into model creation and change the number of layers (e.g. config.num_hidden_layers) before it is passed to from_pretrained.
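A minimal sketch of that config-based approach, assuming you only need to drop layers from the end of the stack (from_pretrained simply builds num_hidden_layers layers and skips loading the rest):
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
config = AutoConfig.from_pretrained(model_name)
config.num_hidden_layers = 24  # keep only the first 24 of the 32 decoder layers
model = AutoModelForCausalLM.from_pretrained(
    model_name, config=config, torch_dtype=torch.bfloat16
)
You'll likely see a warning about unused checkpoint weights for the dropped layers, but the kept layers load their pretrained weights and the layer_idx values stay contiguous.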
It's kind of annoying, and I agree that model layers held in something like a ModuleList should be decoupled from the layer index to allow easier debugging/dev locally without having to wrap/subclass the Model/Config.
The layer index is mostly (and only) used for the cache. And this is more a feature request than a bug: you are manually re-ordering the layers without updating the layer index.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model_name = "meta-llama/Meta-Llama-3-8B"
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch_dtype,
)
# Shift every layer after position 16 down by one slot, update each moved
# layer's self_attn.layer_idx to its new position (the DynamicCache is indexed
# by this attribute), then drop the now-duplicated last slot.
for i in range(16, len(model.model.layers) - 1):
    model.model.layers[i] = model.model.layers[i + 1]
    model.model.layers[i].self_attn.layer_idx = i
del model.model.layers[-1]
prompt = "hello"
tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
output = model(tokenized, return_dict=True, output_hidden_states=True)
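As a quick sanity check (hypothetical, assuming the renumbering above), you can confirm that the layers and their cache indices line up:
assert len(model.model.layers) == 31
assert all(layer.self_attn.layer_idx == i for i, layer in enumerate(model.model.layers))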
It's not really part of the API, so I'm not sure we want to add some kind of trick to automatically update the layer_idx.
Hi Arthur,
Thanks for the response. I think the workarounds you and @grahamannett suggested suit what I need.
Glad we could help! 🤗