OptimalGradCheckpointing

Can't use run_segment with apex.amp

Status: Open. karinaodm opened this issue 4 years ago • 4 comments

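I use code like this:

```python
import apex.amp

# model, inp, optimizer and images are defined elsewhere
run_segment = optimal_grad_checkpointing(model, inp)
# opt_level uses the letter "O" ("O2"), not the digit zero ("02")
run_segment, optimizer = apex.amp.initialize(run_segment, optimizer, opt_level="O2", verbosity=0)
...
output = run_segment(images)
```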

and get the following error:

```
output = run_segment(images)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
    return graph_forward(x, **self.info_dict)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 838, in graph_forward
    output = checkpoint(segment_checkpoint_forward(op), input)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 155, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 74, in forward
    outputs = run_function(*args)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 807, in custom_forward
    outputs = segment(*inputs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
    return graph_forward(x, **self.info_dict)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 840, in graph_forward
    output = op(input)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 349, in forward
    return self._conv_forward(input, self.weight)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 346, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
```
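The error means the convolution received a half-precision input while its weight was still float32. One plausible mechanism, not confirmed in this thread: apex O2 casts the parameters of the model's registered submodules to half and wraps `forward` to cast the inputs (the `new_fwd` frame above), so any layer the cast never reaches keeps float32 weights. The sketch below assumes, hypothetically, that such layers are held in a plain dict like `info_dict`. `PlainDictOps` and `RegisteredOps` are illustrative names, not classes from the repository, and `bfloat16` stands in for `float16` so the sketch runs on a CPU.

```python
import torch
import torch.nn as nn


class PlainDictOps(nn.Module):
    """Hypothetical module that keeps its layer in a plain dict."""

    def __init__(self):
        super().__init__()
        # A plain dict is not registered, so .to()/.half() never visit this layer.
        self.info_dict = {"op": nn.Linear(4, 4)}

    def forward(self, x):
        return self.info_dict["op"](x)


class RegisteredOps(nn.Module):
    """Same layer, registered through nn.ModuleDict."""

    def __init__(self):
        super().__init__()
        self.ops = nn.ModuleDict({"op": nn.Linear(4, 4)})

    def forward(self, x):
        return self.ops["op"](x)


# bfloat16 stands in for float16 so the sketch also runs on a CPU-only machine.
low = torch.bfloat16
unregistered = PlainDictOps().to(low)  # the cast silently skips the dict
registered = RegisteredOps().to(low)   # the cast reaches the Linear

# The input is already low precision, as after an input-casting wrapper.
x = torch.randn(2, 4, dtype=low)

registered_out = registered(x)  # input and weight dtypes match

try:
    unregistered(x)  # weight is still float32, so this raises
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
```

If this is the cause, registering the layers (for example through `nn.ModuleDict`) would let a model-wide cast reach them.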

It would be useful to be able to combine Optimal Gradient Checkpointing with apex.amp or torch.cuda.amp.

karinaodm, Nov 11 '21 07:11

Hi,

It could be that the PyTorch checkpointing function does not support apex. Did you try torch.cuda.amp?
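A minimal sketch of the torch.cuda.amp route, assuming PyTorch 1.11 or newer (for the `use_reentrant` argument). Instead of converting the model to half as apex O2 does, `torch.autocast` (the device-generic form of `torch.cuda.amp.autocast`) picks a dtype per op, so parameters stay float32 and inputs are cast at each op. The toy `segment` is a stand-in for one checkpointed segment, not the repository's graph, and the sketch falls back to bfloat16 autocast on CPU.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# float16 autocast on a GPU, bfloat16 autocast on CPU so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Toy stand-in for one checkpointed segment; its parameters stay float32.
segment = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)).to(device)
inputs = torch.randn(4, 16, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    # autocast casts per op, so each Linear sees matching input and weight dtypes.
    # use_reentrant=False lets gradients reach the parameters even though
    # `inputs` does not require grad.
    output = checkpoint(segment, inputs, use_reentrant=False)

# Backward re-runs the segment's forward; checkpoint restores the autocast state.
output.float().sum().backward()
```

Because the weights are never converted, a layer that a model-wide cast would miss is still handled correctly here.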

jianweif, Nov 11 '21 08:11

I would like to try torch.cuda.amp, but torch.cuda.amp.autocast was only introduced in PyTorch 1.6, and OptimalGradCheckpointing only works with PyTorch 1.5.
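Since the choice hinges on version boundaries (1.5, 1.6 and later 1.10), a small check on `torch.__version__` can decide whether the native AMP path is available. `has_native_amp` below is a hypothetical helper, not part of PyTorch or the repository. It compares integer tuples, because plain string comparison would rank "1.10" below "1.6".

```python
import re


def has_native_amp(torch_version: str) -> bool:
    """Return True if this PyTorch version ships torch.cuda.amp.autocast (1.6+)."""
    # Accepts strings such as "1.5.1", "1.6.0a0+abc" or "1.10.0+cu113".
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {torch_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Integer tuples order correctly: (1, 10) > (1, 6), unlike "1.10" < "1.6".
    return (major, minor) >= (1, 6)
```

Usage: `has_native_amp(torch.__version__)`.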

karinaodm, Nov 11 '21 08:11

Our automatic graph parsing depends on torch.jit and is quite sensitive to the PyTorch version. If you write a manual parse_graph function, it should definitely work with 1.6.

I haven't tested automatic parsing on 1.6, but it will likely work, because I don't expect many relevant changes between PyTorch 1.5 and 1.6.

Let me know whether you are able to use it with PyTorch 1.6. I will also test compatibility with other versions when I get time.

Thanks

jianweif, Nov 11 '21 08:11

Yes, it works with torch.cuda.amp on PyTorch 1.10 after I fixed the line mentioned in https://github.com/lordfjw/OptimalGradCheckpointing/issues/3#issuecomment-966102707

Thanks!
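For reference, the usual torch.cuda.amp training step pairs autocast with a `GradScaler`. This is a sketch, assuming the checkpointed model can be called like any other `nn.Module`; `model` is a hypothetical stand-in for `run_segment`. On a CPU-only machine both autocast and the scaler are disabled, so the step runs in float32. On recent PyTorch releases, `torch.cuda.amp.GradScaler` may emit a deprecation warning.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_cuda = device == "cuda"

# Hypothetical stand-in for the checkpointed run_segment module.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# GradScaler guards float16 gradients against underflow; when disabled it passes through.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)


def train_step(images, targets):
    optimizer.zero_grad(set_to_none=True)
    # Unlike apex O2, the model is not converted to half; autocast picks per-op dtypes.
    with torch.autocast(device_type=device, enabled=use_cuda):
        output = model(images)
        loss = criterion(output, targets)
    scaler.scale(loss).backward()  # backward on the (possibly) scaled loss
    scaler.step(optimizer)         # unscales grads, skips the step on inf/NaN
    scaler.update()                # adjusts the scale for the next iteration
    return loss.item()
```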

karinaodm, Nov 11 '21 08:11