
Add dimensionality of heads argument to SABlock

Open NabJa opened this issue 1 year ago • 7 comments

Fixes #7661.

Description

The changes add a parameter (dim_head) to set the output dimension of each head in the Self-attention Block (SABlock). Currently the total output dimension is fixed to hidden_size, and when increasing the number of heads this dimension is split equally across all heads.

Example

The original implementation automatically determines equally_distributed_head_dim:
(qkv * num_heads * equally_distributed_head_dim = 3 * hidden_size;
in this example -> 3 * 8 * 16 = 3 * 128 = 384)

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)
x = block.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 16]) # <- This corresponds to (qkv batch num_heads sequence_length equally_distributed_head_dim)

The proposed implementation fixes this by setting the new argument dim_head:

block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)
x = block_new.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 768])
> torch.Size([3, 1, 8, 256, 32]) # <- This corresponds to (qkv batch num_heads sequence_length dim_head)
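
For reference, a minimal sketch of the size arithmetic behind these shapes (inner_dim is an illustrative name and not necessarily the identifier used in the PR):

import torch
import torch.nn as nn

hidden_size, num_heads, dim_head = 128, 8, 32

# Each head now gets dim_head channels instead of hidden_size // num_heads.
inner_dim = num_heads * dim_head             # 8 * 32 = 256

# A single linear layer produces q, k and v at once, hence the factor of 3.
qkv = nn.Linear(hidden_size, inner_dim * 3)  # maps 128 -> 768

x = torch.zeros(1, 256, hidden_size)
print(qkv(x).shape)  # torch.Size([1, 256, 768]), rearranged to (3, 1, 8, 256, 32)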

Types of changes

  • [x] Non-breaking change (fix or new feature that would not break existing functionality).
  • [ ] Breaking change (fix or new feature that would cause existing functionality to change).
  • [x] New tests added to cover the changes.
  • [x] Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • [x] Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • [x] In-line docstrings updated.
  • [ ] Documentation updated, tested make html command in the docs/ folder.

NabJa avatar Apr 18 '24 13:04 NabJa

Hi @NabJa, thanks for the contribution. You'll have to fix some of the issues that you're seeing, but others may be related to our testing system. Please follow the DCO instructions for doing a remediation commit, and please run ./runtests.sh --autofix to fix the other issues. @marksgraham is there anything here that would affect merging the generative code?

ericspod avatar Apr 23 '24 11:04 ericspod

Hi @NabJa, I guess it's expected behavior.

Increasing num_heads in the self-attention block (SABlock) does not increase the number of trainable parameters. In a multi-head attention mechanism, the input embeddings are divided into smaller chunks across the specified number of attention heads, and each head then independently performs attention on its allocated chunk of the data. Refer:

  • https://github.com/pytorch/pytorch/blob/81740fd1f6fcd70c6ba4812c1289fe7efcc82908/torch/nn/modules/activation.py#L1010
  • https://discuss.huggingface.co/t/what-does-increasing-number-of-heads-do-in-the-multi-head-attention/1847
  • https://www.mathworks.com/matlabcentral/answers/2068031-attention-layer-number-of-parameters-doesn-t-change-when-changing-number-of-heads
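
For illustration, a quick check against stock PyTorch (not code from this PR) shows that the parameter count of nn.MultiheadAttention does not change with the number of heads:

import torch.nn as nn

# Each head attends over an embed_dim // num_heads slice of the same projections,
# so the projection sizes, and hence the trainable parameter count, are unchanged.
for num_heads in (2, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=128, num_heads=num_heads)
    print(num_heads, sum(p.numel() for p in mha.parameters()))
# The printed parameter count is identical for every value of num_heads.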

But @ericspod and @marksgraham might have more expertise on this. What are your thoughts? If we indeed want to make this change, perhaps we need to keep the original implementation as the default. Thanks.

KumoLiu avatar Apr 24 '24 06:04 KumoLiu

Hi @KumoLiu ,

thank you for the references! Indeed, the official PyTorch implementation splits the embeddings across all heads, resulting in a head dimension of embedding dimension // number of heads.
However, I still think having the option of setting this manually makes a lot of sense, because I might want to increase the number of heads without losing representational power in every attention map. Other frequently used attention implementations also set the head dimension manually (lucidrains).
I will make another commit so that the default behaviour stays as it is, while additionally offering the option of setting the head dimension manually.
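
For illustration, a minimal sketch of what such a backwards-compatible constructor could look like (the names below are hypothetical and not taken from the actual commit):

from typing import Optional

import torch.nn as nn

class SketchSABlock(nn.Module):
    """Illustrative only: shows how an optional dim_head can keep the old default."""

    def __init__(self, hidden_size: int, num_heads: int, dim_head: Optional[int] = None):
        super().__init__()
        # Fall back to the original equal split when dim_head is not given.
        self.dim_head = dim_head if dim_head is not None else hidden_size // num_heads
        inner_dim = self.dim_head * num_heads
        self.qkv = nn.Linear(hidden_size, inner_dim * 3)   # q, k and v in one projection
        self.out_proj = nn.Linear(inner_dim, hidden_size)  # project back to hidden_size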
Looking forward to your opinions / reviews.

NabJa avatar Apr 24 '24 09:04 NabJa

We need to be careful not to break backwards compatibility here, so the default behaviour stays the same. If that happens, then I don't see anything that would affect the generative code merge.

marksgraham avatar Apr 24 '24 10:04 marksgraham

@marksgraham complete backward compatibility should be guaranteed with 1ccb5de43f936720d8fc82307d703f507682d135. @ericspod DCO is updated and linting passes the checks.

NabJa avatar Apr 24 '24 14:04 NabJa

Would it be possible to hold off merging this for a couple of days? I'm working on some changes for self-attention/cross-attention for the MONAI Generative merge and it would be good to get a little bit further into them so I can work out if I need to make any further changes to accommodate this PR.

marksgraham avatar Apr 25 '24 12:04 marksgraham

> Would it be possible to hold off merging this for a couple of days? I'm working on some changes for self-attention/cross-attention for the MONAI Generative merge and it would be good to get a little bit further into them so I can work out if I need to make any further changes to accommodate this PR.

I think we're good to delay until you're set, thanks.

ericspod avatar Apr 25 '24 13:04 ericspod

/build

KumoLiu avatar May 08 '24 13:05 KumoLiu