OptimalGradCheckpointing

Making it work with deepspeed

Open · eliird opened this issue 1 year ago · 0 comments

I know it's been a while since this repo was uploaded, but I was thinking of getting it to work with DeepSpeed. Did you by any chance try that? Internally, DeepSpeed calls self.module.parameters() for a lot of things, and the segment wrapper does not expose the original parameters. I tried writing a wrapper around the original model that holds both the original parameters and the segment info dict, like:


import torch.nn as nn
from torch.nn import Module


class Model(Module):
    ...  # model definition and forward function


class OGCWrapper(Model):
    def __init__(self, info_dict):
        super().__init__()  # builds the original Model, so the wrapper owns the real parameters
        self.x = nn.Linear(10, 10)  # placeholder layer
        # info_dict is generated with the optimal_grad_checkpointing function and dumped in pickle format
        self.info_dict = info_dict

    def forward(self, x):
        return graph_forward(x, **self.info_dict)  # graph_forward from this repo

but this takes more memory than the model without checkpointing when used with DeepSpeed.
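For reference, here is roughly how I am initializing DeepSpeed with the wrapper and measuring peak memory; the config values, pickle file name, and input shape below are placeholders for my actual setup:

import pickle

import deepspeed
import torch

# Placeholder DeepSpeed config; my real one also enables fp16/ZeRO.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

# info_dict was produced by optimal_grad_checkpointing and pickled earlier.
with open("info_dict.pkl", "rb") as f:
    info_dict = pickle.load(f)

model = OGCWrapper(info_dict)

# DeepSpeed iterates model.parameters() internally (optimizer state,
# partitioning, etc.), which is why the segment wrapper alone was not enough.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

torch.cuda.reset_peak_memory_stats()
x = torch.randn(8, 10, device=engine.device)  # placeholder input
loss = engine(x).sum()
engine.backward(loss)
engine.step()
print(torch.cuda.max_memory_allocated() / 2**20, "MiB peak")

I pass model_parameters explicitly since DeepSpeed builds its optimizer state from that iterator.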

eliird · Oct 25 '24 03:10