[whisper] static kv cache
What does this PR do?
Supersedes https://github.com/huggingface/transformers/pull/28931 and extends it by adding static k/v cache support for Whisper. Also improves the performance of the eager attention implementation by removing unnecessary reshapes (inspired by LlamaAttention).
Similar to #28931, we use a separate cache for the self-attention and cross-attention layers. We define a lightweight EncoderDecoderCache wrapper that holds these two cache classes and implements common base methods (e.g. to_legacy_cache()) by calling the corresponding methods for each cache class.
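A minimal sketch of that wrapper pattern. It uses hypothetical toy classes (`ToyLayerCache`, `ToyEncoderDecoderCache`) with Python lists standing in for tensors, not the actual transformers Cache classes. The wrapper holds the two caches and implements `to_legacy_cache()` by delegating to each one and merging the results per layer into the (self-attn k, self-attn v, cross-attn k, cross-attn v) tuple format used by encoder-decoder models.

```python
from typing import Any, List, Tuple


class ToyLayerCache:
    """Hypothetical stand-in for a k/v cache class: one (key, value) entry per layer."""

    def __init__(self) -> None:
        self.key_cache: List[list] = []
        self.value_cache: List[list] = []

    def update(self, key: list, value: list, layer_idx: int) -> Tuple[list, list]:
        if layer_idx == len(self.key_cache):
            # First time this layer is seen: store its k/v
            self.key_cache.append(list(key))
            self.value_cache.append(list(value))
        else:
            # Grow along the "sequence" axis, like a dynamic cache concatenating new k/v
            self.key_cache[layer_idx] = self.key_cache[layer_idx] + list(key)
            self.value_cache[layer_idx] = self.value_cache[layer_idx] + list(value)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def to_legacy_cache(self) -> Tuple[Tuple[Any, ...], ...]:
        return tuple(zip(self.key_cache, self.value_cache))


class ToyEncoderDecoderCache:
    """Hypothetical lightweight wrapper: one cache for self-attention, one for cross-attention."""

    def __init__(self, self_attention_cache: ToyLayerCache, cross_attention_cache: ToyLayerCache) -> None:
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

    def to_legacy_cache(self) -> Tuple[Tuple[Any, ...], ...]:
        # Call the same method on each wrapped cache, then merge per layer into
        # (self_k, self_v, cross_k, cross_v)
        self_legacy = self.self_attention_cache.to_legacy_cache()
        cross_legacy = self.cross_attention_cache.to_legacy_cache()
        return tuple(s + c for s, c in zip(self_legacy, cross_legacy))


cache = ToyEncoderDecoderCache(ToyLayerCache(), ToyLayerCache())
cache.self_attention_cache.update(["k0"], ["v0"], layer_idx=0)
cache.cross_attention_cache.update(["ck0"], ["cv0"], layer_idx=0)
legacy = cache.to_legacy_cache()
print(legacy)  # ((['k0'], ['v0'], ['ck0'], ['cv0']),)
```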
However, there is one hurdle in enabling compatibility with torch.compile. Namely, we have to determine whether we're in the first decoding step, or second step onwards:
- In the first decoding step, we compute the cross-attention k/v states and update the cache accordingly
- In the second step onwards, we re-use the k/v states directly from the cache. There’s no further update to the cross-attention cache, since the k/v states are derived entirely from the encoder hidden-states (which stay fixed)
=> the difficulty is in detecting whether we're in the first decoding step or in the second step onwards. In eager mode, we can condition on past_key_values.get_seq_length() to determine the decoding step. However, with torch.compile this introduces a graph break. Consequently, we add a boolean flag is_updated to the StaticCache class, which tells us whether the cache has already been updated. The alternative would be to use the same logic as the Flax code, where we re-compute the cross-attention k/v states at every step. Benchmarks show this approach is 1.4x slower than using the CPU-side boolean flag.
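The control flow above can be sketched in plain Python. This is a minimal illustration, not the PR's code: `ToyCrossAttnCache` and `cross_attention_kv` are hypothetical names, Python lists stand in for tensors, and `compute_kv` stands in for the cross-attention k/v projections of the encoder hidden-states. The key point is that the branch reads a plain Python bool rather than a value stored in a tensor.

```python
from typing import Callable, List, Optional, Tuple


class ToyCrossAttnCache:
    """Hypothetical cross-attention cache with one `is_updated` flag per layer."""

    def __init__(self, num_layers: int) -> None:
        self.keys: List[Optional[list]] = [None] * num_layers
        self.values: List[Optional[list]] = [None] * num_layers
        # Plain Python bools: has this layer's cross-attention k/v been filled yet?
        self.is_updated: List[bool] = [False] * num_layers


def cross_attention_kv(
    cache: ToyCrossAttnCache,
    layer_idx: int,
    encoder_hidden_states: list,
    compute_kv: Callable[[list], Tuple[list, list]],
) -> Tuple[list, list]:
    if cache.is_updated[layer_idx]:
        # Step 2 onwards: the encoder hidden-states are fixed, so re-use the cached k/v
        return cache.keys[layer_idx], cache.values[layer_idx]
    # Step 1: project the encoder hidden-states to k/v and store them in the cache
    key, value = compute_kv(encoder_hidden_states)
    cache.keys[layer_idx], cache.values[layer_idx] = key, value
    cache.is_updated[layer_idx] = True
    return key, value


# Usage: count how often the stand-in projection runs over three decoding steps
calls = []


def fake_projection(hidden: list) -> Tuple[list, list]:
    calls.append(1)
    return [x * 2 for x in hidden], [x * 3 for x in hidden]


cache = ToyCrossAttnCache(num_layers=1)
for step in range(3):
    k, v = cross_attention_kv(cache, 0, [1.0, 2.0], fake_projection)
print(len(calls))  # 1: the projection only runs on the first decoding step
```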
Using the .generate API with Whisper medium, we get an approximately 5x speed-up when generating 64 tokens with sdpa attention. Note that we compile only the forward pass:
| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|---|---|---|---|
| 1 | 55.6 | 270.7 | 4.9 |
| 2 | 111.4 | 541.3 | 4.9 |
| 4 | 222.3 | 1078.8 | 4.9 |
| 8 | 446.3 | 2167.4 | 4.9 |
Extended results:
Whisper large-v3
| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|---|---|---|---|
| 1 | 41.1 | 190.4 | 4.6 |
| 2 | 82.1 | 381.2 | 4.6 |
| 4 | 162.9 | 761.2 | 4.7 |
| 8 | 331.3 | 1522.5 | 4.6 |
Distil-Whisper distil-large-v3
| bsz | dynamic tok/s | compiled tok/s | Speed-up |
|---|---|---|---|
| 1 | 278.7 | 449.1 | 1.6 |
| 2 | 560.5 | 900.3 | 1.6 |
| 4 | 1113.2 | 1798.7 | 1.6 |
| 8 | 2225.0 | 3592.8 | 1.6 |
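The Speed-up column is compiled tok/s divided by dynamic tok/s. A quick Python check over the bsz = 1 rows of the three tables above (values copied from the tables):

```python
# (dynamic tok/s, compiled tok/s) at bsz = 1, copied from the tables above
results = {
    "whisper-medium": (55.6, 270.7),
    "whisper-large-v3": (41.1, 190.4),
    "distil-large-v3": (278.7, 449.1),
}


def speed_up(dynamic: float, compiled: float) -> float:
    """Ratio of compiled to dynamic throughput, rounded to one decimal."""
    return round(compiled / dynamic, 1)


for name, (dyn, comp) in results.items():
    print(f"{name}: {speed_up(dyn, comp)}x")
# prints 4.9x, 4.6x and 1.6x, matching the Speed-up column
```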
As expected, the speed-ups for Distil-Whisper are less pronounced:
- With only 2 decoder layers, the decoder forward pass is already >6x faster than Whisper large-v3, so the decoder graph being compiled is very small
- The overhead from the logits post-processing now occupies a greater proportion of the generation time. Compiling the logits processors is a good next step for speeding up generation further.
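One way to reason about both points is Amdahl's law: if a fraction `f` of the generation time is spent in the compiled forward pass, and compile speeds that part up by a factor `s`, the end-to-end speed-up is `1 / ((1 - f) + f / s)`. The values of `f` and `s` below are illustrative assumptions, not measurements from this PR.

```python
def overall_speed_up(f: float, s: float) -> float:
    """Amdahl's law: end-to-end speed-up when a fraction f of runtime is sped up by s."""
    if not 0.0 <= f <= 1.0 or s <= 0:
        raise ValueError("f must be in [0, 1] and s must be positive")
    return 1.0 / ((1.0 - f) + f / s)


# Hypothetical: the same per-forward compile gain, but a different share of runtime
print(round(overall_speed_up(f=0.95, s=6.0), 2))  # 4.8  (decoder dominates runtime)
print(round(overall_speed_up(f=0.50, s=6.0), 2))  # 1.71 (post-processing weighs more)
```

The smaller the share of time spent in the compiled forward pass, the more the uncompiled logits post-processing caps the overall gain.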
Code example:
```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch

# Log graph breaks and recompilations to check that the compiled forward stays a single graph
torch._logging.set_logs(graph_breaks=True, recompiles=True)

torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
model.to(torch_device, dtype=torch_dtype)

# Load a single audio sample and convert it to log-mel input features
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").to(torch_device)
input_features = inputs.input_features.to(torch_dtype)

# Compile the forward pass only, and use a static k/v cache so shapes stay fixed across decoding steps
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
model.generation_config.cache_implementation = "static"

# Warm-up: the first calls trigger compilation (and CUDA graph capture with "reduce-overhead" on GPU)
for _ in range(2):
    model.generate(input_features)

# Inference
pred_ids = model.generate(input_features)
print(processor.batch_decode(pred_ids, skip_special_tokens=True))
```
In refactoring the eager attention implementation for the cache abstraction, I managed to remove a lot of wasteful .view operations, generally aligning it with LLaMA and giving a performance boost even without compile (TODO: quantify speed-up).
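To illustrate the kind of reshape that can be dropped (a NumPy sketch with arbitrary small shapes, not the PR's code): as we understand the previous BART-style path, batch and heads were flattened into one dimension with `.view` before a batched matmul, whereas LlamaAttention-style code multiplies the 4D (batch, heads, seq, head_dim) tensors directly. Both routes give the same attention scores.

```python
import numpy as np

bsz, num_heads, q_len, kv_len, head_dim = 2, 4, 3, 5, 8
rng = np.random.default_rng(0)
query = rng.standard_normal((bsz, num_heads, q_len, head_dim))
key = rng.standard_normal((bsz, num_heads, kv_len, head_dim))

# Route 1: flatten (batch, heads) into one dim, batched matmul, then reshape back
q_flat = query.reshape(bsz * num_heads, q_len, head_dim)
k_flat = key.reshape(bsz * num_heads, kv_len, head_dim)
scores_flat = (q_flat @ k_flat.transpose(0, 2, 1)).reshape(bsz, num_heads, q_len, kv_len)

# Route 2: matmul broadcasts over the leading (batch, heads) dims, so no reshapes are needed
scores_4d = query @ key.swapaxes(-1, -2)

print(scores_4d.shape)  # (2, 4, 3, 5)
print(np.allclose(scores_flat, scores_4d))  # True
```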
The only regression comes when using FA2 and compile, where we have to introduce a bunch of new .transpose operations for compatibility with the shape of our k/v cache (TODO: quantify regression). This is also a known problem in LLaMA.
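A NumPy sketch of where those transposes come from, assuming the cache stores k/v as (batch, num_heads, seq_len, head_dim), as in the LLaMA implementation, while the flash-attn kernels take (batch, seq_len, num_heads, head_dim). Every k/v tensor read from the cache therefore needs its head and sequence axes swapped before the kernel call.

```python
import numpy as np

bsz, num_heads, seq_len, head_dim = 2, 4, 16, 8

# Cache layout: (batch, num_heads, seq_len, head_dim)
key_states = np.arange(bsz * num_heads * seq_len * head_dim, dtype=np.float32).reshape(
    bsz, num_heads, seq_len, head_dim
)

# Flash-attention layout: (batch, seq_len, num_heads, head_dim); swap axes 1 and 2
# (the NumPy analogue of torch's key_states.transpose(1, 2))
key_fa2 = key_states.transpose(0, 2, 1, 3)
print(key_fa2.shape)  # (2, 16, 4, 8)

# The swap is a pure permutation of axes: the same element is found at swapped indices
print(key_fa2[1, 5, 2, 3] == key_states[1, 2, 5, 3])  # True
```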
There are a few tidy-up points left TODO. Once we're happy with the design, I'll complete the PR with the final checklist items:
- [x] Fix failing fast tests
- [x] Tidy docstrings for new arguments (past_key_values, cache_position)
- [x] Update model doc with FA2 usage
- [x] Run all Whisper slow tests
- [x] Run all ASR pipeline slow tests
- [ ] Check gradients propagate correctly when training with output_attentions=True
You can reference my PR #30949 for the failing tests: it passes all the tests that the current main branch passes and will save you a lot of time debugging @sanchit-gandhi
Commit https://github.com/huggingface/transformers/pull/31166/commits/93c97c1feff239fee19e8012294d2622a6e3339f details what the design would look like with a lightweight EncoderDecoderCache wrapper around the current Cache classes.
While this wrapper simplifies the API slightly, it also introduces the maintenance burden that we wanted to avoid with encoder-decoder static cache, since it adds a new class that has to be updated in accordance with the existing Cache classes. My opinion is that the improvement to the API is minimal, and does not justify the maintenance burden.
Consequently, I've reverted to the tuple design originally proposed and discussed with @gante. I'd be interested to hear whether you still agree with this design now that you've seen what it looks like in code!
(@sanchit-gandhi pls ping when it's ready for re-review!)
Also, 5x speedup 🔥 🔥
(copying from internal discussion with @gante)
Tried a few different options for detecting the decoding step: the one that seemed cleanest and most performant was adding a boolean flag (commit). We get a 1.4x speed-up using this flag compared to re-computing the cross-attn k/v states at each step:
- Dynamic: 450 tok/s
- Re-computing: 1600 tok/s
- With flag: 2200 tok/s
I also compared this to the theoretical upper bound, where we prefill the cross-attn k/v cache outside of the generation loop and re-use this cache at every step (not just for steps 2 onwards). This achieves 2300 tok/s, so we can be pretty confident we’re at the limit of performance with this boolean flag approach (2200 tok/s).
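The relative figures follow directly from the reported throughputs: the flag is about 1.4x faster than re-computing the cross-attn k/v states, and reaches about 96% of the prefilled upper bound.

```python
# Throughputs in tok/s, as reported above
dynamic, recompute, with_flag, upper_bound = 450, 1600, 2200, 2300

print(round(with_flag / recompute, 1))    # 1.4  (flag vs re-computing cross-attn k/v)
print(round(with_flag / upper_bound, 2))  # 0.96 (fraction of the prefilled upper bound)
print(round(with_flag / dynamic, 1))      # 4.9  (flag + compile vs dynamic)
```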
Thanks both for the reviews! Confirming that the slow tests pass on the DGX A100.
Going to merge this one to enable static kv cache for:
- Short-form generation
- Long-form generation without fallback (i.e. sequential generation without temperature fallback)
We'll need a follow-up PR to enable:
- Long-form generation with fallback: remember that we dynamically reduce the batch size when we do temperature fallback. We'll need to change this to fixed batch sizes for compile
- Long-form chunked generation with pipeline: again, the batch size is set dynamically in the pipeline class, depending on the length of the inputs
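Since the compiled graph and the pre-allocated static cache are tied to a given batch size, shrinking the batch mid-generation presumably forces recompilation. One possible approach, sketched here as an assumption rather than the planned follow-up, is to keep the batch size fixed and track which rows are real with a mask, instead of dropping samples. `pad_to_fixed_batch` is a hypothetical helper operating on plain lists.

```python
from typing import List, Tuple


def pad_to_fixed_batch(
    active_rows: List[List[float]], batch_size: int, pad_value: float = 0.0
) -> Tuple[List[List[float]], List[bool]]:
    """Hypothetical helper: pad the rows still being decoded back up to a fixed batch size.

    Returns the padded batch and a mask that is True for real rows and False for dummy rows,
    whose outputs would simply be discarded.
    """
    if not active_rows or len(active_rows) > batch_size:
        raise ValueError("need between 1 and batch_size active rows")
    width = len(active_rows[0])
    if any(len(row) != width for row in active_rows):
        raise ValueError("all rows must have the same length")
    num_dummy = batch_size - len(active_rows)
    padded = [list(row) for row in active_rows] + [[pad_value] * width for _ in range(num_dummy)]
    mask = [True] * len(active_rows) + [False] * num_dummy
    return padded, mask


# e.g. only 2 of 4 samples need temperature fallback, but the batch keeps 4 rows
batch, mask = pad_to_fixed_batch([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], batch_size=4)
print(len(batch), mask)  # 4 [True, True, False, False]
```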
Hi, I am getting some cache errors while doing generation with Llama 3 and FSDP. I am using flash_attention_2 and use_cache=True in the generate function, with the latest transformers from the repo (including your recent PR).
[rank1]: Traceback (most recent call last):
[rank1]: File "/fs01/home/snajafi/codes/llm-research/src/squadv2_finetuning.py", line 129, in <module>
[rank1]: app.run(main)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/absl/app.py", line 308, in run
[rank1]: _run_main(main, args)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
[rank1]: sys.exit(main(argv))
[rank1]: File "/fs01/home/snajafi/codes/llm-research/src/squadv2_finetuning.py", line 91, in main
[rank1]: results = train(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/src/utils/train_utils.py", line 108, in train
[rank1]: eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity, eval_scores = evaluation(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/src/utils/train_utils.py", line 405, in evaluation
[rank1]: for ret_row, ret_loss in model.predict(batch):
[rank1]: File "/fs01/home/snajafi/codes/llm-research/src/llm.py", line 245, in predict
[rank1]: answers, log_ps = self.generation_pass(batch)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/src/llm.py", line 216, in generation_pass
[rank1]: predictions_output = self.model.generate(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/peft/peft_model.py", line 1491, in generate
[rank1]: outputs = self.base_model.generate(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/generation/utils.py", line 1945, in generate
[rank1]: result = self._sample(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/generation/utils.py", line 2693, in _sample
[rank1]: outputs = self(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 1174, in forward
[rank1]: outputs = self.model(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 978, in forward
[rank1]: layer_outputs = decoder_layer(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 857, in forward
[rank1]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
[rank1]: return self.checkpoint_fn( # type: ignore[misc]
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
[rank1]: return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
[rank1]: return fn(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
[rank1]: return fn(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 494, in checkpoint
[rank1]: ret = function(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 718, in forward
[rank1]: hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/llm-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/models/llama/modeling_llama.py", line 431, in forward
[rank1]: key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
[rank1]: File "/fs01/home/snajafi/codes/llm-research/transformers/src/transformers/cache_utils.py", line 366, in update
[rank1]: return self.key_cache[layer_idx], self.value_cache[layer_idx]
[rank1]: IndexError: list index out of range
Hey @SaeedNajafi - do you have a minimal reproducer you could use to open a new issue on the repo? Thanks!
The pipeline needs more work, specifically for longer audio and the merging solution. Your contribution is welcome, especially for 1): if you have a working snippet, feel free to add it to the docs.
Thanks. I deleted the comment once I saw the PR already in progress https://github.com/huggingface/transformers/pull/31772 for this exact thing. I think it's better to wait for the merge.