RepViT
Incorrect implementation for residual connection
The implementation of `Residual()` in repvit.py is incorrect:
```python
if isinstance(self.m, Conv2d_BN):
    m = self.m.fuse()
    assert(m.groups == m.in_channels)
    identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
    identity = torch.nn.functional.pad(identity, [1,1,1,1])
    m.weight += identity.to(m.weight.device)
    return m
```
This is for converting a 1x1 conv to a 3x3 conv. For an identity connection, the implementation is more like:
```python
# assumes m.weight has shape (C, C, 3, 3), i.e. a dense conv (groups == 1)
identity = torch.zeros_like(m.weight)
for i in range(m.weight.shape[0]):
    identity[i, i, 1, 1] = 1.0  # center of 3x3 kernel
m.weight += identity
```
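Which construction acts as the identity depends on the weight layout: PyTorch stores conv weights as `(out_channels, in_channels // groups, kH, kW)`. The sketch below assumes a 3x3 kernel with `padding=1` and equal input/output channels; `identity_kernel` and `check` are hypothetical helpers, not part of repvit.py. Output channel `o` reads from in-group input index `o % (channels // groups)`, so with `groups == 1` the helper reduces to the diagonal loop above, and with `groups == channels` (depthwise, the case the `assert(m.groups == m.in_channels)` in the quoted code checks) it yields a single centered 1 per output channel.

```python
import torch
import torch.nn.functional as F


def identity_kernel(channels: int, groups: int = 1, k: int = 3) -> torch.Tensor:
    """Hypothetical helper: a k x k conv weight that maps x to x.

    Weight layout is (out_channels, in_channels // groups, k, k); output
    channel o must read input channel o, whose index inside its group is
    o % (channels // groups).
    """
    assert channels % groups == 0
    cin_per_group = channels // groups
    weight = torch.zeros(channels, cin_per_group, k, k)
    center = k // 2
    for o in range(channels):
        weight[o, o % cin_per_group, center, center] = 1.0
    return weight


def check(channels: int, groups: int) -> bool:
    """Apply the kernel with padding=1 and compare the output to the input."""
    x = torch.randn(2, channels, 8, 8)
    w = identity_kernel(channels, groups)
    y = F.conv2d(x, w, padding=1, groups=groups)
    return torch.allclose(y, x)


# dense (groups=1), grouped (groups=2), depthwise (groups == channels)
for groups in (1, 2, 8):
    print(groups, check(8, groups))

# For a depthwise conv, the kernel equals the padded torch.ones(...) tensor.
padded_ones = F.pad(torch.ones(8, 1, 1, 1), [1, 1, 1, 1])
print(torch.equal(identity_kernel(8, groups=8), padded_ones))  # True
```

Indexing by `o % cin_per_group` rather than `o` keeps the same code valid for any `groups`; the plain `identity[i, i, 1, 1]` loop is only in range when `m.weight.shape[1] == m.weight.shape[0]`.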