xformers or flash-attention support?
We need xformers or flash-attention support for 'mps' devices; it could speed up attention-layer inference 3-5x!
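For context, here is a minimal sketch (assuming PyTorch 2.x built with MPS support on Apple Silicon) of the built-in `scaled_dot_product_attention` path one would benchmark against; to my understanding this is the op a fused flash-attention kernel on 'mps' would replace:

```python
import torch
import torch.nn.functional as F

# Assumption: PyTorch 2.x with the MPS backend available (Apple Silicon).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Toy attention shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 1024, 64, device=device)
k = torch.randn(1, 8, 1024, 64, device=device)
v = torch.randn(1, 8, 1024, 64, device=device)

# PyTorch's fused SDPA entry point. On MPS this currently uses a generic
# (non-flash) kernel, which is the path a Metal flash-attention kernel
# would ideally speed up.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```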
He's talking about metal-flash-attention, which has surpassed Apple's ML Stable Diffusion in terms of performance.
Thank you for your work on that =) – it looks like something I'd like to incorporate wherever I can. Is there any way I can patch in flash-attention? I was figuring at least via attention.metal(?)
I'm also curious how well, if at all, it could be integrated into MLX, if you haven't seen/heard about it yet:
https://github.com/ml-explore/mlx/issues
I planned on integrating it into my ComfyUI Stable Diffusion instance in lieu of PyTorch's attention path. I don't know if I'm just too novice to know any shortcuts, but it seems like there's no way around manually parsing/editing the relevant files one definition at a time.
But I'm figuring that, unlike MLX vs. PyTorch, MFA probably uses the exact same vars/defs as MPS, and thus it could be implemented either atop MPS or in lieu of it without the headache. A rough sketch of what I mean is below.
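Here's roughly how I was imagining the patch: a hedged sketch (the module path, attribute name, and the MFA binding are hypothetical, not ComfyUI's or MFA's actual API) of monkey-patching whatever attention function the UNet calls so it dispatches to an MFA kernel instead of the stock path:

```python
import torch
import torch.nn.functional as F

# Hypothetical binding to a compiled metal-flash-attention kernel;
# the real MFA project ships Metal shaders, not this Python API.
# import mfa_bindings

def mfa_attention(q, k, v, is_causal=False):
    """Drop-in replacement for the model's attention call sites.

    Falls back to PyTorch's SDPA until a real MFA binding is wired in.
    """
    # return mfa_bindings.flash_attention(q, k, v, causal=is_causal)  # hypothetical
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

# Monkey-patch idea: point the model's attention entry point at mfa_attention.
# The module/attribute below is a placeholder for wherever ComfyUI (or any
# other Stable Diffusion front end) actually computes attention.
# import comfy.ldm.modules.attention as attn_mod   # hypothetical path
# attn_mod.optimized_attention = mfa_attention
```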
Edit: Also, FYI, per the MLX devs (link below), they're very encouraging of any contributions that speed it up, so I figure MFA has a much better chance of being implemented there.
https://github.com/ml-explore/mlx/issues/40
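For reference, a hedged sketch of what the MLX-side integration point might look like, assuming a recent MLX release that exposes `mx.fast.scaled_dot_product_attention` (the fused attention op an MFA-style Metal kernel would naturally back; exact availability depends on your MLX version):

```python
import math
import mlx.core as mx

# Toy shapes: (batch, heads, seq_len, head_dim). MLX arrays live in
# unified memory, so no explicit device transfers are needed on Apple Silicon.
q = mx.random.normal((1, 8, 1024, 64))
k = mx.random.normal((1, 8, 1024, 64))
v = mx.random.normal((1, 8, 1024, 64))

# Fused attention op in recent MLX releases -- the natural place for an
# MFA-style Metal kernel to plug in. (Assumption: your MLX version exposes
# mx.fast; older versions may not.)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(64))
mx.eval(out)
print(out.shape)  # (1, 8, 1024, 64)
```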