
[Feature Request]: oneflow.nn.functional is missing the scaled_dot_product_attention attribute

Open · ccssu opened this issue 2 years ago · 1 comment

Background and motivation


https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/models/attention_processor.py#L806-L807

>>> import oneflow.nn.functional as F
>>> hasattr(F,"scaled_dot_product_attention")
False
>>> import oneflow
>>> oneflow.__version__
'0.9.1+cu117.git.e4118c70a'

oneflow.__version__ = '0.9.1+cu117.git.e4118c70a'

diffusers >= 0.19.3

API Proposal

  • PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention
  • aten/src/ATen/native/transformers/attention.cpp#L695-L702: https://github.com/pytorch/pytorch/blob/e5548f81956157e1fdb76cb52579bffe5915dee1/aten/src/ATen/native/transformers/attention.cpp#L695-L702

API Usage

No response

Alternatives

No response

Risks

No response

ccssu · Sep 03 '23 03:09

  • Scaled Dot Product Attention Operator Development Research
    • 1. Algorithm and Mathematical Formula
    • 2. Target Framework
    • 3. Torch Operator Interface
      • 3.1 Python Interface
      • 3.2 C++ Interface
        • Operator Interface Definition
        • Operator Interface Implementation
    • Unit Testing and Validation
    • Performance Testing
    • Documentation and Sharing

Scaled Dot Product Attention Operator Development Research

1. Algorithm and Mathematical Formula

Scaled dot-product attention is an attention mechanism widely used in deep learning. Its mathematical formula is:

$$ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$

Reference: 1. The Annotated Transformer

2. Target Framework

The goal is to implement the scaled dot-product attention operator in OneFlow.
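
As a starting point, an unfused math-path version could be composed from existing OneFlow primitives before any fused Flash/memory-efficient kernel work. The sketch below is only that: a minimal sketch assuming oneflow.matmul, oneflow.softmax, oneflow.tril, oneflow.ones, Tensor.masked_fill, and oneflow.nn.functional.dropout behave like their PyTorch counterparts.

import math
import oneflow as flow

# Minimal math-path sketch of the proposed operator; not a fused kernel.
def sdpa_math(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    d_k = query.size(-1)
    scores = flow.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if is_causal:
        # Each query position may only attend to key positions at or before it.
        tgt_len, src_len = query.size(-2), key.size(-2)
        causal = flow.tril(flow.ones(tgt_len, src_len, device=query.device))
        scores = scores.masked_fill(causal == 0, float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == flow.bool:
            # Boolean mask: True marks positions that are allowed to attend.
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))
        else:
            # Float mask: added to the attention scores.
            scores = scores + attn_mask
    attn = flow.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        attn = flow.nn.functional.dropout(attn, p=dropout_p)
    return flow.matmul(attn, value)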

3. Torch Operator Interface

3.1 Python Interface

scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention

import math
import torch
import torch.nn.functional as F
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

import torch
import torch.nn.functional as F
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

y = F.scaled_dot_product_attention(query,key,value)
print(y.shape)
print(y.flatten()[:10])

y = attention(query,key,value)[0]
print(y.shape)
print(y.flatten()[:10])
Output
torch.Size([32, 8, 128, 64])
tensor([0.4727, 0.5171, 0.5117, 0.5264, 0.4832, 0.5044, 0.4951, 0.5439, 0.4460,
        0.4988], device='cuda:0', dtype=torch.float16)
torch.Size([32, 8, 128, 64])
tensor([0.4727, 0.5171, 0.5117, 0.5264, 0.4832, 0.5044, 0.4951, 0.5439, 0.4460,
        0.4988], device='cuda:0', dtype=torch.float16)
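
The example above only exercises the default arguments. A second check covering the attn_mask and is_causal parameters (smaller shapes for readability) could look like the sketch below; with a lower-triangular boolean mask the two calls should agree within fp16 tolerance.

import torch
import torch.nn.functional as F

query = torch.rand(2, 4, 16, 8, dtype=torch.float16, device="cuda")
key = torch.rand(2, 4, 16, 8, dtype=torch.float16, device="cuda")
value = torch.rand(2, 4, 16, 8, dtype=torch.float16, device="cuda")

# Built-in causal masking: each position attends only to itself and earlier positions.
y_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Equivalent explicit boolean mask (True = may attend); attn_mask and is_causal
# are not meant to be combined in the PyTorch API.
mask = torch.ones(16, 16, dtype=torch.bool, device="cuda").tril()
y_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)

print(torch.allclose(y_causal, y_masked, atol=1e-3))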

3.2 C++ Interface

Operator Interface Definition

# packages/pytorch/aten/src/ATen/native/native_functions.yaml:13806


- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out

# TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS WITH THIS OP BUILTIN
- func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
  python_module: nn
  variants: function
  autogen: _scaled_dot_product_attention.out

# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> int
  dispatch:
    Meta: _fused_sdp_choice_meta
    CPU, NestedTensorCPU: _fused_sdp_choice_cpp
    CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda

- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None) -> (Tensor, Tensor)
  variants: function

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, int philox_seed, int philox_offset, Tensor debug_attn_mask)
  dispatch:
    CUDA: _scaled_dot_product_flash_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda

Operator Interface Implementation

# pytorch/aten/src/ATen/native/transformers/attention.cpp:775

Tensor scaled_dot_product_attention(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const c10::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal) {
  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal);
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  if (query_.device().type() == DeviceType::CUDA){
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
      query_, key, value, attn_mask_, dropout_p, is_causal);
  }
  sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
  switch (backend) {
    case sdp::SDPBackend::flash_attention: {
      auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
          query_, key, value, dropout_p, is_causal);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::efficient_attention: {
      bool compute_logsumexp =
          (query_.requires_grad() || key.requires_grad() ||
           value.requires_grad());
      auto out_and_lse = at::_scaled_dot_product_efficient_attention(
          query_, key, value, compute_logsumexp, is_causal);
      return std::get<0>(out_and_lse);
    }
    case sdp::SDPBackend::math:
      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
          key,
          value,
          attn_mask_,
          dropout_p,
          is_causal));
    default:
      TORCH_CHECK(
          false,
          "No viable backend for scaled_dot_product_attention was found.");
      return Tensor();
  }
}
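
For reference, in the PyTorch releases contemporary with this issue the backend selected by _fused_sdp_choice_stub can also be constrained from Python with the torch.backends.cuda.sdp_kernel context manager, which is handy for cross-checking the fused paths against the math path; a hedged sketch:

import torch
import torch.nn.functional as F

q = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")

# Force the flash-attention path (errors out if it is unsupported on this setup).
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    y_flash = F.scaled_dot_product_attention(q, k, v)

# Force the unfused math path (the reference formulation from section 1).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    y_math = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(y_flash, y_math, atol=1e-3))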

Unit Testing and Validation

  • pytorch/benchmarks/transformer/sdp_backwards.py
  • pytorch/benchmarks/transformer/sdp.py
  • pytorch/test/test_nestedtensor.py
  • pytorch/test/test_transformers.py
  • ...
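
A minimal correctness test for the proposed OneFlow operator could follow the same pattern as test_transformers.py, comparing against the unfused math reference from section 1. The sketch below is written against PyTorch for illustration and assumes the OneFlow API keeps the same signature.

import math
import torch
import torch.nn.functional as F

def reference_attention(q, k, v):
    # Unfused math reference, matching the formula in section 1.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def test_sdpa_matches_reference():
    torch.manual_seed(0)
    q, k, v = (torch.rand(4, 2, 32, 16) for _ in range(3))
    out = F.scaled_dot_product_attention(q, k, v)
    ref = reference_attention(q, k, v)
    assert torch.allclose(out, ref, atol=1e-5)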

Performance Testing

Evaluate the performance of the optimized implementation, comparing its speed and resource consumption against the baseline in both training and inference. A rough timing sketch follows.
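
For example, torch.utils.benchmark can compare the fused operator with the unfused attention() reference from section 3.1; this sketch assumes that attention() is already defined in scope.

import torch
import torch.nn.functional as F
from torch.utils import benchmark

q = torch.rand(32, 8, 1024, 64, dtype=torch.float16, device="cuda")
k = torch.rand(32, 8, 1024, 64, dtype=torch.float16, device="cuda")
v = torch.rand(32, 8, 1024, 64, dtype=torch.float16, device="cuda")

fused = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
naive = benchmark.Timer(
    stmt="attention(q, k, v)",  # reference function from section 3.1
    globals={"attention": attention, "q": q, "k": k, "v": v},
)
print(fused.blocked_autorange())
print(naive.blocked_autorange())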

Documentation and Sharing

ccssu · Sep 10 '23 16:09