RepViT

Incorrect implementation for residual connection

Open kensun619 opened this issue 7 months ago • 0 comments

The implementation of `Residual` in repvit.py is incorrect when it fuses the residual connection into the wrapped `Conv2d_BN`:

    if isinstance(self.m, Conv2d_BN):
        m = self.m.fuse()
        assert(m.groups == m.in_channels)
        identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
        identity = torch.nn.functional.pad(identity, [1,1,1,1])
        m.weight += identity.to(m.weight.device)
        return m
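A minimal PyTorch sketch of what the `pad(identity, [1,1,1,1])` step does in general: padding a 1x1 kernel with one zero on each side of both spatial dims gives a 3x3 kernel whose only nonzero tap is the centre, so a 3x3 conv with `padding=1` produces the same output as the original 1x1 conv with no padding. The channel count, input size, and random weights are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A random dense 1x1 kernel: (out_channels, in_channels, 1, 1).
w1 = torch.randn(8, 8, 1, 1)
# Pad one zero on each side of the two spatial dims -> (8, 8, 3, 3),
# with w1 sitting at the centre tap and zeros everywhere else.
w3 = F.pad(w1, [1, 1, 1, 1])

x = torch.randn(2, 8, 16, 16)
y1 = F.conv2d(x, w1)             # 1x1 conv, no padding
y3 = F.conv2d(x, w3, padding=1)  # 3x3 conv, padding=1 keeps the spatial size

print(w3.shape)                          # torch.Size([8, 8, 3, 3])
print(torch.allclose(y1, y3, atol=1e-5)) # True
```

The zero border taps only ever multiply the zero padding or neighbouring pixels by 0, which is why the two outputs agree.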

This code pads a 1x1 kernel of ones to 3x3, which is how a 1x1 conv is converted to an equivalent 3x3 conv. For an identity connection, the implementation should look more like:

    # Assumes a dense conv (groups == 1) with out_channels == in_channels.
    identity = torch.zeros_like(m.weight)
    for i in range(m.weight.shape[0]):
        identity[i, i, 1, 1] = 1.0  # center of 3x3 kernel
    m.weight += identity
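A minimal sketch for checking which 3x3 kernel acts as an identity under `conv2d(..., padding=1)`, assuming PyTorch and a conv with equal input and output channels. The helper `identity_kernel_3x3` is hypothetical (it is not part of repvit.py). The right kernel depends on `groups`: for a dense conv (`groups=1`, weight shape `(C, C, 3, 3)`) it is the `identity[i, i, 1, 1] = 1` pattern proposed above, while for a depthwise conv (`groups=C`, weight shape `(C, 1, 3, 3)`) it is a single 1 at the centre of each `(1, 3, 3)` slice.

```python
import torch
import torch.nn.functional as F


def identity_kernel_3x3(channels: int, groups: int) -> torch.Tensor:
    """Hypothetical helper: a 3x3 kernel that maps x -> x under conv2d(padding=1).

    Assumes in_channels == out_channels == `channels` and channels % groups == 0,
    so the weight has shape (channels, channels // groups, 3, 3).
    """
    in_per_group = channels // groups
    k = torch.zeros(channels, in_per_group, 3, 3)
    for o in range(channels):
        # Output channel o must read only input channel o; inside its group that
        # input channel has local index o % in_per_group. The 1 sits at the centre tap.
        k[o, o % in_per_group, 1, 1] = 1.0
    return k


torch.manual_seed(0)
C = 8
x = torch.randn(2, C, 16, 16)

dense_id = identity_kernel_3x3(C, groups=1)  # shape (8, 8, 3, 3)
dw_id = identity_kernel_3x3(C, groups=C)     # shape (8, 1, 3, 3)

dense_ok = torch.allclose(F.conv2d(x, dense_id, padding=1), x, atol=1e-6)
dw_ok = torch.allclose(F.conv2d(x, dw_id, padding=1, groups=C), x, atol=1e-6)
print(dense_ok, dw_ok)  # True True

# The padded-ones construction from the quoted code, built for each weight shape.
ones_dw = F.pad(torch.ones(C, 1, 1, 1), [1, 1, 1, 1])     # (C, 1, 3, 3)
ones_dense = F.pad(torch.ones(C, C, 1, 1), [1, 1, 1, 1])  # (C, C, 3, 3)
print(torch.equal(ones_dw, dw_id))  # True
# With groups=1, every output channel sums all C input channels instead.
print(torch.allclose(F.conv2d(x, ones_dense, padding=1), x))  # False
```

Because `m.weight.shape[1]` is 1 when `groups == in_channels`, the padded-ones tensor and the per-channel centre-1 tensor coincide in the depthwise case, while in the dense case only the `[i, i, 1, 1]` pattern reproduces the input.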

kensun619, Jul 09 '25 06:07