
Poor reconstruction using DAE pre-trained weights

OeslleLucena opened this issue 1 year ago · 1 comment

I am trying to load the pretrained weights from the tutorial https://github.com/Project-MONAI/tutorials/blob/main/self_supervised_pretraining/vit_unetr_ssl/ssl_finetune.ipynb that were trained using the DAE method https://github.com/Project-MONAI/research-contributions/tree/main/DAE/Pretrain_full_contrast. The weights used are the "ssl_pretrained_weights.pth" available at https://github.com/Project-MONAI/MONAI-extra-test-data/releases.
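As a sanity check on the checkpoint layout: the snippet below round-trips a dummy file with the same `{"model": state_dict}` structure as `ssl_pretrained_weights.pth` to show the access pattern (the dummy key and temp path are placeholders, not the real checkpoint contents).

```python
# Hypothetical sketch: the checkpoint stores the state dict under the
# "model" key; round-trip a dummy file with that layout to show how
# the keys are read back. The real file path and keys will differ.
import os
import tempfile

import torch

dummy = {"model": {"encoder.patch_embed.proj.weight": torch.zeros(1)}}
path = os.path.join(tempfile.mkdtemp(), "ckpt.pth")
torch.save(dummy, path)

model_dict = torch.load(path, map_location="cpu")["model"]
print(sorted(model_dict.keys()))
```

Listing the keys this way before remapping makes it easy to spot which ones have no counterpart in the target model.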

When I run the following code to see how well I can reconstruct one sample from the HNSCC dataset:

```python
import torch
import torch.nn.functional as F
from monai.networks.nets import SwinUNETR

swinUnetr = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=1,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)

pretrained_path = "/home/ol18/Downloads/ssl_pretrained_weights.pth"
model_dict = torch.load(pretrained_path)["model"]
pretrained_weights_keys = model_dict.keys()
store_dict = swinUnetr.state_dict()
net_keys = store_dict.keys()

print(len(net_keys), len(pretrained_weights_keys))

# Drop pretraining-only parameters that have no counterpart in SwinUNETR.
count = 0
del model_dict["encoder.mask_token"]
del model_dict["encoder.norm.weight"]
del model_dict["encoder.norm.bias"]

# Rename "encoder.*" keys from the DAE checkpoint to SwinUNETR's "swinViT.*".
for key, value in model_dict.items():
    if key[:8] == "encoder.":
        if key[8:19] == "patch_embed":
            new_key = "swinViT." + key[8:]
        else:
            new_key = "swinViT." + key[8:18] + key[20:]
        store_dict[new_key] = value
        count += 1
    elif key in net_keys:
        store_dict[key] = value
        count += 1
    else:
        print(key)  # keys that could not be matched

print(count)
swinUnetr.load_state_dict(store_dict)

for key in batch.keys():
    # get inputs
    inputs = batch[key]["image"]
    B, C, H, W, Z = inputs.shape
    # Add Gaussian noise (variance 0.1), downsample 4x, and resample
    # back to 96^3 to mimic the DAE degradations.
    noise = (0.1 ** 0.5) * torch.randn(B, C, H, W, Z).to(inputs.device)
    img_noisy = inputs + noise
    img_lowres = F.interpolate(img_noisy, size=(96 // 4, 96 // 4, 96 // 4))
    img_resam = F.interpolate(img_lowres, size=(96, 96, 96))

    swinUnetr.to(inputs.device)
    x_rec = swinUnetr(img_resam)

    self._log_image_tensorboard(
        inputs[0],
        img_resam[0],
        x_rec[0],
        f"test_images_0_{key}",
    )
```
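For what it's worth, the string slicing in the loop above can be pulled into a small standalone helper so it can be unit-tested against checkpoint keys in isolation; this is a hypothetical refactor of the same logic, not part of the tutorial.

```python
# Hypothetical helper mirroring the key remapping above: "encoder.*"
# keys become "swinViT.*"; for non-patch_embed keys, the two characters
# at offsets 18-19 of the original key are dropped, exactly as in
# key[8:18] + key[20:]. Other keys pass through unchanged.
def remap_key(key: str) -> str:
    if key[:8] == "encoder.":
        if key[8:19] == "patch_embed":
            return "swinViT." + key[8:]
        return "swinViT." + key[8:18] + key[20:]
    return key

print(remap_key("encoder.patch_embed.proj.weight"))
```

Testing this helper against a few keys printed from the real checkpoint would quickly confirm whether the fixed-offset slicing matches the actual key layout, or silently mangles some names.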

I've got the following outputs for the original image, noisy and noisy + low resolution:

original image as input: [image: no_augmentations]

noisy image as input: [image: noisy]

noisy + low resolution image as input: [image: all_augmentations]

From my understanding, these reconstructions are fairly poor given the amount of pretraining that was done. Am I missing something here? Can someone point me in the right direction?
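One way to make "fairly poor" concrete would be to compute PSNR between the original and reconstructed volumes instead of judging the screenshots by eye; the helper below is a hypothetical metric function, not something from the tutorial.

```python
# Hypothetical PSNR helper: higher is better; identical images give inf.
# data_range is the value range of the reference image (1.0 here assumes
# intensities normalized to [0, 1]).
import numpy as np

def psnr(ref: np.ndarray, rec: np.ndarray, data_range: float = 1.0) -> float:
    mse = np.mean((ref - rec) ** 2)
    if mse == 0:
        return float("inf")
    return float(10 * np.log10(data_range ** 2 / mse))
```

Something like `psnr(inputs[0].cpu().numpy(), x_rec[0].detach().cpu().numpy())`, tracked across samples, would show whether the reconstructions are systematically bad or only look bad for a few cases.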

Many thanks

@tangy5

OeslleLucena avatar Jun 14 '24 15:06 OeslleLucena

I got similar results to yours. Were you able to figure it out?

yalcintur avatar Nov 30 '24 10:11 yalcintur