Refactor mask_along_axis{,_iid} and add SpecAugment transform
The previous way of doing SpecAugment via Frequency/TimeMasking transforms has the following problems:
- Only zero masking can be done; masking by mean value is not supported.
- mask_along_axis is hard-coded to mask the 1st dimension, and mask_along_axis_iid is hard-coded to mask the 2nd or 3rd dimension of the input tensor.
- For 3D spectrogram tensors whose first dimension is batch or channel, all examples in a batch (or all channels) have to share the same mask, because mask_along_axis_iid only supports 4D tensors due to the hard-coding above.
- For 2D spectrogram tensors without a batch or channel dimension, time/frequency masking can't be applied at all, since mask_along_axis only supports 3D tensors due to the hard-coding above.
- It's not straightforward to apply multiple time/frequency masks with the current design.
To solve these issues, here we
- Extend mask_along_axis_iid to support 3D tensors and mask_along_axis to support 2D tensors. Now both of them are able to mask one of the last two dimensions (where the time or frequency dimension lives) of the input tensor.
- Add a SpecAugment transform that directly calls mask_along_axis{,_iid}. It can apply a configurable number of time/frequency masks within a single transform and supports either zero- or mean-value-based masking. Note that if the input tensor is 3D or 4D, we always apply different masks along the first 1 or 2 dimensions.
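The behavior described above can be sketched in pure Python as follows. Nested lists stand in for tensors, and a 2D input is assumed to be laid out as (freq, time). The names `spec_augment` and `_mask_2d` and the `zero_masking` flag are hypothetical illustrations and are not the torchaudio API. The band sampling (an integer width drawn uniformly from [0, mask_param]) is a simplification chosen here. Using the mean of each 2D slice as the fill value is also an assumption: the PR says "mean value" but does not specify over which dimensions.

```python
import random
from typing import Optional


def _mask_2d(spec: list, axis: int, mask_param: int, mask_value: float,
             rng: random.Random) -> list:
    """Return a copy of a 2D (freq, time) list with one random band masked.

    axis=0 masks a band of frequency rows; axis=1 masks a band of time columns.
    """
    n_freq, n_time = len(spec), len(spec[0])
    size = n_freq if axis == 0 else n_time
    width = rng.randint(0, min(mask_param, size))  # band width in [0, mask_param]
    start = rng.randint(0, size - width)           # band start, chosen so it fits
    out = [row[:] for row in spec]                 # never mutate the input
    for f in range(n_freq):
        for t in range(n_time):
            idx = f if axis == 0 else t
            if start <= idx < start + width:
                out[f][t] = mask_value
    return out


def spec_augment(spec: list, n_time_masks: int, time_mask_param: int,
                 n_freq_masks: int, freq_mask_param: int,
                 zero_masking: bool = True,
                 rng: Optional[random.Random] = None) -> list:
    """Apply several freq/time masks to a 2D, 3D or 4D nested-list spectrogram."""
    if rng is None:
        rng = random.Random()
    if isinstance(spec[0][0], list):
        # 3D/4D input: recurse over the leading batch/channel dimension(s) so
        # every 2D slice draws its own, independent masks.
        return [spec_augment(s, n_time_masks, time_mask_param, n_freq_masks,
                             freq_mask_param, zero_masking, rng) for s in spec]
    # Fill value: 0, or (assumed) the mean of this 2D slice before masking.
    n_cells = len(spec) * len(spec[0])
    mask_value = 0.0 if zero_masking else sum(map(sum, spec)) / n_cells
    out = spec
    for _ in range(n_freq_masks):
        out = _mask_2d(out, 0, freq_mask_param, mask_value, rng)
    for _ in range(n_time_masks):
        out = _mask_2d(out, 1, time_mask_param, mask_value, rng)
    return out


# Usage: a 3D batch of two (freq=3, time=4) spectrograms; each gets its own masks.
example = [[float(4 * f + t) for t in range(4)] for f in range(3)]
masked = spec_augment([example, example], n_time_masks=2, time_mask_param=2,
                      n_freq_masks=1, freq_mask_param=1, zero_masking=False,
                      rng=random.Random(0))
```

Recursing over the leading dimensions is one simple way to get the "different masks per batch item / channel" behavior for 3D and 4D inputs. A tensor implementation would more likely broadcast a per-example mask instead.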
Additionally, we may consider deprecating AxisMasking, TimeMasking and FrequencyMasking in the future.
I am okay with that, but in that case, instead of extending them, I'd start a new implementation and keep the existing ones as-is, so that the new implementation can be tailored properly.
Yeah, agreed. I'm doing exactly what you suggested: keeping them as they are. I extended mask_along_axis{,_iid} so that the new SpecAugment class can work properly.
What I meant was to keep even mask_along_axis{,_iid} as-is, because they are public APIs. Looking at their implementation and the related functions, I feel the implementation could be more straightforward.
I think we need to add tests for the new capability. Also, could we split the PR into 1. changes to the existing functions and 2. the addition of SpecAugment?
@xiaohui-zhang has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.