OptimalGradCheckpointing
Making it work with DeepSpeed
I know it's been a while since this repo was uploaded, but I was thinking of getting it to work with DeepSpeed. Did you by any chance try that? Internally, DeepSpeed calls self.module.parameters() for a lot of things, and the segment wrapper does not expose the original parameters. I tried writing a wrapper around the original model that contains both the original parameters and the segment dict, like:
```python
import torch.nn as nn

class Model(nn.Module):
    ...  # model definition and forward function (e.g. nn.Linear(10, 10) layers)

class OGCWrapper(Model):
    def __init__(self, info_dict):
        super().__init__()  # without this the wrapper has no registered parameters
        # generated using the optimal_grad_checkpointing function and dumped in a pickle
        self.info_dict = info_dict

    def forward(self, x):
        return graph_forward(x, **self.info_dict)
```
but this takes more memory than the model without checkpointing when used with DeepSpeed.
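
My suspicion is that unpickling the info_dict creates fresh tensor storages, so the segments inside it hold their own copies of every weight. The wrapper then carries two full sets of parameters: the ones registered via super().__init__() (which DeepSpeed sees and partitions) and the unmanaged pickled copies (which it doesn't). Something like this is what I'd try next — a rough sketch, assuming the dict can be built from the live wrapper at init time instead of being loaded from disk (the exact arguments to optimal_grad_checkpointing are elided here):

```python
import torch.nn as nn

class OGCWrapper(Model):
    def __init__(self):
        super().__init__()  # real parameters, visible to DeepSpeed
        # Build the segment dict from this live module so the segments alias
        # the registered parameters instead of holding unpickled copies.
        # optimal_grad_checkpointing is the function from this repo; its
        # remaining arguments are elided.
        self.info_dict = optimal_grad_checkpointing(self, ...)

    def forward(self, x):
        return graph_forward(x, **self.info_dict)
```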
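
To check whether the weights really are duplicated, something like this should work — assuming the segments stored in info_dict are nn.Modules (an assumption about the repo's internals):

```python
def segments_alias_wrapper(wrapper, segment_modules):
    """True if every segment parameter shares storage with one registered on
    the wrapper, i.e. no duplicated weights inflating memory."""
    own_ptrs = {p.data_ptr() for p in wrapper.parameters()}
    seg_ptrs = {p.data_ptr() for m in segment_modules for p in m.parameters()}
    return seg_ptrs <= own_ptrs
```

If this returns False for the unpickled dict, the duplicated storages would at least partly explain the extra memory.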