Correctly initialize the text model (Mistral) of Idefics2 with Flash Attention
This PR attempts to resolve the issue of the text model not being loaded with Flash Attention 2.
Relevant issue: #30394
Currently, no matter which combination of parameters I pass when instantiating the Idefics2 models, the text model is not loaded with Flash Attention 2. Here are several examples:
- Pass `_attn_implementation` to `from_pretrained`:
import torch
from transformers import Idefics2ForConditionalGeneration
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
torch_dtype=torch.bfloat16,
_attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)
Output:
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 3.21it/s]
Mistral model attention implementation: sdpa
- Pass `attn_implementation` to `from_pretrained`:
import torch
from transformers import Idefics2ForConditionalGeneration
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)
Output:
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 3.08it/s]
Mistral model attention implementation: sdpa
- Pass a `config` object with the property `_attn_implementation`:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration
config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
config=config,
torch_dtype=torch.bfloat16,
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)
Output:
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 3.09it/s]
Mistral model attention implementation: sdpa
- Pass both `config` and `attn_implementation` to `from_pretrained`:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration
config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
config=config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)
Output:
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 3.04it/s]
Mistral model attention implementation: sdpa
This PR contains a simple patch which allows the text model to be loaded with Flash Attention. Here is the same example with the changes included:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration
config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
config=config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)
Output:
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 3.13it/s]
Mistral model attention implementation: flash_attention_2
It is not an ideal fix, since it requires passing both a config object and an `attn_implementation` parameter. Moreover, it relies on the use_flash_attention_2 parameter, which is deprecated and may be removed in a future release.
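For context, the change boils down to something along these lines inside Idefics2Model.__init__, where AutoModel is already imported (a rough sketch of the approach, not the exact diff):

```python
# Rough sketch (assumed shape of the patch, not the exact diff): build the text
# backbone with Flash Attention 2 enabled whenever the top-level config asks for it.
# This relies on the deprecated use_flash_attention_2 flag, which is what triggers
# the deprecation warning shown in the output above.
self.text_model = AutoModel.from_config(
    config.text_config,
    use_flash_attention_2=config._attn_implementation == "flash_attention_2",
)
```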
Criticism, feedback and requests for changes are welcome.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Ping: @amyeroberts
Hi @zafstojano, thanks for opening this PR and addressing this issue!
At the moment, the diff and commit history contain lots of changes which are unrelated to this PR and should be resolved before merge. It looks like what happens after rebasing and pushing without force pushing. If this is the case, simply force pushing should resolve it.
@amyeroberts I have now force-pushed only my changes 👍
(curious) how to push without force after rebasing ... 👀 ?
(curious) how to push without force after rebasing ... 👀 ?
I've done it before but can't remember exactly the steps I took to achieve it! I think it rejects the push, you can pull and then push again.
@amyeroberts thank you for the constructive feedback.
I am currently experiencing some weird behavior when I integrate those changes; perhaps I am not 100% familiar with the internals of the transformers library.
For the following implementation of the `__init__` method in Idefics2Model:
class Idefics2Model(Idefics2PreTrainedModel):
def __init__(self, config: Idefics2Config):
super().__init__(config)
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self.vision_model = Idefics2VisionTransformer(config.vision_config)
self.connector = Idefics2Connector(config)
torch_dtype = config.text_config.torch_dtype
if config.torch_dtype is not None:
torch_dtype = config.torch_dtype
attn_implementation = config.text_config._attn_implementation
if config._attn_implementation is not None:
attn_implementation = config._attn_implementation
print("=================")
print("torch_dtype being passed to text_model in Idefics2Model.__init__():", torch_dtype)
print("=================")
self.text_model = AutoModel.from_config(
config.text_config,
attn_implementation=attn_implementation,
torch_dtype=torch_dtype,
)
self.image_seq_len = config.perceiver_config.resampler_n_latents
self.image_token_id = self.config.image_token_id
self.post_init()
and the following code sample:
import torch
from transformers import Idefics2ForConditionalGeneration
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)
I get the output:
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
=================
torch_dtype being passed to text_model in Idefics2Model.__init__(): torch.float32
=================
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 2.68it/s]
Perceiver model flash attention: True
Vision model flash attention: True
Text model flash attention: True
-----------------
Model dtype: torch.bfloat16
So, flash attention is correctly propagated to all submodules when the user specifies attn_implementation="flash_attention_2" in Idefics2ForConditionalGeneration.from_pretrained, but for some reason the torch_dtype is not: look at the torch_dtype being passed to the text model.
Do you have any idea why this is happening?
Unrelated to the above issue, I have another suggestion: Since the vision model is initialized from config.vision_config, and in turn it uses this sub-config file to infer the attention implementation, it would be a good idea to override the config.vision_config._attn_implementation property with the one inferred above. What do you think?
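For illustration, the override could look roughly like this inside Idefics2Model.__init__ (a sketch of the idea, not a finished patch):

```python
# Sketch: copy the attention implementation resolved on the top-level config onto
# the vision sub-config before the vision tower is built, so that
# Idefics2VisionTransformer picks up the same implementation.
config.vision_config._attn_implementation = config._attn_implementation
self.vision_model = Idefics2VisionTransformer(config.vision_config)
```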
Hi @zafstojano, thanks for sharing this script!
OK, so the behaviour of torch_dtype is quite complex and not the area of the code I'm most familiar with. In terms of what's happening in the script, I think:
- `torch_dtype` isn't set in the idefics2 config, so it defaults to torch.float32
- this is what is passed along when constructing the model. However, when loading the model from a checkpoint, what actually happens is we create an empty model, and then fill in the values when loading the weights. The `torch.float32` you're seeing is when this empty model is made.
- When loading the weights, the `torch_dtype` value passed in the `from_pretrained` call determines how to load the weights. If it's `"auto"`, then we use the value in the config. If it's set to `torch.xxx`, then it uses this value. If unset, it defaults to `torch.float32`. This will determine the model's dtype (see the small illustration after this list).
- In this case, we don't want to pass `torch_dtype` from the config, as this just describes the format of the weights as they were saved. Instead, we should just pass the params for setting the attention value and skip the torch_dtype logic.
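To make the torch_dtype point concrete, here is a minimal illustration (reusing the checkpoint from the examples above; the exact dtypes printed depend on the checkpoint's config):

```python
import torch
from transformers import Idefics2ForConditionalGeneration

ckpt = "HuggingFaceM4/idefics2-8b"

# torch_dtype unset: the weights are loaded in the default torch.float32.
model_default = Idefics2ForConditionalGeneration.from_pretrained(ckpt)
print(model_default.dtype)

# torch_dtype="auto": the torch_dtype stored in the checkpoint's config is used.
model_auto = Idefics2ForConditionalGeneration.from_pretrained(ckpt, torch_dtype="auto")
print(model_auto.dtype)

# Explicit torch_dtype: the weights are loaded in the requested dtype.
model_bf16 = Idefics2ForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
print(model_bf16.dtype)
```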
cc @younesbelkada to confirm if this is right and if there's anything else to be aware of. To understand more: if composite models like this one and llava have their language model saved in e.g. float16 and their vision tower in float32, what will happen when we use torch_dtype="auto"?
Since the vision model is initialized from config.vision_config, and in turn it uses this sub-config file to infer the attention implementation, it would be a good idea to override the config.vision_config._attn_implementation property with the one inferred above. What do you think?
Yes! Good idea
Hi @amyeroberts @younesbelkada
With the implementation you proposed, for the following sample code:
import torch
from transformers import Idefics2ForConditionalGeneration
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)
I get the following output:
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00, 3.23it/s]
Perceiver model flash attention: True
Vision model flash attention: True
Text model flash attention: True
-----------------
Model dtype: torch.bfloat16
The reason I wanted to explicitly pass torch_dtype is the warning `You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour`.
Is this acceptable?
Moreover, when using the vision tower with Flash Attention, I get this exception:
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
File "/home/z/.pyenv/versions/3.11.5/lib/python3.11/runpy.py", line 198, in _run_module_as_main
return _run_code(code, main_globals, None,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/lib/python3.11/runpy.py", line 88, in _run_code
exec(code, run_globals)
File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "uxo_ml/idefics/train.py", line 162, in <module>
trainer.train()
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 1859, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 3138, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 3161, in compute_loss
outputs = model(**inputs)
^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/utils/operations.py", line 825, in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/utils/operations.py", line 813, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/peft/peft_model.py", line 563, in forward
return self.get_base_model()(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1823, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1643, in forward
image_hidden_states = self.connector(
^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1317, in forward
image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1287, in forward
layer_outputs = perceiver_layer(
^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1220, in forward
latents, self_attn_weights, present_key_value = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1004, in forward
attn_output = self._flash_attention_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1071, in _flash_attention_forward
attn_output_unpad = flash_attn_varlen_func(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1066, in flash_attn_varlen_func
return FlashAttnVarlenFunc.apply(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 581, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: query and key must have the same dtype
The above error can be fixed by casting the input image hidden states to the same dtype as the input tokens going into the Mistral model:
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
).last_hidden_state.to(dtype=self.dtype, device=input_ids.device)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
).to(dtype=self.dtype, device=input_ids.device)
However, I still get the warning about upcasting:
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Weirdly, this happens even if I cast all params to the same dtype (e.g. bfloat16).
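For what it's worth, here is a quick (illustrative) way to double-check which dtypes the parameters actually end up in, assuming `model` is the loaded Idefics2ForConditionalGeneration:

```python
import torch
from collections import Counter

# Count the dtypes across all parameters; if everything was really cast to
# bfloat16, this should show a single torch.bfloat16 entry.
print(Counter(p.dtype for p in model.parameters()))

# List (a few of) the parameters that are still in float32, if any.
leftover = [name for name, p in model.named_parameters() if p.dtype == torch.float32]
print(leftover[:10])
```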
@zafstojano I see. So this is an issue, and a tricky one at that.
@younesbelkada it doesn't seem to be the case that passing torch_dtype correctly propagates the specified dtype to the other classes. Although it seems to set it globally -- the parent model has bfloat16 set -- if I query the weights in the loaded mistral model they're all in float32.
hmm interesting ok, I will have a deeper look then !
@younesbelkada @zafstojano Just to follow up on the dtype investigation, I suspect there might be a difference between the torch_dtype being passed in the model inits during instantiation, and the torch dtype used when the pretrained weights are loaded in. I just ran a quick test, and the weights do seem to be loaded in as expected:
import torch
from transformers import Idefics2ForConditionalGeneration
print("Loading in as torch.bfloat16")
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)
print("\nLoading in as torch.float32")
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
attn_implementation="flash_attention_2",
)
model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)
Produces output:
Loading in as torch.bfloat16
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00, 2.06it/s]
torch.bfloat16
torch.bfloat16
torch.bfloat16
Loading in as torch.float32
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00, 1.82it/s]
torch.float32
torch.float32
torch.float32
thanks for investigating @amyeroberts and apologies for not investigating ! Seems all is good then ? 🙏
@younesbelkada Yep! I think so
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.