trax
Multiple heads option is not working in SelfAttention
Description
I pass some input activations through a single SelfAttention layer with n_heads=2, and the code breaks. With n_heads=1, everything works fine.
Environment information
OS: macOS
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-datasets==4.3.0
tensorflow-estimator==2.4.0
tensorflow-hub==0.12.0
tensorflow-metadata==0.30.0
tensorflow-text==2.4.3
$ pip freeze | grep jax
jax==0.2.19
jaxlib==0.1.70
$ python -V
Python 3.8.10
Steps to reproduce:
Here is a minimal example:
import trax
import numpy as np

# Two-headed self-attention layer (the failing case).
attention = trax.layers.SelfAttention(n_heads=2)

# Dummy activations: batch size 1, sequence length 100, feature dimension 1.
activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)
inputs = (activations,)

attention.init(inputs)
output = attention(inputs)
Error logs:
File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File [...]/layers/research/efficient_attention.py, line 1637, in forward_unbatched_h
return forward_unbatched(*i_h, weights=w_h, state=s_h)
File [...]/layers/research/efficient_attention.py, line 1175, in forward_unbatched
q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32)
IndexError: tuple index out of range
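For what it's worth, this IndexError is exactly what NumPy raises when shape[-2] is read from a 1-D array, so my guess (not verified against the trax source) is that the non-reference per-head slicing hands forward_unbatched a query with one dimension too few:

import numpy as np

q = np.arange(100, dtype=np.float32)  # 1-D array, shape (100,)
q.shape[-2]  # IndexError: tuple index out of range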
If I construct the SelfAttention layer with the use_reference_code argument:
attention = trax.layers.SelfAttention(n_heads=2, use_reference_code=True)
everything works fine.
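For completeness, the full workaround script as I run it (the same repro as above, with only use_reference_code=True added):

import trax
import numpy as np

# Same setup as the repro, but using the reference code path.
attention = trax.layers.SelfAttention(n_heads=2, use_reference_code=True)
activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)
inputs = (activations,)

attention.init(inputs)
output = attention(inputs)  # runs without raising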
Is this a bug?