Simplify CrossAttention to run on Apple Neural Engine
Is your feature request related to a problem? Please describe. I'm trying to convert portions of unet into CoreML. However, CrossAttention fails to compile to the Apple Neural Engine.
Describe the solution you'd like My best guess, after a lot of experimentation converting CrossAttention in various ways, is that there are too many reshapes and transposes.
Is there a way to simplify
def reshape_heads_to_batch_dim(self, tensor):
    batch_size, seq_len, dim = tensor.shape
    head_size = self.heads
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
    return tensor

def reshape_batch_dim_to_heads(self, tensor):
    batch_size, seq_len, dim = tensor.shape
    head_size = self.heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return tensor
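(For reference, these two helpers are exact inverses of each other. A quick standalone check, with the head count passed as an explicit argument instead of read from `self.heads` -- names here are illustrative, not from the diffusers source:)

```python
import torch

# Standalone versions of the two helpers above, with `heads` passed explicitly.
def reshape_heads_to_batch_dim(tensor, heads):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)

def reshape_batch_dim_to_heads(tensor, heads):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // heads, heads, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // heads, seq_len, dim * heads)

# Round trip is the identity: batch -> (batch * heads) -> batch.
x = torch.randn(2, 5, 8)
y = reshape_batch_dim_to_heads(reshape_heads_to_batch_dim(x, 4), 4)
```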
Using an einsum? Or some other method?
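For example, one possible einsum-based reformulation that keeps the heads in their own axis, so the batched attention needs no permutes and no batch-dim merge at all. This is just a sketch with illustrative names and shapes (not taken from the diffusers source), and I haven't verified it converts with coremltools:

```python
import torch

def einsum_attention(q, k, v, heads):
    # q: (batch, q_len, dim); k, v: (batch, kv_len, dim)
    b, q_len, dim = q.shape
    kv_len = k.shape[1]
    d = dim // heads
    # Split out the head axis with a single view per tensor; no permutes.
    q = q.view(b, q_len, heads, d)
    k = k.view(b, kv_len, heads, d)
    v = v.view(b, kv_len, heads, d)
    # b=batch, h=head, q/k=query/key positions, d=per-head dim
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * d ** -0.5
    attn = scores.softmax(dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", attn, v)
    # Merge heads back into the channel dim.
    return out.reshape(b, q_len, dim)
```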
Thanks!
Hey @MatthewWaller,
Generally, yes, I think we can do this; we would have to see a PR, though. Would you like to open a PR for this one? :-)
Also cc @pcuenca
Yeah, giving it a shot now. I want to make sure the einsum I'm using will convert with coremltools, though. Will check back in!
@pcuenca do you have any general ideas how to handle Apple Neural Engine?
Should we only focus on MPS for now?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.