
Demo usage of Flash Attention

Open ggerganov opened this issue 2 years ago • 7 comments

This is my understanding of how Flash Attention works based on this picture:

[figure: Flash Attention tiled-algorithm diagram from the HazyResearch repo]

ref: https://github.com/HazyResearch/flash-attention
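The algorithm in that picture is the "online softmax" recurrence: process K/V in blocks while keeping a running row maximum and a running denominator, so the full attention score matrix never has to exist. Here is a minimal NumPy sketch of that recurrence (shapes, the block size, and function names are illustrative assumptions, not the ggml implementation):

```python
import numpy as np

def naive_attention(Q, K, V):
    # Reference: materializes the full n x n score matrix.
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def flash_attention(Q, K, V, block=4):
    n, d = Q.shape
    O = np.zeros((n, d))
    m = np.full(n, -np.inf)   # running row maxima
    l = np.zeros(n)           # running softmax denominators
    scale = 1.0 / np.sqrt(d)
    for j in range(0, K.shape[0], block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T * scale                  # n x block tile of scores
        m_new = np.maximum(m, S.max(axis=-1))
        alpha = np.exp(m - m_new)             # rescale earlier accumulators
        P = np.exp(S - m_new[:, None])
        l = l * alpha + P.sum(axis=-1)
        O = O * alpha[:, None] + P @ Vj
        m = m_new
    return O / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 16)) for _ in range(3))
assert np.allclose(flash_attention(Q, K, V), naive_attention(Q, K, V))
```

The recurrence is exact, not an approximation: rescaling the accumulators by `alpha` whenever the running maximum changes reproduces the standard softmax result block by block.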

The implementation is here:

https://github.com/ggerganov/llama.cpp/blob/flash-attn/ggml.c#L8122-L8367

I don't plan on merging this because on M1 it performs the same as without FA. However, in whisper.cpp I gained performance from using this exact same call in the Encoder:

https://github.com/ggerganov/whisper.cpp/blob/0a2d1210bcb98978214bbf4e100922a413afd39d/whisper.cpp#L1482-L1508

Putting this here if someone wants to play with it or figures out how to implement sparse attention. The idea is just to merge the ggml operators into a single op and avoid intermediate tensors.
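To illustrate the fusion idea: each operator in the naive chain materializes a full intermediate tensor, while a fused op can stream one query row at a time with only O(n) scratch. A hedged NumPy sketch (function names are made up for illustration; this is not the ggml code):

```python
import numpy as np

def attention_separate_ops(Q, K, V):
    # Mirrors a chain of separate graph operators: each step below
    # produces a full tensor that must be stored between ops.
    KQ        = Q @ K.T                                   # intermediate: n x n
    KQ_scaled = KQ / np.sqrt(Q.shape[-1])                 # intermediate: n x n
    KQ_soft   = np.exp(KQ_scaled - KQ_scaled.max(axis=-1, keepdims=True))
    KQ_soft  /= KQ_soft.sum(axis=-1, keepdims=True)       # intermediate: n x n
    return KQ_soft @ V

def attention_fused(Q, K, V):
    # Single fused pass: scores for one query row live in a length-n
    # scratch buffer and are discarded immediately, so no n x n tensor
    # is ever allocated.
    n, d = Q.shape
    out = np.empty((n, d))
    scale = 1.0 / np.sqrt(d)
    for i in range(n):
        s = (K @ Q[i]) * scale      # length-n scratch
        s = np.exp(s - s.max())
        out[i] = (s @ V) / s.sum()
    return out
```

Both functions compute the same result; the difference is purely in how much intermediate memory traffic the computation generates.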

ggerganov avatar Apr 05 '23 15:04 ggerganov

Please merge this because it's amazing on x86 with longer context. I tried generating 1500 tokens with the 7B model (--ignore-eos -c 2048 -n 1500). On the master branch the generation took 1385 seconds. On the flash-attn branch it took 200 seconds.

bakamomi avatar Apr 05 '23 18:04 bakamomi

~~Strange, when comparing #775 to this I noticed a regression in the time it took to generate 1024 tokens.~~

#775

llama_print_timings:        load time =  2776.69 ms
llama_print_timings:      sample time =   801.28 ms /  1024 runs   (    0.78 ms per run)
llama_print_timings: prompt eval time =  1912.48 ms /    14 tokens (  136.61 ms per token)
llama_print_timings:        eval time = 189655.06 ms /  1023 runs   (  185.39 ms per run)
llama_print_timings:       total time = 193245.67 ms

#778 + #775 (fluke)

llama_print_timings:        load time =  2745.93 ms
llama_print_timings:      sample time =   814.02 ms /  1024 runs   (    0.79 ms per run)
llama_print_timings: prompt eval time =  1880.81 ms /    14 tokens (  134.34 ms per token)
llama_print_timings:        eval time = 250896.90 ms /  1023 runs   (  245.26 ms per run)
llama_print_timings:       total time = 254470.03 ms

> Please merge this because it's amazing on x86 with longer context. I tried generating 1500 tokens with the 7B model (--ignore-eos -c 2048 -n 1500). On the master branch the generation took 1385 seconds. On the flash-attn branch it took 200 seconds.

~~Are you certain that uplift isn't a result of #775? If you cloned the flash-attn branch it included that commit.~~

Edit: Will run some more tests just to make sure this isn't coincidental for my machine.

Edit 2: It was a fluke. Re-running on this PR again, I now get a slightly better result.

llama_print_timings:        load time =  2875.85 ms
llama_print_timings:      sample time =   802.98 ms /  1024 runs   (    0.78 ms per run)
llama_print_timings: prompt eval time =  1938.22 ms /    14 tokens (  138.44 ms per token)
llama_print_timings:        eval time = 180435.09 ms /  1023 runs   (  176.38 ms per run)
llama_print_timings:       total time = 184126.72 ms

rabidcopy avatar Apr 05 '23 19:04 rabidcopy

Alright, #775 clearly contributed to the results I got. I pulled master again with #775 already merged and now I'm getting:

llama_print_timings:        load time =   928.29 ms
llama_print_timings:      sample time =   859.52 ms /  1500 runs   (    0.57 ms per run)
llama_print_timings: prompt eval time =   454.82 ms /     8 tokens (   56.85 ms per token)
llama_print_timings:        eval time = 200382.48 ms /  1500 runs   (  133.59 ms per run)
llama_print_timings:       total time = 202195.50 ms

Exactly the same result as with flash attention. Just for reference this is what I got previously:

llama_print_timings:        load time =  1817.75 ms
llama_print_timings:      sample time =   857.91 ms /  1500 runs   (    0.57 ms per run)
llama_print_timings: prompt eval time =  1454.90 ms /     8 tokens (  181.86 ms per token)
llama_print_timings:        eval time = 1383048.29 ms /  1500 runs   (  922.03 ms per run)
llama_print_timings:       total time = 1385748.68 ms

I guess FA needs more testing.

bakamomi avatar Apr 05 '23 19:04 bakamomi

> Alright, #775 clearly contributed to the results I got. I pulled master again with #775 already merged and now I'm getting:
>
> llama_print_timings:        load time =   928.29 ms
> llama_print_timings:      sample time =   859.52 ms /  1500 runs   (    0.57 ms per run)
> llama_print_timings: prompt eval time =   454.82 ms /     8 tokens (   56.85 ms per token)
> llama_print_timings:        eval time = 200382.48 ms /  1500 runs   (  133.59 ms per run)
> llama_print_timings:       total time = 202195.50 ms
>
> Exactly the same result as with flash attention. Just for reference this is what I got previously:
>
> llama_print_timings:        load time =  1817.75 ms
> llama_print_timings:      sample time =   857.91 ms /  1500 runs   (    0.57 ms per run)
> llama_print_timings: prompt eval time =  1454.90 ms /     8 tokens (  181.86 ms per token)
> llama_print_timings:        eval time = 1383048.29 ms /  1500 runs   (  922.03 ms per run)
> llama_print_timings:       total time = 1385748.68 ms
>
> I guess FA needs more testing.

Wow, that's quite a dramatic change nonetheless! I guess some systems were hit way harder than others by the V transpose on every token.

rabidcopy avatar Apr 05 '23 19:04 rabidcopy

I couldn't find a measurable difference between this and master on a 9900k.

slaren avatar Apr 05 '23 21:04 slaren

Yeah, no noticeable difference on a Ryzen 2600 either. But it will be interesting to see if this can go somewhere.

rabidcopy avatar Apr 05 '23 22:04 rabidcopy

There's a good chance that the CPU is more bottlenecked by compute than the GPU is, and that the original implementation already prefetches cache lines.

See: https://github.com/HazyResearch/flash-attention/issues/59

jon-chuang avatar Apr 12 '23 01:04 jon-chuang

Would this implementation also work on GPUs? Has anyone tried it there to see how well it performs?

NikolaBorisov avatar Sep 02 '23 00:09 NikolaBorisov

@ggerganov Thanks. What about merging it disabled by default, so people can at least try it on master via a CLI arg?

WilliamTambellini avatar Sep 20 '23 21:09 WilliamTambellini

How do you do the benchmarking?

Aya-ZIbra avatar Sep 26 '23 02:09 Aya-ZIbra

Is this the correct implementation? I think the effect on the GPU is good because it uses shared memory, which has higher bandwidth. On the CPU, should the block data be kept temporarily in registers to obtain higher bandwidth?

Waylon-Zhu avatar Sep 27 '23 07:09 Waylon-Zhu

A faster Metal implementation:

https://github.com/philipturner/metal-flash-attention

cc: @philipturner

sroussey avatar Sep 27 '23 15:09 sroussey