[REQ] Context manager similar to torch.no_grad
Description
It would be great if Warp offered a context manager similar to PyTorch's torch.no_grad().
Context
There are various use cases for such a feature. An obvious one is efficiency when only some of the operations within a tape are relevant to the gradient. Another is implementing a custom integrator that remains differentiable but excludes the often noisy contact computations from the scope of a tape. With this feature, it would be as easy as inheriting from an existing simulator and putting the contact force computations within such a context.
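To make the request concrete, here is a minimal sketch of the behaviour I have in mind. It uses a toy stand-in for a tape rather than Warp's own wp.Tape, so ToyTape, record, recording, and no_grad are all hypothetical names chosen for illustration, not existing Warp API.

```python
from contextlib import contextmanager


class ToyTape:
    """Hypothetical stand-in for a recording tape; this is NOT wp.Tape."""

    def __init__(self):
        self.launches = []
        self.recording = True

    def record(self, name):
        # Only remember the launch while recording is enabled.
        if self.recording:
            self.launches.append(name)


@contextmanager
def no_grad(tape):
    """Pause recording on `tape` inside the with-block (hypothetical API)."""
    previous = tape.recording
    tape.recording = False
    try:
        yield tape
    finally:
        # Restore the earlier state even if the block raises.
        tape.recording = previous


tape = ToyTape()
tape.record("integrate_positions")
with no_grad(tape):
    tape.record("compute_contact_forces")  # skipped by the tape
tape.record("integrate_velocities")
print(tape.launches)  # ['integrate_positions', 'integrate_velocities']
```

In a real simulator subclass, only the contact force step would sit inside the with-block, and everything else would be recorded as usual.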
If I am not mistaken, there are currently three different ways of achieving something similar:
- Implementing a custom adjoint kernel that prevents gradient computation.
- Manually interrupting the computation graph by passing detached copies of arrays to the kernels that should not contribute to gradient computation.
- A post-computation cleanup of recorded launches, i.e., something like tape.launches = [l for l in tape.launches if not (isinstance(l, list) and l[0].key == 'no_grad_kernel')].
All three are suboptimal in terms of code readability, efficiency, or both. Also, I am not sure whether the third option works properly or breaks something within the tape.
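For reference, the filtering in the third option can be sketched on plain Python objects as below. This only illustrates what the list comprehension does. It assumes that recorded launches are lists whose first element has a key attribute, as the expression above implies, and it says nothing about whether a real tape stays consistent afterwards. The helper drop_launches and the sample entries are made up for illustration.

```python
from types import SimpleNamespace


def drop_launches(launches, kernel_key):
    """Return the entries that are not launches of the kernel named `kernel_key`."""
    return [
        l for l in launches
        if not (isinstance(l, list) and l[0].key == kernel_key)
    ]


# Stand-ins for kernel objects; only the `key` attribute matters here.
contact = SimpleNamespace(key="no_grad_kernel")
integrate = SimpleNamespace(key="integrate_kernel")


def marker():
    """Placeholder for a non-list entry, which the filter leaves untouched."""


recorded = [[integrate, 128], [contact, 128], marker, [integrate, 128]]
filtered = drop_launches(recorded, "no_grad_kernel")
print([l[0].key for l in filtered if isinstance(l, list)])
```

The isinstance check is what keeps non-list entries from being indexed and dropped by accident.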