DeepSpeed
Update vae.py
Since the DSVAE class already inherits from torch.nn.Module, it does not need to inherit from CUDAGraph as well, so the CUDAGraph inheritance can be removed. Instead of calling self.vae.requires_grad_(requires_grad=False), the torch.no_grad() context manager can be used during initialization to disable gradient computation for the self.vae module. The _graph_replay_decoder, _graph_replay_encoder, and _graph_replay methods can also be decorated with @torch.no_grad().
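For reference, here is a minimal sketch of the two mechanisms mentioned above. TinyVAE and Wrapper are hypothetical stand-ins, not the actual DSVAE code. It assumes standard PyTorch behaviour: requires_grad_(False) changes the parameters' flags persistently, while torch.no_grad() only disables gradient tracking for operations run inside it. It may therefore be worth checking whether the two are interchangeable for this change.

```python
import torch
import torch.nn as nn


class TinyVAE(nn.Module):
    """Hypothetical stand-in for the wrapped VAE (not DeepSpeed's module)."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 2)
        self.dec = nn.Linear(2, 4)

    def decode(self, z):
        return self.dec(z)


class Wrapper(nn.Module):
    """Hypothetical wrapper that only inherits from nn.Module."""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae
        # Persistently marks every parameter of self.vae as not requiring grad.
        self.vae.requires_grad_(False)

    @torch.no_grad()  # no autograd graph is recorded inside this method
    def decode(self, z):
        return self.vae.decode(z)


wrapper = Wrapper(TinyVAE())
z = torch.randn(3, 2, requires_grad=True)
out = wrapper.decode(z)

# True: requires_grad_(False) changed the flag on every parameter.
frozen = all(not p.requires_grad for p in wrapper.vae.parameters())

# Building a module inside torch.no_grad() does not change parameter flags;
# nn.Parameter still defaults to requires_grad=True.
with torch.no_grad():
    built_under_no_grad = TinyVAE()
still_trainable = all(p.requires_grad for p in built_under_no_grad.parameters())
```

In this sketch the decorator keeps the output of decode detached from autograd even though the input requires grad, which is the effect the @torch.no_grad() suggestion is after for the replay methods.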
@microsoft-github-policy-service agree