ProtoSeg

Intra-class prototypes are all the same.

Open ChenT1994 opened this issue 3 years ago • 2 comments

Hi. Thanks for your great work, but I have one problem. I downloaded the model file hrnet_w48_proto_lr1x_hrnet_proto_80k_latest.pth that you provided, and used the code below to compute the self-similarity of the prototypes.

import torch
import seaborn as sns

# load the released checkpoint and pull out the prototype tensor
model = torch.load('hrnet_w48_proto_lr1x_hrnet_proto_80k_latest.pth', map_location='cpu')
protos = model['state_dict']['module.prototypes']  # (num_classes, k, dim)
feat = protos.view(-1, protos.shape[-1])           # flatten to (num_classes*k, dim)
simi = feat @ feat.t()                             # pairwise dot products
simi = simi.cpu().numpy()
sns.heatmap(simi)
print(simi[0, :10])
print(simi[100, 100:110])

I got this result:

The prototypes within one class are all the same! I'm confused by this. Could you please give me some help? Thank you.
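
For anyone who wants to see the symptom without downloading the checkpoint, here is a synthetic reproduction (the tensor below is a made-up stand-in, not the released weights): if all prototypes of a class are identical unit vectors, the self-similarity matrix shows constant 10×10 blocks of 1.0 along the diagonal, which matches the pattern described above.

```python
import torch

# Hypothetical stand-in for the checkpoint tensor (19 classes, 10
# prototypes each, 720-dim), built so that all 10 prototypes of a
# class are identical -- this is NOT the real checkpoint, it only
# reproduces the symptom described above.
torch.manual_seed(0)
base = torch.nn.functional.normalize(torch.randn(19, 720), dim=1)
protos = base.unsqueeze(1).expand(19, 10, 720)

feat = protos.reshape(-1, protos.shape[-1])  # (190, 720)
simi = feat @ feat.t()                       # dot product = cosine (unit rows)

# every 10x10 intra-class block is constant at 1.0
print(torch.allclose(simi[:10, :10], torch.ones(10, 10), atol=1e-5))  # True
```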

ChenT1994 avatar Feb 20 '23 07:02 ChenT1994

I was able to reproduce the problem. The 10 different prototypes for each class have almost identical feature values. What is the reason for this?

import torch
import seaborn as sns
import matplotlib.pyplot as plt

# from lib.models.nets.hrnet import HRNet_W48_Proto

def normalize(x):
    mean = torch.mean(x, dim=1, keepdim=True)
    std = torch.std(x, dim=1, keepdim=True)
    return (x - mean) / std

# downloaded from https://github.com/tfzhou/pretrained_weights/releases/download/v.cvpr22/hrnet_w48_proto_lr1x_hrnet_proto_80k_latest.pth
path_state_dict = "checkpoints/hrnet_w48_proto_lr1x_hrnet_proto_80k_latest.pth"

model = torch.load(path_state_dict, map_location="cpu")["state_dict"]
prototypes = model["module.prototypes"]
print("Shape of prototypes: ", prototypes.shape)

for c in range(19):

    p = prototypes[c,:,:]
    print(f"Shape prototypes class {c}: ", p.shape)
    for k in range(10):
        print(f"Features prototype {k} class {c}: ", p[k,:10])

    similarity = torch.mm(normalize(p), normalize(p).t()).cpu().numpy()

    ax = sns.heatmap(similarity, vmin=0, vmax=1, cmap="YlGnBu")
    plt.savefig(f"heatmap_{c}.png")
    plt.close()
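
One side note on the snippet above: `normalize()` there is a per-row z-score rather than L2 normalization, so the plotted values are not true cosine similarities. A sketch of the cosine variant (the function name is my own; the conclusion is the same either way, since identical rows stay identical under both normalizations):

```python
import torch
import torch.nn.functional as F

def cosine_self_similarity(p):
    # p: (k, d) prototypes of one class; after L2-normalizing the rows,
    # the matrix product gives cosine similarities in [-1, 1]
    p = F.normalize(p, dim=1)
    return p @ p.t()

p = torch.randn(10, 720)   # stand-in for prototypes[c]
sim = cosine_self_similarity(p)
print(sim.diagonal())      # all ~1.0 (each prototype vs. itself)
```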

JohannesK14 avatar Sep 12 '23 13:09 JohannesK14

Hey all, thanks for opening this issue, and thanks to the authors for providing the open-source code.

I am exploring the proposed prototype approach for my thesis and found some inconsistencies that I would like to discuss. For example, I ran into the same problem when training the model on a custom dataset: after a few epochs, all prototypes of the same class collapse onto each other, while the classes themselves clearly become better separated from one another.

I collected the following questions:

  • How do you explain the identical prototypes within one class?

  • How does this relate to Figure 3 of your paper, which shows the prototypes capturing distinct features within one class, when they are all identical after training?

  • According to the pixel-prototype contrastive loss (PPC) introduced in the paper, prototypes should also be dissimilar within a class. So is there perhaps too little emphasis on separating the intra-class prototypes?
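
On the last point: a minimal sketch of what an intra-class separation penalty could look like (this is my own illustrative formulation, not the paper's actual PPC implementation; the function name and tensor shapes are assumptions):

```python
import torch

def intra_class_separation_loss(prototypes):
    """prototypes: (num_classes, k, d). Penalizes cosine similarity
    between different prototypes of the same class; high when collapsed."""
    p = torch.nn.functional.normalize(prototypes, dim=-1)
    sim = torch.einsum("ckd,cjd->ckj", p, p)    # (c, k, k) cosine similarities
    k = p.shape[1]
    off_diag = ~torch.eye(k, dtype=torch.bool)  # mask out self-similarity
    return sim[:, off_diag].mean()

torch.manual_seed(0)
# collapsed prototypes (all k copies identical) -> loss at the maximum
collapsed = torch.randn(3, 1, 8).expand(3, 5, 8)
print(float(intra_class_separation_loss(collapsed)))        # ~1.0
print(float(intra_class_separation_loss(torch.randn(3, 5, 8))))
```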

I'll present some of my findings here: the following images show the average cosine distance $d$ between the prototypes of each pair of classes:

  • 0 (blue) means identical prototypes
  • 1 (white) means orthogonal prototypes
  • 2 (red) means opposite prototypes

A good representation finds prototypes that are well separable, so values $d\geq1$ are desired. However, when comparing the prototypes within one class (the diagonal elements), it becomes evident that they collapse after some time during training:

[figure: heatmap of average cosine distances between class prototypes during training]

For comparison, I trained the model with only one prototype per class, which led to the same result, suggesting that using $k>1$ prototypes brings no benefit in the final model:

[figure: the same heatmap for the model trained with one prototype per class]

These findings made me check the model provided in the official repo here, which verifies the findings from above:

[figure: distance heatmap for the released checkpoint]

To compute the entries of the distance matrix, we sum the cosine distances between each prototype of one class and each prototype of the other class, and divide by the number of pairs, according to the following formulas:

INTER-class prototype distances: the formula for two different classes $A$ and $B$ with $n$ prototypes $\mathbf{v}_i$ each is

$$d_{AB} = \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n}\left(1 - \frac{\mathbf{v}_i^A \cdot \mathbf{v}_j^B}{\lVert\mathbf{v}_i^A\rVert\,\lVert\mathbf{v}_j^B\rVert}\right)$$

INTRA-class prototype distances: on the diagonal, to ensure that each pair of prototypes is counted only once (and that the zero distance of a prototype to itself is excluded), the formula becomes

$$d_{AA} = \frac{2}{n(n-1)}\sum_{i=1}^{n}\sum_{j=i+1}^{n}\left(1 - \frac{\mathbf{v}_i^A \cdot \mathbf{v}_j^A}{\lVert\mathbf{v}_i^A\rVert\,\lVert\mathbf{v}_j^A\rVert}\right)$$
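
Concretely, the computation described above could be sketched as follows (the function name and the random stand-in `protos` are my own; the real input would be the checkpoint's prototype tensor):

```python
import torch
import torch.nn.functional as F

def avg_prototype_distance(a, b, same_class=False):
    """a, b: (n, d) prototypes of two classes. Returns the average
    cosine distance 1 - cos(v_i, v_j) over prototype pairs."""
    d = 1.0 - F.normalize(a, dim=1) @ F.normalize(b, dim=1).t()  # (n, n)
    if same_class:
        # diagonal entry: use only pairs i < j so each pair counts once
        iu = torch.triu_indices(d.shape[0], d.shape[1], offset=1)
        return d[iu[0], iu[1]].mean()
    return d.mean()

torch.manual_seed(0)
protos = torch.randn(19, 10, 720)  # stand-in: (classes, k, dim)
D = torch.stack([
    torch.stack([avg_prototype_distance(protos[i], protos[j], same_class=(i == j))
                 for j in range(19)])
    for i in range(19)])
print(D.shape)  # torch.Size([19, 19])
```

Collapsed prototypes within a class drive the corresponding diagonal entry toward 0, which is exactly what the heatmaps above show.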

It would be great if you could provide some insights and explanations about this topic. Thank you!

Miri-xyz avatar Sep 20 '23 11:09 Miri-xyz