Questions about making ResNets equivariant with e2cnn
I'm trying to create an equivariant ResNet model based on `e2_wide_resnet.py`, and I would really appreciate it if you could clarify some of my doubts.
Main questions
- I see that in your repo `e2_wide_resnet.py` is the equivariant version of `wide_resnet.py`, and `WideBasic` is equivalent to `BasicBlock`. In standard, non-equivariant CNNs, we pass the arguments `in_planes` and `out_planes` to `BasicBlock`, which correspond to the number of input and output channels of the block. As far as I know, the number of channels in a standard CNN corresponds to the number of representations in an `e2cnn.nn.FieldType`. However, in `WideBasic`, instead of passing only the input and output FieldTypes `in_fiber` and `out_fiber`, you also pass an inner FieldType, `inner_fiber`. So my question is: why is passing those two FieldTypes (as in a non-equivariant network) not enough, and why do we need the third one?
- Apart from `BasicBlock`, I also want to use a different block needed for most models in the ResNet family: `Bottleneck`. `Bottleneck` is a bit more complicated than `BasicBlock`; this is what its initialization looks like:
class Bottleneck(nn.Module):
    # torchvision's Bottleneck widens its output channels by this factor
    expansion: int = 4

    def __init__(self, inplanes: int, planes: int, ...) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
which looks like this in a diagram from the ResNeXt paper (for 32 groups):
At the beginning, we calculate `width`, which is the number of output channels of `self.conv1` and the number of input channels of `self.conv2` (128 in the diagram above; e.g. `width = int(64 * 4 / 64.) * 32 = 128` for `planes=64`, `base_width=4`, `groups=32`). Given that the equivariant bottleneck needs a different number of channels (i.e. a different number of representations in the FieldType), I create a new FieldType `width_fiber`:
# I think len(out_fiber) is the number of channels in E2
planes = len(out_fiber)
width = int(planes * (base_width / 64.)) * groups

# now we need to get the same field type but with `width` representations (number of channels);
# dirty check that all representations in `in_fiber` are the same (otherwise the next lines are incorrect)
first_rep_type = type(in_fiber.representations[0])
for rep in in_fiber.representations:
    assert first_rep_type == type(rep)

# FIXME hardcoded representation, should be the same representation as in_fiber
in_rep = 'regular'
# create new fiber with `width` representations
width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])

self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
...
Does that seem correct to you? Is there a cleaner way to retrieve the representation type from FieldType instead of hardcoding it?
After doing that, I again need to change the number of channels (from `planes` to `planes * expansion`), and I do a similar thing as before for `exp_out_fiber`. Here is the whole code for the `E2Bottleneck` initialization:
class E2Bottleneck(nn.EquivariantModule):
    def __init__(
            self,
            in_fiber: nn.FieldType,
            inner_fiber: nn.FieldType,
            out_fiber: nn.FieldType = None,
            ...):
        # I think len(out_fiber) is the number of channels in E2
        planes = len(out_fiber)
        width = int(planes * (base_width / 64.)) * groups
        # now we need to get the same field type but with `width` representations (number of channels);
        # dirty check that all representations in `in_fiber` are the same
        first_rep_type = type(in_fiber.representations[0])
        for rep in in_fiber.representations:
            assert first_rep_type == type(rep)
        # FIXME hardcoded representation, should be the same representation as in_fiber
        in_rep = 'regular'
        # create new fiber with `width` representations
        width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.gspace.representations[in_rep]])
        self.conv1 = conv1x1(in_fiber, width_fiber, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(width_fiber)
        self.conv2 = conv(width_fiber, width_fiber, stride, groups, dilation, sigma=sigma, F=F, initialize=False)
        self.bn2 = nn.InnerBatchNorm(width_fiber)
        # create new fiber with `planes * self.expansion` representations
        exp_out_fiber = nn.FieldType(in_fiber.gspace,
                                     planes * self.expansion * [in_fiber.gspace.representations[in_rep]])
        self.conv3 = conv1x1(width_fiber, exp_out_fiber, sigma=sigma, F=F, initialize=False)
        self.bn3 = nn.InnerBatchNorm(exp_out_fiber)
        # unlike torch.nn.ReLU, e2cnn's nn.ReLU is tied to a FieldType,
        # so one instance per type is needed instead of a single shared self.relu
        self.relu1 = nn.ReLU(width_fiber, inplace=True)
        self.relu2 = nn.ReLU(width_fiber, inplace=True)
        self.relu3 = nn.ReLU(exp_out_fiber, inplace=True)
        self.downsample = downsample
        self.stride = stride
Smaller questions about `e2_wide_resnet.py`
- Why do we use conv layers with kernel size 3 for rotations 0, 2 and 4 while conv layers with kernel size 5 for others?
if rotations in [0, 2, 4]:
    conv = conv3x3
else:
    conv = conv5x5
- Is this initialization correct?
elif isinstance(module, torch.nn.BatchNorm2d):
    module.weight.data.fill_(1)
    module.bias.data.zero_()
elif isinstance(module, torch.nn.Linear):
    module.bias.data.zero_()
`BatchNorm2d` isn't even used here; should we replace it with `InnerBatchNorm`, or remove that part entirely? (`InnerBatchNorm` doesn't have the instance variables `weight` or `bias`, from what I've checked.) Also, why is the standard linear layer initialized to 0 instead of using a standard initialization?
Looking forward to your reply and please let me know if there is anything unclear in my question :)
Hi @blazejdolicki
Main questions:
- If I understood your question correctly, in my code `inner_type` represents what you called `out_planes`. You can then verify that, when the blocks are instantiated here, I always just set the output type to the same type as `inner_type`. The only exception is the last layer of the model, where I need to output invariant features. This is the only reason why `WideBasic` has the option to specify a different output type.
- The relation between the number of channels in the conventional model and in the equivariant one depends on what you want to do. If you actually want to count the number of effective channels in the equivariant model (relevant for preserving the computational/memory cost), you should look at `out_fiber.size`. Instead, `len(out_fiber)` tells you the number of independent features. If you use `regular_representation`, this is equivalent to the number of channels in a GCNN; in this sense, `out_fiber.size` is `|G|` times larger, as each GCNN channel effectively stores an activation for each element of the group `G`. In my code, this behaviour is regulated by the `fixparams` argument. If set to `False`, I make sure that `out_fiber.size = out_planes`; instead, if `fixparams=True`, I scale up the number of channels by roughly `sqrt(|G|)` to ensure the total number of parameters stays the same. If you use `len(out_fiber) = out_channels`, you are generating a much larger model; I'd not recommend this, as it's likely going to be a very expensive architecture (see the sketch after point 2.3).
2.1 I am not sure what you mean with `type(in_fiber.representations[0])`. Regarding your "dirty check", our new library includes a shortcut for that, check here.
2.2 Once you have verified that all representations in `in_fiber` are the same, you can just access one as `in_fiber.representations[0]` and use it to build `width_fiber` as `width_fiber = nn.FieldType(in_fiber.gspace, width * [in_fiber.representations[0]])`. This avoids hard-coding `in_rep` (see the sketch below).
2.3 The same trick can be used for the expansion. If you want to preserve the total number of channels or the number of parameters (as recommended earlier), I'd recommend estimating the number of channels using something like this and passing it `planes * self.expansion`, so that the number of channels is rounded after taking the product.
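Putting 2.2 and 2.3 together, a minimal sketch of the `len(fiber)` vs `fiber.size` distinction and of building `width_fiber` without hard-coding the representation name could look like this (assuming e2cnn; the sizes and the `width` value are illustrative):

from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=4)                      # C4: |G| = 4
in_fiber = nn.FieldType(gspace, 16 * [gspace.regular_repr])

print(len(in_fiber))   # 16 -> number of independent fields (GCNN "channels")
print(in_fiber.size)   # 64 -> effective tensor channels: 16 * |G|

# reuse the representation of in_fiber instead of looking it up by name;
# all fields are assumed to share the same representation
rep = in_fiber.representations[0]
assert all(r.name == rep.name for r in in_fiber.representations)
width = 32                                           # hypothetical bottleneck width
width_fiber = nn.FieldType(in_fiber.gspace, width * [rep])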
Smaller questions:
- The grid where images are sampled is perfectly symmetric only with respect to 90-degree rotations (contained in `C_N` for N = 1, 2 or 4). Equivariance to other rotations is necessarily approximate; using slightly wider filters allows for better stability to smaller rotations.
- Yeah, the initialization of `BatchNorm2d` is not necessary, but you don't need to add anything: `InnerBatchNorm` is already initialized in that way during instantiation. Also, note that I initialized to 0 only the bias, not the weights, of `torch.nn.Linear`.
Hope this helped! Let me know if you have other questions.
Gabriele
Hey Gabriele, thanks for your elaborate answer, this makes it much clearer now! It seems that `e2wrn.py` is an improved version of `e2_wide_resnet.py`. Apart from some variable renaming (for example, `conv2triv` was changed to `totrivial`), I see that before it was a parameter of the `Wide_ResNet` model class, while now it's only included in `_wide_layer`. What was the reasoning behind that change? Moreover, now there is no `GroupPooling` layer; isn't it necessary for invariance?
I have another question related to the comment above. Based on my understanding, it seems to me that the group space used in a `FieldType` should correspond to its representations. For example, a FieldType with a trivial gspace should have trivial representations. Based on the examples, I see that this is not the case: for example, in the first layer the gspace is `Rot2dOnR2()` (instead of `TrivialOnR2()`) while the representations are trivial (for example, in the 3rd code cell here). Can you explain why my reasoning is incorrect?
Hi @blazejdolicki
Sorry for the late reply
Regarding the difference between `e2wrn.py` and `e2_wide_resnet.py`: the first is a cleaner example I prepared for the tutorial of the library, while the second was a more flexible model I built to be able to run different experiments. I do not use `GroupPooling` in `e2wrn.py` since the last convolutional layer already maps to invariant features (trivial representations), so there is no need to perform further pooling.
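In code, the two routes to invariant features look roughly like this (a sketch, assuming e2cnn; the sizes are illustrative):

from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=4)
feat = nn.FieldType(gspace, 32 * [gspace.regular_repr])

# (a) pool each regular field over the group, as in e2_wide_resnet.py:
gpool = nn.GroupPooling(feat)
print(gpool.out_type)   # 32 trivial (invariant) fields

# (b) let the last convolution map directly to trivial fields, as in e2wrn.py:
triv = nn.FieldType(gspace, 32 * [gspace.trivial_repr])
last_conv = nn.R2Conv(feat, triv, kernel_size=3, padding=1)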
Regarding the other question: the gspace describes the symmetry group (e.g. `Rot2dOnR2(N=8)` = 8 discrete rotations, `TrivialOnR2()` = no rotations). The representations passed as the second argument, instead, define how this symmetry group transforms the channels; a trivial representation just means the channel values are left unchanged by the group, which is the right choice for an RGB input even when the gspace itself contains rotations.
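For instance, both of the following FieldTypes live on the same rotation gspace; only the transformation law of their channels differs (a sketch):

from e2cnn import gspaces, nn

gspace = gspaces.Rot2dOnR2(N=8)   # the symmetry group: 8 discrete rotations

# rotations move pixels but leave the 3 channel values untouched (an RGB input):
rgb_in = nn.FieldType(gspace, 3 * [gspace.trivial_repr])

# rotations additionally permute the 8 activations within each regular field:
feats = nn.FieldType(gspace, 16 * [gspace.regular_repr])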
Take a look at this tutorial; does it make things more clear?
Best, Gabriele
Thanks, that makes it clearer. I implemented and trained an equivariant model for N=4 rotations, but when I check the probabilities between 90-degree rotations of the same images (so perfect rotations, without interpolation artifacts), sometimes there is a significant difference, such as for the second image here (each subplot title contains the class probabilities for the corresponding image).
Do you think this comes from some justifiable numerical inaccuracy, or is there something inherently wrong with the model? Below I'm attaching my model architecture; the invariance is obtained by converting to trivial representations.
E2ResNet(
(conv1): R2Conv([4-Rotations: 3 representations], [4-Rotations: 16 representations], kernel_size=5, stride=1, padding=2, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
(maxpool): PointwiseMaxPool()
(layer1): SequentialModule(
(0): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
(conv2): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 16 representations])
)
(1): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 16 representations])
(conv2): R2Conv([4-Rotations: 16 representations], [4-Rotations: 16 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 16 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 16 representations])
)
)
(layer2): SequentialModule(
(0): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 16 representations], [4-Rotations: 32 representations], kernel_size=3, stride=2, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 32 representations])
(conv2): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 32 representations])
(downsample): SequentialModule(
(0): R2Conv([4-Rotations: 16 representations], [4-Rotations: 32 representations], kernel_size=1, stride=2, bias=False)
(1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 32 representations])
(conv2): R2Conv([4-Rotations: 32 representations], [4-Rotations: 32 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 32 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 32 representations])
)
)
(layer3): SequentialModule(
(0): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 32 representations], [4-Rotations: 64 representations], kernel_size=3, stride=2, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 64 representations])
(conv2): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 64 representations])
(downsample): SequentialModule(
(0): R2Conv([4-Rotations: 32 representations], [4-Rotations: 64 representations], kernel_size=1, stride=2, bias=False)
(1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 64 representations])
(conv2): R2Conv([4-Rotations: 64 representations], [4-Rotations: 64 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 64 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 64 representations])
)
)
(layer4): SequentialModule(
(0): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 64 representations], [4-Rotations: 128 representations], kernel_size=3, stride=2, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 128 representations])
(conv2): R2Conv([4-Rotations: 128 representations], [4-Rotations: 128 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 128 representations])
(downsample): SequentialModule(
(0): R2Conv([4-Rotations: 64 representations], [4-Rotations: 128 representations], kernel_size=1, stride=2, bias=False)
(1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): E2BasicBlock(
(conv1): R2Conv([4-Rotations: 128 representations], [4-Rotations: 128 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn1): InnerBatchNorm([4-Rotations: 128 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True, type=[4-Rotations: 128 representations])
(conv2): R2Conv([4-Rotations: 128 representations], [4-Rotations: 512 representations], kernel_size=3, stride=1, padding=1, bias=False)
(bn2): InnerBatchNorm([4-Rotations: 512 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True, type=[4-Rotations: 512 representations])
(downsample): SequentialModule(
(0): R2Conv([4-Rotations: 128 representations], [4-Rotations: 512 representations], kernel_size=1, stride=1, bias=False)
(1): InnerBatchNorm([4-Rotations: 512 representations], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=2, bias=True)
)
This model is created using this script:
class E2BasicBlock(nn.EquivariantModule):
    expansion: int = 1

    def __init__(
            self,
            in_fiber: nn.FieldType,
            inner_fiber: nn.FieldType,
            out_fiber: nn.FieldType = None,
            stride: int = 1,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            F: float = 1.,
            sigma: float = 0.45,
    ) -> None:
        super(E2BasicBlock, self).__init__()
        if out_fiber is None:
            out_fiber = in_fiber
        self.in_type = in_fiber
        inner_class = inner_fiber
        self.out_type = out_fiber

        if isinstance(in_fiber.gspace, gspaces.FlipRot2dOnR2):
            rotations = in_fiber.gspace.fibergroup.rotation_order
        elif isinstance(in_fiber.gspace, gspaces.Rot2dOnR2):
            rotations = in_fiber.gspace.fibergroup.order()
        else:
            rotations = 0

        if rotations in [0, 2, 4]:
            conv = conv3x3
        else:
            conv = conv5x5

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv(self.in_type, inner_class, stride=stride, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(inner_class)
        self.relu = nn.ReLU(inner_class, inplace=True)
        self.conv2 = conv(inner_class, self.out_type, sigma=sigma, F=F, initialize=False)
        self.bn2 = nn.InnerBatchNorm(self.out_type)
        # a separate relu module is needed because the field type changes
        self.relu2 = nn.ReLU(self.out_type, inplace=True)
        self.stride = stride

        # `downsample` in resnet.py is the equivalent of `shortcut` in e2_wide_resnet.py
        self.downsample = None
        if stride != 1 or self.in_type != self.out_type:
            self.downsample = nn.SequentialModule(
                conv1x1(self.in_type, self.out_type, stride=stride, bias=False, sigma=sigma, F=F, initialize=False),
                nn.InnerBatchNorm(self.out_type),
            )

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu2(out)
        return out

    # abstract method of nn.EquivariantModule
    def evaluate_output_shape(self, input_shape):
        raise NotImplementedError
class E2ResNet(torch.nn.Module):
    def __init__(
            self,
            block: Type[Union[E2BasicBlock, E2Bottleneck]],
            layers: List[int],
            num_classes: int = 1000,
            N: int = 8,
            restrict: int = 1,
            flip: bool = True,
            main_fiber: str = "regular",
            inner_fiber: str = "regular",
            F: float = 1.,
            sigma: float = 0.45,
            deltaorth: bool = False,
            fixparams: bool = True,
            initial_stride: int = 1,
            conv2triv: bool = True,
            zero_init_residual: bool = False,
            groups: int = 1,
            width_per_group: int = 64,
            replace_stride_with_dilation: Optional[List[bool]] = None
    ) -> None:
        """
        :param block: Type of block used in the model (E2BasicBlock or E2Bottleneck)
        :param layers: Number of blocks in each layer
        :param num_classes:
        :param N:
        :param restrict:
        :param flip: Whether the model is flip equivariant.
        :param main_fiber:
        :param inner_fiber:
        :param F:
        :param sigma:
        :param deltaorth:
        :param fixparams:
        :param conv2triv:
        :param zero_init_residual:
        :param groups:
        :param width_per_group:
        :param replace_stride_with_dilation:
        """
        super(E2ResNet, self).__init__()
        # Standard initialization of ResNet
        # Number of output channels of the first convolution
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Equivariant part of initialization of ResNet
        self._fixparams = fixparams
        self.conv2triv = conv2triv
        self._layer = 0
        self._N = N
        # if the model is [F]lip equivariant
        self._f = flip
        # level of [R]estriction:
        #   r < 0 : never do restriction, i.e. initial group (either D8 or C8) preserved for the whole network
        #   r = 0 : do restriction before first layer, i.e. initial group doesn't have rotation equivariance (C1 or D1)
        #   r > 0 : restrict after every block, i.e. start with 8 rotations, then restrict to 4 and finally 1
        self._r = restrict
        self._F = F
        self._sigma = sigma

        if self._f:
            self.gspace = gspaces.FlipRot2dOnR2(N)
        else:
            self.gspace = gspaces.Rot2dOnR2(N)

        if self._r == 0:
            id = (0, 1) if self._f else 1
            self.gspace, _, _ = self.gspace.restrict(id)

        # Start building layers
        # field type of layer lifting the Z^2 input to N rotations
        self.in_lifting_type = nn.FieldType(self.gspace, [self.gspace.trivial_repr] * 3)
        # field type for the first lifted layer
        self.next_in_type = FIBERS[main_fiber](self.gspace, self.inplanes, fixparams=self._fixparams)
        # number of output channels in each outer layer
        num_channels = [64, 128, 256, 512]

        # For this initial cnn, torchvision ResNet uses kernel_size=7, stride=2, padding=3,
        # wide_resnet.py uses kernel_size=3, stride=1, padding=1
        # and e2_wide_resnet.py uses kernel_size=5. We follow the latter.
        self.conv1 = conv5x5(self.in_lifting_type, self.next_in_type, sigma=sigma, F=F, initialize=False)
        self.bn1 = nn.InnerBatchNorm(self.next_in_type)
        self.relu = nn.ReLU(self.next_in_type, inplace=True)
        self.maxpool = nn.PointwiseMaxPool(self.next_in_type, kernel_size=3, stride=2, padding=1)

        # self.layer_i is equivalent to self.block_i in wide_resnet.py (for the ith layer)
        # self._make_layer is equivalent to NetworkBlock (which contains the same method) in wide_resnet.py
        # and to _wide_layer in e2_wide_resnet.py
        self.layer1 = self._make_layer(block, num_channels[0], layers[0], stride=initial_stride,
                                       dilate=replace_stride_with_dilation[0],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)
        # first restriction layer
        if self._r > 0:
            id = (0, 4) if self._f else 4
            self.restrict1 = self._restrict_layer(id)
        else:
            self.restrict1 = lambda x: x
        self.layer2 = self._make_layer(block, num_channels[1], layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)
        # second restriction layer
        if self._r > 1:
            id = (0, 1) if self._f else 1
            self.restrict2 = self._restrict_layer(id)
        else:
            self.restrict2 = lambda x: x
        self.layer3 = self._make_layer(block, num_channels[2], layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber)

        if self.conv2triv:
            out_fiber = "trivial"
        else:
            out_fiber = None
        self.layer4 = self._make_layer(block, num_channels[3], layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2],
                                       main_fiber=main_fiber, inner_fiber=inner_fiber, out_fiber=out_fiber)

        if not self.conv2triv:
            self.mp = nn.GroupPooling(self.layer4.out_type)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        linear_input_features = self.mp.out_type.size if not self.conv2triv else self.layer4.out_type.size
        self.fc = torch.nn.Linear(linear_input_features, num_classes)

        for module in self.modules():
            if isinstance(module, nn.R2Conv):
                if deltaorth:
                    init.deltaorthonormal_init(module.weights.data, module.basisexpansion)
                else:
                    init.generalized_he_init(module.weights.data, module.basisexpansion)
            elif isinstance(module, torch.nn.Linear):
                module.bias.data.zero_()

        num_params = sum([p.numel() for p in self.parameters() if p.requires_grad])
        print("Total number of learnable parameters:", num_params)

    def _make_layer(self, block: Type[Union[E2BasicBlock, E2Bottleneck]], planes: int, num_blocks: int,
                    stride: int = 1, dilate: bool = False,
                    main_fiber: str = "regular",
                    inner_fiber: str = "regular",
                    out_fiber: str = None,
                    ) -> nn.SequentialModule:
        self._layer += 1
        logging.info(f"Start building layer {self._layer}")
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        layers = []
        main_type = FIBERS[main_fiber](self.gspace, planes, fixparams=self._fixparams)
        inner_class = FIBERS[inner_fiber](self.gspace, planes, fixparams=self._fixparams)
        out_f = main_type

        # add the first block, which starts with `self.next_in_type` and ends with `planes` channels;
        # use stride=`stride` for the first block and stride=1 for all the rest (default value)
        first_block = block(in_fiber=self.next_in_type,
                            inner_fiber=inner_class,
                            out_fiber=out_f,
                            stride=stride,
                            # downsample=downsample,
                            groups=self.groups,
                            base_width=self.base_width,
                            dilation=previous_dilation,
                            sigma=self._sigma,
                            F=self._F)
        layers.append(first_block)

        self.next_in_type = first_block.out_type
        out_f = self.next_in_type
        for _ in range(1, num_blocks - 1):
            next_block = block(in_fiber=self.next_in_type,
                               inner_fiber=inner_class,
                               out_fiber=out_f,
                               groups=self.groups,
                               base_width=self.base_width,
                               dilation=self.dilation,
                               sigma=self._sigma,
                               F=self._F)
            layers.append(next_block)
            self.next_in_type = out_f

        # add the last block
        if out_fiber is None:
            out_fiber = main_fiber
        out_type = FIBERS[out_fiber](self.gspace, planes, fixparams=self._fixparams)
        last_block = block(in_fiber=self.next_in_type,
                           inner_fiber=inner_class,
                           out_fiber=out_type,
                           groups=self.groups,
                           base_width=self.base_width,
                           dilation=self.dilation,
                           sigma=self._sigma,
                           F=self._F)
        layers.append(last_block)
        # track the output type of the last block (out_type may differ from out_f, e.g. for layer4)
        self.next_in_type = out_type

        logging.info(f"Built layer {self._layer}")
        return nn.SequentialModule(*layers)

    def _restrict_layer(self, subgroup_id):
        layers = list()
        layers.append(nn.RestrictionModule(self.next_in_type, subgroup_id))
        layers.append(nn.DisentangleModule(layers[-1].out_type))
        self.next_in_type = layers[-1].out_type
        self.gspace = self.next_in_type.gspace
        restrict_layer = nn.SequentialModule(*layers)
        return restrict_layer

    def features(self, x):
        x = nn.GeometricTensor(x, self.in_lifting_type)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        out = self.maxpool(x)
        x1 = self.layer1(out)
        x2 = self.layer2(self.restrict1(x1))
        x3 = self.layer3(self.restrict2(x2))
        x4 = self.layer4(x3)
        return x1, x2, x3, x4

    def _forward_impl(self, x: Tensor) -> Tensor:
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = nn.GeometricTensor(x, self.in_lifting_type)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(self.restrict1(x))
        x = self.layer3(self.restrict2(x))
        x = self.layer4(x)
        if not self.conv2triv:
            x = self.mp(x)
        x = self.avgpool(x.tensor)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
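For reference, an instantiation along these lines should reproduce the architecture printed above (a sketch; the flag values `N=4`, `restrict=-1`, `flip=False`, `fixparams=False` and `conv2triv=True` are read off the printout and not otherwise verified):

# ResNet-18-style layout: 2 blocks per layer, C4 rotations, trivial output fields
model = E2ResNet(E2BasicBlock, layers=[2, 2, 2, 2], num_classes=2,
                 N=4, restrict=-1, flip=False, fixparams=False, conv2triv=True)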
Hi @blazejdolicki
sorry for the late reply.
No, the equivariance to 90 deg should be almost perfect, with an absolute error probably lower than 1e-5.
At a first look, your code seems OK. Usually this problem arises when you use strided convolutions with odd-size filters but your input images have an even size.
Check page 3 of our paper for more details on this issue and how to solve it.
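If you want to measure the error directly, a quick check could look like this (a sketch: `model` is your E2ResNet in eval mode, and the odd input size avoids the alignment issue above):

import torch

model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 65, 65)                      # odd spatial size
    y = model(x)
    y_rot = model(torch.rot90(x, k=1, dims=(2, 3)))    # exact 90-degree rotation
    print((y - y_rot).abs().max())                     # should be around 1e-5 or below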
Let me know if this was your problem
Best, Gabriele