Use rand + topk in place of multinomial distribution for sampling the indices to be masked
What does this PR do?
Use rand + topk in place of the multinomial distribution for sampling the indices to be masked. The overall idea is as follows:
In the `_generate_mask` function, we were using the multinomial distribution to sample the indices to be masked. This PR replaces it with a combination of the rand + topk functions.
The probability distribution (`float_mask`) that we feed to the multinomial distribution has a special structure: it is either 1 or 0. So when you call something like `torch.multinomial(float_mask, num_samples=min_num_masked)`, we are basically saying that we want to sample some indices where `float_mask == 1`.
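For context, here is a minimal, self-contained sketch of that multinomial behaviour; the mask values and `min_num_masked` below are illustrative, not taken from the actual code:

```python
import torch

torch.manual_seed(0)

# Hypothetical 0/1 mask: only columns 0, 2 and 3 are eligible for masking
float_mask = torch.tensor([[1., 0., 1., 1., 0.]])
min_num_masked = 2

# multinomial (without replacement) can only draw indices where the weight is non-zero,
# so the result is a random subset of {0, 2, 3}
idx = torch.multinomial(float_mask, num_samples=min_num_masked)
print(idx)
```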
This can be re-written as:

```python
# make a tensor of random values; the 0.001 keeps the minimum above 0
# (in practice this should not be needed, since we are sampling real values
# and the likelihood of drawing exactly 0 is vanishingly small)
random_values = torch.rand_like(float_mask) + 0.001
# all the 0 entries in float_mask stay 0, while the non-zero entries
# get a random value assigned to them
random_values = random_values * float_mask
# select the top-k values, i.e. a random subset of the non-zero
# positions in float_mask
_, indices = torch.topk(random_values, k=min_num_masked, dim=1, sorted=False)
```
We need the `random_values = random_values * float_mask` step because without it we would always return the first k (or some pseudo-deterministic set of k) values from `float_mask`; multiplying by random values makes the selection random. Note that this only works when `float_mask` is either 0 or 1. It would not work in other cases.
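The three lines above can be exercised end to end with a small runnable sketch; the mask shape and `min_num_masked` value are illustrative, and the check confirms that every selected index falls on a position where the mask is 1:

```python
import torch

torch.manual_seed(0)

# Hypothetical 0/1 mask (batch x seq_len); each row has at least min_num_masked ones
float_mask = torch.tensor([[1., 0., 1., 1., 0., 1.],
                           [0., 1., 1., 0., 1., 1.]])
min_num_masked = 2

# rand + topk: assign random scores to eligible positions, then take the k largest
random_values = torch.rand_like(float_mask) + 0.001
random_values = random_values * float_mask
_, indices = torch.topk(random_values, k=min_num_masked, dim=1, sorted=False)

# every selected index must point at a position where float_mask is 1
selected = torch.gather(float_mask, 1, indices)
print(bool((selected == 1).all()))
print(tuple(indices.shape))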
This version is about 2x faster and about 4x more memory efficient when float_mask is of shape (1500, 500))`. The performance improvements are even better for bigger shapes.
Fixes #{issue number}
Does your PR introduce any breaking changes? If yes, please list them:
No
Check list:
- [x] Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements) - Yes discussed on workplace,
- [x] Did you read the contributor guideline?
- [x] Did you make sure that your PR does only one thing instead of bundling different changes together?
- [x] Did you make sure to update the documentation with your changes? (if necessary)
- [x] Did you write any new necessary tests?
- [x] Did you verify new and existing tests pass locally with your changes?
- [] Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes) - Not sure if this applies.