Replace Vanilla Attention with Natten

Open RisabBiswas opened this issue 3 years ago • 1 comments

Hello!!

Thank you for sharing the details on NATTEN. Your research is very promising. I am trying to use NATTEN in place of vanilla attention in my existing ViT-based encoder-decoder architecture. Could you please share some details on how to do this? It would be very helpful. My existing architecture breaks the input image into patches, similar to the vanilla ViT architecture. It would be great if NATTEN could be applied to these patches.

Thank you in advance! Looking forward to hearing from you :)

RisabBiswas, Oct 25 '22 05:10

Hello and thank you for your interest.

Yes, in theory, dot-product self-attention can be replaced with NA/DiNA. There are a few things I'd point out:

  • NA2D expects inputs of shape [B, H, W, C].
  • NA2D also has relative positional biases, so you might be able to drop your positional encoding if you're using that.
  • You could also use NA1D, but we'd strongly discourage that for images, as the 2D kernels are better than the 1D kernels in the current version.
  • It is also recommended to use a combination of NA and DiNA for best performance. The easiest way to do this is to alternate layers: one with NA (dilation=1), the next with NA at the maximum dilation (dilation=MAX), which is DiNA.
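As a rough sketch of the first point: a ViT patch embedding usually produces a flat token sequence of shape [B, N, C], so the tokens need to be viewed as a grid before NA2D can consume them. The helper below is hypothetical (its name and arguments are not part of NATTEN) and only computes the target shape. It assumes there is no class token and that the image size is divisible by the patch size.

```python
def token_grid_shape(batch, num_tokens, channels, image_hw, patch_size):
    """Return the [B, H, W, C] shape for a flat [B, N, C] token sequence.

    Hypothetical helper, not part of NATTEN. Assumes no class token and
    that the image height and width are divisible by patch_size.
    """
    img_h, img_w = image_hw
    if img_h % patch_size or img_w % patch_size:
        raise ValueError("image size must be divisible by patch size")
    grid_h, grid_w = img_h // patch_size, img_w // patch_size
    if grid_h * grid_w != num_tokens:
        raise ValueError(
            f"expected {grid_h * grid_w} tokens, got {num_tokens} "
            "(is there a class token?)"
        )
    return (batch, grid_h, grid_w, channels)


# 224x224 image with 16x16 patches gives a 14x14 grid of 196 tokens.
shape = token_grid_shape(batch=8, num_tokens=196, channels=384,
                         image_hw=(224, 224), patch_size=16)
print(shape)  # (8, 14, 14, 384)
```

In PyTorch, calling `x.reshape(shape)` on the [B, N, C] tensor would then give the channels-last layout described above. If your model prepends a class token, it would have to be split off first.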

A good reference for ViT-like models: Our isotropic variants.

Also keep in mind that if your input is an NxN feature map and you set the NA2D kernel size to N, the output will be identical to self-attention, so you typically want a kernel size smaller than the input.

Also note that your input feature maps have to be at least KDxKD (that is, K*D along each spatial axis), where K is the kernel size and D is the dilation. Otherwise the torch module in NATTEN will pad your inputs with zeros and crop the output, which might make things a little slower.
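To make the size constraint concrete, here is a minimal sketch that picks the largest dilation that still fits the feature map and builds the alternating NA/DiNA schedule mentioned earlier. The function names are hypothetical and not part of NATTEN, and taking dilation=MAX to mean floor(min(H, W) / K) is an assumption that follows from the K*D constraint.

```python
def max_dilation(feature_hw, kernel_size):
    """Largest D with kernel_size * D <= min(H, W), and never below 1."""
    return max(1, min(feature_hw) // kernel_size)


def needs_padding(feature_hw, kernel_size, dilation):
    """True if the feature map is smaller than K*D along some axis."""
    return min(feature_hw) < kernel_size * dilation


def dilation_schedule(depth, feature_hw, kernel_size):
    """Alternate NA (dilation=1) and DiNA (dilation=max) across layers."""
    d_max = max_dilation(feature_hw, kernel_size)
    return [1 if i % 2 == 0 else d_max for i in range(depth)]


# 14x14 feature map, kernel size 7: the largest dilation that fits is 2.
print(dilation_schedule(4, (14, 14), 7))  # [1, 2, 1, 2]
print(needs_padding((14, 14), 7, 2))      # False: 14 >= 7 * 2
print(needs_padding((14, 14), 7, 3))      # True: 14 < 7 * 3
```

Each entry in the schedule would be passed as the dilation of the corresponding attention layer.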

A few performance notes:

  1. If you're working on images, it's best not to flatten across the spatial axes (keeping your x in shape [B, H, W, C] is typically best if you don't have a lot of convolutions). Most operations apply to the last axis by default (nn.Linear, nn.LayerNorm, and the like). In our models NAT and DiNAT, we found that keeping the inputs "channels last" and transposing before and after convolutions is typically faster in hierarchical models.
  2. It is highly recommended to choose a number of channels that is a multiple of 32, and to split your attention heads so that each head has 32 dimensions. With 32 dimensions per head, the backend will call the more efficient "tiled" kernels instead, which run up to 10X the speed of the naive kernels.
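The second note amounts to a small piece of arithmetic, sketched below. The helper is hypothetical and not part of NATTEN; it only derives the head count that gives 32 dimensions per head and rejects channel counts that are not a multiple of 32.

```python
def heads_for_channels(channels, head_dim=32):
    """Number of attention heads so that each head has head_dim dims.

    Hypothetical helper, not part of NATTEN.
    """
    if channels <= 0 or channels % head_dim:
        raise ValueError(
            f"channels={channels} is not a positive multiple of {head_dim}"
        )
    return channels // head_dim


print(heads_for_channels(384))  # 12 heads of 32 dims each
print(heads_for_channels(64))   # 2 heads of 32 dims each
```

A model with 384 channels would therefore use 12 heads, whereas something like 400 channels would not split evenly into 32-dim heads.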

Please let us know if you have any other questions.

alihassanijr, Oct 26 '22 18:10

Closing this due to inactivity, but feel free to reopen the issue if you still have questions.

alihassanijr, Nov 16 '22 05:11