
[ET-VK] Implement SDPA + KV-Cache operator

Open SS-JIA opened this issue 1 year ago • 5 comments

Stack from ghstack (oldest at bottom):

  • -> #5799
  • #5831
  • #5830

Context

As the title says, this diff adds an implementation of a fused SDPA + KV-Cache update operator, which will be used in LLaMA models. Currently the SDPA portion of the operator is implemented via its constituent operators; a future optimization opportunity would be to implement a single flash attention shader.
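
For reference, the constituent operators compute standard scaled dot-product attention over the updated caches: with $Q$ the projected queries, $K$ and $V$ the cached keys/values up to the current position, $d$ the head dimension, and $M$ the additive causal mask,

$$\text{out} = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}} + M\right)V$$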

Reference Implementation

For future reference, a reference implementation of the SDPA + KV cache update mechanism is shown below. It was originally used to check intermediate outputs, but in the end I decided to compare against the sdpa_with_kv_cache operator in extension/llm for simplicity.

#include <ATen/ATen.h>

#include <cmath>
#include <limits>
#include <vector>

at::Tensor convert_boolean_attn_mask(
    const at::Tensor& attn_mask,
    caffe2::TypeMeta dtype) {
  // Convert boolean mask to additive mask; need to invert mask to indicate what
  // to mask *out*.
  if (attn_mask.dtype() == at::kBool) {
    return at::where(
        attn_mask.logical_not(),
        -std::numeric_limits<double>::infinity(),
        at::scalar_tensor(
            0.0, at::TensorOptions().dtype(dtype).device(attn_mask.device())));
  }
  // Otherwise, attn_mask represents an additive attention tensor
  return attn_mask;
}

at::Tensor construct_attention_mask(
    const at::Tensor& q,
    const at::Tensor& k_cache,
    const int start_pos) {
  const int max_seq_len = k_cache.size(1);
  const int seq_len = q.size(1);
  at::Tensor attn_mask_base =
      at::ones({max_seq_len, start_pos + seq_len}, q.options().dtype(at::kBool))
          .tril();

  at::Tensor attn_mask_sliced =
      at::slice(attn_mask_base, 0, start_pos, start_pos + seq_len);

  attn_mask_sliced = convert_boolean_attn_mask(attn_mask_sliced, q.dtype());

  return attn_mask_sliced;
}
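
// Illustration (not from the original diff): with start_pos = 2 and
// seq_len = 2, attn_mask_base is a (max_seq_len x 4) lower-triangular
// boolean matrix, and slicing rows [2, 4) gives the (2 x 4) mask
//   1 1 1 0
//   1 1 1 1
// so each new query token attends to every cached position and to itself.
// convert_boolean_attn_mask then maps 1 -> 0.0 and 0 -> -inf, producing an
// additive mask that is simply added to the attention scores before softmax.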

std::vector<at::Tensor> sdpa_reference_impl(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    const c10::optional<at::Tensor> __attn_mask_ignored,
    const double dropout_p,
    const bool is_causal,
    const c10::optional<double> scale) {
  at::Tensor attn_mask =
      construct_attention_mask(q_projected, key_cache, start_pos);

  at::Tensor key_cache_updated = at::slice_scatter(
      key_cache, k_projected, 1, start_pos, start_pos + k_projected.size(1));
  at::Tensor value_cache_updated = at::slice_scatter(
      value_cache, v_projected, 1, start_pos, start_pos + v_projected.size(1));

  at::Tensor key_cache_sliced =
      at::slice(key_cache_updated, 1, 0, start_pos + q_projected.size(1));

  at::Tensor value_cache_sliced =
      at::slice(value_cache_updated, 1, 0, start_pos + q_projected.size(1));

  at::Tensor q_transposed = q_projected.transpose(1, 2);
  at::Tensor k_transposed = key_cache_sliced.transpose(1, 2);
  at::Tensor v_transposed = value_cache_sliced.transpose(1, 2);

  // Skip doing repeat_interleave; assume that num_attention_heads ==
  // num_kv_heads

  const float scale_factor =
      1.0f / std::sqrt(static_cast<float>(q_transposed.size(-1)));

  at::Tensor k_transposed_2 = k_transposed.transpose(-2, -1);

  at::Tensor attn_weight_prescale = at::matmul(q_transposed, k_transposed_2);
  at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;

  at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
  at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);

  return {
      out.transpose(1, 2),
      key_cache_sliced,
      value_cache_sliced,
      q_transposed,
      k_transposed,
      v_transposed,
      k_transposed_2,
      attn_weight_prescale,
      attn_weight,
      attn_weight_softmax,
      out,
  };
}
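
For completeness, here is a minimal sketch of how the reference implementation above could be invoked. The shapes, batch size, and head counts below are made up for illustration (they are not taken from the diff), assuming a [batch, seq_len, num_heads, head_dim] layout for the projections and [batch, max_seq_len, num_kv_heads, head_dim] for the caches:

// Illustrative shapes: batch = 1, seq_len = 3, max_seq_len = 128,
// num_heads = num_kv_heads = 4, head_dim = 32.
at::Tensor q = at::rand({1, 3, 4, 32});
at::Tensor k = at::rand({1, 3, 4, 32});
at::Tensor v = at::rand({1, 3, 4, 32});
at::Tensor k_cache = at::zeros({1, 128, 4, 32});
at::Tensor v_cache = at::zeros({1, 128, 4, 32});

std::vector<at::Tensor> outputs = sdpa_reference_impl(
    q, k, v, k_cache, v_cache,
    /*start_pos=*/0, /*seq_len=*/3,
    /*attn_mask=*/c10::nullopt,
    /*dropout_p=*/0.0,
    /*is_causal=*/true,
    /*scale=*/c10::nullopt);

// outputs[0] is the attention output transposed back to
// [batch, seq_len, num_heads, head_dim]; the remaining entries expose the
// intermediate tensors that were used when checking against other
// implementations.
at::Tensor attn_out = outputs[0];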

Differential Revision: D63724114

SS-JIA avatar Oct 01 '24 22:10 SS-JIA

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/5799

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit c22e92565b601bfbd91a5ddc41c25893d8658f45 with merge base 84f5a561e53108f3ee7e99f20df05b65cc359488: :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar Oct 01 '24 22:10 pytorch-bot[bot]

This pull request was exported from Phabricator. Differential Revision: D63724114

facebook-github-bot avatar Oct 01 '24 22:10 facebook-github-bot


This pull request has been merged in pytorch/executorch@6e871c3b617fb3ff72dc8ad155e2f04ab441be3f.

facebook-github-bot avatar Oct 07 '24 21:10 facebook-github-bot