[Feature Request]: oneflow.nn.functional is missing the scaled_dot_product_attention attribute
Background and motivation
https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/models/attention_processor.py#L806-L807
>>> import oneflow.nn.functional as F
>>> hasattr(F,"scaled_dot_product_attention")
False
>>> import oneflow
>>> oneflow.__version__
'0.9.1+cu117.git.e4118c70a'
oneflow.__version__ = 0.9.1+cu117.git.e4118c70a
diffusers >= 0.19.3
API Proposal
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention
- aten/src/ATen/native/transformers/attention.cpp#L695-L702: https://github.com/pytorch/pytorch/blob/e5548f81956157e1fdb76cb52579bffe5915dee1/aten/src/ATen/native/transformers/attention.cpp#L695-L702
API Usage
No response
Alternatives
No response
Risks
No response
Scaled Dot Product Attention Operator Development Survey
- 1. Algorithm and Mathematical Formula
- 2. Target Framework
- 3. Torch Operator Interface
  - 3.1 Python Interface
  - 3.2 C++ Interface
    - Operator interface definition
    - Operator interface implementation
- Unit tests and validation
- Performance testing
- Documentation and sharing
Scaled Dot Product Attention Operator Development Survey
1. Algorithm and Mathematical Formula
Scaled dot product attention is an attention mechanism widely used in deep learning. Its mathematical formula is:
$$ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V $$
Reference: The Annotated Transformer
2. Target Framework
The goal is to implement the scaled dot product attention operator in OneFlow.
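Until a fused operator is available, the requested functionality can be approximated at the Python level by composing ops OneFlow already exposes. A minimal sketch (assuming oneflow.matmul, oneflow.nn.functional.softmax, and Tensor.transpose behave like their PyTorch counterparts; only the unmasked, no-dropout path is covered):

import math
import oneflow as flow
import oneflow.nn.functional as F

def sdpa_fallback(query, key, value):
    # softmax(Q K^T / sqrt(d_k)) V composed from ops OneFlow already provides
    d_k = query.shape[-1]
    scores = flow.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    return flow.matmul(F.softmax(scores, dim=-1), value)

q = flow.rand(2, 8, 128, 64, device="cuda")
k = flow.rand(2, 8, 128, 64, device="cuda")
v = flow.rand(2, 8, 128, 64, device="cuda")
print(sdpa_fallback(q, k, v).shape)  # (2, 8, 128, 64)

A fused kernel is still desirable because this composition materializes the full attention score matrix.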
3. Torch Operator Interface
3.1 Python Interface
scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention
import math
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
import torch
import torch.nn.functional as F
query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
y = F.scaled_dot_product_attention(query,key,value)
print(y.shape)
print(y.flatten()[:10])
y = attention(query,key,value)[0]
print(y.shape)
print(y.flatten()[:10])
Output
torch.Size([32, 8, 128, 64])
tensor([0.4727, 0.5171, 0.5117, 0.5264, 0.4832, 0.5044, 0.4951, 0.5439, 0.4460,
0.4988], device='cuda:0', dtype=torch.float16)
torch.Size([32, 8, 128, 64])
tensor([0.4727, 0.5171, 0.5117, 0.5264, 0.4832, 0.5044, 0.4951, 0.5439, 0.4460,
0.4988], device='cuda:0', dtype=torch.float16)
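The comparison above leaves attn_mask, dropout_p, and is_causal at their defaults. is_causal=True corresponds to a lower-triangular (causal) mask; a short sketch reusing the tensors defined above:

# causal attention: each position may only attend to itself and earlier positions
y_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)

# the same mask made explicit: a boolean lower-triangular matrix (True = may attend)
L = query.size(-2)
causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device="cuda"))
y_masked = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)

print((y_causal - y_masked).abs().max())  # expected to be on the order of fp16 rounding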
3.2 C++ Interface
Operator interface definition
# packages/pytorch/aten/src/ATen/native/native_functions.yaml:13806
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out

# TODO: THIS NEEDS TO BE REMOVED BUT PEOPLE HAVE TRAINED THEIR MODELS WITH THIS OP BUILTIN
- func: _scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
  python_module: nn
  variants: function
  autogen: _scaled_dot_product_attention.out

# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> int
  dispatch:
    Meta: _fused_sdp_choice_meta
    CPU, NestedTensorCPU: _fused_sdp_choice_cpp
    CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda

- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None) -> (Tensor, Tensor)
  variants: function

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, int philox_seed, int philox_offset, Tensor debug_attn_mask)
  dispatch:
    CUDA: _scaled_dot_product_flash_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
Operator interface implementation
// pytorch/aten/src/ATen/native/transformers/attention.cpp:775
Tensor scaled_dot_product_attention(
    const Tensor& query_,
    const Tensor& key,
    const Tensor& value,
    const c10::optional<Tensor>& attn_mask_,
    double dropout_p,
    bool is_causal) {
  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal);
  int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
  if (query_.device().type() == DeviceType::CUDA) {
    choice_int = _fused_sdp_choice_stub(query_.device().type(),
        query_, key, value, attn_mask_, dropout_p, is_causal);
  }
  sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
  switch (backend) {
    case sdp::SDPBackend::flash_attention: {
      auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
          query_, key, value, dropout_p, is_causal);
      return std::get<0>(out_lse_softmax);
    }
    case sdp::SDPBackend::efficient_attention: {
      bool compute_logsumexp =
          (query_.requires_grad() || key.requires_grad() ||
           value.requires_grad());
      auto out_and_lse = at::_scaled_dot_product_efficient_attention(
          query_, key, value, compute_logsumexp, is_causal);
      return std::get<0>(out_and_lse);
    }
    case sdp::SDPBackend::math:
      return std::get<0>(at::_scaled_dot_product_attention_math(
          query_,
          key,
          value,
          attn_mask_,
          dropout_p,
          is_causal));
    default:
      TORCH_CHECK(
          false,
          "No viable backend for scaled_dot_product_attention was found.");
      return Tensor();
  }
}
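The backend chosen by _fused_sdp_choice can be constrained from Python, which is handy for exercising each branch of the dispatch above. A sketch using the torch.backends.cuda.sdp_kernel context manager available in PyTorch 2.0-era releases (later versions replace it with torch.nn.attention.sdpa_kernel):

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = torch.rand(2, 8, 128, 64, dtype=torch.float16, device="cuda")
k = torch.rand(2, 8, 128, 64, dtype=torch.float16, device="cuda")
v = torch.rand(2, 8, 128, 64, dtype=torch.float16, device="cuda")

# force the math (reference) backend
with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    y_math = F.scaled_dot_product_attention(q, k, v)

# force flash attention (requires fp16/bf16 inputs and a supported GPU)
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
    y_flash = F.scaled_dot_product_attention(q, k, v)

print((y_math - y_flash).abs().max())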
Unit tests and validation
- pytorch/benchmarks/transformer/sdp_backwards.py
- pytorch/benchmarks/transformer/sdp.py
- pytorch/test/test_nestedtensor.py
- pytorch/test/test_transformers.py
- ...
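For the OneFlow operator, a correctness test could follow the same pattern as test_transformers.py: compare the fused result against the math formula at a looser fp16 tolerance. A minimal sketch written against the PyTorch op (an OneFlow test would swap in the proposed oneflow.nn.functional.scaled_dot_product_attention once it exists):

import math
import torch
import torch.nn.functional as F

def reference_sdpa(q, k, v):
    # plain softmax(Q K^T / sqrt(d_k)) V, computed in fp32 for a tighter reference
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def test_sdpa_matches_reference():
    q, k, v = (torch.rand(4, 8, 64, 32, dtype=torch.float16, device="cuda") for _ in range(3))
    out = F.scaled_dot_product_attention(q, k, v)
    ref = reference_sdpa(q.float(), k.float(), v.float()).to(torch.float16)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)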
Performance testing
Evaluate the performance of the optimized model, comparing its speed and resource consumption during training and inference.
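A minimal timing sketch for the CUDA case, comparing the fused op against the plain composition (sizes and iteration counts are illustrative):

import time
import torch
import torch.nn.functional as F

q = k = v = torch.rand(16, 16, 512, 64, dtype=torch.float16, device="cuda")

def bench(fn, iters=50):
    # warm up, then time with explicit synchronization so GPU work is included
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

fused = lambda: F.scaled_dot_product_attention(q, k, v)
composed = lambda: torch.softmax(
    torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5), dim=-1
) @ v

print(f"fused: {bench(fused) * 1e3:.3f} ms  composed: {bench(composed) * 1e3:.3f} ms")

Peak memory can be compared the same way via torch.cuda.max_memory_allocated(), where the fused kernels are expected to show the largest benefit since they avoid materializing the full score matrix.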