Offloaded KV Cache
What does this PR do?
Fixes #30704
This PR introduces OffloadedCache, a KV cache implementation that reduces GPU memory usage in exchange for more CPU memory usage and a small increase in generation time. During the forward passes in generate, it keeps only two layers of the KV cache on the device: the current layer and the next one. All other layers live on the CPU and are prefetched/evicted as necessary.
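For intuition, here is a minimal, hedged sketch of that prefetch/evict rotation. The class, attribute names, tensor shapes, and wrap-around indexing are illustrative only, not the actual implementation in this PR:

```python
import torch


class TwoLayerOffloadSketch:
    """Illustrative sketch only, not the actual OffloadedCache code in this PR."""

    def __init__(self, num_layers: int, device: str = "cuda"):
        self.device = device
        # One (key, value) tensor per layer; all start in pinned CPU memory.
        self.key_cache = [torch.zeros(1, 2, 16, 8, pin_memory=True) for _ in range(num_layers)]
        self.value_cache = [torch.zeros(1, 2, 16, 8, pin_memory=True) for _ in range(num_layers)]
        self.prefetch_stream = torch.cuda.Stream()
        # Bring the first layer onto the device up front.
        self.prefetch_layer(0)

    def prefetch_layer(self, layer_idx: int):
        # Copy a layer's KV to the GPU on a side stream so the transfer
        # overlaps with compute on the current layer.
        with torch.cuda.stream(self.prefetch_stream):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(self.device, non_blocking=True)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(self.device, non_blocking=True)

    def evict_layer(self, layer_idx: int):
        # Move a layer that is no longer needed back to the CPU.
        self.key_cache[layer_idx] = self.key_cache[layer_idx].to("cpu", non_blocking=True)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int):
        # Wait for this layer's prefetch to finish, evict the previous layer,
        # and kick off the prefetch of the next one.
        self.prefetch_stream.synchronize()
        self.evict_layer(layer_idx - 1)
        self.prefetch_layer((layer_idx + 1) % len(self.key_cache))
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```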
It can be used by passing cache_implementation="offloaded" in the GenerationConfig like this:
```python
from transformers import GenerationConfig

# `model` and `inputs` are assumed to already exist (e.g. from AutoModelForCausalLM / a tokenizer).
gen_config = GenerationConfig(
    cache_implementation="offloaded",
    # other generation options such as
    num_beams=4,
    num_beam_groups=2,
    num_return_sequences=4,
    diversity_penalty=1.0,
    max_new_tokens=50,
    early_stopping=True,
)
outputs = model.generate(
    inputs["input_ids"],
    generation_config=gen_config,
)
```
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. Issue #30704
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @ArthurZucker @gante
cc @ArthurZucker @gante
I have incorporated the comment about removing the two legacy implementations and added a couple of tests that ensure the cache works the same way as the DynamicCache and that its peak memory usage is lower.
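For illustration, a hedged sketch of what such checks could look like; the checkpoint, prompt, and exact assertions below are placeholders, not the tests added in this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def check_offloaded_matches_dynamic(model_name="gpt2", prompt="Hello"):
    # Illustrative check: same outputs as the default DynamicCache, lower peak GPU memory.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    def run(cache_implementation):
        torch.cuda.reset_peak_memory_stats()
        out = model.generate(
            **inputs,
            generation_config=GenerationConfig(
                max_new_tokens=20,
                do_sample=False,
                cache_implementation=cache_implementation,
            ),
        )
        return out, torch.cuda.max_memory_allocated()

    baseline_out, baseline_peak = run(None)          # default DynamicCache
    offloaded_out, offloaded_peak = run("offloaded")

    assert torch.equal(baseline_out, offloaded_out)
    assert offloaded_peak <= baseline_peak
```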
Is something holding this back? The failing tests are all marked as flaky.
@n17s just the limited bandwidth on our end -- my apologies for the delay, the PR looks great! Rebasing should fix the CI issues, we've fixed them on main 🤗
> just the limited bandwidth on our end
@gante Oh sorry, I was not aware. I removed flash attention, but this unmasked a bug related to the synchronization of my non-blocking tensor transfers (I had never tested without it before 😁). I have added explicit synchronization right before a tensor is needed on the device or is safe to be evicted. This makes the tests pass while keeping all the transfers non-blocking.
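Roughly, the added synchronization follows this pattern; a sketch under the assumption of a dedicated prefetch stream, with illustrative names:

```python
import torch

prefetch_stream = torch.cuda.Stream()


def wait_before_use():
    # The layer was copied to the GPU on the side stream with non_blocking=True;
    # block until that copy has finished before the attention kernels read it.
    prefetch_stream.synchronize()


def wait_before_evict():
    # Make sure the compute (default) stream is done with the layer's tensors
    # before starting the non-blocking copy back to the CPU.
    torch.cuda.default_stream().synchronize()
```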
Some tests are still failing but they seem unrelated and/or flaky.
Since the tests have moved, I have rebased to the latest.
Gentle ping @ArthurZucker
I was off for a week! Back now, will review!
Looks great! Small Q about using CUDA streams. We should also add this to the docs / show an example, as I guess this works with compile, no?
Not sure. How to test it?
You can just compile the model's forward and do this:
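Presumably something along these lines (a reconstruction, since the exact snippet isn't shown here; `model` and `inputs` are assumed to already exist):

```python
import torch
from transformers import GenerationConfig

# Compile only the forward pass, then generate with the offloaded cache.
model.forward = torch.compile(model.forward)

gen_config = GenerationConfig(cache_implementation="offloaded", max_new_tokens=20)
outputs = model.generate(inputs["input_ids"], generation_config=gen_config)
```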
So this does not work with the offloaded cache. The error message is
`AttributeError: 'NoneType' object has no attribute 'synchronize'`
stemming from
`self.prefetch_stream.synchronize()`
Based on my understanding of the discussion on this PyTorch issue, what is happening is that CUDA streams are not preserved by the current implementation of torch.compile. Since this seems to be a limitation of how torch.compile currently handles streams, and may eventually be fixed according to that discussion, I'd say there's not much that can be done on our side.
In light of the recent discussion on the issue that started this PR (#30704), a static offloaded cache can handle the torch.compile case much better than this class.
What is missing here for me is mostly raising an error to protect usage (GPU only) and resolving the merge conflicts; then this should be good to go.
Merge conflicts have been resolved, and I added a check in the constructor to protect usage. Fun fact: the class works even if inference happens on the CPU, as long as there is a GPU on the machine.
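A rough sketch of the kind of guard meant here; the actual check and error message in the PR may differ:

```python
import torch


class OffloadedCacheSketch:
    """Illustrative guard only, not the actual class added in this PR."""

    def __init__(self):
        # Offloading relies on CUDA streams and asynchronous host/device copies,
        # so constructing the cache without a GPU should fail fast.
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache requires an available CUDA GPU.")
        self.prefetch_stream = torch.cuda.Stream()
```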
(It's a green light on my side to merge otherwise!)
One thing missing is a bit of doc / a doc example of how to easily activate this.
Added some docs
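For quick reference, activation boils down to either the GenerationConfig route shown earlier or, assuming the usual behavior of generate accepting GenerationConfig fields as keyword overrides, a one-liner like:

```python
# `model` and `inputs` are assumed to already exist.
outputs = model.generate(inputs["input_ids"], cache_implementation="offloaded", max_new_tokens=50)
```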
This cache is very helpful (for low-memory long-context inference), thanks for adding it!