
Fix bigbird random attention

Open Bearnardd opened this issue 3 years ago • 10 comments

What does this PR do?

Fixes the bug mentioned in the issue by transitioning from np.random to jax.random. It also includes several minor changes needed to run the new code and pass all the tests.

Fixes #17355: https://github.com/huggingface/transformers/issues/17355
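For context, the gist of the change is moving from NumPy's stateful global RNG to JAX's explicit key-based RNG. A minimal sketch of the two styles (illustrative only, not the actual modeling code):

import numpy as np
import jax

# Stateful NumPy RNG: relies on hidden global state
np_indices = np.random.permutation(12)[:3]

# JAX RNG: an explicit key that is split and passed around
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
jax_indices = jax.random.permutation(subkey, 12)[:3]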

Before submitting

  • [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.

Who can review?

@sanchit-gandhi @thevasudevgupta @patrickvonplaten

Bearnardd avatar Jan 05 '23 18:01 Bearnardd

The documentation is not available anymore as the PR was closed or merged.

Hi @sanchit-gandhi! Thank you very much for this detailed review. It is really helpful since this is my first time working with JAX :). I will apply the changes during the weekend. Have a great day!

Bearnardd avatar Jan 13 '23 12:01 Bearnardd

Awesome, very glad to hear that the pointers were helpful 🤗 feel free to post here if you have any questions - it's a bit of a fiddly fix and I'm more than happy to help if you get stuck on anything!

There's actually a similar rng trick that we use in Flax BEIT: https://github.com/huggingface/transformers/blob/b210c83a78022226ce48402cd67d8c8da7afbd8d/src/transformers/models/beit/modeling_flax_beit.py#L161

You can follow through the logic we employ with "droppath" and droppath_rng to see a working example of what we want to do here!
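For anyone following along, here is a rough sketch of that rng-threading pattern in flax.linen (simplified, not the actual BEiT code; the module and stream names are placeholders):

import jax
import flax.linen as nn

class DropPathLike(nn.Module):
    rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        if deterministic or self.rate == 0.0:
            return x
        # pull a key from the named rng stream supplied at apply() time
        rng = self.make_rng("droppath")
        keep_prob = 1.0 - self.rate
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = jax.random.bernoulli(rng, p=keep_prob, shape=shape)
        return x * mask / keep_prob

# at call time the stream is provided explicitly, e.g.
# module.apply({}, x, deterministic=False, rngs={"droppath": jax.random.PRNGKey(0)})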

sanchit-gandhi avatar Jan 13 '23 14:01 sanchit-gandhi

Hi @sanchit-gandhi! Sorry for the late response, but lately I was in the process of changing workplaces and then on vacation, so I have not checked GitHub for a while :). I have implemented your comments, but I have two follow-up questions:

  1. Should I remove all NumPy calls in the modeling file, even ones like np.zeros or np.arange, or only the ones related to randomness?

  2. I have some problems with indices_prng_key in the scenario where FlaxBigBirdBlockSparseAttention is used with deterministic=True, for which indices_prng_key=None. Even though deterministic is set to True, the random JAX functions are still being called, and since the provided rng key is None this results in an error.

Bearnardd avatar Feb 15 '23 19:02 Bearnardd

Hey @Bearnardd! Awesome to see that you've picked up this PR again!

  1. Yes please! If you could replace all NumPy calls with their JAX equivalents that would be grand! This will keep all tensors on the accelerator device (GPU/TPU) rather than pulling them back to the host
  2. In this case, could we add if/else logic that returns the correct attention mask when deterministic? E.g.
if self.deterministic:
    # do the deterministic inference attention with no randomness
    ...
else:
    # do the stochastic training attention with jnp randomness
    ...

A similar logic is used in the Flax dropout module: https://flax.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout

sanchit-gandhi avatar Feb 17 '23 15:02 sanchit-gandhi

Hi @sanchit-gandhi! I have replaced all NumPy calls, but frankly I am not sure if I understand the second part correctly. Could you explain what you mean by deterministic inference attention and where that if/else logic should be placed?

Bearnardd avatar Feb 24 '23 21:02 Bearnardd

Hey @Bearnardd! Very cool! Do you mind pushing your changes to origin so that I can see the code? This might make it easier to help out with the deterministic issue!

Essentially, we only want to do the random operations when we're training, not at inference time. During inference, we want everything to be deterministic. This is just like dropout - we only apply it during training; at inference we disable dropout and have all the nodes active.

We can check whether the model is deterministic through the attribute self.deterministic (analogous to self.training in PyTorch). What we need to do is add some logic so that the random calls are only made when self.deterministic=False (training): we know we're in training mode and want the randomness, so we make all the random calls. Otherwise, when self.deterministic=True (inference), we don't want any of the randomness, so we skip all of it.
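To tie this back to the PR, a hypothetical sketch of what that gating could look like for the random block indices (the helper name and arguments are made up for illustration, not the actual modeling code):

import jax
import jax.numpy as jnp

def pick_rand_block_indices(num_blocks, num_rand, deterministic, indices_prng_key=None):
    if deterministic:
        # inference: no randomness, so no prng key is needed at all
        return jnp.arange(num_rand)
    # training: explicit JAX randomness driven by the provided key
    return jax.random.permutation(indices_prng_key, num_blocks)[:num_rand]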

sanchit-gandhi avatar Mar 03 '23 14:03 sanchit-gandhi

Hi @sanchit-gandhi! Sure, I will push the changes around Friday since I am currently on a business trip and do not have my personal laptop :/

Bearnardd avatar Mar 08 '23 12:03 Bearnardd

Hi @sanchit-gandhi! I have pushed the changes.

Bearnardd avatar Mar 12 '23 23:03 Bearnardd

Hi @sanchit-gandhi! All the "Copied from" statements are back, except the one for PredictionHead, since the different dtype still counts as not copied and results in an error.

Bearnardd avatar Apr 18 '23 18:04 Bearnardd

Hi @amyeroberts @sanchit-gandhi! I changed the if check to use deterministic and added unittest skips for the equivalence tests. Probably around the weekend I will create an issue regarding the bug in PyTorch's implementation, as well as a PR with a fix. Nevertheless, I think this PR is ready to be merged.
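(For reference, the skip uses the standard unittest decorator; the test class, method name, and reason string below are just placeholders:)

import unittest

class FlaxBigBirdModelTest(unittest.TestCase):
    @unittest.skip("block-sparse attention now uses jax.random, so PT/Flax outputs need not match exactly")
    def test_equivalence_pt_to_flax(self):
        ...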

Bearnardd avatar Apr 26 '23 22:04 Bearnardd

Yep - it all looks good to me. Thanks again for this contribution, @Bearnardd!

amyeroberts avatar Apr 27 '23 17:04 amyeroberts