[Transformers Optimizer] CLIP-ViT encoder attention not getting fused
Describe the issue
I'm trying to optimize the vision encoder of a CLIP model exported from Hugging Face Transformers, but its attention subgraphs don't get fused.
I tried a plain ViT model using the vanilla (eager) attention implementation, and the optimizer is able to fuse the operations (it doesn't when using SDPA, but that's another story).
To reproduce
import torch
from torch import nn
import transformers
from transformers import CLIPVisionModel
import onnxruntime
from onnxruntime.transformers import optimizer

pixel_values = torch.rand(1, 3, 224, 224)


class Embedder(nn.Module):
    def __init__(self):
        super().__init__()
        # Use the eager (vanilla) attention implementation rather than SDPA
        self.encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16", attn_implementation='eager')

    def forward(self, pixel_values):
        # Return only the pooled image embedding
        return self.encoder(pixel_values=pixel_values).pooler_output


embedder = Embedder()
with torch.inference_mode():
    _ = embedder(pixel_values)  # sanity-check forward pass before export

torch.onnx.export(
    embedder,
    pixel_values,
    "model.onnx",
    input_names=("pixel_values",),
    output_names=("embedding",),
    dynamic_axes={"pixel_values": {0: "batch"}},
    do_constant_folding=True,
    opset_version=17
)

# Optimize the same exported model with both the ViT and the CLIP fusion configs
optimized_vit = optimizer.optimize_model('model.onnx', model_type='vit', num_heads=12, hidden_size=768, verbose=True)
optimized_clip = optimizer.optimize_model('model.onnx', model_type='clip', num_heads=12, hidden_size=768, verbose=True)

print(f"Using PyTorch {torch.__version__}, Transformers {transformers.__version__}, and ORT {onnxruntime.__version__}")
print("Optimizing with ViT config:")
print(optimized_vit.get_fused_operator_statistics())
print("Optimizing with CLIP config:")
print(optimized_clip.get_fused_operator_statistics())
Outputs:
Using PyTorch 2.3.1, Transformers 4.41.2, and ORT 1.18.0
Optimizing with ViT config:
{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 0, 'GemmFastGelu': 0, 'LayerNormalization': 3, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 23, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
Optimizing with CLIP config:
{'Attention': 0, 'LayerNormalization': 3, 'SkipLayerNormalization': 23}
No Attention operators are fused with either config.
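To check fusion programmatically rather than by eye, the statistics dict can be filtered. A minimal sketch, stdlib only; `fused_ops` and `attention_fully_fused` are hypothetical helpers, and the expectation of one fused attention node per encoder layer (12 layers for clip-vit-base-patch16) is an assumption:

```python
# Hypothetical helpers for the dict returned by get_fused_operator_statistics()
# (operator type -> number of fused nodes).


def fused_ops(stats: dict[str, int]) -> dict[str, int]:
    """Return only the operator types that were actually fused (count > 0)."""
    return {op: n for op, n in stats.items() if n > 0}


def attention_fully_fused(stats: dict[str, int], num_layers: int) -> bool:
    """True if the attention node count equals the number of encoder layers.

    Assumption: a successful fusion yields one Attention or
    MultiHeadAttention node per layer.
    """
    attention = stats.get("Attention", 0) + stats.get("MultiHeadAttention", 0)
    return attention == num_layers


# Statistics copied from the CLIP-config run with ORT 1.18.0 above.
clip_stats_1_18 = {"Attention": 0, "LayerNormalization": 3, "SkipLayerNormalization": 23}

print(fused_ops(clip_stats_1_18))  # {'LayerNormalization': 3, 'SkipLayerNormalization': 23}
print(attention_fully_fused(clip_stats_1_18, num_layers=12))  # False
```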
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.18.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
Can you specify transformers version as well? They made a lot of changes recently.
Can you specify transformers version as well? They made a lot of changes recently.
Hi,
In [6]: import transformers
In [7]: transformers.__version__
Out[7]: '4.41.2'
I also tried 4.28.1, which (I think) was released around the time CLIP fusion support was added to onnxruntime, but got the same results:
Using PyTorch 2.3.1, Transformers 4.28.1, and ORT 1.18.0
Optimizing with ViT config:
{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 0, 'GemmFastGelu': 0, 'LayerNormalization': 3, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 23, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
Optimizing with CLIP config:
{'Attention': 0, 'LayerNormalization': 3, 'SkipLayerNormalization': 23}
I'll run the test again with PyTorch 1.13.1 later.
Edit: my mistake, 4.28.1 predates both of the PRs that added CLIP attention fusion.
Can you try your code with the nightly ORT package instead of the stable ORT 1.18.0 package? New fusions for CLIP were recently added in this PR and aren't in the stable ORT package currently. Alternatively, you can try the "from source" instructions in the PR's description.
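A script can warn before running the optimizer on a release that lacks these fusions. A minimal sketch, assuming (only from the versions reported in this thread) that the 1.18.0 release lacks the CLIP fusions and the 1.19.0 nightly has them; `parse_version`, `check_ort_for_clip_fusion`, and the `(1, 19, 0)` threshold are hypothetical, not part of onnxruntime:

```python
import re
import warnings

# Assumption (from the versions reported in this thread): the CLIP fusions are
# missing from the 1.18.0 release and present in the 1.19.0 nightly build.
MIN_ORT_FOR_CLIP_FUSION = (1, 19, 0)


def parse_version(version: str) -> tuple[int, ...]:
    """Return the leading (major, minor, patch) numbers of a version string.

    Any suffix after the third number (e.g. a dev/nightly tag) is ignored.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def check_ort_for_clip_fusion(ort_version: str) -> bool:
    """Warn and return False if this ORT version likely lacks the CLIP fusions."""
    if parse_version(ort_version) >= MIN_ORT_FOR_CLIP_FUSION:
        return True
    warnings.warn(
        f"onnxruntime {ort_version} may not include the CLIP attention fusions; "
        "try the nightly package or a build from source",
        stacklevel=2,
    )
    return False


# In the repro script: check_ort_for_clip_fusion(onnxruntime.__version__)
print(check_ort_for_clip_fusion("1.19.0"))  # True
print(check_ort_for_clip_fusion("1.18.0"))  # False, plus a UserWarning
```

Because only the first three numbers are compared, a nightly build tagged with a dev suffix on 1.19.0 passes the check, even though PEP 440 ordering would place such a dev build before the final 1.19.0 release.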
Sweet! I assumed that PR was already part of the stable release. It's looking much better with ort-nightly, thank you!
Using PyTorch 2.3.1, Transformers 4.41.2, and ORT 1.19.0
Optimizing with ViT config:
{'EmbedLayerNormalization': 0, 'Attention': 0, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 0, 'GemmFastGelu': 0, 'LayerNormalization': 3, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 23, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
Optimizing with CLIP config:
{'Attention': 12, 'LayerNormalization': 3, 'QuickGelu': 12, 'SkipLayerNormalization': 23}
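The two CLIP-config results (ORT 1.18.0 vs. nightly) can also be compared mechanically. A minimal stdlib sketch; `diff_stats` is a hypothetical helper, and the input dicts are copied from the outputs in this thread:

```python
def diff_stats(before: dict[str, int], after: dict[str, int]) -> dict[str, tuple[int, int]]:
    """Map each operator type whose fused count changed to (before, after).

    Operator types missing from one dict are treated as having count 0.
    """
    ops = set(before) | set(after)
    return {
        op: (before.get(op, 0), after.get(op, 0))
        for op in sorted(ops)
        if before.get(op, 0) != after.get(op, 0)
    }


clip_1_18 = {"Attention": 0, "LayerNormalization": 3, "SkipLayerNormalization": 23}
clip_nightly = {"Attention": 12, "LayerNormalization": 3, "QuickGelu": 12, "SkipLayerNormalization": 23}

print(diff_stats(clip_1_18, clip_nightly))  # {'Attention': (0, 12), 'QuickGelu': (0, 12)}
```

Only the attention and QuickGelu counts change; 12 of each matches the 12 encoder layers of clip-vit-base-patch16, while the LayerNormalization and SkipLayerNormalization counts stay the same.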