Demo usage of Flash Attention
This is my understanding of how Flash Attention works, based on the diagram in the reference below:

ref: https://github.com/HazyResearch/flash-attention
The implementation is here:
https://github.com/ggerganov/llama.cpp/blob/flash-attn/ggml.c#L8122-L8367
I don't plan on merging this because on M1 it performs the same as without FA.
However, in whisper.cpp I have gained performance by using this exact same call in the Encoder:
https://github.com/ggerganov/whisper.cpp/blob/0a2d1210bcb98978214bbf4e100922a413afd39d/whisper.cpp#L1482-L1508
Putting this here in case someone wants to play with it or figures out how to implement sparse attention.
The idea is simply to merge the ggml operators into a single op and avoid intermediate tensors.
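For intuition, the fused-op idea can be sketched in a few lines of pure Python (a minimal single-query sketch, not the ggml code; function names are mine): instead of materializing the full score and probability tensors, a single pass over the K/V rows keeps a running max, a running softmax denominator, and a rescaled output accumulator — the online-softmax trick Flash Attention builds on. Scaling by 1/sqrt(d) and masking are omitted for brevity.

```python
import math

def fused_attention(q, K, V):
    """Single-query attention in one pass over K/V, with an online softmax.

    No intermediate score or probability vector is ever materialized:
    only a running max `m`, denominator `z`, and accumulator `acc` are kept.
    (Sketch only: no 1/sqrt(d) scaling, no masking.)
    """
    d = len(V[0])
    m = -math.inf       # running max of scores seen so far
    z = 0.0             # running softmax denominator
    acc = [0.0] * d     # running (unnormalized) weighted sum of V rows
    for k, v in zip(K, V):
        s = sum(qi * ki for qi, ki in zip(q, k))   # one score at a time
        m_new = max(m, s)
        scale = math.exp(m - m_new)                # rescale old partial sums
        p = math.exp(s - m_new)
        z = z * scale + p
        acc = [a * scale + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / z for a in acc]
```

The result matches the usual materialize-scores-then-softmax computation exactly; the win is that the per-row temporaries stay in fast memory instead of round-tripping a full score tensor through RAM.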
Please merge this because it's amazing on x86 with longer context. I tried generating 1500 tokens with the 7B model (`--ignore-eos -c 2048 -n 1500`). On the master branch the generation took 1385 seconds; on the flash-attn branch it took 200 seconds.
~~Strange, when comparing #775 to this I noticed a regression in the time it took to generate 1024 tokens.~~

#775

```
llama_print_timings: load time = 2776.69 ms
llama_print_timings: sample time = 801.28 ms / 1024 runs ( 0.78 ms per run)
llama_print_timings: prompt eval time = 1912.48 ms / 14 tokens ( 136.61 ms per token)
llama_print_timings: eval time = 189655.06 ms / 1023 runs ( 185.39 ms per run)
llama_print_timings: total time = 193245.67 ms
```

#778 + #775 (fluke)

```
llama_print_timings: load time = 2745.93 ms
llama_print_timings: sample time = 814.02 ms / 1024 runs ( 0.79 ms per run)
llama_print_timings: prompt eval time = 1880.81 ms / 14 tokens ( 134.34 ms per token)
llama_print_timings: eval time = 250896.90 ms / 1023 runs ( 245.26 ms per run)
llama_print_timings: total time = 254470.03 ms
```
> Please merge this because it's amazing on x86 with longer context. I tried generating 1500 tokens with the 7B model (`--ignore-eos -c 2048 -n 1500`). On the master branch the generation took 1385 seconds. On the flash-attn branch it took 200 seconds.
~~Are you certain that uplift isn't a result of #775? If you cloned the flash-attn branch, it included that commit.~~
Edit: I will run some more tests just to make sure this isn't coincidental on my machine. Edit 2: It was a fluke. Re-running on this PR, I now get a slightly better result:
```
llama_print_timings: load time = 2875.85 ms
llama_print_timings: sample time = 802.98 ms / 1024 runs ( 0.78 ms per run)
llama_print_timings: prompt eval time = 1938.22 ms / 14 tokens ( 138.44 ms per token)
llama_print_timings: eval time = 180435.09 ms / 1023 runs ( 176.38 ms per run)
llama_print_timings: total time = 184126.72 ms
```
Alright, #775 clearly contributed to the results I got. I pulled master again with #775 already merged and now I'm getting:
```
llama_print_timings: load time = 928.29 ms
llama_print_timings: sample time = 859.52 ms / 1500 runs ( 0.57 ms per run)
llama_print_timings: prompt eval time = 454.82 ms / 8 tokens ( 56.85 ms per token)
llama_print_timings: eval time = 200382.48 ms / 1500 runs ( 133.59 ms per run)
llama_print_timings: total time = 202195.50 ms
```
Exactly the same result as with flash attention. Just for reference, this is what I got previously:
```
llama_print_timings: load time = 1817.75 ms
llama_print_timings: sample time = 857.91 ms / 1500 runs ( 0.57 ms per run)
llama_print_timings: prompt eval time = 1454.90 ms / 8 tokens ( 181.86 ms per token)
llama_print_timings: eval time = 1383048.29 ms / 1500 runs ( 922.03 ms per run)
llama_print_timings: total time = 1385748.68 ms
```
I guess FA needs more testing.
> Alright, #775 clearly contributed to the results I got. I pulled master again with #775 already merged and now I'm getting: […] Exactly the same result as with flash attention. […] I guess FA needs more testing.
Wow, that's quite a dramatic change nonetheless! I guess some systems were hit much harder than others by the V transpose on every token.
I couldn't find a measurable difference between this and master on a 9900k.
Yeah, no noticeable difference on a Ryzen 2600 either. But it will be interesting to see if this can go somewhere.
There's a good chance that the CPU is more bottlenecked by compute than the GPU is, and that the original implementation already prefetches cache lines effectively.
See: https://github.com/HazyResearch/flash-attention/issues/59
Would this implementation also work on GPUs? Has anyone tested how well it performs there?
@ggerganov Thanks. What about merging it disabled by default, so that people can at least try it against master via a CLI argument?
How do you do the benchmarking?
Is this the correct implementation? I think the effect on the GPU is good because it uses shared memory, which has higher bandwidth. On the CPU, should the block data be kept in registers to obtain higher bandwidth?
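To make the block/tile question concrete, here is a minimal pure-Python sketch (not the ggml code; names and the block size are hypothetical) of the tiled formulation: each K/V tile's scores exist only as tile-sized temporaries — the working set a GPU keeps in shared memory, and that a CPU would want resident in registers/L1 — and tiles are merged with a running max plus a rescale of the partial sums.

```python
import math

def tiled_attention(q, K, V, block=4):
    """Single-query attention over K/V processed in tiles of `block` rows.

    Only tile-sized temporaries (`scores`) plus a running max `m`,
    denominator `z`, and accumulator `acc` are ever live, so the hot data
    can stay in fast memory regardless of sequence length.
    """
    d = len(V[0])
    m = -math.inf
    z = 0.0
    acc = [0.0] * d
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        # tile-local scores: the only per-tile temporary
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in Kb]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)      # rescale previous tiles' partials
        z *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, Vb):
            p = math.exp(s - m_new)
            z += p
            acc = [a + p * vj for a, vj in zip(acc, v)]
        m = m_new
    return [a / z for a in acc]
```

The output is independent of `block`, which is what lets the block size be tuned to the hardware (shared-memory capacity on GPU, register/L1 pressure on CPU) without changing the result.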
A faster Metal implementation:
https://github.com/philipturner/metal-flash-attention
cc: @philipturner