Attention takes softmax over wrong dimension
PLEASE CORRECT ME IF I'M WRONG.
I believe the line `attn = attn.softmax(dim=2)` is incorrect:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/05321d644e4fed67d8b2856adc2f8585e79dfbee/labml_nn/diffusion/ddpm/unet.py#L188
Dim 1 holds the index (i) over the query sequence entries, and dim 2 holds the index (j) over the key sequence entries. If my understanding is correct, for any query (dim 1) we want the weights over its associated keys to sum to 1, so the information copied to that query position keeps the same scale, and the query is free to ignore information from many keys.
However, the current implementation makes the attention from all query positions to any given key sum to 1 after the softmax. A query position may then have near-zero weight on every key, while EVERY key is forced to be used by at least one query position.
So we should not take the softmax over the key dimension (2) but over the query dimension (1).
This implementation, which is based on the current one, uses dim 1: https://github.com/pdearena/pdearena/blob/db7664bb8ba1fe6ec3217e4079979a5e4f800151/pdearena/modules/conditioned/twod_unet.py#L223
Or am I mistaken about the output of the softmax?
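You are right that each query should attend to all keys and be normalized over them, but that is exactly what `softmax(dim=2)` achieves in the `[batch, query, key, heads]` format: for every fixed batch entry, query position and head, the weights over the keys (dim 2) sum to 1. Summing over the queries for a fixed key is what `dim=1` would do. Each query therefore attends over all keys, and the current implementation is correct.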
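To check which axis each choice normalizes, here is a dependency-free sketch that mimics `Tensor.softmax(dim=...)` on nested Python lists in the `[batch, query, key, heads]` layout. The sizes, the random scores and the helper `softmax_along` are hypothetical, made up for illustration; they are not code from either repository.

```python
import math
import random

# Hypothetical toy sizes, chosen only for illustration.
B, I, J, H = 2, 3, 4, 2  # batch, queries, keys, heads

random.seed(0)
# attn[b][i][j][h]: raw score of query i against key j in head h,
# i.e. the [batch, query, key, heads] layout discussed above.
attn = [[[[random.gauss(0.0, 1.0) for _ in range(H)]
          for _ in range(J)]
         for _ in range(I)]
        for _ in range(B)]


def softmax_along(scores, dim):
    """Hypothetical helper: mimic scores.softmax(dim=dim) for dim in (1, 2)."""
    out = [[[[0.0] * H for _ in range(J)] for _ in range(I)] for _ in range(B)]
    fixed = J if dim == 1 else I    # axis held constant (together with b and h)
    reduced = I if dim == 1 else J  # axis the weights are normalized over
    for b in range(B):
        for h in range(H):
            for f in range(fixed):
                # (i, j) pairs visited while sweeping the normalized axis
                idx = [(r, f) if dim == 1 else (f, r) for r in range(reduced)]
                row = [scores[b][i][j][h] for i, j in idx]
                m = max(row)  # subtract the max for numerical stability
                exps = [math.exp(x - m) for x in row]
                total = sum(exps)
                for (i, j), e in zip(idx, exps):
                    out[b][i][j][h] = e / total
    return out


p2 = softmax_along(attn, dim=2)  # what the repository does
p1 = softmax_along(attn, dim=1)  # what the issue proposes

# dim=2: for every (batch, query, head), the weights over keys sum to 1.
key_sums = [sum(p2[b][i][j][h] for j in range(J))
            for b in range(B) for i in range(I) for h in range(H)]
# dim=1: for every (batch, key, head), the weights over queries sum to 1.
query_sums = [sum(p1[b][i][j][h] for i in range(I))
              for b in range(B) for j in range(J) for h in range(H)]

print(all(abs(s - 1.0) < 1e-9 for s in key_sums))    # True
print(all(abs(s - 1.0) < 1e-9 for s in query_sums))  # True
```

With `dim=2`, each query's weights over the keys sum to 1, which is the property the issue asks for; `dim=1` instead makes each key's weights over the queries sum to 1.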