RED-CNN
Network bug: the conv layers were not supposed to share weights
In networks.py, when the RED_CNN class is initialized, only four convolution layers are instantiated (two nn.Conv2d and two nn.ConvTranspose2d):
def __init__(self, out_ch=96):
    super(RED_CNN, self).__init__()
    self.conv_first = nn.Conv2d(1, out_ch, kernel_size=5, stride=1, padding=0)
    self.conv = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
    self.conv_t = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
    self.conv_t_last = nn.ConvTranspose2d(out_ch, 1, kernel_size=5, stride=1, padding=0)
    self.relu = nn.ReLU()
Then, in the forward function, self.conv and self.conv_t are each referenced multiple times, so the same layer (i.e., the same weights) is applied at several different depths of the network:
def forward(self, x):
    # encoder
    residual_1 = x.clone()
    out = self.relu(self.conv_first(x))
    out = self.relu(self.conv(out))        # self.conv, 1st use
    residual_2 = out.clone()
    out = self.relu(self.conv(out))        # self.conv again, same weights
    out = self.relu(self.conv(out))        # and again
    residual_3 = out.clone()
    out = self.relu(self.conv(out))        # and again
    # decoder
    out = self.conv_t(out)                 # self.conv_t, 1st use
    out += residual_3
    out = self.conv_t(self.relu(out))      # self.conv_t again, same weights
    out = self.conv_t(self.relu(out))      # and again
    out += residual_2
    out = self.conv_t(self.relu(out))      # and again
    out = self.conv_t_last(self.relu(out))
    out += residual_1
    out = self.relu(out)
    return out
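A quick way to see the sharing (a minimal sketch, assuming the buggy class above is the RED_CNN exported by networks.py):

import torch
from networks import RED_CNN

model = RED_CNN()

# self.conv is one module called four times in the encoder,
# so every call reuses the same weight tensor.
print(model.conv.weight.shape)   # torch.Size([96, 96, 5, 5])

# Only 4 distinct conv/deconv layers are registered, giving 465,889
# parameters instead of the ~1.85M a 10-layer RED-CNN should have.
print(sum(p.numel() for p in model.parameters()))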
I guess the right way to construct the network is like this:
class RED_CNN(nn.Module):
    def __init__(self, out_ch=96):
        super(RED_CNN, self).__init__()
        self.conv_first = nn.Conv2d(1, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv3 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv4 = nn.Conv2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t1 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t2 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t3 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t4 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=5, stride=1, padding=0)
        self.conv_t_last = nn.ConvTranspose2d(out_ch, 1, kernel_size=5, stride=1, padding=0)
        self.relu = nn.ReLU()

    def forward(self, x):
        # encoder
        residual_1 = x
        out = self.relu(self.conv_first(x))
        out = self.relu(self.conv1(out))
        residual_2 = out
        out = self.relu(self.conv2(out))
        out = self.relu(self.conv3(out))
        residual_3 = out
        out = self.relu(self.conv4(out))
        # decoder
        out = self.conv_t1(out)
        out += residual_3
        out = self.conv_t2(self.relu(out))
        out = self.conv_t3(self.relu(out))
        out += residual_2
        out = self.conv_t4(self.relu(out))
        out = self.conv_t_last(self.relu(out))
        out += residual_1
        out = self.relu(out)
        return out
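As a sanity check of the corrected network (a small sketch; the 1x64x64 dummy input size is an arbitrary choice): each 5x5 convolution with padding=0 shrinks every spatial dimension by 4, and each matching ConvTranspose2d grows it back by 4, so all three residual additions line up and the output shape matches the input. The parameter count also roughly quadruples, since the ten layers no longer share weights.

import torch
from networks import RED_CNN   # assuming the fixed class replaces the old one

model = RED_CNN()
x = torch.randn(1, 1, 64, 64)  # dummy single-channel image patch

y = model(x)
print(y.shape)                 # torch.Size([1, 1, 64, 64])

# Ten independently parameterized layers now: 1,848,865 parameters
# versus 465,889 with the shared layers.
print(sum(p.numel() for p in model.parameters()))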
Sorry that I only found this bug so late. Thank you very much.