
Correctly initialize the text model (Mistral) of Idefics2 with Flash Attention

Open zafstojano opened this issue 1 year ago • 13 comments

This PR attempts to resolve the issue of the text model not being loaded with Flash Attention 2

Relevant issue: #30394


Currently, whatever combination of parameters I pass to the instantiation of the Idefics2 models, the text model is not being loaded with Flash Attention 2. Here are several examples:

  1. Pass _attn_implementation to from_pretrained:
import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.21it/s]
Mistral model attention implementation:  sdpa
  2. Pass attn_implementation to from_pretrained:
import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.08it/s]
Mistral model attention implementation:  sdpa
  3. Pass a config object with the property _attn_implementation:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.09it/s]
Mistral model attention implementation:  sdpa
  4. Pass both config and attn_implementation to from_pretrained:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.04it/s]
Mistral model attention implementation:  sdpa

This PR contains a simple patch which would allow the text model to be loaded with Flash Attention. Here is the output with the changes included:

import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.13it/s]
Mistral model attention implementation:  flash_attention_2

It is not an ideal fix, since it requires passing both a config object and an attn_implementation parameter. Moreover, it relies on the use_flash_attention_2 parameter, which the warning above reports as deprecated and which may be removed in a future release.

Criticism, feedback and requests for changes are welcome.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Ping: @amyeroberts

zafstojano, Apr 22 '24 14:04

Hi @zafstojano, thanks for opening this PR and addressing this issue!

At the moment, the diff and commit history contain lots of changes unrelated to this PR, which should be resolved before merge. It looks like what happens after rebasing and pushing without force pushing. If this is the case, simply force pushing should resolve it.

amyeroberts, Apr 22 '24 15:04

@amyeroberts I have now force-pushed only my changes 👍

zafstojano, Apr 22 '24 19:04

(curious) how to push without force after rebasing ... 👀 ?

ydshieh, Apr 22 '24 19:04

(curious) how to push without force after rebasing ... 👀 ?

I've done it before but can't remember exactly the steps I took to achieve it! I think it rejects the push, you can pull and then push again.

amyeroberts, Apr 22 '24 19:04

@amyeroberts thank you for the constructive feedback.

I am currently seeing some weird behavior when I integrate those changes; perhaps I am not fully familiar with the internals of the transformers library.

For the following implementation of the init method in Idefics2Model:

class Idefics2Model(Idefics2PreTrainedModel):
    def __init__(self, config: Idefics2Config):
        super().__init__(config)
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        self.vision_model = Idefics2VisionTransformer(config.vision_config)
        self.connector = Idefics2Connector(config)
        torch_dtype = config.text_config.torch_dtype
        if config.torch_dtype is not None:
            torch_dtype = config.torch_dtype
        attn_implementation = config.text_config._attn_implementation
        if config._attn_implementation is not None:
            attn_implementation = config._attn_implementation
        print("=================")
        print("torch_dtype being passed to text_model in Idefics2Model.__init__():", torch_dtype)
        print("=================")
        self.text_model = AutoModel.from_config(
            config.text_config,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = self.config.image_token_id

        self.post_init()

and the following code sample:

import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)

I get the output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
=================
torch_dtype being passed to text_model in Idefics2Model.__init__(): torch.float32
=================
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  2.68it/s]
Perceiver model flash attention:  True
Vision model flash attention:  True
Text model flash attention:  True
-----------------
Model dtype:  torch.bfloat16

So, flash attention is correctly propagated to all submodules when the user specifies attn_implementation="flash_attention_2" in Idefics2ForConditionalGeneration.from_pretrained, but for some reason the torch_dtype is not: look at the torch_dtype being passed to the text model.

Do you have any idea why this is happening?

Unrelated to the above issue, I have another suggestion: Since the vision model is initialized from config.vision_config, and in turn it uses this sub-config file to infer the attention implementation, it would be a good idea to override the config.vision_config._attn_implementation property with the one inferred above. What do you think?
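
The override suggested above could be sketched roughly as follows. This is only an illustration: it uses plain SimpleNamespace stand-ins rather than real Idefics2Config objects, and the helper name propagate_attn_implementation and its sub_config_names parameter are hypothetical, not part of transformers.

```python
from types import SimpleNamespace


def propagate_attn_implementation(config, sub_config_names=("text_config", "vision_config")):
    """Copy the parent's attention implementation onto each sub-config.

    Hypothetical helper: `config` is any object exposing `_attn_implementation`
    and the named sub-config attributes. Sub-configs keep their own value when
    the parent has none.
    """
    parent_impl = getattr(config, "_attn_implementation", None)
    if parent_impl is None:
        return config
    for name in sub_config_names:
        sub_config = getattr(config, name, None)
        if sub_config is not None:
            sub_config._attn_implementation = parent_impl
    return config


# Stand-in objects for illustration only, not real transformers configs.
config = SimpleNamespace(
    _attn_implementation="flash_attention_2",
    text_config=SimpleNamespace(_attn_implementation="sdpa"),
    vision_config=SimpleNamespace(_attn_implementation="eager"),
)
propagate_attn_implementation(config)
print(config.text_config._attn_implementation)    # flash_attention_2
print(config.vision_config._attn_implementation)  # flash_attention_2
```

With this shape, the vision model and the text model would both read the same value from their own sub-config, so neither needs the parent config at construction time.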

zafstojano, Apr 23 '24 13:04

Hi @zafstojano, thanks for sharing this script!

OK, so the behaviour of torch_dtype is quite complex and not the area of the code I'm most familiar with. In terms of what's happening in the script, I think:

  • torch_dtype isn't set in the Idefics2 config, so it defaults to torch.float32.
  • This is what is passed along when constructing the model. However, when loading the model from a checkpoint, we actually create an empty model and then fill in the values when loading the weights. The torch.float32 you're seeing is from when this empty model is created.
  • When loading the weights, the torch_dtype value passed in the from_pretrained call determines how the weights are loaded. If it's "auto", we use the value in the config. If it's set to torch.xxx, that value is used. If unset, it defaults to torch.float32. This determines the model's dtype.
  • In this case, we don't want to pass torch_dtype from the config, as this just describes the format of the weights as they were saved. Instead, we should just pass the params for setting the attention value and skip the torch_dtype logic.
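
As a rough illustration of the three cases in the third bullet, here is a dependency-free sketch. The function name resolve_load_dtype is hypothetical and dtypes are plain strings; it mirrors the description above rather than the actual transformers code, and the fallback used when "auto" is requested but the config has no dtype is an assumption (the library may handle that case differently).

```python
def resolve_load_dtype(requested=None, config_dtype=None, default="float32"):
    """Pick the dtype used when loading weights, following the description above.

    - requested is None (unset): use the default, float32.
    - requested == "auto": use the dtype recorded in the config; falling back
      to the default when the config has none is an assumption of this sketch.
    - requested is an explicit dtype: use it as given.
    """
    if requested is None:
        return default
    if requested == "auto":
        return config_dtype if config_dtype is not None else default
    return requested


print(resolve_load_dtype("auto", config_dtype="bfloat16"))     # bfloat16
print(resolve_load_dtype("float16", config_dtype="bfloat16"))  # float16
print(resolve_load_dtype(config_dtype="bfloat16"))             # float32
```

The last call shows why the config value alone does not decide the loaded dtype: when nothing is requested, the saved format is ignored.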

cc @younesbelkada to confirm whether this is right and whether there's anything else to be aware of. To understand more: if composite models like this one and LLaVA have their language model saved in, e.g., float16 and their vision tower in float32, what will happen when we use torch_dtype="auto"?

Since the vision model is initialized from config.vision_config, and in turn it uses this sub-config file to infer the attention implementation, it would be a good idea to override the config.vision_config._attn_implementation property with the one inferred above. What do you think?

Yes! Good idea

amyeroberts, Apr 25 '24 19:04

Hi @amyeroberts @younesbelkada

With the implementation you proposed, for the following sample code:

import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)

I get the following output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.23it/s]
Perceiver model flash attention:  True
Vision model flash attention:  True
Text model flash attention:  True
-----------------
Model dtype:  torch.bfloat16

The reason I wanted to pass torch_dtype explicitly is the warning "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour".

Is this acceptable?

zafstojano, Apr 27 '24 16:04

Moreover, when using the vision tower with Flash Attention, I get this exception:

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
  File "/home/z/.pyenv/versions/3.11.5/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "uxo_ml/idefics/train.py", line 162, in <module>
    trainer.train()
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/utils/operations.py", line 825, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/utils/operations.py", line 813, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/peft/peft_model.py", line 563, in forward
    return self.get_base_model()(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1823, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1643, in forward
    image_hidden_states = self.connector(
                          ^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1317, in forward
    image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1287, in forward
    layer_outputs = perceiver_layer(
                    ^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1220, in forward
    latents, self_attn_weights, present_key_value = self.self_attn(
                                                    ^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1004, in forward
    attn_output = self._flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1071, in _flash_attention_forward
    attn_output_unpad = flash_attn_varlen_func(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1066, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 581, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: query and key must have the same dtype

The above error can be fixed by casting the input image hidden states to the same dtype as the input tokens going into the Mistral model:

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state.to(dtype=self.dtype, device=input_ids.device)

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
            ).to(dtype=self.dtype, device=input_ids.device)

However, I still get the warning about upcasting:

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.

Weirdly, this happens even if I cast all params to the same dtype (e.g. bfloat16).

zafstojano, Apr 27 '24 16:04

@zafstojano I see. So this is an issue, and a tricky one at that.

@younesbelkada it doesn't seem that passing torch_dtype correctly propagates the specified dtype to the other classes. It does seem to be set globally (the parent model has bfloat16 set), but if I query the weights in the loaded Mistral model they're all in float32.

amyeroberts, Apr 30 '24 18:04

hmm interesting ok, I will have a deeper look then !

younesbelkada, May 02 '24 11:05

@younesbelkada @zafstojano Just to follow up on the dtype investigation, I suspect there might be a difference between the torch_dtype being passed in the model inits during instantiation, and the torch dtype used when the pretrained weights are loaded in. I just ran a quick test, and the weights do seem to be loaded in as expected:

import torch
from transformers import Idefics2ForConditionalGeneration

print("Loading in as torch.bfloat16")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)

print("\nLoading in as torch.float32")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    attn_implementation="flash_attention_2",
)

model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)

Produces output:

Loading in as torch.bfloat16
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  2.06it/s]
torch.bfloat16
torch.bfloat16
torch.bfloat16

Loading in as torch.float32
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.82it/s]
torch.float32
torch.float32
torch.float32
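
For a finer-grained check than three spot reads, a per-submodule dtype summary could look something like the sketch below. The helper summarize_dtypes and its depth parameter are hypothetical; it only assumes an iterable of (name, tensor-like) pairs whose second item has a .dtype attribute, which is the shape of what model.named_parameters() yields in PyTorch. The stand-in data here is made up so the example runs without loading a model.

```python
from collections import Counter, defaultdict
from types import SimpleNamespace


def summarize_dtypes(named_tensors, depth=2):
    """Count dtypes per submodule prefix.

    `named_tensors` is any iterable of (name, tensor-like) pairs where the
    second item exposes `.dtype`, e.g. `model.named_parameters()`.
    `depth` is how many dotted name components form the grouping key.
    """
    summary = defaultdict(Counter)
    for name, tensor in named_tensors:
        prefix = ".".join(name.split(".")[:depth])
        summary[prefix][str(tensor.dtype)] += 1
    return {prefix: dict(counts) for prefix, counts in summary.items()}


# Made-up stand-in data; with a real model pass `model.named_parameters()`.
fake_params = [
    ("model.text_model.embed_tokens.weight", SimpleNamespace(dtype="torch.bfloat16")),
    ("model.text_model.norm.weight", SimpleNamespace(dtype="torch.float32")),
    ("model.vision_model.embeddings.position_embedding.weight", SimpleNamespace(dtype="torch.bfloat16")),
]
report = summarize_dtypes(fake_params)
for prefix, counts in report.items():
    print(prefix, counts)
```

Any prefix that reports more than one dtype would point to a submodule with mixed precision, which is the situation the "query and key must have the same dtype" error suggests.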

amyeroberts, May 23 '24 13:05

thanks for investigating @amyeroberts and apologies for not investigating ! Seems all is good then ? 🙏

younesbelkada, May 23 '24 13:05

@younesbelkada Yep! I think so

amyeroberts, May 23 '24 14:05

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Jun 17 '24 08:06