Update _SimpleConsensus to use static autograd methods (for PyTorch >1.3)
Thank you so much for sharing your implementation of TPN!
Problem
I've been trying to get it to work in one of my own projects; however, I ran into the same issue mentioned in #28, where the user posted a stack trace with the error message "Legacy autograd function with non-static forward method is deprecated." The error occurs when forward() is called with the old code while the averaging consensus (_SimpleConsensus) is in use.
Environment
- PyTorch 1.7.0 + CUDA 11.0
- Ubuntu 16.04.2 LTS
- Python 3.7.8
Summary of changes
To make the _SimpleConsensus class (a subclass of torch.autograd.Function) compatible with PyTorch >1.3:
- Removed the __init__ method from _SimpleConsensus.
- Use the static apply method instead of forward to pass the input tensor through _SimpleConsensus.
- The forward() method of _SimpleConsensus uses ctx.save_for_backward(args) to cache the input tensor x, dim, and consensus_type.
- self.shape is no longer a member of _SimpleConsensus; it is reconstructed in each call to backward() by retrieving x from ctx.saved_tensors and calling x.size().
This is consistent with the template given in the PyTorch docs, which I referenced.
Discussion
The changes in this PR work for me; I can now run forward() without issue. However, as a disclaimer, due to the nature of my project I'm using my own testing script instead of the testing framework provided in this repo. For completeness, my model loading code looks like this:
```python
import torch
from TPN.mmaction.models.recognizers import TSN3D

PRETRAINED_MODEL_PATH = "/path/to/my/model/kinetics400_tpn_r50f32s2.pth"

# model_cfg (not shown here) is a dict holding the model config sections
model = TSN3D(model_cfg["backbone"],
              necks=model_cfg["necks"],
              cls_head=model_cfg["cls_head"],
              spatial_temporal_module=model_cfg["spatial_temporal_module"],
              segmental_consensus=model_cfg["segmental_consensus"])

# Load the pretrained weights into the freshly built model
pretrained = torch.load(PRETRAINED_MODEL_PATH)
model.load_state_dict(pretrained)
```
Please let me know if there's any additional testing (suites or otherwise) I should run, or if there's a contributing guide that I've overlooked. Furthermore, I'm happy to provide more details as needed. Thanks!
It says ctx is not defined for me. Am I missing something?
I missed a typo; thanks for pointing that out. I've overwritten my previous changes on my branch with the fix. Does it work for you now?
The "ctx is not defined" error is solved, but there's a new one. It seems save_for_backward can only save tensors, not dim (int) and consensus_type (str):

TypeError: save_for_backward can only save variables, but argument 1 is of type int
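A minimal sketch of the general fix, using a hypothetical PowerMean function (the mean of x**p over dim) rather than the repo's code: tensors that backward() actually needs go through save_for_backward, while plain Python values such as an int dim or exponent are stored as attributes on ctx. Passing the ints to save_for_backward as well would be expected to raise a TypeError like the one quoted above.

```python
import torch


class PowerMean(torch.autograd.Function):
    """Hypothetical example: mean of x**p over `dim`, with keepdim=True."""

    @staticmethod
    def forward(ctx, x, dim, p):
        # Only tensors are passed to save_for_backward; calling
        # ctx.save_for_backward(x, dim, p) would be rejected.
        ctx.save_for_backward(x)
        # Non-tensor arguments are kept as plain attributes on ctx.
        ctx.dim = dim
        ctx.p = p
        return x.pow(p).mean(dim=dim, keepdim=True)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        n = x.size(ctx.dim)
        # d/dx_i of mean(x**p) = p * x_i**(p - 1) / n; grad_output broadcasts over dim
        grad_x = grad_output * ctx.p * x.pow(ctx.p - 1) / n
        # No gradients for the non-tensor inputs dim and p
        return grad_x, None, None
```

The same pattern applies to dim and consensus_type in _SimpleConsensus: backward() simply returns None in the positions of those non-tensor inputs.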
I've created a PR on your fork with the code that works for me: PR