coremltools

How to deploy a Vision Transformer on the ANE with faster uncached load times

Open issue, opened by cvv-student (2 comments)

❓Question

I want to deploy some ViT models on an iPhone. I followed https://machinelearning.apple.com/research/vision-transformers for deployment and wrote a simple demo based on the code from https://github.com/apple/ml-vision-transformers-ane. However, the uncached load time on the phone is very long. Following the blog, the input is already aligned to 64 bytes, but loading is still very slow. Is there any way to speed it up? This is my test case:

```python
import torch
import coremltools as ct
import math
from torch import nn


class SelfAttn(torch.nn.Module):
    """Window self-attention in the channels-first (B, C, 1, L) layout."""

    def __init__(self, window_size, num_heads, dim, dim_out):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.dim = dim
        self.dim_out = dim_out
        # 1x1 convolutions play the role of the Q/K/V linear projections
        self.q_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.k_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.v_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )

    def forward(self, x):
        B, HW, C = x.shape
        image_shape = (B, C, self.window_size, self.window_size)
        x_2d = x.permute((0, 2, 1)).reshape(image_shape)  # BCHW
        x_flat = torch.unsqueeze(x.permute((0, 2, 1)), 2)  # BC1L
        q, k, v_2d = self.q_proj(x_flat), self.k_proj(x_flat), self.v_proj(x_2d)

        # Split channels into heads: q and v as BC1L, k permuted to BL1C
        mh_q = torch.split(q, self.dim_out // self.num_heads, dim=1)  # BC1L
        mh_v = torch.split(
            v_2d.reshape(B, -1, x_flat.shape[2], x_flat.shape[3]), self.dim_out // self.num_heads, dim=1
        )
        mh_k = torch.split(
            torch.permute(k, (0, 3, 2, 1)), self.dim_out // self.num_heads, dim=3
        )
        scale_factor = 1 / math.sqrt(mh_q[0].size(1))  # 1 / sqrt(head_dim)
        attn_weights = [
            torch.einsum("bchq,bkhc->bkhq", qi, ki) * scale_factor
            for qi, ki in zip(mh_q, mh_k)
        ]  # each is (B, L_k, 1, L_q)
        attn_weights = [
            torch.softmax(aw, dim=1) for aw in attn_weights
        ]  # softmax applied on channel "C" (the key axis)
        mh_x = [torch.einsum("bkhq,bchk->bchq", wi, vi) for wi, vi in zip(attn_weights, mh_v)]
        x = torch.cat(mh_x, dim=1)  # BC1L
        return x


window_size = 8
patch_batch = 1024  # batch of 1024 windows, each window_size x window_size tokens
emb_dim = 96
emb_dim_out = 96
x = torch.rand(patch_batch, window_size * window_size, emb_dim)
qkv_layer = SelfAttn(window_size, 1, emb_dim, emb_dim_out).eval()  # num_heads = 1
jit = torch.jit.trace(qkv_layer, (x,))

mlmod_fixed_shape = ct.convert(
    jit,
    inputs=[
        ct.TensorType("x", x.shape),
    ],
    convert_to="mlprogram",
)
mlmodel_path = "test_ane.mlpackage"
mlmod_fixed_shape.save(mlmodel_path)
```
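The 64-byte guideline the question refers to concerns the last (innermost) axis of each tensor, not only the model input. A rough way to audit the test case is to compute the last-axis size in bytes for the main intermediate tensors in SelfAttn.forward. This is a minimal sketch assuming float16 storage (2 bytes per element, the default compute precision for `mlprogram` conversions); the helper names and the shape table are illustrative only and are not part of coremltools.

```python
from typing import Dict, Tuple

BYTES_PER_ELEM = 2  # assumption: tensors are stored as float16 on the ANE
ALIGNMENT = 64      # byte alignment guideline from the referenced blog post


def last_axis_bytes(shape: Tuple[int, ...], bytes_per_elem: int = BYTES_PER_ELEM) -> int:
    """Size in bytes of the innermost (last) axis of a tensor with this shape."""
    return shape[-1] * bytes_per_elem


def is_aligned(shape: Tuple[int, ...], bytes_per_elem: int = BYTES_PER_ELEM) -> bool:
    """True if the last axis spans a whole multiple of ALIGNMENT bytes."""
    return last_axis_bytes(shape, bytes_per_elem) % ALIGNMENT == 0


# Shapes from SelfAttn.forward for the test case (batch 1024, window_size=8, emb_dim=96)
B, WS, C = 1024, 8, 96
L = WS * WS
SHAPES: Dict[str, Tuple[int, ...]] = {
    "x (B, HW, C)": (B, L, C),
    "x_flat, q, k (B, C, 1, L)": (B, C, 1, L),
    "x_2d, v_2d (B, C, H, W)": (B, C, WS, WS),
    "permuted k (B, L, 1, C)": (B, L, 1, C),
    "attn_weights (B, L, 1, L)": (B, L, 1, L),
}

for name, shape in SHAPES.items():
    print(f"{name:28s} last axis = {last_axis_bytes(shape):4d} bytes, aligned = {is_aligned(shape)}")
```

Under this rule the (B, C, 1, L) path has 128- or 192-byte last axes, while the 2D `x_2d`/`v_2d` tensors have an 8-element (16-byte) last axis. The sketch only flags shapes; it does not establish whether padding is what drives the load time.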

These are my profiler results: [profiler screenshot 20240808-091653]. The uncached load took nearly 36 seconds, even though the model is essentially a single attention layer.
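To try changes without redeploying to the phone each time, one option is to time model loading on a Mac. Below is a minimal sketch: a generic timer that calls a loader several times and records each duration. The first call approximates an uncached load only if this model has not already been compiled and cached on that machine; later calls show the cached path. `time_loads` is a hypothetical helper, and Mac timings will not match iPhone numbers.

```python
import time
from typing import Callable, List


def time_loads(load: Callable[[], object], repeats: int = 3) -> List[float]:
    """Call `load` `repeats` times and return the wall-clock seconds of each call."""
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        load()
        timings.append(time.perf_counter() - start)
    return timings


# Example usage on macOS (requires coremltools and the saved test_ane.mlpackage):
#
#   import coremltools as ct
#   timings = time_loads(
#       lambda: ct.models.MLModel(
#           "test_ane.mlpackage", compute_units=ct.ComputeUnit.CPU_AND_NE
#       )
#   )
#   print(["%.2f s" % t for t in timings])
```

Restricting `compute_units` to CPU and Neural Engine keeps the GPU out of the measurement, so the numbers reflect the ANE compile path more directly.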

cvv-student, Aug 08 '24 01:08

Can you try using slicing rather than torch.split? In my experience it generally leads to better performance.

Something like this:

```python
# d_qk and n_head correspond to dim_out and num_heads in SelfAttn above
mh_q = [
    q[:, i * (self.d_qk // self.n_head):(i + 1) * (self.d_qk // self.n_head), :, :]
    for i in range(self.n_head)
]
```
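The slicing above divides the channel axis into equal per-head ranges. A small helper makes that index arithmetic explicit and guards a case where slicing and `torch.split` differ: if the channel count is not divisible by the number of heads, `torch.split` returns an extra, smaller chunk, whereas `range(n_head)` slicing silently drops the leftover channels. `head_slices` is a hypothetical name used only for illustration.

```python
from typing import List


def head_slices(dim_out: int, num_heads: int) -> List[slice]:
    """Return one slice per head, covering the channel axis in equal parts."""
    if num_heads < 1 or dim_out % num_heads != 0:
        raise ValueError(f"dim_out={dim_out} is not divisible by num_heads={num_heads}")
    head_dim = dim_out // num_heads
    return [slice(i * head_dim, (i + 1) * head_dim) for i in range(num_heads)]


# Inside SelfAttn.forward this could read, e.g.:
#   mh_q = [q[:, s, :, :] for s in head_slices(self.dim_out, self.num_heads)]
print(head_slices(96, 1))  # [slice(0, 96, None)]
print(head_slices(96, 3))  # [slice(0, 32, None), slice(32, 64, None), slice(64, 96, None)]
```

Note that the test case builds SelfAttn with num_heads=1, so there is a single slice spanning all 96 channels; in that particular model, switching from torch.split to slicing likely changes very little in the graph.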

kinghchan, Nov 08 '24 13:11

Thanks for your reply, but the load time does not seem to have decreased. Also, the official demo uses torch.split as well. Here is my revised version:

```python
# Inside SelfAttn.forward, replacing the three torch.split calls
mh_q = [
    q[:, i * (self.dim_out // self.num_heads):(i + 1) * (self.dim_out // self.num_heads), :, :]
    for i in range(self.num_heads)
]
mh_v = [
    v_2d.reshape(B, -1, x_flat.shape[2], x_flat.shape[3])[:, i * (self.dim_out // self.num_heads):(i + 1) * (self.dim_out // self.num_heads), :, :]
    for i in range(self.num_heads)
]
mh_k = [
    torch.permute(k, (0, 3, 2, 1))[:, :, :, i * (self.dim_out // self.num_heads):(i + 1) * (self.dim_out // self.num_heads)]
    for i in range(self.num_heads)
]
```
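Before comparing load times between the two versions, it can help to confirm that the slicing rewrite yields exactly the same per-head tensors as torch.split. This is a sketch in plain PyTorch on random tensors shaped like `q` and the permuted `k`; the helper names are hypothetical, and it checks eager-mode tensors only, not the converted Core ML program.

```python
import torch


def split_heads(t: torch.Tensor, head_dim: int, dim: int) -> list:
    # Original approach: torch.split into chunks of head_dim along `dim`
    return list(torch.split(t, head_dim, dim=dim))


def slice_heads(t: torch.Tensor, num_heads: int, dim: int) -> list:
    # Revised approach: one explicit slice per head along `dim`
    head_dim = t.shape[dim] // num_heads
    heads = []
    for i in range(num_heads):
        index = [slice(None)] * t.dim()
        index[dim] = slice(i * head_dim, (i + 1) * head_dim)
        heads.append(t[tuple(index)])
    return heads


def same_heads(t: torch.Tensor, num_heads: int, dim: int) -> bool:
    # True if both approaches give the same number of heads with identical values
    a = split_heads(t, t.shape[dim] // num_heads, dim)
    b = slice_heads(t, num_heads, dim)
    return len(a) == len(b) and all(torch.equal(x, y) for x, y in zip(a, b))


torch.manual_seed(0)
q = torch.rand(2, 96, 1, 64)    # (B, C, 1, L): layout of q and the reshaped v
k_t = torch.rand(2, 64, 1, 96)  # (B, L, 1, C): layout of the permuted k
for heads in (1, 2, 3):
    print(heads, same_heads(q, heads, dim=1), same_heads(k_t, heads, dim=3))
```

If both checks pass, any load-time difference between the two exports comes from how the converter lowers split versus slicing, not from a change in the math.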

cvv-student, Nov 13 '24 02:11