[Transformers Optimizer] CLIP-ViT encoder attention not getting fused

Open mbahri opened this issue 1 year ago • 2 comments

Describe the issue

I'm trying to optimize the vision encoder of a CLIP model exported from HuggingFace Transformers, but the attention subgraphs don't get fused.

I also tried a plain ViT model using the vanilla implementation of attention, and there the optimizer is able to fuse the operations (it isn't when using SDPA, but that's another story).

To reproduce

import torch
from torch import nn

import transformers
from transformers import CLIPVisionModel

import onnxruntime
from onnxruntime.transformers import optimizer


pixel_values = torch.rand(1, 3, 224, 224)

class Embedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16", attn_implementation='eager')
    def forward(self, pixel_values):
        return self.encoder(pixel_values=pixel_values).pooler_output


embedder = Embedder()

with torch.inference_mode():
    _ = embedder(pixel_values)

    torch.onnx.export(
        embedder,
        pixel_values,
        "model.onnx",
        input_names=("pixel_values",),
        output_names=("embedding",),
        dynamic_axes={"pixel_values": {0: "batch"}},
        do_constant_folding=True,
        opset_version=17
    )

# Optimize the same exported graph twice: once with the ViT fusion rules, once with the CLIP ones.
optimized_vit = optimizer.optimize_model('model.onnx', model_type='vit', num_heads=12, hidden_size=768, verbose=True)
optimized_clip = optimizer.optimize_model('model.onnx', model_type='clip', num_heads=12, hidden_size=768, verbose=True)

print(f"Using PyTorch {torch.__version__}, Transformers {transformers.__version__}, and ORT {onnxruntime.__version__}")

print("Optimizing with ViT config:")
print(optimized_vit.get_fused_operator_statistics())

print("Optimizing with CLIP config:")
print(optimized_clip.get_fused_operator_statistics())

Outputs:

Using PyTorch 2.3.1, Transformers 4.41.2, and ORT 1.18.0
Optimizing with ViT config:
{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 0, 'GemmFastGelu': 0, 'LayerNormalization': 3, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 23, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
Optimizing with CLIP config:
{'Attention': 0, 'LayerNormalization': 3, 'SkipLayerNormalization': 23}

No fused Attention operators in either case.
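The statistics dictionaries are easier to read once the zero entries are dropped. A minimal sketch, assuming only that the statistics come back as a plain dict of operator name to count as printed above; `nonzero_fusions` and `attention_fused` are hypothetical helper names, not part of onnxruntime:

```python
def nonzero_fusions(stats):
    """Keep only the fused operators that were actually produced."""
    return {op: count for op, count in stats.items() if count > 0}


def attention_fused(stats):
    """True if any attention-style fused operator has a nonzero count."""
    # Hypothetical check: only looks at the two attention keys seen in the output above.
    return any(stats.get(op, 0) > 0 for op in ("Attention", "MultiHeadAttention"))


# Statistics copied from the CLIP-config run reported above (ORT 1.18.0).
clip_stats = {'Attention': 0, 'LayerNormalization': 3, 'SkipLayerNormalization': 23}

print(nonzero_fusions(clip_stats))
print(attention_fused(clip_stats))
```

With the real optimizer output, the same helpers could be applied to `optimized_clip.get_fused_operator_statistics()` directly.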

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.18.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

mbahri avatar Jun 28 '24 17:06 mbahri

Can you specify transformers version as well? They made a lot of changes recently.

xadupre avatar Jul 01 '24 08:07 xadupre

Can you specify transformers version as well? They made a lot of changes recently.

Hi,

In [6]: import transformers

In [7]: transformers.__version__
Out[7]: '4.41.2'

I also tried 4.28.1, which (I think) was released around the time CLIP fusion support was added to onnxruntime, but got the same results:

Using PyTorch 2.3.1, Transformers 4.28.1, and ORT 1.18.0
Optimizing with ViT config:
{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 0, 'GemmFastGelu': 0, 'LayerNormalization': 3, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 23, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
Optimizing with CLIP config:
{'Attention': 0, 'LayerNormalization': 3, 'SkipLayerNormalization': 23}

I'll run the test again with PyTorch 1.13.1 later.

Edit: my mistake, 4.28.1 is much older than either of the two PRs that added CLIP attention fusion.

mbahri avatar Jul 01 '24 09:07 mbahri

Can you try your code with the nightly ORT package instead of the stable ORT 1.18.0 package? New fusions for CLIP were recently added in this PR and aren't in the stable ORT package currently. Alternatively, you can try the "from source" instructions in the PR's description.

kunal-vaishnavi avatar Jul 02 '24 07:07 kunal-vaishnavi

Sweet! I assumed that PR was already part of the stable release. It's looking much better with ort-nightly, thank you:

Using PyTorch 2.3.1, Transformers 4.41.2, and ORT 1.19.0
Optimizing with ViT config:
{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 0, 'GemmFastGelu': 0, 'LayerNormalization': 3, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 23, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
Optimizing with CLIP config:
{'Attention': 12, 'LayerNormalization': 3, 'QuickGelu': 12, 'SkipLayerNormalization': 23}
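To make the change between the stable and nightly runs explicit, the two CLIP-config dictionaries can be diffed. This is only a sketch over the numbers posted in this thread; `fusion_diff` is a hypothetical helper name, not an onnxruntime API:

```python
def fusion_diff(before, after):
    """Map each operator whose count changed to a (before, after) pair."""
    ops = sorted(set(before) | set(after))
    return {
        op: (before.get(op, 0), after.get(op, 0))
        for op in ops
        if before.get(op, 0) != after.get(op, 0)
    }


# CLIP-config statistics copied from the runs in this thread.
stable_1_18 = {'Attention': 0, 'LayerNormalization': 3, 'SkipLayerNormalization': 23}
nightly_1_19 = {'Attention': 12, 'LayerNormalization': 3, 'QuickGelu': 12, 'SkipLayerNormalization': 23}

print(fusion_diff(stable_1_18, nightly_1_19))
```

For these inputs the only operators that differ are `Attention` and `QuickGelu`, each going from 0 to 12; the LayerNormalization and SkipLayerNormalization counts are unchanged.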

mbahri avatar Jul 02 '24 15:07 mbahri