
What's the idea behind asymmetric padding during downsampling?

Open · Arksyd96 opened this issue 2 years ago · 7 comments

What's the idea behind this downsampling with asymmetric padding? Why not just use a symmetric padding of 1, so that everything fits perfectly?
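Whether a symmetric padding of 1 "fits" can be checked with the usual output-size formula for a convolution without dilation, out = floor((size + pad_before + pad_after - kernel) / stride) + 1. The minimal sketch below is plain Python; the helper name `conv_out_size` is hypothetical. It compares padding 1 on both sides with padding 0 before and 1 after (the `pad = (0,1,0,1)` in the `Downsample` layer quoted in this issue), assuming kernel_size=3 and stride=2 as in that layer.

```python
def conv_out_size(size: int, kernel: int, stride: int, pad_before: int, pad_after: int) -> int:
    """Output length along one axis for a convolution without dilation (floor mode)."""
    return (size + pad_before + pad_after - kernel) // stride + 1


# kernel_size=3, stride=2 as in the Downsample layer; compare the two padding choices.
for size in (7, 8, 63, 64):
    symmetric = conv_out_size(size, 3, 2, 1, 1)   # padding=1 on both sides
    asymmetric = conv_out_size(size, 3, 2, 0, 1)  # nothing before, 1 zero after
    print(f"{size:>3} -> symmetric {symmetric}, asymmetric {asymmetric}")
```

For even sizes both choices halve the input (8 -> 4, 64 -> 32), so output shape alone does not favor either; they only disagree on odd sizes, where symmetric padding rounds up and the asymmetric version rounds down.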

```python
import torch
import torch.nn as nn


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            # F.pad orders the last two dims as (left, right, top, bottom):
            # one zero column is added on the right and one zero row at the bottom.
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            # 2x2 average pooling with stride 2 halves H and W.
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
```
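For even inputs the padding choice does not change the output shape, but it does change which pixels the stride-2 grid is centered on. The minimal sketch below is plain Python, and `tf_same_padding` and `window_centers` are hypothetical helpers. It computes the kernel-window centers along one axis for both paddings, plus the "SAME" padding split as TensorFlow documents it (total padding split in half, with any odd extra pixel going after). It assumes kernel_size=3, stride=2 and no dilation.

```python
import math


def tf_same_padding(size: int, kernel: int, stride: int) -> tuple[int, int]:
    """(before, after) padding for TensorFlow-style 'SAME' along one axis.
    When the total is odd, the extra zero goes after (bottom/right)."""
    out = math.ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    before = total // 2
    return before, total - before


def window_centers(size: int, kernel: int, stride: int, pad_before: int, pad_after: int) -> list[int]:
    """Input index at the center of each kernel window (no dilation)."""
    n_out = (size + pad_before + pad_after - kernel) // stride + 1
    return [i * stride - pad_before + kernel // 2 for i in range(n_out)]


size, kernel, stride = 8, 3, 2
print("TF 'SAME' split:", tf_same_padding(size, kernel, stride))           # (0, 1)
print("pad (0, 1) centers:", window_centers(size, kernel, stride, 0, 1))   # [1, 3, 5, 7]
print("pad (1, 1) centers:", window_centers(size, kernel, stride, 1, 1))   # [0, 2, 4, 6]
print("TF 'SAME' split, size 7:", tf_same_padding(7, kernel, stride))      # (1, 1)
```

With `pad = (0, 1, 0, 1)` the windows are centered on odd indices, while symmetric `padding=1` centers them on even indices, so the two layers sample grids shifted by one pixel. For even sizes, (0, 1) is exactly TensorFlow's "SAME" split; for odd sizes "SAME" pads (1, 1), so the hard-coded tuple only matches it on even inputs. One plausible reading, which this thread does not confirm, is that the layer was written to reproduce a TensorFlow "SAME"-padded stride-2 convolution so that its behavior (and any weights trained that way) carries over unchanged.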

Arksyd96 · Jan 19 '24

Did you get an answer to this question? I have been wondering about it myself.

ChaosAdmStudent · Mar 11 '24

No, unfortunately not yet!

Arksyd96 · Mar 11 '24

I also wonder why.

x2x5 · Apr 21 '24

Has anyone found the answer? I'm curious why.

GGYtilE · Aug 13 '24

I'm also curious about this question. Does anyone have an idea?

weiaicunzai · Aug 31 '24

Clearly, no one wants to answer this question lol

Arksyd96 · Sep 01 '24