Persistent Caching of Jitted Functions on GPU for Brax Envs and Autodiff
Hi,
I am trying to use XLA's persistent compilation cache on GPU to speed up my Brax code. According to the tracking issues in JAX, this should be possible, and I have confirmed that it works for most functions on my side.
Unfortunately, for my use case (computing the Jacobian/Hessian of env.step with respect to obs), my code exits prematurely without an error when I call the persistently cached Jacobian/Hessian on subsequent runs. This happens regardless of env and backend. I have included a minimal example below that reproduces what I am seeing. To reproduce the issue:
- Run minimal.py: the entire program executes, and ./cache is created and populated
- Run minimal.py again: the second print (line 63) does not execute and the program quits prematurely
Main takeaways:
- Persistent caching of jacrev (the script default) applied to my step function wrapper fails
- Persistently cached hessian of my step function wrapper fails as well
- Persistent caching of jacfwd applied to my step function wrapper works without issue
- No issues when I don't jit (or jit without persistent caching) for any of the above use cases
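For readers who want to see the shape of the setup, here is a rough sketch of the three cases listed above (jacrev, jacfwd, and hessian under jit with a persistent cache directory). It is not the minimal.py mentioned earlier and is not claimed to reproduce the failure: `toy_step` is a hypothetical stand-in for a wrapper around env.step, and the `jax_compilation_cache_dir` config option is assumed to be available, which holds for recent JAX releases.

```python
import jax
import jax.numpy as jnp

# Assumption: recent JAX releases enable the persistent compilation cache
# through this config option; "./cache" mirrors the directory named in the post.
jax.config.update("jax_compilation_cache_dir", "./cache")


def toy_step(obs):
    # Hypothetical stand-in for a wrapper around env.step: maps obs -> next obs.
    return jnp.sin(obs) + 0.1 * obs * jnp.sum(obs)


obs = jnp.arange(3.0)

# The three variants discussed above, each jitted so it is eligible for caching.
jac_rev = jax.jit(jax.jacrev(toy_step))(obs)
jac_fwd = jax.jit(jax.jacfwd(toy_step))(obs)
hess = jax.jit(jax.hessian(toy_step))(obs)

print("jacrev shape:", jac_rev.shape)
print("jacfwd shape:", jac_fwd.shape)
print("hessian shape:", hess.shape)
print("jacrev matches jacfwd:", bool(jnp.allclose(jac_rev, jac_fwd, atol=1e-5)))
```

Running a script like this twice exercises the same write-then-read cache path as the steps above. Whether a tiny function like this actually gets written to the cache depends on the cache's minimum compile-time and entry-size thresholds, so an empty ./cache after the first run does not by itself indicate a problem.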
Thanks for the help!