Add dimensionality of heads argument to SABlock
Fixes #7661.
Description
The changes add a parameter (dim_head) to set the output dimension of each head in the self-attention block (SABlock). Currently the output dimension is fixed to hidden_size, and when increasing the number of heads it is divided equally among all heads.
Example
The original implementation automatically determines equally_distributed_head_dim:
(qkv * num_heads * equally_distributed_head_dim = 3 * hidden_size,
in this example -> 3 * 8 * 16 = 384)
block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)
x = block.input_rearrange(x)
print(x.shape)
> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 16]) # <- This corresponds to (qkv batch num_heads sequence_length equally_distributed_head_dim)
The proposed implementation fixes this by setting the new argument dim_head:
block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)
x = block_new.input_rearrange(x)
print(x.shape)
> torch.Size([1, 256, 768])
> torch.Size([3, 1, 8, 256, 32]) # <- This corresponds to (qkv batch num_heads sequence_length dim_head)
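(Note: with dim_head=32 the qkv projection now outputs 3 * num_heads * dim_head = 3 * 8 * 32 = 768 features instead of 3 * hidden_size = 384, which is why the rearranged tensor ends with a head dimension of 32.)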
Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
Hi @NabJa, thanks for the contribution. You'll have to fix some of the issues you're seeing, but others may be related to our testing system. Please follow the DCO instructions for doing a remediation commit, and please run `./runtests.sh --autofix` to fix other issues. @marksgraham is there anything here that would affect merging the generative code?
Hi @NabJa, I guess it's expected behavior.
Increasing num_heads in the self-attention block (SABlock) does not increase the number of trainable parameters. In a multi-head attention mechanism, the original input embeddings are divided into smaller chunks across the specified number of attention heads, and each head then independently performs attention on its allocated chunk of the data. Refer:
- https://github.com/pytorch/pytorch/blob/81740fd1f6fcd70c6ba4812c1289fe7efcc82908/torch/nn/modules/activation.py#L1010
- https://discuss.huggingface.co/t/what-does-increasing-number-of-heads-do-in-the-multi-head-attention/1847
- https://www.mathworks.com/matlabcentral/answers/2068031-attention-layer-number-of-parameters-doesn-t-change-when-changing-number-of-heads
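As a quick illustration (this uses plain torch.nn.MultiheadAttention rather than MONAI's SABlock, and the embedding size is an arbitrary choice for the sketch), the parameter count does not change with num_heads:

```python
import torch.nn as nn

# Illustration only: the total parameter count is the same for every value of
# num_heads, because the embedding dimension is split across the heads rather
# than expanded per head.
for heads in (1, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=128, num_heads=heads)
    n_params = sum(p.numel() for p in mha.parameters())
    print(f"num_heads={heads}: {n_params} parameters")  # same total each time
```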
But @ericspod and @marksgraham might have more expertise on this. What are your thoughts? If we indeed want to make this change, perhaps we need to include the original implementation. Thanks.
Hi @KumoLiu ,
thank you for the references! Indeed, the official PyTorch implementation splits the embeddings across all heads, resulting in a head dimension of embedding dimension // number of heads.
However, I still think having the option to set this manually makes a lot of sense, because I might want to increase the number of heads without losing representational power in each attention head. Other frequently used attention implementations (e.g. lucidrains) also set the head dimension manually.
I will make another commit so that the default behaviour stays as it is, while additionally offering the option of setting the head dimension manually (roughly as sketched below).
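The idea is roughly the following (a minimal sketch, not the actual SABlock code; the class name, inner_dim, and the einops patterns are only illustrative, and dim_head=None keeps the current behaviour):

```python
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange


class SketchSelfAttention(nn.Module):
    """Minimal sketch of self-attention with an optional per-head dimension."""

    def __init__(self, hidden_size: int, num_heads: int, dim_head: Optional[int] = None):
        super().__init__()
        # Default keeps the current behaviour: split hidden_size across the heads.
        self.dim_head = dim_head if dim_head is not None else hidden_size // num_heads
        self.num_heads = num_heads
        inner_dim = self.dim_head * num_heads
        self.scale = self.dim_head**-0.5
        self.qkv = nn.Linear(hidden_size, inner_dim * 3, bias=False)
        self.out_proj = nn.Linear(inner_dim, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, sequence_length, hidden_size)
        q, k, v = rearrange(self.qkv(x), "b n (qkv h d) -> qkv b h n d",
                            qkv=3, h=self.num_heads)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = rearrange(attn @ v, "b h n d -> b n (h d)")
        # Projecting back to hidden_size keeps the block drop-in compatible.
        return self.out_proj(out)
```

With hidden_size=128, num_heads=8 and dim_head=32 this gives a qkv output of 768 features and per-head tensors of size 32, matching the example above.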
Looking forward to your opinions / reviews.
We need to be careful not to break backwards compatibility here, so the default behaviour should stay the same. As long as that's the case, I don't see anything that would affect the generative code merge.
@marksgraham complete backward compatibility should be guaranteed with 1ccb5de43f936720d8fc82307d703f507682d135 . @ericspod DCO is updated and linting passes the checks.
Would it be possible to hold off merging this for a couple of days? I'm working on some changes for self-attention/cross-attention for the MONAI Generative merge and it would be good to get a little bit further into them so I can work out if i need to make any further changes to accommodate this PR.
I think we're good to delay until you're set, thanks.
/build