Slice before F.scaled_dot_product_attention() to improve the performance

Open mzchtx opened this issue 2 years ago • 8 comments

I think we can slice k, v and mask before calling F.scaled_dot_product_attention() to reduce computation. Otherwise, attention is computed over all max_seq_len cache positions even when input_pos is relatively small.

https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py#L209-L235
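To illustrate the proposal, here is a minimal pure-Python sketch. It is not lit-llama's code, and the helper `attend` and the toy sizes are hypothetical. One query attends over a KV cache with max_seq_len slots, of which only positions 0 to input_pos are filled. In exact arithmetic, masking the unfilled slots with -inf gives the same result as slicing them away, because their softmax weights are zero. The masked version still computes a score for every slot.

```python
import math


def attend(q, keys, values, mask):
    """Single-query scaled dot-product attention over lists of key/value rows.

    mask[j] is True where key j may be attended to. Returns the output vector
    and the number of query-key scores that were computed.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = []
    for k, allowed in zip(keys, mask):
        s = scale * sum(qi * ki for qi, ki in zip(q, k))  # computed even if masked
        scores.append(s if allowed else float("-inf"))
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # masked slots give exactly 0.0
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, len(scores)


# Hypothetical toy cache: max_seq_len slots, positions 0..input_pos are filled.
max_seq_len, input_pos, hs = 8, 2, 4
cache_k = [[0.0] * hs for _ in range(max_seq_len)]
cache_v = [[0.0] * hs for _ in range(max_seq_len)]
for t in range(input_pos + 1):
    cache_k[t] = [0.1 * (t + d) for d in range(hs)]
    cache_v[t] = [float(t - d) for d in range(hs)]
q = [0.5, -0.25, 1.0, 0.0]

# Current behaviour: attend over every slot and mask the unfilled ones.
full_mask = [t <= input_pos for t in range(max_seq_len)]
out_full, n_full = attend(q, cache_k, cache_v, full_mask)

# Proposed behaviour: slice to the filled prefix first, then attend.
n = input_pos + 1
out_sliced, n_sliced = attend(q, cache_k[:n], cache_v[:n], [True] * n)
```

Here both calls return the same output, but the masked call computes 8 query-key scores and the sliced call only 3. In the PyTorch model, the corresponding slice would be along dim 2 of the (B, nh, T, hs) k and v tensors and along the last dim of mask.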

mzchtx, Jun 13 '23 07:06

Hi @mzchtx. What changes are you proposing precisely?

k, v should already be sliced to the length of input_pos with https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py#L217-L218

carmocca, Jun 13 '23 13:06

index_copy does not change the shape of cache_k and cache_v, so k and v still have max_seq_len positions along the sequence dimension, i.e. shape (B, nh, max_seq_len, hs). I think the shape should be (B, nh, input_pos[-1] + 1, hs), covering only the filled positions, to avoid unnecessary computation. Of course, I'm not certain whether F.scaled_dot_product_attention() has additional optimizations that already avoid this.

We ran a test with batch size 128 and an input length of 32 tokens. When generating 32 tokens, setting max_seq_len to 64 versus 128 changed the cost by roughly 30%.
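Here is a back-of-the-envelope sketch of why max_seq_len matters at this batch size. It assumes the LLaMA 7B shape (32 layers, embedding size 4096) and 16-bit cache entries. Because attention runs over the whole cache, each decode step reads every cached key and value, whether filled or not. The helper `kv_bytes_per_step` is hypothetical and counts only K/V reads. It shows a trend; it does not reproduce the measured ~30%.

```python
def kv_bytes_per_step(batch, n_layer, n_embd, cache_len, bytes_per_elem=2):
    """Bytes of cached K and V that attention reads in one decode step.

    Each layer reads K and V of shape (batch, nh, cache_len, hs), and
    nh * hs == n_embd, so each tensor holds batch * cache_len * n_embd values.
    """
    return 2 * n_layer * batch * cache_len * n_embd * bytes_per_elem


GiB = 2 ** 30
# Assumed: LLaMA 7B shape (32 layers, n_embd = 4096) and a 16-bit cache.
for max_seq_len in (64, 128):
    nbytes = kv_bytes_per_step(batch=128, n_layer=32, n_embd=4096,
                               cache_len=max_seq_len)
    print(f"max_seq_len={max_seq_len}: {nbytes / GiB:.0f} GiB of K/V per step")
```

It prints 4 GiB for max_seq_len=64 and 8 GiB for 128. The K/V traffic per step doubles with max_seq_len regardless of how many positions are filled. Slicing to the filled prefix would make it scale with input_pos instead.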

mzchtx, Jun 14 '23 02:06

@mzchtx Would you like to open a PR with your suggested changes?

carmocca, Jun 15 '23 02:06

I ran more tests and made a more detailed comparison:

  1. The current implementation is more efficient when the batch size is relatively small (below 16 or 32). In that regime the workload is IO-bound, so doing some redundant computation is cheaper than the extra memory copies that slicing introduces.
  2. When the batch size is relatively large (above 32), the IO bottleneck diminishes significantly. It is then better to avoid the redundant computation, which makes slicing the more efficient approach.
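The batch-size dependence can be sketched with a simple traffic model. In each decode step, the model weights are read once and shared by the whole batch, while the K/V cache is read once per sequence. The cache's share of memory traffic, which is the only part slicing can reduce, therefore grows with batch size. The helper name and numbers are assumptions: about 6.7e9 parameters at 2 bytes each for LLaMA 7B, and a full 128-position 16-bit cache over 32 layers with n_embd 4096. The model ignores activations, kernel-launch overhead and the extra copy that slicing makes, so it does not locate the exact crossover.

```python
def kv_share_of_traffic(batch, weight_bytes, kv_bytes_per_seq):
    """Fraction of (weight + K/V) bytes read per decode step that is K/V cache."""
    kv = batch * kv_bytes_per_seq
    return kv / (weight_bytes + kv)


# Assumptions: ~6.7e9 parameters stored in 2 bytes each, and per sequence a
# full cache of 128 positions x 32 layers x 4096 dims x (K and V) x 2 bytes.
weight_bytes = 6.7e9 * 2
kv_bytes_per_seq = 128 * 32 * 4096 * 2 * 2

for batch in (1, 8, 32, 128):
    share = kv_share_of_traffic(batch, weight_bytes, kv_bytes_per_seq)
    print(f"batch={batch:>3}: K/V cache is {share:.1%} of weight+K/V reads")
```

Under these assumptions, the cache accounts for well under 1% of the weight-plus-K/V reads at batch 1 but more than a third at batch 128. This is consistent with slicing paying off only for large batches.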

Of course, a custom kernel that avoids both the memory copies generated by slicing and the redundant computation would be the most efficient option. However, it would require additional work, so we can keep the current implementation.

Below are the detailed data; all experiments were run on an A30 with LLaMA 7B.

[Two images: benchmark result tables]

The modified code is shown below:

[Image: modified code]

mzchtx, Jun 17 '23 08:06

@mzchtx The code would need to be indented under the if block as before, since this is only relevant for the kv-cache case.

Leaving #382 aside, I believe the code should be:

-                input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
+                input_pos = torch.tensor([max_seq_length - 1], device=input_pos.device)
                 # shift 1 position to the left
                 cache_k = torch.roll(cache_k, -1, dims=2)
                 cache_v = torch.roll(cache_v, -1, dims=2)
             k = cache_k.index_copy(2, input_pos, k)
             v = cache_v.index_copy(2, input_pos, v)
             kv_cache = k, v
+            input_pos = torch.arange(0, input_pos[-1] + 1, device=input_pos.device)
+            k = k.index_select(2, input_pos)
+            v = v.index_select(2, input_pos)
+            mask = mask.index_select(3, input_pos)
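The order of operations in the diff can be traced with a list-based analogue. This is a hypothetical sketch: `update_and_slice` is not part of lit-llama. It uses one list entry per position along dim 2 of the cache, and plain Python stands in for torch.roll, index_copy and index_select.

```python
def update_and_slice(cache, input_pos, new_entries, mask_row):
    """List analogue of the kv-cache update in the diff, followed by slicing.

    cache:       one entry per position along dim 2 (length max_seq_len)
    input_pos:   positions to write, like the 1-D input_pos tensor
    new_entries: entries to write at those positions, like k or v
    mask_row:    one bool per cache position, like the last dim of mask
    Returns (updated cache, sliced cache, sliced mask row).
    """
    max_seq_len = len(cache)
    if input_pos[-1] >= max_seq_len:
        # Cache full: write into the last slot after shifting every entry one
        # position to the left, like torch.roll(cache, -1, dims=2).
        input_pos = [max_seq_len - 1]
        cache = cache[1:] + cache[:1]
    # index_copy: overwrite the given positions; the length stays max_seq_len.
    cache = list(cache)
    for pos, entry in zip(input_pos, new_entries):
        cache[pos] = entry
    # index_select with arange(0, input_pos[-1] + 1): keep the filled prefix.
    keep = range(input_pos[-1] + 1)
    return cache, [cache[i] for i in keep], [mask_row[i] for i in keep]


# Decoding position 2 with a 4-slot cache that already holds positions 0 and 1.
full, k_sliced, mask_sliced = update_and_slice(
    ["k0", "k1", None, None], [2], ["k2"], [True, True, True, False])
# full        == ["k0", "k1", "k2", None]  (still max_seq_len entries)
# k_sliced    == ["k0", "k1", "k2"]
# mask_sliced == [True, True, True]
```

As in the diff, the full-length cache is what gets stored (kv_cache = k, v happens before slicing). Only the tensors passed to attention shrink to input_pos[-1] + 1 positions.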

What exactly are OPT and Baseline in your tables?

carmocca, Jun 20 '23 01:06

From playing with this, the generated outputs are not the same, meaning that this is not numerically equivalent. However, it's hard to tell if they are worse or just different.
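Here is a generic illustration of how mathematically equivalent computations can disagree. Floating-point addition is not associative. A kernel that reduces over a different number of key positions, or splits them into different blocks, may therefore round differently, even though the extra masked terms are exactly zero. This is a minimal stdlib-only sketch, not a diagnosis of what the attention kernels actually do here.

```python
import math

terms = [0.1, 0.2, 0.3]

left_to_right = (terms[0] + terms[1]) + terms[2]
right_to_left = terms[0] + (terms[1] + terms[2])

# Same three numbers, different grouping, different rounding.
print(left_to_right, right_to_left, left_to_right == right_to_left)
# prints: 0.6000000000000001 0.6 False

# math.fsum returns the correctly rounded sum, independent of order.
print(math.fsum(terms))  # prints: 0.6
```

In greedy decoding, a last-bit difference in a logit can flip the argmax between two near-tied tokens, after which the generated sequences diverge. Differing outputs alone therefore do not show that either variant is worse.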

Lastly, you mention that the benefit is mainly for large batch sizes. Our generation scripts do not support batched inference, and the kv-cache is only used during inference.

I am hesitant to make this change.

carmocca, Jun 20 '23 02:06

I stumbled upon this issue, which might explain the numerical difference: https://github.com/pytorch/pytorch/issues/103082

carmocca, Jun 21 '23 18:06

Yes, this only affects inference, and only at large batch sizes, whereas batch sizes in real scenarios are generally smaller. So I agree with you, and we can keep the current implementation.

mzchtx, Jun 22 '23 13:06