OptimalGradCheckpointing
How do I get the checkpoints when applying it to linear-chain feedforward models?
As the title says. The program shows the peak memory usage and the cut-offs, but I need help or a hint on how to get the checkpoints themselves.
Hi, sorry for the late reply. You can get the checkpointed tensors at https://github.com/jianweif/OptimalGradCheckpointing/blob/main/graph.py#L876
Instead of returning tensor_dict[target], which is only the output tensor, you can return the entire tensor_dict, where each value is a checkpointed tensor in the computation graph.
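
To illustrate the idea, here is a minimal toy sketch of the pattern. It is not the code in graph.py: the function run_chain, the return_all flag, and the integer node keys are hypothetical, and plain numbers stand in for tensors. For simplicity it stores every node of the chain, whereas in the repository the dictionary holds the checkpointed tensors.

```python
# Toy stand-in for a linear-chain forward pass; NOT the code in graph.py.
# run_chain, return_all, and the integer node keys are hypothetical names.

def run_chain(stages, x, return_all=False):
    """Run `stages` in order, storing each intermediate result by node index."""
    source, target = 0, len(stages)
    tensor_dict = {source: x}
    for i, stage in enumerate(stages, start=1):
        # Each node's value is computed from the previous node in the chain.
        tensor_dict[i] = stage(tensor_dict[i - 1])
    if return_all:
        # Return every stored node value, not only the final output.
        return tensor_dict
    # Default behaviour: return only the output of the last node.
    return tensor_dict[target]


# Plain numbers stand in for tensors in this sketch.
stages = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3]
output = run_chain(stages, 5)
all_values = run_chain(stages, 5, return_all=True)
```

With return_all=True the caller receives the whole dictionary, and the final output is still available under the target key.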