RED-CNN

Network bug: conv layers were not supposed to share weights

Open • Miaite opened this issue on Jul 04 '20 • 1 comment

In networks.py, when the RED_CNN class is initialized, only two nn.Conv2d and two nn.ConvTranspose2d layers are instantiated:

    def __init__(self, out_ch=96):
        super(RED_CNN, self).__init__()
        self.conv_first = nn.Conv2d(1, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t_last = nn.ConvTranspose2d(out_ch, 1, kernel_size=5, stride=1, padding=0)
        self.relu = nn.ReLU()
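
For reference, a quick check (a sketch, assuming the class above is importable from networks.py) confirms that only these four conv/deconv modules are registered, besides the ReLU:

    from networks import RED_CNN  # assumed module path

    model = RED_CNN()
    for name, module in model.named_children():
        print(name, module.__class__.__name__)
    # conv_first   Conv2d
    # conv         Conv2d
    # conv_t       ConvTranspose2d
    # conv_t_last  ConvTranspose2d
    # relu         ReLU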

Then, in the forward function, self.conv and self.conv_t are referenced multiple times, so the same layers (the same weights) are applied repeatedly to the features:

    def forward(self, x):
        # encoder
        residual_1 = x.clone()
        out = self.relu(self.conv_first(x))
        out = self.relu(self.conv(out))
        residual_2 = out.clone()
        out = self.relu(self.conv(out))
        out = self.relu(self.conv(out))
        residual_3 = out.clone()
        out = self.relu(self.conv(out))

        # decoder
        out = self.conv_t(out)
        out += residual_3
        out = self.conv_t(self.relu(out))
        out = self.conv_t(self.relu(out))
        out += residual_2
        out = self.conv_t(self.relu(out))
        out = self.conv_t_last(self.relu(out))
        out += residual_1
        out = self.relu(out)
        return out
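
This means each of the four self.conv calls in the encoder (and the four self.conv_t calls in the decoder) reuses a single weight tensor, and the gradients from all of those positions accumulate on it. A rough sanity check is the parameter count; a sketch, assuming the buggy class above:

    from networks import RED_CNN  # assumed module path

    model = RED_CNN(out_ch=96)
    print(sum(p.numel() for p in model.parameters()))
    # ~0.47M parameters; ten independent 5x5 / 96-channel layers would give ~1.85M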

I guess the right way to construct the network is like this:

class RED_CNN(nn.Module):
    def __init__(self, out_ch=96):
        super(RED_CNN, self).__init__()
        self.conv_first = nn.Conv2d(
            1, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv1 = nn.Conv2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv3 = nn.Conv2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv4 = nn.Conv2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t1 = nn.ConvTranspose2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t2 = nn.ConvTranspose2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t3 = nn.ConvTranspose2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t4 = nn.ConvTranspose2d(
            out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t_last = nn.ConvTranspose2d(
            out_ch, 1, kernel_size=5, stride=1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        # encoder
        residual_1 = x
        out = self.relu(self.conv_first(x))
        out = self.relu(self.conv1(out))
        residual_2 = out
        out = self.relu(self.conv2(out))
        out = self.relu(self.conv3(out))
        residual_3 = out
        out = self.relu(self.conv4(out))

        # decoder
        out = self.conv_t1(out)
        out += residual_3
        out = self.conv_t2(self.relu(out))
        out = self.conv_t3(self.relu(out))
        out += residual_2
        out = self.conv_t4(self.relu(out))
        out = self.conv_t_last(self.relu(out))
        out += residual_1
        out = self.relu(out)
        return out
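
A minimal smoke test for the corrected class (a sketch, assuming it replaces the one in networks.py) checks that the output shape matches the input and that the parameter count now reflects ten independent layers:

    import torch

    model = RED_CNN(out_ch=96)
    x = torch.randn(1, 1, 64, 64)   # a single-channel 64x64 patch
    y = model(x)
    assert y.shape == x.shape       # the 5 convs shrink H/W by 20, the 5 deconvs restore it
    print(sum(p.numel() for p in model.parameters()))  # ~1.85M parameters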

Miaite • Jul 04 '20

I found the bug late. Thank you very much.

SSinyu • Nov 17 '20