annotated_deep_learning_paper_implementations

Attention takes softmax over wrong dimension

Open · Trezorro opened this issue 11 months ago · 1 comment

PLEASE CORRECT ME IF I'M WRONG.

I believe the line attn = attn.softmax(dim=2) is incorrect. https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/05321d644e4fed67d8b2856adc2f8585e79dfbee/labml_nn/diffusion/ddpm/unet.py#L188

Dim 1 contains the index (i) over the query sequence entries, and dim 2 contains the index (j) over the key sequence entries. If my understanding is correct, for any query (dim 1), we would like the attention weights over its associated keys to sum to 1, so that the information copied to that query position stays at the same scale and the query is free to ignore information from many keys.

However, the current implementation seems to make it so that, for any key, the attention weights from all query positions to it sum to 1 after the softmax. Then a given query position may receive near-zero weight for all keys, while EVERY key is forced to be used by at least one query position.

We should take the softmax not over the key dimension (2) but over the query dimension (1).

This implementation, which is based on the current one, uses dim 1 instead. https://github.com/pdearena/pdearena/blob/db7664bb8ba1fe6ec3217e4079979a5e4f800151/pdearena/modules/conditioned/twod_unet.py#L223

Or am I mistaken in the output of the softmax?

Trezorro avatar Feb 04 '25 13:02 Trezorro
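
To keep the axes straight, here is a minimal sketch of how a score tensor in the [batch, query, key, heads] layout discussed above can be produced. The tensor sizes and the einsum string are illustrative assumptions, not the repository's actual module:

```python
import torch

# Toy sizes, chosen only for illustration (not taken from the repository).
batch, seq, heads, d_k = 2, 5, 4, 8

q = torch.randn(batch, seq, heads, d_k)  # queries, indexed by i along dim 1
k = torch.randn(batch, seq, heads, d_k)  # keys, indexed by j along dim 1

# Scores laid out as [batch, i (query), j (key), heads]; a common way to
# build this layout is an einsum like the one below.
attn = torch.einsum('bihd,bjhd->bijh', q, k) * d_k ** -0.5
print(attn.shape)  # torch.Size([2, 5, 5, 4])
```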

You are right that each query should attend to all keys and normalize over them, but that is exactly what softmax(dim=2) achieves in the [batch, query, key, heads] format: dim 2 indexes the keys, so for each query the weights over all keys sum to 1. The current implementation is therefore correct.

pharrera avatar May 30 '25 01:05 pharrera
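
As a quick check of the reply above, the following sketch (toy tensors only, not the repository's code) shows what each choice of softmax dimension normalizes over in a [batch, query, key, heads] score tensor:

```python
import torch

# A toy score tensor in [batch, query, key, heads] layout (illustrative sizes).
batch, seq, heads = 2, 5, 4
scores = torch.randn(batch, seq, seq, heads)

# softmax(dim=2) normalizes along the key axis: for every
# (batch, query position, head) the weights over all keys sum to 1.
per_query = scores.softmax(dim=2)
print(torch.allclose(per_query.sum(dim=2), torch.ones(batch, seq, heads)))  # True

# softmax(dim=1) would normalize along the query axis instead: for every
# (batch, key position, head) the weights over all queries would sum to 1,
# which is the behaviour the issue was worried about.
per_key = scores.softmax(dim=1)
print(torch.allclose(per_key.sum(dim=1), torch.ones(batch, seq, heads)))  # True
```

Since dim 2 indexes the keys in this layout, softmax(dim=2) already gives the per-query normalization the issue asks for.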