
WIP: Attn bwd prototype 2 (adding dropout)

Open · ltqin opened this issue 2 years ago • 1 comment

A bare-minimum batched multi-head attention backward kernel. Initially missing functionality, since implemented (a sketch of the reference math follows this list):

  • ~alpha(QK) scaling~ implemented
  • ~masking~ implemented
  • ~dropout~ implemented
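
For reference, a sketch of the math the kernel covers with all three items in place; the notation is mine, not the kernel's, and masking folds into S as additive -inf entries before the softmax (the lse tensor in the output below suggests P is recomputed from the stored row-wise log-sum-exp rather than kept around). With scaling $\alpha$, Bernoulli keep-mask $D$, and keep probability $p$:

$$S = \alpha\, Q K^\top, \qquad P = \mathrm{softmax}(S), \qquad \tilde{P} = \frac{D \odot P}{p}, \qquad Y = \tilde{P} V$$

and the backward pass, given $dY$:

$$dV = \tilde{P}^\top dY, \qquad d\tilde{P} = dY\, V^\top, \qquad dP = \frac{D \odot d\tilde{P}}{p}$$

$$dS_{ij} = P_{ij}\Big(dP_{ij} - \sum_k dP_{ik}\, P_{ik}\Big), \qquad dQ = \alpha\, dS\, K, \qquad dK = \alpha\, dS^\top Q$$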

Some quirks still need to be ironed out, e.g.:

  • The A/B/B1/C tensor names sometimes refer to the Q/K/V/Y tensors
  • The currently exposed tuning parameters are the same as for attention forward; it is not yet clear what should be exposed given the added complexity
  • Some size / init-method combinations report validation failures; not yet clear whether this is a bug or fp16 quantization error
  • Higher-than-expected register spills; given the 128x128x32 tile size, the initial estimate is 192 accumulator VGPRs plus some auxiliary VGPRs for other uses, but the actual budget exceeds 256 VGPRs and spills heavily into global memory (see the arithmetic sketch after this list)
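
One way to arrive at the 192-VGPR estimate, assuming three 128x128 fp32 accumulator tiles are live at once across the 256-thread workgroup (my accounting, not necessarily the kernel's):

$$\frac{128 \times 128 \text{ values}}{256 \text{ threads}} = 64 \text{ VGPRs per tile}, \qquad 3 \times 64 = 192 \text{ VGPRs}$$

Against a 256-VGPR per-thread limit that leaves only 64 registers for addresses, loop state, and staging, which is consistent with the observed spilling.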

Example output

$ CXX=hipcc cmake . -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=/opt/rocm -DAMDGPU_TARGETS=gfx90a
$ cmake --build build -t example_batched_multihead_attention_backward_fp16
$ build/bin/example_batched_multihead_attention_backward_fp16
q_gs_ms_ks: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
k_gs_ns_ks: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
v_gs_os_ns: dim 4, lengths {3, 2, 128, 512}, strides {131072, 65536, 1, 128}
y_gs_ms_os: dim 4, lengths {3, 2, 512, 128}, strides {131072, 65536, 128, 1}
lse_gs_ms_os: dim 3, lengths {3, 2, 512}, strides {1024, 512, 1}
launch_and_time_kernel: grid_dim {24, 1, 1}, block_dim {256, 1, 1} 
Warm up 1 time
Start running 10 times...
Perf: 0.365166 ms, 5.51328 TFlops, 17.2627 GB/s, DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<256, 128, 128, 32, 8, 8, 128, 128, 64, 2, MNKOPadding, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled>
Checking qgrad:
Checking kgrad:
Checking vgrad:
pass
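
(For what it's worth, the grid size is consistent with the printed shapes: 3 batches x 2 heads x 4 M-tiles (512 / 128, with the 128 M-tile) = 24 workgroups in grid_dim.)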

ltqin · Jan 29 '23 10:01

I think we should combine backward and backward_dropout.
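
For illustration, a minimal sketch of what a combined entry point could look like; the struct and field names below are hypothetical, not the actual composable_kernel API. The idea is a single dropout-probability argument, with p == 0 selecting the no-dropout path:

#include <cstdint>

// Hypothetical unified argument struct (not the real CK interface):
// dropout_prob == 0 disables dropout, so one backward device op
// covers both the plain and the dropout variants.
struct AttentionBwdArgs
{
    const void* q;
    const void* k;
    const void* v;
    const void* y;
    const void* dy;
    void* dq;
    void* dk;
    void* dv;
    float alpha;           // QK scaling factor
    float dropout_prob;    // 0 => no dropout; in (0, 1) => inverted dropout
    std::uint64_t seed;    // RNG seed, only read when dropout_prob > 0
    std::uint64_t offset;  // RNG subsequence offset
};

// Inside the kernel, the dropout pass over each probability tile could be
// guarded so the p == 0 case pays no RNG cost:
//
//     if(args.dropout_prob > 0.f)
//         apply_dropout(p_tile, rng, args.dropout_prob); // hypothetical helper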

danyao12 · Feb 07 '23 07:02