BN Fixes
There are some subtle issues with how BatchNorm is handled in the PyTorch version of the code. Currently, workload.model_fn has an update_batch_norm parameter, which in theory should allow the submission to control whether the batch-norm statistics are updated during a forward pass. The issues are the following:
1. The update_batch_norm_fn function stores the old momentum parameter of each BatchNorm layer in a momentum_backup variable before zeroing the parameter, so that it can be restored later. However, if it is called with update_batch_norm=False twice in a row, the second call overwrites momentum_backup with 0, so the momentum then remains zero for the remainder of training.
2. In PyTorch's builtin BatchNorm, a momentum of 0 indicates that the running statistics buffers shouldn't be updated. This is the opposite of how EMA momentum is usually defined (e.g. in Adam), where 1 indicates that the average shouldn't be updated and 0 means it is set to the latest value at every step. The custom BatchNorm modules used in the two librispeech workloads follow this second, more standard convention instead. However, update_batch_norm_fn sets the momentum to zero for all three layer types, resulting in incorrect behavior for the librispeech workloads: with momentum 0, their running statistics are overwritten with the current batch statistics instead of being frozen.
3. The update_batch_norm_fn function sets the BN layers to eval mode. This doesn't make sense, as it prevents the use case where you use batch-computed statistics (train mode) without also updating the running statistics. The BN layers can be set to eval mode separately by passing ForwardPassMode.EVAL to the forward pass, so removing this .eval() call doesn't prevent the submission from using eval mode during a forward pass.
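The fixes for the first and third issues can be sketched as follows. This is a minimal pure-Python sketch, not the PR's actual code: HypotheticalBatchNorm is a stand-in for a torch BatchNorm layer (only its momentum and training attributes are modeled), the update_batch_norm_fn signature taking a list of layers is assumed, and every layer is assumed to follow PyTorch's convention that momentum=0 freezes the running statistics.

```python
class HypotheticalBatchNorm:
  """Hypothetical stand-in for a BN layer: only the attributes the fix touches."""

  def __init__(self, momentum):
    self.momentum = momentum
    self.training = True  # train mode, as set by the forward-pass mode


def update_batch_norm_fn(layers, update_batch_norm):
  """Sketch: freeze or restore running-stat updates without changing train/eval mode."""
  for layer in layers:
    if not update_batch_norm:
      # Back up the momentum only once, so a second call with
      # update_batch_norm=False cannot overwrite the backup with 0.
      if not hasattr(layer, 'momentum_backup'):
        layer.momentum_backup = layer.momentum
      layer.momentum = 0.0  # PyTorch convention: 0 means "don't update".
    elif hasattr(layer, 'momentum_backup'):
      layer.momentum = layer.momentum_backup
      del layer.momentum_backup
    # Deliberately no layer.eval() call: train/eval mode is chosen separately.


layers = [HypotheticalBatchNorm(momentum=0.1)]
update_batch_norm_fn(layers, update_batch_norm=False)
update_batch_norm_fn(layers, update_batch_norm=False)  # second call in a row
update_batch_norm_fn(layers, update_batch_norm=True)
print(layers[0].momentum)  # → 0.1
print(layers[0].training)  # → True
```

Guarding the backup with a presence check makes repeated "freeze" calls idempotent, and leaving the training flag alone keeps batch-statistics normalization available while the running statistics are frozen.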
This PR switches the custom BN code to follow PyTorch's BN convention, so that momentum=0 doesn't update the running buffers. It also fixes the issues in the update_batch_norm_fn function mentioned above.
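The two EMA conventions differ only in which term the parameter weights. A minimal sketch of both update rules (the function names are hypothetical, chosen for illustration) shows why a value of 0 means opposite things, and that a decay-style parameter d corresponds to a PyTorch-style momentum of 1 - d:

```python
def ema_pytorch_bn(running, batch_stat, momentum):
  """PyTorch BN convention: `momentum` weights the new batch statistic."""
  return (1.0 - momentum) * running + momentum * batch_stat


def ema_decay_style(running, batch_stat, decay):
  """Adam-style convention: `decay` (like beta1) weights the old running value."""
  return decay * running + (1.0 - decay) * batch_stat


# A value of 0 behaves oppositely under the two conventions.
print(ema_pytorch_bn(1.0, 5.0, momentum=0.0))  # → 1.0 (running stat frozen)
print(ema_decay_style(1.0, 5.0, decay=0.0))    # → 5.0 (replaced by batch stat)
```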
Thanks for spotting all these issues. I agree, we should incorporate these fixes. I spotted one more subtle issue in our JAX code, similar to the third issue above:
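```python
def __call__(self,
             x: spec.Tensor,
             update_batch_norm: bool = True) -> spec.Tensor:
  conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
  norm = functools.partial(
      nn.BatchNorm,
      # Ties normalization mode to stat updates: running averages are used
      # exactly when the batch-norm statistics are not being updated.
      use_running_average=not update_batch_norm,
      momentum=0.9,
      epsilon=1e-5,
      dtype=self.dtype)
```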
This couples two choices that should be independent: whether to normalize with the running averages, and whether to update the batch-norm statistics. For example, it rules out normalizing with batch statistics in train mode without updating the running statistics. Maybe we need an extra argument in the __call__ functions to distinguish between train and eval mode (or just whether to use the running average), instead of inferring it from the update_batch_norm argument?
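To make the suggestion concrete, here is a minimal pure-Python sketch of a batch-norm step in which the two choices are separate flags. It does not use Flax's API: batch_norm, its state dict, and the update_running_stats flag are hypothetical names, it handles a single feature, and it follows Flax's momentum convention (momentum weights the old running value).

```python
import math


def batch_norm(x, state, *, use_running_average, update_running_stats,
               momentum=0.9, epsilon=1e-5):
  """Normalize a 1-D batch `x` (one feature); return (y, new_state).

  `state` holds 'mean' and 'var'. The two flags are independent (hypothetical API).
  """
  batch_mean = sum(x) / len(x)
  batch_var = sum((v - batch_mean) ** 2 for v in x) / len(x)

  # Choice 1: which statistics to normalize with.
  if use_running_average:
    mean, var = state['mean'], state['var']
  else:
    mean, var = batch_mean, batch_var

  # Choice 2: whether to update the running statistics.
  new_state = dict(state)
  if update_running_stats:
    new_state['mean'] = momentum * state['mean'] + (1 - momentum) * batch_mean
    new_state['var'] = momentum * state['var'] + (1 - momentum) * batch_var

  y = [(v - mean) / math.sqrt(var + epsilon) for v in x]
  return y, new_state


state = {'mean': 0.0, 'var': 1.0}
# Train-mode normalization (batch statistics) without updating running stats.
y, new_state = batch_norm([1.0, 3.0], state,
                          use_running_average=False,
                          update_running_stats=False)
print(new_state == state)  # → True
```

Calling it with use_running_average=False and update_running_stats=False covers the case the current code rules out; how the two flags would map onto the Flax modules is left open here.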