FireRedASR icon indicating copy to clipboard operation
FireRedASR copied to clipboard

用torch原生的flash attention性能更好

Open xphh opened this issue 6 months ago • 2 comments

transformer_decoder.py里面可以替换torch原生的scaled_dot_product_attention函数

第247行:output = self.attention(q, k, v, mask=mask)

改成:output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.bool())

整体性能大概可以提升10%

conformer_encoder.py里面应该也可以,但逻辑稍微有点不一样,我还不知道怎么改,麻烦作者可以看看

xphh avatar Jul 04 '25 07:07 xphh

或者直接使用torch.compile

xphh avatar Jul 07 '25 02:07 xphh

我优化了一版性能提升了50%左右吧,你可以是试一试:https://github.com/FireRedTeam/FireRedASR/pull/105

xsank avatar Nov 13 '25 06:11 xsank