[ET-VK] Implement SDPA + KV-Cache operator
Stack from ghstack (oldest at bottom):
- -> #5799
- #5831
- #5830
## Context
As the title states, this diff adds an implementation of a fused SDPA + KV-cache update operator, which will be used in LLaMA models. Currently the SDPA portion of the operator is implemented via its constituent operators; a future optimization opportunity would be to implement a single flash attention shader.
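For orientation, the snippet below is a minimal sketch (not the Vulkan implementation itself) of the constituent-operator decomposition of SDPA that the shaders follow: Q·Kᵀ, scale, additive mask, softmax, and a final matmul with V. The function name and the assumed `[batch, num_heads, seq_len, head_dim]` layout are illustrative, not part of this diff.

```cpp
#include <ATen/ATen.h>

#include <cmath>

// Illustrative only: q, k, v are assumed to be [batch, num_heads, seq_len, head_dim]
// and attn_mask is an additive mask broadcastable to the attention weights.
at::Tensor sdpa_from_constituent_ops(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const at::Tensor& attn_mask) {
  const double scale = 1.0 / std::sqrt(static_cast<double>(q.size(-1)));
  // Q @ K^T, scaled and masked
  at::Tensor attn_weight = at::matmul(q, k.transpose(-2, -1)) * scale + attn_mask;
  // Row-wise softmax over the context dimension
  at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
  // Weighted sum of values
  return at::matmul(attn_weight_softmax, v);
}
```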
## Reference Implementation
For future reference, a reference implementation of the SDPA + KV cache update mechanism is included below. It was originally used to check intermediate outputs, but in the end I decided to compare against the sdpa_with_kv_cache operator in extension/llm for simplicity.
```cpp
#include <ATen/ATen.h>

#include <cmath>
#include <limits>
#include <vector>

at::Tensor convert_boolean_attn_mask(
    const at::Tensor& attn_mask,
    caffe2::TypeMeta dtype) {
  // Convert boolean mask to additive mask; need to invert mask to indicate
  // what to mask *out*.
  if (attn_mask.dtype() == at::kBool) {
    return at::where(
        attn_mask.logical_not(),
        -std::numeric_limits<double>::infinity(),
        at::scalar_tensor(
            0.0, at::TensorOptions().dtype(dtype).device(attn_mask.device())));
  }
  // Otherwise, attn_mask represents an additive attention tensor
  return attn_mask;
}

at::Tensor construct_attention_mask(
    const at::Tensor& q,
    const at::Tensor& k_cache,
    const int start_pos) {
  const int max_seq_len = k_cache.size(1);
  const int seq_len = q.size(1);

  // Build a causal mask over the first (start_pos + seq_len) cache positions,
  // then keep only the rows that correspond to the current input tokens.
  at::Tensor attn_mask_base =
      at::ones({max_seq_len, start_pos + seq_len}, q.options().dtype(at::kBool))
          .tril();

  at::Tensor attn_mask_sliced =
      at::slice(attn_mask_base, 0, start_pos, start_pos + seq_len);

  attn_mask_sliced = convert_boolean_attn_mask(attn_mask_sliced, q.dtype());

  return attn_mask_sliced;
}

std::vector<at::Tensor> sdpa_reference_impl(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    const c10::optional<at::Tensor> __attn_mask_ignored,
    const double dropout_p,
    const bool is_causal,
    const c10::optional<double> scale) {
  at::Tensor attn_mask =
      construct_attention_mask(q_projected, key_cache, start_pos);

  // Write the new K/V projections into the caches at start_pos, then slice out
  // the valid prefix of each cache for the attention computation.
  at::Tensor key_cache_updated = at::slice_scatter(
      key_cache, k_projected, 1, start_pos, start_pos + k_projected.size(1));
  at::Tensor value_cache_updated = at::slice_scatter(
      value_cache, v_projected, 1, start_pos, start_pos + v_projected.size(1));

  at::Tensor key_cache_sliced =
      at::slice(key_cache_updated, 1, 0, start_pos + q_projected.size(1));
  at::Tensor value_cache_sliced =
      at::slice(value_cache_updated, 1, 0, start_pos + q_projected.size(1));

  // [B, seq_len, H, D] -> [B, H, seq_len, D]
  at::Tensor q_transposed = q_projected.transpose(1, 2);
  at::Tensor k_transposed = key_cache_sliced.transpose(1, 2);
  at::Tensor v_transposed = value_cache_sliced.transpose(1, 2);

  // Skip doing repeat_interleave; assume that num_attention_heads ==
  // num_kv_heads
  float scale_factor = 1.0 / std::sqrt(q_transposed.size(-1));
  at::Tensor k_transposed_2 = k_transposed.transpose(-2, -1);
  at::Tensor attn_weight_prescale = at::matmul(q_transposed, k_transposed_2);
  at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;

  at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
  at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);

  // Return the attention output along with the intermediates, which were
  // originally used to check the Vulkan implementation step by step.
  return {
      out.transpose(1, 2),
      key_cache_sliced,
      value_cache_sliced,
      q_transposed,
      k_transposed,
      v_transposed,
      k_transposed_2,
      attn_weight_prescale,
      attn_weight,
      attn_weight_softmax,
      out,
  };
}
```
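To make the calling convention concrete, here is a hedged usage sketch that exercises sdpa_reference_impl on a single prefill step. The tensor sizes are illustrative assumptions, and the actual comparison against the sdpa_with_kv_cache operator in extension/llm is not reproduced here.

```cpp
// Illustrative driver for sdpa_reference_impl (defined above); shapes are assumptions.
int main() {
  const int64_t batch = 1, max_seq_len = 128, n_heads = 8, head_dim = 64;
  const int64_t start_pos = 0;  // prefill starts at the beginning of the cache
  const int64_t seq_len = 16;

  // Projected q/k/v in [batch, seq_len, num_heads, head_dim] layout.
  at::Tensor q = at::rand({batch, seq_len, n_heads, head_dim});
  at::Tensor k = at::rand({batch, seq_len, n_heads, head_dim});
  at::Tensor v = at::rand({batch, seq_len, n_heads, head_dim});

  // Caches sized to the maximum sequence length.
  at::Tensor k_cache = at::zeros({batch, max_seq_len, n_heads, head_dim});
  at::Tensor v_cache = at::zeros({batch, max_seq_len, n_heads, head_dim});

  std::vector<at::Tensor> outs = sdpa_reference_impl(
      q, k, v, k_cache, v_cache, start_pos, seq_len,
      /*attn_mask=*/c10::nullopt, /*dropout_p=*/0.0, /*is_causal=*/true,
      /*scale=*/c10::nullopt);

  // outs[0] is the attention output, in the same layout as q_projected; the
  // remaining entries are intermediates.
  TORCH_CHECK(outs[0].sizes() == q.sizes());
  return 0;
}
```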
Differential Revision: D63724114
This pull request was exported from Phabricator. Differential Revision: D63724114
This pull request has been merged in pytorch/executorch@6e871c3b617fb3ff72dc8ad155e2f04ab441be3f.