Thank you & a few minor issues
Hi,
Thank you so much for providing this code! Using the automatic computation graph parser, I was able to use the optimal gradient checkpoints during model training without writing much additional code. Now I can train with almost 2x larger batches, which is very helpful for my application!
I just had a few minor issues when running the code and I want to quickly mention them here in case anyone else experiences the same:
- In graph.py line 429: The assertion `len(input_node_names) == len(inputs_nodes_ids)` fails because the list `inputs_nodes_ids` contains `None`. It works after removing all `None` entries from the list (`inputs_nodes_ids = [i for i in inputs_nodes_ids if i is not None]`). However, I'm not sure whether doing this could have any adverse effects?
- In graph.py line 159: Parsing the shape string fails because the `node_type` looks something like `"Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cuda:0)"` (I used PyTorch v1.9). Quick fix (a small demo follows the list):
  ```python
  if 'strides' in node_type:
      # PyTorch 1.9-style traces append strides/device info after the shape
      shape_str = node_type.split('(')[-1].split(', strides')[0]
  else:
      shape_str = node_type.split('(')[-1].split(')')[0]
  ```
- I also had to add a few lines of code to the `get_python_module_from_node_op` function (in graph.py) in order to handle `'prim::ListUnpack'`, `'aten::constant_pad_nd'` and `'aten::squeeze'`, but this was straightforward based on the examples in your code :) (a rough sketch of the kind of mapping I added is below)
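For completeness, here is the shape-string fix from the second point as a tiny standalone demo. The second input string is only my guess at the older format without the strides/device annotations:

```python
def parse_shape_str(node_type):
    """Extract the comma-separated shape digits from a traced node type string."""
    if 'strides' in node_type:
        # e.g. PyTorch 1.9: "Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cuda:0)"
        return node_type.split('(')[-1].split(', strides')[0]
    # older format without the extra annotations, e.g. "Float(2, 1024)"
    return node_type.split('(')[-1].split(')')[0]

print(parse_shape_str("Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cuda:0)"))  # 2, 1024
print(parse_shape_str("Float(2, 1024)"))  # 2, 1024
```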
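And to illustrate the last point: I don't remember the exact structure of `get_python_module_from_node_op`, so the snippet below is only a rough sketch of the kind of op-to-callable mapping I added; the dictionary name and the `prim::ListUnpack` lambda are my own, not code from the repo:

```python
import torch
import torch.nn.functional as F

# Hypothetical additions, written in the spirit of the existing op handlers in graph.py:
# map the traced op name to the Python callable that reproduces the op.
EXTRA_OP_TO_CALLABLE = {
    'prim::ListUnpack': lambda xs: tuple(xs),  # unpack a traced list output
    'aten::constant_pad_nd': F.pad,            # padding with a constant value
    'aten::squeeze': torch.squeeze,            # remove singleton dimensions
}
```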
Thank you again for making my life easier!
Hi IsabelFunke,
Thank you for your comment! I will double-check the logic in graph.py when I get some time, and I will also restructure and modularize the code better. As long as the forward check and backward check pass for your network, the modifications should not have any side effects.
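For reference, a minimal version of such a check could look like the sketch below. It uses a toy sequential model and `torch.utils.checkpoint.checkpoint_sequential` as a stand-in; in practice you would compare your own network against the checkpointed version produced by the parser:

```python
import copy
import torch
from torch.utils.checkpoint import checkpoint_sequential

# Toy stand-in network; replace with your model and its checkpointed version.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 64), torch.nn.ReLU(),
)
model_ckpt = copy.deepcopy(model)
x = torch.randn(8, 64, requires_grad=True)

# Forward check: outputs with and without checkpointing should match.
out_plain = model(x)
out_ckpt = checkpoint_sequential(model_ckpt, 2, x)
assert torch.allclose(out_plain, out_ckpt, atol=1e-6)

# Backward check: parameter gradients should match as well.
out_plain.sum().backward()
out_ckpt.sum().backward()
for p_plain, p_ckpt in zip(model.parameters(), model_ckpt.parameters()):
    assert torch.allclose(p_plain.grad, p_ckpt.grad, atol=1e-6)
```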
Parsing the computation graph from the strings returned by torch.jit.trace is a workaround, and it can be error-prone. I am still looking into parsing the computation graph through the torch.jit C++ API, but no luck yet. I will definitely upgrade this part in the future to make it more robust and easier to use.
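For example, the traced graph can already be walked through the Python bindings instead of parsing its string representation; something along these lines might work (just a rough sketch, I have not checked that it exposes everything the current string parser extracts):

```python
import torch

# Toy model just to have something to trace.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU())
traced = torch.jit.trace(model, torch.randn(2, 16))

# Walk the graph nodes as objects instead of parsing str(traced.graph).
for node in traced.graph.nodes():
    print(node.kind())  # e.g. 'aten::linear', 'aten::relu'
    for out in node.outputs():
        t = out.type()
        if isinstance(t, torch._C.TensorType):
            print('  output sizes:', t.sizes())  # may be None if shapes were not recorded
```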
Thanks!