Fix torch.clamp issue #237
This PR is related to #237! There are two options to fix this problem:
- just convert `max_tensor` to a scalar (this PR)
- make `max_tensor` be loaded on the same device as `self.logit_scale`, like below:

```python
device = self.logit_scale.device
max_tensor = torch.log(torch.tensor(1. / 0.01)).to(device)
logit_scale = torch.clamp(self.logit_scale, max=max_tensor).exp()
```
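For completeness, here is a minimal standalone sketch of the first option as implemented in this PR; the parameter initialization below is only illustrative and not the repo's actual code:

```python
import torch

# illustrative stand-in for self.logit_scale in WindowAttention
logit_scale = torch.nn.Parameter(torch.log(10 * torch.ones(1)))

# option 1: reduce the cap to a plain Python float with .item(),
# so torch.clamp never receives a tensor on a possibly different device
max_value = torch.log(torch.tensor(1. / 0.01)).item()
scale = torch.clamp(logit_scale, max=max_value).exp()
```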
I think the first option is better due to its simplicity. I tested both options on my datasets, and there seems to be no difference in CUDA memory allocation or inference speed.
@ancientmooner Could you please check issue #237?
@CryptoSalamander, I would prefer the second option. I faced the same issue when using torch 2.0: the `.item()` call in the first option leads `torch.dynamo` to break `WindowAttention` into two graphs when tracing the module.
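To illustrate the graph-break concern, here is a small, self-contained sketch (not code from this repo) contrasting the two clamping styles under `torch.compile`; with torch 2.0 default settings, the `.item()` path trips a graph break, which `fullgraph=True` surfaces as an error:

```python
import torch

logit_scale = torch.log(10 * torch.ones(1))  # stand-in for self.logit_scale

def clamp_with_item(x):
    # option 1: .item() returns a Python float, which dynamo cannot keep in-graph
    max_value = torch.log(torch.tensor(1. / 0.01)).item()
    return torch.clamp(x, max=max_value).exp()

def clamp_with_tensor(x):
    # option 2: everything stays a tensor on x's device, so a single graph suffices
    max_tensor = torch.log(torch.tensor(1. / 0.01, device=x.device))
    return torch.clamp(x, max=max_tensor).exp()

# fullgraph=True asks dynamo to refuse any graph break
print(torch.compile(clamp_with_tensor, fullgraph=True)(logit_scale))  # traces as one graph

try:
    torch.compile(clamp_with_item, fullgraph=True)(logit_scale)
except Exception as err:  # with default config, .item() triggers a graph break here
    print("graph break:", err)
```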
@juncgu Thanks for your suggestion. I have modified the code to use the second option. Could you please take a look at this PR? @ancientmooner