Fix bigbird random attention
What does this PR do?
Fixes the bug mentioned in the issue by transitioning from `np.random` to `jax.random`. It also includes several minor changes needed to run the new code and pass all the tests.
Fixes https://github.com/huggingface/transformers/issues/17355
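For context, a minimal sketch of the switch (illustrative arrays, not the actual model code): `np.random` draws from hidden global state, whereas `jax.random` requires an explicit PRNG key to be threaded through the call.

```python
import numpy as np
import jax

# NumPy: implicit global RNG state, not reproducible under jit/pmap
perm_np = np.random.permutation(10)

# JAX: randomness is a pure function of an explicit PRNG key
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # split so a key is never reused
perm_jax = jax.random.permutation(subkey, 10)
```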
Before submitting
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
Who can review?
@sanchit-gandhi @thevasudevgupta @patrickvonplaten
Hi @sanchit-gandhi! Thank you very much for this detailed review. It is really helpful since this is my first time working with JAX :). I will apply the changes during the weekend. Have a great day!
Awesome, very glad to hear that the pointers were helpful 🤗 feel free to post here if you have any questions - it's a bit of a fiddly fix and I'm more than happy to help if you get stuck on anything!
There's actually a similar rng trick that we use in Flax BEIT: https://github.com/huggingface/transformers/blob/b210c83a78022226ce48402cd67d8c8da7afbd8d/src/transformers/models/beit/modeling_flax_beit.py#L161
You can follow through the logic we employ with "droppath" and droppath_rng to see a working example of what we want to do here!
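Roughly, the named-rng pattern looks like this (a minimal sketch, not the actual BEIT code; the module below is purely illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class DropPathSketch(nn.Module):  # illustrative module, not the real BEIT layer
    rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        if deterministic or self.rate == 0.0:
            return x
        # pull a key from the rng stream named "droppath"
        rng = self.make_rng("droppath")
        keep = jax.random.bernoulli(rng, 1.0 - self.rate, (x.shape[0],) + (1,) * (x.ndim - 1))
        return x * keep / (1.0 - self.rate)


module = DropPathSketch()
x = jnp.ones((2, 4))
params = module.init(jax.random.PRNGKey(0), x)  # deterministic by default, no rng stream needed
# at call time, supply the named rng stream only when randomness is wanted
out = module.apply(params, x, deterministic=False, rngs={"droppath": jax.random.PRNGKey(1)})
```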
Hi @sanchit-gandhi! Sorry for the late response, but lately I was in the process of changing workplaces as well as on vacation, so I have not checked GitHub for a while :). I have implemented your comments, but I have two follow-up questions:
- Should I remove all `numpy` calls in the modeling file, even the ones like `np.zeros` or `np.arange`, or only the ones related to the randomness?
- I have some problems with `indices_prng_key` for the scenario where `FlaxBigBirdBlockSparseAttention` is used but `deterministic=True`, for which `indices_prng_key=None`. Even though `deterministic` is set to `True`, the random jax functions are still being called, and in this case the provided `rng_key=None`, which results in an error.
Hey @Bearnardd! Awesome to see that you've picked up this PR again!
- Yes please! If you could replace all NumPy calls with their JAX equivalents, that would be grand! This will keep all tensors on the accelerator device (GPU/TPU) rather than pulling them back to the host (see the sketch at the end of this comment).
- In this case, could we add `if`/`else` logic that returns the correct attention mask when deterministic? E.g.
```python
if self.deterministic:
    # do the deterministic inference attention with no randomness
    ...
else:
    # do the stochastic training attention with jnp randomness
    ...
```
A similar logic is used in the Flax dropout module: https://flax.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout
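To illustrate the first point, a rough before/after of the kind of NumPy-to-JAX swap meant here (illustrative shapes, not the actual BigBird code):

```python
import numpy as np
import jax.numpy as jnp

# before: host-side NumPy arrays
block_ids_np = np.arange(12).reshape(3, 4)
mask_np = np.zeros((3, 4), dtype=np.float32)

# after: device-side JAX arrays with the same semantics
block_ids = jnp.arange(12).reshape(3, 4)
mask = jnp.zeros((3, 4), dtype=jnp.float32)
```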
Hi @sanchit-gandhi! I have replaced all NumPy calls, but frankly I am not sure if I understand the second part correctly. Could you explain what you mean by deterministic inference attention and where that if/else logic should be placed?
Hey @Bearnardd! Very cool! Do you mind pushing your changes to origin so that I can see the code? This might make it easier to help out with the deterministic issue!
Essentially, we only want to do the random operations when we're training, not at inference time. During inference, we want everything to be deterministic. This is like dropout: we only apply it during training, and at inference we disable dropout and have all the nodes active.
We can check whether the model is deterministic through the attribute `self.deterministic` (like `self.training` in PyTorch). What we need to do is add some logic so that the random calls are only made if `self.deterministic=False` (training): we know we're in training mode and we want all of the randomness, so we activate all the random calls.
Else, `self.deterministic=True` (inference) and we're in deterministic mode, so we don't want to do any of the randomness, i.e. we skip all of it.
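Roughly, the gating could look something like this (a sketch only; the module and field names below are illustrative, not the actual FlaxBigBirdBlockSparseAttention code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class SparseAttentionSketch(nn.Module):
    # illustrative module, not the real FlaxBigBirdBlockSparseAttention
    num_blocks: int = 8

    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        if deterministic:
            # inference: fall back to a fixed block order, no rng key needed
            rand_blocks = jnp.arange(self.num_blocks)
        else:
            # training: sample the random attention blocks from a named rng stream
            rng = self.make_rng("indices")
            rand_blocks = jax.random.permutation(rng, self.num_blocks)
        # ... build the block-sparse attention pattern from rand_blocks ...
        return hidden_states
```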
Hi @sanchit-gandhi! Sure, I will push the changes around Friday since I am currently on a business trip and I do not have my personal laptop :/
Hi @sanchit-gandhi! I have pushed the changes.
Hi @sanchit-gandhi, all the `# Copied from` statements are back, except the one for the PredictionHead, since a different dtype still counts as not copied and it results in an error.
Hi @amyeroberts @sanchit-gandhi! I changed the if check to `deterministic` and added a `unittest.skip` for the equivalence tests. Probably around the weekend I will create an issue regarding the bug in PyTorch's implementation as well as a PR to fix it. Nevertheless, I think this PR is ready to be merged.
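For reference, a rough sketch of what such a skip looks like (the test class and method names here are only illustrative):

```python
import unittest


class FlaxBigBirdModelTest(unittest.TestCase):  # illustrative name
    @unittest.skip("Random attention indices differ between the PyTorch and Flax implementations")
    def test_equivalence_pt_to_flax(self):
        ...
```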
Yep - it all looks good to me. Thanks again for this contribution, @Bearnardd!