[REQ] Context manager similar to torch.no_grad

Open · JonathanKuelz opened this issue 11 months ago · 0 comments

Description

It would be great if Warp offered a feature similar to PyTorch's with torch.no_grad() context manager.

Context

There are various use cases for such a feature. An obvious one is efficiency when only some of the operations recorded on a tape are relevant for the gradient. Another is a custom integrator that remains differentiable but excludes the often noisy contact computations from the scope of the tape. With this feature, that would be as easy as inheriting from an existing simulator and wrapping the contact force computations in such a context.
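To make the requested semantics concrete, here is a minimal pure-Python sketch of the integrator use case. ToyTape, its no_grad method, integrator_step and the penalty contact force f = -k * x are all hypothetical and illustrative; this is not Warp's wp.Tape API or an existing Warp feature. It assumes the context would behave like torch.no_grad(): operations inside the block still run forward, but are not recorded, so the backward pass treats their outputs as constants.

```python
from contextlib import contextmanager


class ToyTape:
    """Hypothetical stand-in for a gradient tape; this is NOT Warp's wp.Tape."""

    def __init__(self):
        self.launches = []      # backward closures, in forward order
        self._recording = True  # switched off inside no_grad()

    def record(self, backward_fn):
        # Ops executed while recording is disabled leave no trace on the tape.
        if self._recording:
            self.launches.append(backward_fn)

    @contextmanager
    def no_grad(self):
        # Save and restore the previous state so nested blocks behave correctly.
        previous, self._recording = self._recording, False
        try:
            yield
        finally:
            self._recording = previous

    def backward(self, seed):
        # Replay recorded ops in reverse order, accumulating adjoints in `grads`.
        grads = dict(seed)
        for backward_fn in reversed(self.launches):
            backward_fn(grads)
        return grads


def integrator_step(tape, vals, dt, k, exclude_contact):
    """One explicit step x_new = x + dt * f with a hypothetical contact force f = -k * x."""

    def contact_force():
        vals["f"] = -k * vals["x"]

        def back(g):
            # df/dx = -k
            g["x"] = g.get("x", 0.0) - k * g.get("f", 0.0)

        tape.record(back)

    if exclude_contact:
        with tape.no_grad():  # the requested feature: contact ops are not taped
            contact_force()
    else:
        contact_force()

    vals["x_new"] = vals["x"] + dt * vals["f"]

    def back(g):
        g_out = g.get("x_new", 0.0)
        g["x"] = g.get("x", 0.0) + g_out       # d(x_new)/dx = 1
        g["f"] = g.get("f", 0.0) + dt * g_out  # d(x_new)/df = dt

    tape.record(back)


def run(exclude_contact, x=1.0, dt=0.1, k=2.0):
    # Returns (forward value, d(x_new)/dx, number of taped ops).
    tape = ToyTape()
    vals = {"x": x}
    integrator_step(tape, vals, dt, k, exclude_contact)
    grads = tape.backward({"x_new": 1.0})
    return vals["x_new"], grads["x"], len(tape.launches)


print(run(exclude_contact=False))  # gradient 1 - k*dt, two taped ops
print(run(exclude_contact=True))   # contact treated as constant: gradient 1, one taped op
```

With contacts excluded, the forward result is unchanged, but the gradient of x_new with respect to x loses the -k*dt contact term and the tape holds one fewer entry.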

If I am not mistaken, there are currently three different ways of achieving something similar:

  1. Implementing a custom adjoint kernel that prevents gradient computation.
  2. Manually interrupting the computation graph by passing detached copies of arrays to the kernels that should not contribute to gradient computation.
  3. A post-computation cleanup of recorded launches, i.e., something like tape.launches = [l for l in tape.launches if not (isinstance(l, list) and l[0].key == 'no_grad_kernel')].
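Option 3 can be illustrated on a toy tape whose entries mimic the shape the filter above assumes (lists whose first element has a .key, plus non-list entries). ToyKernel, ToyTape, filter_by_key and discard_launches are hypothetical names for illustration only; nothing here is Warp's API, and the sketch does not answer whether editing a real wp.Tape's launches is safe. It contrasts the key-based filter with a hypothetical scoped variant that drops exactly what was recorded inside a block. Both still pay the recording cost and discard entries afterwards, unlike the requested no_grad context.

```python
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ToyKernel:
    key: str  # mirrors the `.key` attribute used in the option 3 filter


class ToyTape:
    """Hypothetical tape; NOT Warp's wp.Tape. Entries mimic the shape assumed in option 3."""

    def __init__(self):
        self.launches = []

    def record_launch(self, kernel, *args):
        self.launches.append([kernel, *args])  # list entry: [kernel, ...]

    def record_func(self, fn):
        self.launches.append(fn)  # non-list entry, skipped by the isinstance() check


def filter_by_key(tape, key):
    # Option 3: after recording, drop every launch of the kernel with this key.
    tape.launches = [entry for entry in tape.launches
                     if not (isinstance(entry, list) and entry[0].key == key)]


@contextmanager
def discard_launches(tape):
    # Hypothetical scoped variant: drop only the entries recorded inside the block.
    start = len(tape.launches)
    try:
        yield
    finally:
        del tape.launches[start:]


contact, integrate = ToyKernel("no_grad_kernel"), ToyKernel("integrate")

# Option 3: record everything, then filter by kernel key.
tape = ToyTape()
tape.record_launch(contact, "step")
tape.record_launch(integrate, "step")
tape.record_func(lambda: None)
filter_by_key(tape, "no_grad_kernel")
print([e[0].key if isinstance(e, list) else "func" for e in tape.launches])

# Scoped variant: a launch of the same kernel outside the block is kept.
tape2 = ToyTape()
with discard_launches(tape2):
    tape2.record_launch(contact, "step")
tape2.record_launch(integrate, "step")
tape2.record_launch(contact, "step")
print([e[0].key for e in tape2.launches])
```

In this toy, filtering by key removes every launch of that kernel anywhere on the tape, whereas the scoped variant is tied to a region of code, which is closer to what a no_grad context would express.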

All of these are suboptimal in terms of code readability, efficiency, or both. I am also not sure whether option 3 works properly or breaks something within the tape.

JonathanKuelz · Mar 20 '25 19:03