
Multiple heads option is not working in SelfAttention


Description

I use some input activations and a single SelfAttention layer with n_heads=2, and my code breaks. When I set n_heads=1, everything works fine.

Environment information

OS: macOS

$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.19
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-datasets==4.3.0
tensorflow-estimator==2.4.0
tensorflow-hub==0.12.0
tensorflow-metadata==0.30.0
tensorflow-text==2.4.3
$ pip freeze | grep jax
jax==0.2.19
jaxlib==0.1.70
$ python -V
Python 3.8.10

Steps to reproduce:

Here is a minimal example:

import trax
import numpy as np

# A single self-attention layer with two heads.
attention = trax.layers.SelfAttention(n_heads=2)

# Batch of 1, sequence length 100, feature depth 1.
activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)
inputs = (activations,)  # renamed from `input` to avoid shadowing the builtin

init = attention.init(inputs)

output = attention(inputs)
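
As an aside, trax layers are more often initialized from a shape signature than from concrete arrays; here is the same repro in that style (a sketch, assuming the failure does not depend on how init is called):

import trax
from trax import shapes
import numpy as np

attention = trax.layers.SelfAttention(n_heads=2)
activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)

# Initialize from a ShapeDtype signature instead of the concrete batch.
attention.init(shapes.signature(activations))
output = attention(activations)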

Error logs:

  File "[...]/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File "[...]/layers/research/efficient_attention.py", line 1637, in forward_unbatched_h
    return forward_unbatched(*i_h, weights=w_h, state=s_h)

  File "[...]/layers/research/efficient_attention.py", line 1175, in forward_unbatched
    q_info = kv_info = np.arange(q.shape[-2], dtype=np.int32)

IndexError: tuple index out of range
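
The IndexError says that q.shape has no axis at index -2, i.e. q arrives in forward_unbatched with fewer than two dimensions once the input has been split across heads. A tiny illustration of the same exception, using a hypothetical shape rather than anything taken from the trax internals:

import numpy as np

q = np.zeros((100,))  # a 1-D array: q.shape is the 1-tuple (100,)
q.shape[-2]           # raises IndexError: tuple index out of range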

kenenbek · Aug 17 '21 16:08

If I construct the SelfAttention layer with the use_reference_code argument:

attention = trax.layers.SelfAttention(n_heads=2, use_reference_code=True)

Everything works fine.
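
For completeness, the full repro with the workaround applied; only the constructor call changes:

import trax
import numpy as np

# Same repro as above, but using the reference (unoptimized) attention path.
attention = trax.layers.SelfAttention(n_heads=2, use_reference_code=True)

activations = np.random.randint(0, 10, (1, 100, 1)).astype(np.float32)
inputs = (activations,)

attention.init(inputs)
output = attention(inputs)  # runs without raising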

Is this a bug?

kenenbek · Aug 31 '21 12:08