PytorchInsight icon indicating copy to clipboard operation
PytorchInsight copied to clipboard

Replace with Group Normalization

Open Haus226 opened this issue 1 year ago • 0 comments

#2 #33 Regarding the issues above, I think the following lines are the proper way to replace the original implementation with the group normalization. The second function is the slightly modified official implementation of SGE block to align with the GN in pytorch

self.gn = nn.GroupNorm(1, 1)
def forward(self, x):
      b, c, h, w = x.size()
      x = x.view(b * self.groups, -1, h, w) 
      xn = x * self.avg_pool(x)
      xn = xn.sum(dim=1, keepdim=True)
      xn = xn.view(b * self.groups, -1, h, w)
      t = self.gn.forward(xn)
      x = x * self.sig(t.view(b * self.groups, 1, h, w))
      x = x.view(b, c, h, w)
      return x

def oforward(self, x):
        b, c, h, w = x.size()
        x = x.view(b * self.groups, -1, h, w) 
        xn = x * self.avg_pool(x)
        # Reduce the weighted channels in each groups to obtain the attention maps for each groups
        # (This operation is not performed in GN)
        xn = xn.sum(dim=1, keepdim=True)
        # Flatten the spatial in each groups
        t = xn.view(b * self.groups, -1)
        # I think we should use the std of the original t instead of the one updated by subtracting a mean from it.
        var = t.var(dim=1, keepdim=True, unbiased=False)
        t = (t - t.mean(dim=1, keepdim=True)) / torch.sqrt(var + self.eps)
        t = t.view(b, self.groups, h, w)
        t = t * self.weight + self.bias
        t = t.view(b * self.groups, 1, h, w)
        x = x * self.sig(t)
        x = x.view(b, c, h, w)
        return x

Following is the testing code with the result:4.3839216232299807e-07

running_sum = 0
for _ in range(100):
    t = torch.rand(32, 512, 21, 21)
    m = SGE(64, 512) # number of groups and input channels
    running_sum += (m.forward(t) - m.oforward(t)).max().item()
print("The average maximum difference between the tensor is : ", running_sum / 100)

Haus226 avatar Aug 31 '24 14:08 Haus226