CLAM icon indicating copy to clipboard operation
CLAM copied to clipboard

Size mismatch when loading a checkpoint into CLAM_SB

Open · Himanshunitrr opened this issue 1 year ago · 1 comment

I am trying to use CLAM_SB directly, but loading the checkpoint fails with missing keys, unexpected keys, and size mismatches. How can I resolve this?




import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.model_mil import MIL_fc, MIL_fc_mc
from models.model_clam import CLAM_SB, CLAM_MB
import pdb
import os
import pandas as pd
from utils.utils import *
from utils.core_utils import Accuracy_Logger
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

def initiate_model(args, ckpt_path, device='cuda'):
    """Build a CLAM / MIL model from ``args`` and load weights from ``ckpt_path``.

    Args:
        args: Namespace-like object. Reads ``drop_out``, ``n_classes``,
            ``embed_dim``, ``model_size`` and ``model_type``. ``model_type``
            is ``'clam_sb'``, ``'clam_mb'``, or anything else for the MIL
            baselines.
        ckpt_path: Path to a state-dict checkpoint readable by ``torch.load``.
        device: Device the model is moved to after the weights are loaded.

    Returns:
        The model with the checkpoint weights loaded, moved to ``device`` and
        switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the
            checkpoint's keys or tensor shapes do not match the architecture
            selected by ``args.model_type`` / ``args.model_size``.
    """
    print('Init Model')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes, "embed_dim": args.embed_dim}

    # Only the CLAM variants take a size argument.
    if args.model_size is not None and args.model_type in ['clam_sb', 'clam_mb']:
        model_dict.update({"size_arg": args.model_size})

    if args.model_type == 'clam_sb':
        model = CLAM_SB(**model_dict)
    elif args.model_type == 'clam_mb':
        model = CLAM_MB(**model_dict)
    else:  # args.model_type == 'mil'
        if args.n_classes > 2:
            model = MIL_fc_mc(**model_dict)
        else:
            model = MIL_fc(**model_dict)

    print_network(model)

    # Map tensors to CPU while loading so that a checkpoint saved on a GPU can
    # be opened on any host; the model is moved to ``device`` afterwards.
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # Skip the instance loss-function entries and strip '.module' from key
    # names (presumably left by a module wrapper at save time -- confirm).
    ckpt_clean = {
        key.replace('.module', ''): value
        for key, value in ckpt.items()
        if 'instance_loss_fn' not in key
    }
    model.load_state_dict(ckpt_clean, strict=True)

    model.to(device)
    model.eval()
    return model

from argparse import Namespace

# Loading this checkpoint into CLAM_SB reports unexpected keys
# "classifiers.0.*" / "classifiers.1.*" and an attention_c weight of shape
# [2, 256], i.e. one attention branch and one classifier per class. That is
# the multi-branch layout, so the checkpoint must be loaded with
# model_type="clam_mb" (NOTE(review): confirm against the training config
# used for camelyon_40x_cv).
args = Namespace(
    k=10,
    models_exp_code="task_1_tumor_vs_normal_CLAM_50_s1",
    save_exp_code="task_1_tumor_vs_normal_CLAM_50_s1_cv",
    task="task_1_tumor_vs_normal",
    model_type="clam_mb",
    results_dir="results",
    data_root_dir="DATA_ROOT_DIR",
    drop_out=0.25,
    embed_dim=1024,
    n_classes=2,
    model_size="small",
)

model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")

Error:


Init Model
CLAM_SB(
  (attention_net): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Attn_Net_Gated(
      (attention_a): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Tanh()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_b): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): Sigmoid()
        (2): Dropout(p=0.25, inplace=False)
      )
      (attention_c): Linear(in_features=256, out_features=1, bias=True)
    )
  )
  (classifiers): Linear(in_features=512, out_features=2, bias=True)
  (instance_classifiers): ModuleList(
    (0-1): 2 x Linear(in_features=512, out_features=2, bias=True)
  )
  (instance_loss_fn): CrossEntropyLoss()
)
Total number of parameters: 790791
Total number of trainable parameters: 790791
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], [line 1](vscode-notebook-cell:?execution_count=13&line=1)
----> [1] model = initiate_model(args, "/data/hmaurya/CLAM/clam_weights/camelyon_40x_cv/camelyon_40x_cv_CLAM_10_s1/s_0_checkpoint.pt")

Cell In[1], [line 42](vscode-notebook-cell:?execution_count=1&line=42)
     [40]         continue
     [41]    ckpt_clean.update({key.replace('.module', ''):ckpt[key]})
---> [42 model.load_state_dict(ckpt_clean, strict=True)
     [44] _ = model.to(device)
     [45] _ = model.eval()

File /data/hmaurya/temp_conda_envs/lib/python3.10/site-packages/torch/nn/modules/module.py:2581, in Module.load_state_dict(self, state_dict, strict, assign)
   [2573]        error_msgs.insert(
   [2574]             0,
   [2575]             "Missing key(s) in state_dict: {}. ".format(
   [2576]                 ", ".join(f'"{k}"' for k in missing_keys)
   [2577]             ),
   [2578]         )
   [2580] if len(error_msgs) > 0:
-> [2581]     raise RuntimeError(
   [2582]         "Error(s) in loading state_dict for {}:\n\t{}".format(
   [2583]             self.__class__.__name__, "\n\t".join(error_msgs)
   [2584]         )
   [2585]     )
   [2586] return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for CLAM_SB:
	Missing key(s) in state_dict: "classifiers.weight", "classifiers.bias". 
	Unexpected key(s) in state_dict: "classifiers.0.weight", "classifiers.0.bias", "classifiers.1.weight", "classifiers.1.bias". 
	size mismatch for attention_net.3.attention_c.weight: copying a param with shape torch.Size([2, 256]) from checkpoint, the shape in current model is torch.Size([1, 256]).
	size mismatch for attention_net.3.attention_c.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).

Himanshunitrr avatar Apr 03 '25 17:04 Himanshunitrr

Is this caused by the hard-coded `n_classes` in https://github.com/mahmoodlab/CLAM/blob/f1e93945d5f5ac6ed077cb020ed01cf984780a77/models/model_clam.py#L85?

If I change that value to 2 (my `n_classes`), it works fine.

Himanshunitrr avatar Apr 03 '25 17:04 Himanshunitrr