Unnecessary conv3 Layer Weights in DAE with MONAI v0.8.0 Leading to Compatibility Issues with v0.9.0+
Description
In MONAI v0.8.0, UnetResBlock creates a conv3 layer regardless of whether it is used, so the DAE SSL weights for the SwinUNETR architecture contain unnecessary parameters. Specifically, the conv3 layers in encoder2, encoder3, encoder4, and encoder10 appear redundant, since these blocks have identical input and output feature dimensions and the residual path never uses conv3. This redundancy becomes problematic when upgrading to MONAI v0.9.0+, where conv3 is only instantiated when needed, and loading the checkpoint then fails because the expected layers are missing.
The MONAI v0.8.0 snippet below shows the unconditional creation:
def __init__(self, ...):
    # Rest of code
    self.conv3 = get_conv_layer(
        spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
    )
    self.downsample = in_channels != out_channels
    stride_np = np.atleast_1d(stride)
    if not np.all(stride_np == 1):
        self.downsample = True

def forward(self, inp):
    residual = inp
    # Rest of code
    if self.downsample:
        residual = self.conv3(residual)
        residual = self.norm3(residual)
In MONAI v0.9.0+, the conv3 layer is created conditionally, based on the input and output channels (and the stride). The DAE SSL weights include conv3 in all UnetResBlocks, even in blocks where it is never used. When these SSL weights are loaded for finetuning with the DAE/BTCV_Finetune repository under MONAI v0.9.0+, the code throws an error because the expected conv3 layers are never created.
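For reference, a simplified sketch of the newer behaviour (not a verbatim copy of the upstream UnetResBlock code): conv3 and norm3 are only instantiated when the block actually needs to project the residual.

self.downsample = in_channels != out_channels
stride_np = np.atleast_1d(stride)
if not np.all(stride_np == 1):
    self.downsample = True
if self.downsample:
    # Only created when the channel count changes or the stride is not 1.
    self.conv3 = get_conv_layer(
        spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True
    )
    self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels)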
The load_from method in DAE/BTCV_Finetune/swin_unetr_og.py includes the following lines:
self.encoder1.layer.conv1.conv.weight.copy_(weights["model"]["encoder1.layer.conv1.conv.weight"])
self.encoder1.layer.conv2.conv.weight.copy_(weights["model"]["encoder1.layer.conv2.conv.weight"])
self.encoder1.layer.conv3.conv.weight.copy_(weights["model"]["encoder1.layer.conv3.conv.weight"])
self.encoder2.layer.conv1.conv.weight.copy_(weights["model"]["encoder2.layer.conv1.conv.weight"])
self.encoder2.layer.conv2.conv.weight.copy_(weights["model"]["encoder2.layer.conv2.conv.weight"])
self.encoder2.layer.conv3.conv.weight.copy_(weights["model"]["encoder2.layer.conv3.conv.weight"])
self.encoder3.layer.conv1.conv.weight.copy_(weights["model"]["encoder3.layer.conv1.conv.weight"])
self.encoder3.layer.conv2.conv.weight.copy_(weights["model"]["encoder3.layer.conv2.conv.weight"])
self.encoder3.layer.conv3.conv.weight.copy_(weights["model"]["encoder3.layer.conv3.conv.weight"])
self.encoder4.layer.conv1.conv.weight.copy_(weights["model"]["encoder4.layer.conv1.conv.weight"])
self.encoder4.layer.conv2.conv.weight.copy_(weights["model"]["encoder4.layer.conv2.conv.weight"])
self.encoder4.layer.conv3.conv.weight.copy_(weights["model"]["encoder4.layer.conv3.conv.weight"])
self.encoder10.layer.conv1.conv.weight.copy_(weights["model"]["encoder10.layer.conv1.conv.weight"])
self.encoder10.layer.conv2.conv.weight.copy_(weights["model"]["encoder10.layer.conv2.conv.weight"])
self.encoder10.layer.conv3.conv.weight.copy_(weights["model"]["encoder10.layer.conv3.conv.weight"])
Proposed solution
Skip loading the conv3 weights in blocks encoder2, encoder3, encoder4, and encoder10, where the input and output feature sizes are the same, to ensure compatibility with newer MONAI versions.
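One way to do this in load_from is to guard the conv3 copy on the layer's existence, so the same code works with both MONAI versions. A minimal sketch, assuming the weight-dictionary layout shown above; copy_resblock_weights is a hypothetical helper, not part of the repository:

def copy_resblock_weights(block, weights, prefix):
    # conv1 and conv2 exist in every UnetResBlock, so these copies are always safe.
    block.layer.conv1.conv.weight.copy_(weights["model"][prefix + ".layer.conv1.conv.weight"])
    block.layer.conv2.conv.weight.copy_(weights["model"][prefix + ".layer.conv2.conv.weight"])
    # In MONAI v0.9.0+ conv3 only exists when the block downsamples,
    # so copy it only if the attribute is present.
    if hasattr(block.layer, "conv3"):
        block.layer.conv3.conv.weight.copy_(weights["model"][prefix + ".layer.conv3.conv.weight"])

with torch.no_grad():
    for name in ("encoder1", "encoder2", "encoder3", "encoder4", "encoder10"):
        copy_resblock_weights(getattr(self, name), weights, name)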
It could also be beneficial to remove the mentioned weights from DAE_SSL_WEIGHTS to avoid further confusion.
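A one-off pruning script could look like the following; the checkpoint file names are placeholders and the key layout is taken from the load_from listing above:

import torch

ckpt = torch.load("dae_ssl_weights.pt", map_location="cpu")  # placeholder file name
unused_keys = [
    enc + ".layer.conv3.conv.weight"
    for enc in ("encoder2", "encoder3", "encoder4", "encoder10")
]
for key in unused_keys:
    # Drop the redundant conv3 weights if they are present in the checkpoint.
    ckpt["model"].pop(key, None)
torch.save(ckpt, "dae_ssl_weights_pruned.pt")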