Poor reconstruction using DAE pre-trained weights
I am trying to load the pretrained weights from the tutorial https://github.com/Project-MONAI/tutorials/blob/main/self_supervised_pretraining/vit_unetr_ssl/ssl_finetune.ipynb, which were trained with the DAE method (https://github.com/Project-MONAI/research-contributions/tree/main/DAE/Pretrain_full_contrast). The weights I use are the "ssl_pretrained_weights.pth" file available at https://github.com/Project-MONAI/MONAI-extra-test-data/releases.
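Before remapping any keys, it can help to see which modules the checkpoint actually contains. A minimal pure-Python sketch that groups parameter names by their leading dot-separated components; the helper name `summarize_key_prefixes` is hypothetical and the default depth of 2 is an arbitrary choice:

```python
from collections import Counter
from typing import Iterable


def summarize_key_prefixes(keys: Iterable[str], depth: int = 2) -> Counter:
    """Count parameter names by their first `depth` dot-separated components.

    Hypothetical helper for inspecting a checkpoint, e.g. to see how many
    entries belong to the encoder versus any decoder/reconstruction head.
    """
    return Counter(".".join(key.split(".")[:depth]) for key in keys)
```

For example, `summarize_key_prefixes(torch.load(pretrained_path, map_location="cpu")["model"].keys())`, assuming the checkpoint stores its weights under `"model"` as the code in this post does.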
I run the following code to check how well one sample from the HNSCC dataset is reconstructed:
```python
import torch
import torch.nn.functional as F
from monai.networks.nets import SwinUNETR

swinUnetr = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=1,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)

# Load the SSL checkpoint on CPU; the weights are stored under the "model" entry
pretrained_path = "/home/ol18/Downloads/ssl_pretrained_weights.pth"
model_dict = torch.load(pretrained_path, map_location="cpu")["model"]
pretrained_weights_keys = model_dict.keys()
store_dict = swinUnetr.state_dict()
net_keys = store_dict.keys()
print(len(net_keys), len(pretrained_weights_keys))

# Remove the mask token and the final encoder norm before copying weights
del model_dict["encoder.mask_token"]
del model_dict["encoder.norm.weight"]
del model_dict["encoder.norm.bias"]

# Copy checkpoint weights into the SwinUNETR state dict, renaming
# "encoder.*" keys to the "swinViT.*" names used by SwinUNETR
count = 0
for key, value in model_dict.items():
    if key[:8] == "encoder.":
        if key[8:19] == "patch_embed":
            new_key = "swinViT." + key[8:]
        else:
            new_key = "swinViT." + key[8:18] + key[20:]
        store_dict[new_key] = value
        count += 1
    elif key in net_keys:
        store_dict[key] = value
        count += 1
    else:
        # Checkpoint keys with no match in SwinUNETR are skipped
        print(key)
print(count)
swinUnetr.load_state_dict(store_dict)

# `batch` and `self._log_image_tensorboard` come from the surrounding
# evaluation code (not shown)
for key in batch.keys():
    # get inputs
    inputs = batch[key]["image"]
    B, C, H, W, Z = inputs.shape
    # Additive Gaussian noise with variance 0.1
    noise = (0.1**0.5) * torch.randn(B, C, H, W, Z).to(inputs.device)
    img_noisy = inputs + noise
    # Downsample to 24^3, then upsample back to 96^3 (nearest neighbour by default)
    img_lowres = F.interpolate(img_noisy, size=(96 // 4, 96 // 4, 96 // 4))
    img_resam = F.interpolate(img_lowres, size=(96, 96, 96))
    swinUnetr.to(inputs.device)
    x_rec = swinUnetr(img_resam)
    self._log_image_tensorboard(
        inputs[0],
        img_resam[0],
        x_rec[0],
        f"test_images_0_{key}",
    )
```
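One thing worth ruling out is how much of SwinUNETR is actually covered by the checkpoint: any parameter that is never overwritten in `store_dict` keeps its random initialization. A minimal pure-Python sketch that reuses the key renaming from the code in this post and splits the keys into "not provided by the checkpoint" and "in the checkpoint but unused"; the helper names `remap_key` and `check_coverage` are hypothetical:

```python
from typing import Iterable, List, Tuple


def remap_key(key: str) -> str:
    """Same renaming as in the post: "encoder.*" -> "swinViT.*", other keys unchanged."""
    if key[:8] == "encoder.":
        if key[8:19] == "patch_embed":
            return "swinViT." + key[8:]
        return "swinViT." + key[8:18] + key[20:]
    return key


def check_coverage(
    model_keys: Iterable[str], ckpt_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (model keys the checkpoint does not provide, checkpoint keys the model lacks)."""
    model_set = set(model_keys)
    remapped = {remap_key(k) for k in ckpt_keys}
    not_loaded = sorted(model_set - remapped)  # these keep their initial values
    unused = sorted(remapped - model_set)  # these are silently skipped
    return not_loaded, unused
```

For example, `not_loaded, unused = check_coverage(swinUnetr.state_dict().keys(), model_dict.keys())` after the three `del` statements. If most decoder/output parameters end up in `not_loaded`, the reconstruction path is largely untrained, whatever the quality of the pretrained encoder.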
I get the following reconstructions for the original image, the noisy image, and the noisy + low-resolution image as input:
[Image: reconstruction with the original image as input]
[Image: reconstruction with the noisy image as input]
[Image: reconstruction with the noisy + low-resolution image as input]
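To compare the three cases with a number rather than by eye, one option is the peak signal-to-noise ratio (PSNR) between the original volume and the reconstruction. A minimal NumPy sketch; the helper name `psnr` is hypothetical, and taking the data range from the reference volume is an assumption (pass the known intensity range of your preprocessing if you have it):

```python
import numpy as np


def psnr(reference, estimate, data_range=None):
    """PSNR in dB between two arrays of the same shape (higher means closer)."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    if reference.shape != estimate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {estimate.shape}")
    if data_range is None:
        # Assumption: use the reference's own intensity range
        data_range = float(reference.max() - reference.min())
    mse = float(np.mean((reference - estimate) ** 2))
    if mse == 0.0:
        return float("inf")  # identical arrays
    return float(10.0 * np.log10(data_range**2 / mse))
```

For example, `psnr(inputs[0].cpu().numpy(), x_rec[0].detach().cpu().numpy())` inside the loop, computed once for each of the three input variants.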
As far as I understand, these reconstructions are fairly poor given the amount of pretraining that was done. Am I missing something here? Could someone point me in the right direction?
Many thanks
@tangy5
I also got results similar to yours. Were you able to figure it out?