
About the loss

I would like to ask:

Struggle-Forever opened this issue 3 years ago • 5 comments

The purpose of a contrastive loss is to minimize the distance between positive pairs while maximizing the distance between negative pairs. However, I can only find the minimization of positive-pair distances in this loss; I don't see where the negative-pair distances are maximized. Can you tell me which lines of code achieve this?
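For reference, the per-anchor loss this repository implements (Eq. 2, the $\mathcal{L}^{sup}_{out}$ form, of the SupCon paper) can be written as

$$
\mathcal{L}_i = \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)},
$$

where $P(i)$ is the set of positives for anchor $i$, $A(i)$ is every sample in the batch except $i$ itself, and $\tau$ is the temperature; the negatives appear only in the denominator.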

— Struggle-Forever, Sep 03 '22

I think this line does it.

— HobbitLong, Sep 03 '22

> I think this line does it.

`log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))`

The `logits` denote all samples' distances, and `torch.log(exp_logits.sum(1, keepdim=True))` denotes the negative samples' distances.

The `log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))` denotes the positive samples' distances, and `mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)` denotes the average positive loss.
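As a quick sanity check (toy values; this ignores the self-contrast masking the repository applies when building `exp_logits`), that line is exactly a row-wise log-softmax:

```python
import torch

logits = torch.randn(3, 5)                                      # toy pairwise similarities
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))  # the quoted line
assert torch.allclose(log_prob, torch.log_softmax(logits, dim=1))
```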

What confuses me is that this feels like it only minimizes the positive-pair distance; the term that maximizes the negative-pair distance does not seem to be in the final loss.

I feel like I'm not understanding something. Can you help me?
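For what it's worth, a minimal toy example (hypothetical numbers: a single anchor with similarities to one positive and two negatives; not code from the repo) shows where the negatives enter the loss:

```python
import torch

# Similarities of a single anchor to [positive, negative, negative].
logits = torch.tensor([[2.0, 1.5, 0.5]], requires_grad=True)
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
loss = -log_prob[0, 0]   # keep only the positive's term, as the mask does
loss.backward()
print(logits.grad)       # negative for the positive pair, positive for both negatives
```

The maximization of negative distance is implicit: raising a negative's similarity inflates the log-sum-exp denominator and therefore the loss, so gradient descent pushes the negatives' similarities down while pulling the positive's up.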

— Struggle-Forever, Sep 03 '22

> I think this line does it.

I still don't understand it. Please help me, thanks.

— Struggle-Forever, Sep 03 '22

> I think this line does it.
>
> I still don't understand it. Please help me, thanks.

The `log_prob` denotes all samples' distances, and `mask * log_prob` obtains the positive samples. That is, all the sample distances enter the numerator/denominator, and the mask then extracts the loss of the positive samples. This time my understanding should be correct.
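A tiny demo of that last step (toy values and a hypothetical mask): the mask zeroes out everything except the positive-pair entries before averaging:

```python
import torch

log_prob = torch.tensor([[-0.5, -1.2, -2.0],
                         [-1.0, -0.3, -1.7]])   # toy per-pair log-probabilities
mask = torch.tensor([[1., 0., 1.],
                     [0., 1., 0.]])             # 1 marks a positive pair
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
print(mean_log_prob_pos)                        # tensor([-1.2500, -0.3000])
```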

— Struggle-Forever, Sep 03 '22

I also wonder about that. At this line:

`log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))`

Aren't the distances between positive pairs also part of the denominator, because of `exp_logits.sum(1, keepdim=True)`?
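For what it's worth, a minimal reconstruction of the loss (toy batch; the names and the self-contrast mask are paraphrased, so check the repository's `losses.py` for the exact code) makes this visible: the denominator sums over every sample except the anchor itself, positives included, which matches the $\mathcal{L}^{sup}_{out}$ formulation of the paper.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
feats = F.normalize(torch.randn(4, 8), dim=1)          # toy batch of 4 embeddings
labels = torch.tensor([0, 0, 1, 1])
logits = feats @ feats.T / 0.07                        # cosine similarity / temperature

pos_mask = (labels[:, None] == labels[None, :]).float()
self_mask = 1.0 - torch.eye(4)                         # drop self-contrast only
pos_mask = pos_mask * self_mask

exp_logits = torch.exp(logits) * self_mask             # denominator: ALL others, positives included
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1)
loss = -mean_log_prob_pos.mean()
print(loss)
```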

— RizhaoCai, Dec 18 '22