Slice before F.scaled_dot_product_attention() to improve performance
I think we can slice k, v, and mask before calling F.scaled_dot_product_attention() to reduce computation; otherwise, attention is computed over the full max_seq_len even when input_pos is relatively small.
https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py#L209-L235
Hi @mzchtx. What changes are you proposing precisely?
k, v should already be sliced to the length of input_pos with https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py#L217-L218
index_copy does not change the shape of cache_k and cache_v. Here, the shapes of k and v still remain at max_seq_len, i.e. (B, nh, max_seq_len, hs). I think the shape should be (B, nh, input_pos, hs) to avoid unnecessary computation. Of course, I'm not certain whether F.scaled_dot_product_attention() has any internal optimizations to avoid this.
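For example (arbitrary shapes, just to illustrate the point that index_copy keeps the cache's full length along dim 2):

```python
import torch

B, nh, max_seq_len, hs = 1, 2, 8, 4
cache_k = torch.zeros(B, nh, max_seq_len, hs)
k_new = torch.randn(B, nh, 3, hs)            # 3 freshly computed positions
input_pos = torch.tensor([0, 1, 2])

k = cache_k.index_copy(2, input_pos, k_new)  # out-of-place write into the cache
print(k.shape)                               # torch.Size([1, 2, 8, 4]) -- still max_seq_len, not 3
```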
We conducted a test with a batch size of 128 and an input token length of 32. When we set max_seq_len to 64 versus 128, the cost of generating 32 tokens differed by roughly 30%.
@mzchtx Would you like to open a PR with your suggested changes?
I tested more data and had a more detailed comparison:
- The current implementation is more efficient when the batch size is relatively small (batch size < 16 or 32). In that regime the kernel is IO-bound, so performing extra computation in exchange for fewer IO operations is more efficient.
- The IO bottleneck diminishes significantly when the batch size is relatively large (batch size > 32). In that regime it is advisable to avoid redundant computation, which makes slicing the more efficient approach.
Of course, implementing a custom kernel that avoids the memory copies introduced by slicing while also eliminating the redundant computation would be the most efficient approach, but it would require additional work. Therefore, we can keep the current implementation.
Below are the detailed data; all experiments were run on an A30 with LLaMA 7B.
The modified code is shown below:
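In sketch form (shapes as in lit_llama/model.py: k and v of shape (B, nh, max_seq_len, hs), with the mask's last dimension indexing cached positions; the exact patch may differ), the idea is:

```python
# Before F.scaled_dot_product_attention(): keep only the cache positions that
# have actually been filled (<= input_pos[-1]), so attention cost scales with
# input_pos rather than max_seq_len.
valid = torch.arange(0, input_pos[-1] + 1, device=input_pos.device)
k = k.index_select(2, valid)        # (B, nh, valid_len, hs)
v = v.index_select(2, valid)        # (B, nh, valid_len, hs)
mask = mask.index_select(3, valid)  # keep the matching mask columns

y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
```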
@mzchtx The code would need to be indented under the if as before, since this is only relevant for the kv-cache case.
Leaving #382 aside, I believe the code should be:
- input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
+ input_pos = torch.tensor([max_seq_length - 1], device=input_pos.device)
# shift 1 position to the left
cache_k = torch.roll(cache_k, -1, dims=2)
cache_v = torch.roll(cache_v, -1, dims=2)
k = cache_k.index_copy(2, input_pos, k)
v = cache_v.index_copy(2, input_pos, v)
kv_cache = k, v
+ input_pos = torch.arange(0, input_pos[-1] + 1, device=input_pos.device)
+ k = k.index_select(2, input_pos)
+ v = v.index_select(2, input_pos)
+ mask = mask.index_select(3, input_pos)
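With that indentation, and the surrounding lines paraphrased from the linked model.py (so this is a sketch rather than an exact patch), the kv-cache branch would read roughly:

```python
if kv_cache is not None:
    cache_k, cache_v = kv_cache
    # check if we have reached the token limit
    if input_pos[-1] >= cache_k.size(2):
        input_pos = torch.tensor([max_seq_length - 1], device=input_pos.device)
        # shift 1 position to the left
        cache_k = torch.roll(cache_k, -1, dims=2)
        cache_v = torch.roll(cache_v, -1, dims=2)
    k = cache_k.index_copy(2, input_pos, k)
    v = cache_v.index_copy(2, input_pos, v)
    kv_cache = k, v
    # proposed addition: attend only over the filled prefix of the cache
    input_pos = torch.arange(0, input_pos[-1] + 1, device=input_pos.device)
    k = k.index_select(2, input_pos)
    v = v.index_select(2, input_pos)
    mask = mask.index_select(3, input_pos)
```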
What are OPT and Baseline exactly in your tables?
From playing with this, the generated outputs are not the same, meaning that this is not numerically equivalent. However, it's hard to tell if they are worse or just different.
Lastly, you mention that the benefit is mainly for large batch sizes. Our generation scripts do not support batched inference, and the kv-cache is only used during inference.
I am hesitant to make this change.
I stumbled upon this issue: https://github.com/pytorch/pytorch/issues/103082; it might explain the numerical difference.
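A quick way to check whether SDPA itself is the source of the discrepancy (a standalone sketch with arbitrary shapes; the two calls are mathematically equivalent, so any difference would come from kernel selection):

```python
import torch
import torch.nn.functional as F

B, nh, max_seq_len, hs, t = 2, 8, 64, 32, 9    # t = number of filled positions
q = torch.randn(B, nh, 1, hs)
k = torch.randn(B, nh, max_seq_len, hs)
v = torch.randn(B, nh, max_seq_len, hs)
mask = torch.zeros(1, 1, 1, max_seq_len, dtype=torch.bool)
mask[..., :t] = True                           # only the first t positions are valid

full = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
sliced = F.scaled_dot_product_attention(q, k[:, :, :t], v[:, :, :t])
print(torch.allclose(full, sliced))            # may be False if the two calls
                                               # dispatch to different kernels
```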
Yes, this only has an impact on inference, and only with large batch sizes, whereas in real scenarios batch sizes are generally smaller, so I agree with you and we can leave the status quo.