Swin-Transformer

Fix torch.clamp issue #237

Open CryptoSalamander opened this pull request 3 years ago • 3 comments

This PR is related to #237! There are two options to fix this problem (both are sketched in runnable form after the list):

  1. Just convert the max tensor to a Python scalar (this PR).
  2. Load the max tensor onto the same device as self.logit_scale, like below:

device = self.logit_scale.device
max_tensor = torch.log(torch.tensor(1. / 0.01)).to(device)
logit_scale = torch.clamp(self.logit_scale, max=max_tensor).exp()
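
For reference, a minimal self-contained sketch of both options. This is not the code from the PR: logit_scale below merely stands in for WindowAttention's self.logit_scale parameter (4 heads assumed), and the exact diff may differ.

import torch

# Stand-in for WindowAttention's self.logit_scale parameter (4 heads assumed).
logit_scale = torch.nn.Parameter(torch.log(10 * torch.ones(4, 1, 1)))
if torch.cuda.is_available():
    logit_scale = logit_scale.cuda()

# Original pattern: `max` is a freshly created CPU tensor, which can clash with a
# CUDA logit_scale on some torch versions (presumably the mismatch behind #237).
# clamped = torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()

# Option 1 (this PR): convert the bound to a Python scalar via .item().
clamped_scalar = torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01)).item()).exp()

# Option 2: move the bound onto the same device as logit_scale.
max_tensor = torch.log(torch.tensor(1. / 0.01)).to(logit_scale.device)
clamped_device = torch.clamp(logit_scale, max=max_tensor).exp()

assert torch.allclose(clamped_scalar, clamped_device)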

I think the first option is better due to its simplicity. I tested both options on my datasets, and there seems to be no difference in CUDA memory allocation or inference speed.

CryptoSalamander avatar Jul 14 '22 09:07 CryptoSalamander

@ancientmooner Could you please check issue #237?

CryptoSalamander avatar Jul 14 '22 09:07 CryptoSalamander

@CryptoSalamander, I would prefer the second option. I faced the same issue when using torch 2.0, and the item() call in the first option causes torch.dynamo to break WindowAttention into two graphs when tracing the module.
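
To observe this behaviour, one could compile the clamp with torch.compile and fullgraph=True, which refuses to fall back on a graph break. The sketch below is illustrative only and not code from this PR; the function names are made up, and whether option 1 raises depends on the torch 2.x build.

import torch

# Stand-in for the clamp inside WindowAttention.forward (names are illustrative).
logit_scale = torch.nn.Parameter(torch.log(10 * torch.ones(4, 1, 1)))

def clamp_option_1(scale):
    # .item() forces a tensor -> Python-scalar conversion; per the report above,
    # this makes torch.dynamo split WindowAttention into two graphs on torch 2.0.
    return torch.clamp(scale, max=torch.log(torch.tensor(1. / 0.01)).item()).exp()

def clamp_option_2(scale):
    # Pure tensor ops kept on a single device; traces as one graph.
    max_tensor = torch.log(torch.tensor(1. / 0.01)).to(scale.device)
    return torch.clamp(scale, max=max_tensor).exp()

# fullgraph=True asks dynamo to error out rather than split the graph, so
# compiling clamp_option_1 this way may raise, while clamp_option_2 compiles.
compiled = torch.compile(clamp_option_2, fullgraph=True)
print(compiled(logit_scale).shape)  # torch.Size([4, 1, 1])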

juncgu avatar Apr 11 '23 00:04 juncgu

@juncgu Thanks for your suggestion. I have modified the code to use the second option. Could you please take a look at this PR? @ancientmooner

CryptoSalamander avatar Apr 11 '23 14:04 CryptoSalamander