SAM: Sharpness-Aware Minimization (PyTorch)
```
base_optimizer = torch.optim.Adam
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1)

torch.save({"optz_state_dict": optimizer.state_dict()}, "state.pth")
checkpoint = torch.load("state.pth")
optimizer.load_state_dict(checkpoint["optz_state_dict"])
```
By using the above code, the saved state size is more than halved compared...
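The size difference is likely because the base optimizer's own state (e.g. Adam's moment buffers) lives on the wrapped optimizer rather than on the SAM wrapper itself. A minimal sketch of saving and restoring both state dicts, assuming the wrapper exposes the wrapped optimizer as `optimizer.base_optimizer` (as in this repo's implementation; adjust the attribute name if yours differs):

```
import torch

# hedged sketch: Adam's exp_avg / exp_avg_sq buffers live on the wrapped optimizer,
# so save its state dict alongside the SAM wrapper's own state dict
torch.save({
    "sam_state_dict": optimizer.state_dict(),
    "base_state_dict": optimizer.base_optimizer.state_dict(),
}, "state.pth")

checkpoint = torch.load("state.pth")
optimizer.base_optimizer.load_state_dict(checkpoint["base_state_dict"])
optimizer.load_state_dict(checkpoint["sam_state_dict"])  # loaded last so the wrapper can re-sync param_groups
```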
Link to the paper: https://arxiv.org/abs/2206.04920
Any chance of this being implemented in this module?
```
def _grad_norm(self):
    shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
    norm = torch.norm(
        torch.stack([
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups for p in group["params"]
            if p.grad is not None
        ]),
        p=2
    )
    return norm
```
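For context, a rough sketch of how this norm is typically consumed by the ascent step: the perturbation radius `rho` is divided by the global gradient norm, and each parameter is moved along its (optionally adaptively scaled) gradient. Names such as `rho`, `adaptive`, and `old_p` follow this repo's README; the exact implementation may differ in detail.

```
@torch.no_grad()
def first_step(self, zero_grad=False):
    # sketch: perturb each parameter along its gradient, with the step length
    # rho normalized by the global norm computed in _grad_norm
    grad_norm = self._grad_norm()
    for group in self.param_groups:
        scale = group["rho"] / (grad_norm + 1e-12)
        for p in group["params"]:
            if p.grad is None:
                continue
            self.state[p]["old_p"] = p.data.clone()  # remember w so second_step can restore it
            e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
            p.add_(e_w)  # move to the perturbed point w + e(w)
    if zero_grad:
        self.zero_grad()
```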
As per the title: if the closure is not None, we should assume that the parameters' gradients have not been computed, and immediately run the closure. See, e.g., the step function...
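A minimal sketch of the suggested behavior, assuming the two-step interface (`first_step`/`second_step`) used by this repo; the surrounding details of the real `step` method may differ:

```
import torch

@torch.no_grad()
def step(self, closure=None):
    assert closure is not None, "SAM requires a closure that recomputes the loss"
    closure = torch.enable_grad()(closure)  # the closure runs backward(), so grad must be re-enabled

    closure()                         # compute the initial gradients right away, as suggested
    self.first_step(zero_grad=True)   # perturb the weights using those gradients
    closure()                         # recompute the loss/gradients at the perturbed point
    self.second_step()                # restore the weights and apply the base optimizer's update
```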
Hello, I am trying to use the step function (with the transformers and accelerate libraries) while passing the closure. The step function has a @torch.no_grad() decorator, and thus we specify enable_grad...
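For reference, a minimal closure-based usage sketch: because `step` is decorated with `@torch.no_grad()`, the implementation typically re-wraps the closure with `torch.enable_grad()` so its backward pass still works. `model`, `criterion`, `inputs`, and `targets` below are placeholders.

```
def closure():
    # step() re-wraps this with torch.enable_grad(), so backward() works even
    # though step itself is decorated with @torch.no_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss

loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step(closure)
optimizer.zero_grad()
```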
As mentioned in the README, the suggested usage can potentially cause problems if you use batch normalization. Will LayerNorm or GroupNorm cause problems in principle? I use SAM with a Swin Transformer and...
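LayerNorm and GroupNorm keep no running statistics, so in principle they do not suffer from this issue; the concern is specific to BatchNorm, whose running mean/variance would be updated twice per step (once per forward pass). A hedged sketch of the usual workaround, along the lines of the helpers shipped with this repo's example code, which freezes BN momentum during the second forward pass:

```
import torch
from torch.nn.modules.batchnorm import _BatchNorm

def disable_running_stats(model):
    # set BN momentum to 0 so the second forward pass does not update running stats
    def _disable(module):
        if isinstance(module, _BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0
    model.apply(_disable)

def enable_running_stats(model):
    # restore the original BN momentum before the first forward pass
    def _enable(module):
        if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
    model.apply(_enable)

# per-batch usage (sketch):
# enable_running_stats(model);  first forward/backward;  optimizer.first_step(zero_grad=True)
# disable_running_stats(model); second forward/backward; optimizer.second_step(zero_grad=True)
```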
How can you combine SAM with GradScaler and gradient clipping, given that you can't unscale twice?
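One possible pattern, offered only as a sketch rather than an endorsed recipe: let the GradScaler track just the second pass, and unscale the first pass's gradients manually via `scaler.get_scale()`, so `scaler.unscale_` is called exactly once per update. `first_step`/`second_step` follow this repo's two-step interface; `model`, `criterion`, `optimizer`, and `loader` are placeholders, and the inf/nan handling is simplified (calling `second_step` directly bypasses `scaler.step`'s skip-on-overflow logic).

```
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    # ---- first pass: scaled backward, manual unscale, perturb the weights ----
    with torch.autocast(device_type="cuda"):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    inv_scale = 1.0 / scaler.get_scale()
    for p in model.parameters():              # manual unscale avoids a second scaler.unscale_ call
        if p.grad is not None:
            p.grad.mul_(inv_scale)
    optimizer.first_step(zero_grad=True)

    # ---- second pass: the single tracked unscale_, clipping, then the real update ----
    with torch.autocast(device_type="cuda"):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                # the one unscale the scaler bookkeeps per update
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.second_step(zero_grad=True)     # note: no inf/nan skip check is performed here
    scaler.update()
```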