
Simplify CrossAttention to run on Apple Neural Engine

Open MatthewWaller opened this issue 3 years ago • 3 comments

Is your feature request related to a problem? Please describe.

I'm trying to convert portions of the UNet to CoreML. However, CrossAttention fails to compile for the Apple Neural Engine.

Describe the solution you'd like

After a lot of experimentation with converting CrossAttention in various ways, my best guess is that it contains too many reshapes and transposes.

Is there a way to simplify

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

Using an einsum? Or some other method?
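To make the einsum idea concrete, here is a minimal sketch of what such a rewrite could look like. It assumes standard scaled dot-product attention around the two helpers above. NumPy stands in for PyTorch so the snippet is self-contained, and the same subscript strings should carry over to `torch.einsum`. The function names `attention_reference` and `attention_einsum` are made up for illustration. Whether coremltools converts these einsums, or whether the result actually runs on the Neural Engine, is not something this checks.

```python
import numpy as np


def reshape_heads_to_batch_dim(tensor, heads):
    # Same logic as the method quoted above, with heads passed explicitly.
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.transpose(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)


def reshape_batch_dim_to_heads(tensor, heads):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // heads, heads, seq_len, dim)
    return tensor.transpose(0, 2, 1, 3).reshape(batch_size // heads, seq_len, dim * heads)


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_reference(q, k, v, heads):
    # Reshape + permute version: fold heads into the batch dimension and back.
    scale = (q.shape[-1] // heads) ** -0.5
    q, k, v = (reshape_heads_to_batch_dim(t, heads) for t in (q, k, v))
    scores = np.matmul(q, k.transpose(0, 2, 1)) * scale
    out = np.matmul(softmax(scores), v)
    return reshape_batch_dim_to_heads(out, heads)


def attention_einsum(q, k, v, heads):
    # Einsum version: one reshape per input, one at the end, no explicit permute.
    batch_size, q_len, dim = q.shape
    head_dim = dim // heads
    scale = head_dim ** -0.5
    q = q.reshape(batch_size, q_len, heads, head_dim)
    k = k.reshape(batch_size, k.shape[1], heads, head_dim)
    v = v.reshape(batch_size, v.shape[1], heads, head_dim)
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scale
    out = np.einsum("bhqk,bkhd->bqhd", softmax(scores), v)
    return out.reshape(batch_size, q_len, dim)


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 5, 16))
k = rng.standard_normal((2, 7, 16))
v = rng.standard_normal((2, 7, 16))

# Largest element-wise difference between the two formulations.
max_diff = np.abs(attention_reference(q, k, v, 4) - attention_einsum(q, k, v, 4)).max()
print(max_diff)
```

The einsum keeps the heads as their own axis instead of folding them into the batch dimension, so the explicit permutes disappear from the Python code. A converter may still lower each einsum to transposes plus a matmul internally, so this is only a starting point for experiments.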

Thanks!

MatthewWaller, Sep 28 '22 16:09

Hey @MatthewWaller,

Generally yes, I think we can do this, but we would have to see a PR first. Would you like to open a PR for this one? :-)

patrickvonplaten, Sep 29 '22 18:09

Also cc @pcuenca

patrickvonplaten, Sep 29 '22 18:09

Yeah, giving it a shot now. I want to make sure the einsum I'm using will convert with coremltools, though. Will check back in!

MatthewWaller, Sep 29 '22 18:09

@pcuenca do you have any general ideas on how to handle the Apple Neural Engine?

patrickvonplaten, Oct 27 '22 08:10

Should we only focus on MPS for now?

patrickvonplaten, Oct 27 '22 08:10

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Nov 20 '22 15:11