
Do we need to add `.detach()` after `var` in `INN.BatchNorm1d`?

Open · Zhangyanbo opened this issue 4 years ago · 2 comments

In `INN.BatchNorm1d`, the `forward` function is:

def forward(self, x, log_p=0, log_det_J=0):
    if self.compute_p:
        if not self.training:
            # if in self.eval()
            var = self.running_var  # [dim]
        else:
            # if in training
            # TODO: Do we need to add .detach() after var?
            var = torch.var(x, dim=0, unbiased=False)  # [dim]

        x = super(BatchNorm1d, self).forward(x)

        log_det = -0.5 * torch.log(var + self.eps)
        log_det = torch.sum(log_det, dim=-1)

        return x, log_p, log_det_J + log_det
    else:
        return super(BatchNorm1d, self).forward(x)
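
For context, the `-0.5 * torch.log(var + self.eps)` term is the change-of-variables correction of the normalization (ignoring any affine part): each feature is mapped as

y_d = (x_d - mean_d) / sqrt(var_d + eps)

so the Jacobian is diagonal and

log|det J| = sum_d -0.5 * log(var_d + eps)

Adding `.detach()` would not change the value of `log_det`, only whether its gradient with respect to `var` (and hence with respect to the modules that produced `x`) is kept.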

Do we need `var` to carry gradient information? It seems we are not training `BatchNorm1d` itself, but the modules before it. Are there any references on this?
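
A quick sketch of the check I have in mind (`lin` below is just a stand-in for "the modules before it", not INNLab code):

import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(3, 3)  # stand-in for the modules before BatchNorm1d
x = lin(torch.randn(5, 3))

# without .detach(): the log-det term sends gradient back to `lin`
var = torch.var(x, dim=0, unbiased=False)
log_det = torch.sum(-0.5 * torch.log(var + 1e-5))
(grad_w,) = torch.autograd.grad(log_det, lin.weight)
print(grad_w.abs().sum())  # non-zero

# with .detach(): the term is a constant w.r.t. every parameter
var_detached = torch.var(x, dim=0, unbiased=False).detach()
log_det_detached = torch.sum(-0.5 * torch.log(var_detached + 1e-5))
print(log_det_detached.requires_grad)  # False -- contributes no gradient at all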

Zhangyanbo avatar Apr 25 '21 20:04 Zhangyanbo

Compare with `nn.BatchNorm1d`:

import torch
import torch.nn as nn

x = torch.randn((5, 3))
bn = nn.BatchNorm1d(3, affine=False)

bn(x)

The output is:

tensor([[-1.6941,  0.2933, -0.2451],
        [-0.1313, -0.2711,  1.4740],
        [ 0.2754, -0.2282,  0.4445],
        [ 0.1287, -1.4409, -0.0721],
        [ 1.4213,  1.6469, -1.6014]])

So, if we do not use the affine parameters in `bn`, we don't need any gradient for the BatchNorm itself.
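
Indeed, with `affine=False` the module has no learnable parameters at all, only running-statistics buffers:

print(list(bn.parameters()))            # []  (weight and bias are None when affine=False)
print(bn.running_mean, bn.running_var)  # buffers, updated without gradients

The remaining question is only whether the gradient of the log-det term should flow through the normalization to the modules before it.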

Zhangyanbo avatar Apr 25 '21 20:04 Zhangyanbo

Experiments show that if we add `.detach()`, the training loss does not decrease, while without `.detach()` training works. So, in the latest version, I added a parameter `requires_grad: bool` to `INN.BatchNorm1d`.
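
A minimal usage sketch of the new flag (assuming the package is imported as `INN`; the constructor arguments and the internal logic shown in the comments are my assumptions for illustration, not the exact INNLab implementation):

import INN

# keep gradients through var -- the setting that lets the training loss decrease
bn = INN.BatchNorm1d(3, requires_grad=True)  # assumed: first argument is the feature dimension

# inside forward, the flag would act roughly like this (sketch):
# var = torch.var(x, dim=0, unbiased=False)
# if not self.requires_grad:
#     var = var.detach()  # log-det then stops sending gradient to earlier modules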

Zhangyanbo avatar Apr 26 '21 18:04 Zhangyanbo