
Switch Norm 1d for 3D tensors

Open · adrienchaton opened this issue 6 years ago · 1 comment

Hello, thank you for providing a code implementation of your paper.

I am interested in trying your normalization in my current experiment, which works on raw waveforms and audio "style". Adaptively switching between different feature normalizations is therefore of prime interest, and I hope your proposal will work well for my use case.

However, in your `Switchable-Normalization/devkit/ops/switchable_norm.py`, the 1d normalization only accepts 2D tensors and the 2d normalization only accepts 4D tensors, whereas PyTorch's `BatchNorm1d` and `InstanceNorm1d` accept both 2D and 3D tensors.
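For reference, a quick check of the PyTorch behaviour mentioned above: `nn.BatchNorm1d` accepts both `(N, C)` and `(N, C, L)` inputs, and `nn.InstanceNorm1d` normalizes an `(N, C, L)` input per sample and channel over `L`. The tensor sizes below are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_features = 8

bn = nn.BatchNorm1d(num_features)
inorm = nn.InstanceNorm1d(num_features)

x2d = torch.randn(4, num_features)       # (N, C), e.g. the output of nn.Linear
x3d = torch.randn(4, num_features, 100)  # (N, C, L), e.g. the output of nn.Conv1d

y_bn_2d = bn(x2d)     # statistics over N for each channel
y_bn_3d = bn(x3d)     # statistics over N and L for each channel
y_in_3d = inorm(x3d)  # statistics over L for each (sample, channel) pair

print(tuple(y_bn_2d.shape), tuple(y_bn_3d.shape), tuple(y_in_3d.shape))
```

A `(B, C, L)` conv1d output therefore needs statistics that broadcast against a length dimension, which is exactly what the 2D-only SwitchNorm1d lacks.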

If possible, how should I apply your SwitchNorm1d to 3D tensors, for instance the output of a conv1d layer?

Thank you!

adrienchaton · Jun 12 '19 12:06

Hi, I modified the original code to fit the case you mentioned:

```python
import torch
import torch.nn as nn


class SwitchNorm1d(nn.Module):
    """Switchable Normalization for 3D inputs of shape (B, C, L), e.g. nn.Conv1d outputs."""

    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        # Per-channel affine parameters, shaped to broadcast over (B, C, L).
        self.weight = nn.Parameter(torch.ones(1, num_features, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1))
        # Importance logits for the (IN, LN, BN) statistics, softmax-normalized in forward().
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
        self.register_buffer('running_var', torch.zeros(1, num_features, 1))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x):  # x: (B, C, L)
        self._check_input_dim(x)
        # LN statistics: across channels at each time step.
        mean_ln = x.mean(1, keepdim=True)  # (B, 1, L)
        var_ln = x.var(1, keepdim=True)

        # IN statistics: across time for each sample and channel.
        mean_in = x.mean(-1, keepdim=True)  # (B, C, 1)
        var_in = x.var(-1, keepdim=True)
        temp = var_in + mean_in ** 2  # per-(sample, channel) second moment

        if self.training:
            # BN statistics: pooled from the IN statistics over the batch.
            mean_bn = mean_in.mean(0, keepdim=True)  # (1, C, 1)
            var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2

            # Running statistics are buffers and must not be tracked by autograd.
            with torch.no_grad():
                if self.using_moving_average:
                    # running = momentum * running + (1 - momentum) * batch statistic
                    self.running_mean.mul_(self.momentum).add_((1 - self.momentum) * mean_bn)
                    self.running_var.mul_(self.momentum).add_((1 - self.momentum) * var_bn)
                else:
                    self.running_mean.add_(mean_bn)
                    self.running_var.add_(mean_bn ** 2 + var_bn)
        else:
            # At inference time, use the running statistics directly.
            mean_bn = self.running_mean
            var_bn = self.running_var

        mean_weight = torch.softmax(self.mean_weight, 0)
        var_weight = torch.softmax(self.var_weight, 0)

        mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
        var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias
```
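A note on `momentum`: with `using_moving_average=True`, this module updates its running statistics as `running = momentum * running + (1 - momentum) * batch_stat`, so `momentum=0.997` keeps 99.7% of the old value at every step. PyTorch's `nn.BatchNorm1d` documents the opposite convention, `running = (1 - momentum) * running + momentum * batch_stat` with a default of 0.1, so 0.997 here corresponds to `momentum=0.003` there. The plain-Python sketch below (the helper names are hypothetical, for illustration only) counts how many updates each rule needs to reach 95% of a constant batch statistic, starting from the zero-initialized buffers.

```python
def ema_switchnorm(running, batch_stat, momentum=0.997):
    """Update rule used by SwitchNorm1d above: momentum weights the OLD value."""
    return momentum * running + (1 - momentum) * batch_stat


def ema_pytorch(running, batch_stat, momentum=0.1):
    """Update rule documented for torch.nn.BatchNorm1d: momentum weights the NEW value."""
    return (1 - momentum) * running + momentum * batch_stat


def steps_to_reach(update, target_fraction, batch_stat=1.0, **kwargs):
    """Count updates until a running value starting at 0 reaches
    target_fraction * batch_stat, for a constant batch statistic."""
    running, steps = 0.0, 0
    while running < target_fraction * batch_stat:
        running = update(running, batch_stat, **kwargs)
        steps += 1
    return steps


# Closed form for a constant statistic m starting from 0:
# running after n steps = m * (1 - w ** n), where w is the weight on the old value.
print(steps_to_reach(ema_switchnorm, 0.95))               # 998
print(steps_to_reach(ema_pytorch, 0.95))                  # 29
print(steps_to_reach(ema_pytorch, 0.95, momentum=0.003))  # 998
```

In practice this means the running statistics used in eval mode need on the order of a thousand training steps to settle with the default `momentum=0.997`.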

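For temporal data, LN is computed slightly differently than for images (SwitchNorm2d): here the LN statistics are taken across channels at each time step, giving one mean and variance per (sample, time step) with shape (B, 1, L), whereas SwitchNorm2d pools them over all channels and spatial positions of each sample. It should work as desired :)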
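To make the difference concrete, here is a minimal sketch of the statistic shapes for a `(B, C, L)` tensor. The IN, BN and per-time-step LN means match the code above; the per-sample LN pools over both `C` and `L`, which is assumed here to be the 1D analogue of how SwitchNorm2d pools over `C`, `H` and `W`. The sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
B, C, L = 2, 4, 16
x = torch.randn(B, C, L)  # e.g. the output of an nn.Conv1d layer

# IN: one statistic per (sample, channel), pooled over time.
mean_in = x.mean(-1, keepdim=True)                     # (B, C, 1)
# BN: one statistic per channel, pooled over batch and time.
mean_bn = x.mean(dim=(0, 2), keepdim=True)             # (1, C, 1)
# LN as in SwitchNorm1d above: one statistic per (sample, time step), pooled over channels.
mean_ln_per_step = x.mean(1, keepdim=True)             # (B, 1, L)
# Image-style LN (assumed analogue of SwitchNorm2d): one statistic per sample.
mean_ln_per_sample = x.mean(dim=(1, 2), keepdim=True)  # (B, 1, 1)

for name, t in [("IN", mean_in), ("BN", mean_bn),
                ("LN per step", mean_ln_per_step),
                ("LN per sample", mean_ln_per_sample)]:
    print(f"{name:14s} {tuple(t.shape)}")

# All statistics broadcast against x; equal mixing weights are used here only for illustration.
mixed_mean = (mean_in + mean_ln_per_step + mean_bn) / 3  # (B, C, L)
```

Because the per-step LN mean has a length dimension, the mixed mean varies over time, while with the per-sample variant it would be constant along `L` for each sample and channel.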

zqiao11 · Mar 11 '23 12:03