Add modules for easily constructing residual networks
Here are two modules that implement residual blocks in a clean, object-oriented way. They subclass the container modules and accept modules as positional arguments, just like nn.Sequential. They also pretty-print nicely, again thanks to subclassing the container modules.
The ResidualBlock module just takes a sequence of modules and adds a 'shortcut connection' between its input and its output. In other words, its final output is the sum of its last module's output and the original input. I also provide a ResidualBlockWithShortcut module, which lets you customize the shortcut connection, for instance to ensure it has the same shape as the output of the main branch.
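To make the semantics concrete, here is a minimal sketch of how such modules could be implemented by subclassing nn.Sequential; the constructor signature (in particular the shortcut keyword argument) is illustrative, not necessarily the exact code in this PR:

```python
import torch
from torch import nn

class ResidualBlock(nn.Sequential):
    """Runs its children in sequence and adds the original input to the
    result, i.e. output = main_branch(x) + x."""
    def forward(self, x):
        return super().forward(x) + x

class ResidualBlockWithShortcut(nn.Sequential):
    """Like ResidualBlock, but the shortcut is a module of its own,
    e.g. a 1x1 convolution that matches the main branch's output shape."""
    def __init__(self, *args, shortcut=None):
        super().__init__(*args)
        # Registered as a named child so it shows up when printed; the
        # forward pass below skips it when running the main branch.
        self.shortcut = shortcut if shortcut is not None else nn.Identity()

    def forward(self, x):
        out = x
        for name, module in self._modules.items():
            if name != "shortcut":
                out = module(out)
        return out + self.shortcut(x)

# The main branch preserves the input shape (padding=1), so the sum works:
block = ResidualBlock(nn.ReLU(), nn.Conv2d(10, 10, 3, padding=1))
x = torch.randn(2, 10, 8, 8)
assert block(x).shape == x.shape

# When the main branch changes the channel count, a 1x1 conv shortcut
# projects the input to the matching shape:
proj = ResidualBlockWithShortcut(
    nn.Conv2d(10, 20, 3, padding=1),
    shortcut=nn.Conv2d(10, 20, 1),
)
assert proj(x).shape == (2, 20, 8, 8)
```

Note that because nn.Sequential iterates over all registered children, the custom forward() above has to skip the shortcut module explicitly when running the main branch.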
No more hand-rolled, easily-gotten-wrong network topology spread across separately defined __init__() and forward() methods when all you want is a standard ResNet! The code looks like this:
model = nn.Sequential(
    nn.Conv2d(1, 10, 1),
    ResidualBlock(
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
    ),
    nn.MaxPool2d(2),
    ResidualBlock(
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
    ),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(7*7*10, 10),
    nn.LogSoftmax(dim=-1),
)
and the model looks like this when printed:
Sequential(
  (0): Conv2d(1, 10, kernel_size=(1, 1), stride=(1, 1))
  (1): ResidualBlock(
    (0): ReLU()
    (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): ResidualBlock(
    (0): ReLU()
    (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Flatten()
  (6): Linear(in_features=490, out_features=10, bias=True)
  (7): LogSoftmax()
)
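For reference, the in_features=490 of the final Linear layer assumes 28x28 single-channel inputs (e.g. MNIST): the residual blocks preserve shape because their convolutions use padding=1, and each MaxPool2d(2) halves the spatial size, so 28 -> 14 -> 7 and 7*7*10 = 490. A quick shape check:

```python
import torch
from torch import nn

# Assuming a 28x28 single-channel input (e.g. MNIST). The residual blocks
# are shape-preserving, so only the conv stem and the pooling layers matter
# for the flattened feature count.
x = torch.randn(1, 1, 28, 28)
x = nn.Conv2d(1, 10, 1)(x)   # -> (1, 10, 28, 28)
x = nn.MaxPool2d(2)(x)       # -> (1, 10, 14, 14)
x = nn.MaxPool2d(2)(x)       # -> (1, 10, 7, 7)
x = nn.Flatten()(x)          # -> (1, 490)
assert x.shape == (1, 490)
```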
The references for residual blocks are: Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385), and, by the same authors, "Identity Mappings in Deep Residual Networks" (https://arxiv.org/abs/1603.05027).