[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients
The feature, motivation and pitch
Problem
The solver's jax.lax.while_loop implementation prevents gradient computation through the environment step during gradient-based trajectory optimization. This happens in the solver whenever opt.iterations > 1.
Error encountered with a jax.jit-compiled grad function:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
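A minimal, self-contained reproduction of this error, using a toy fixed-point iteration as a stand-in for the MJX solver. solve_while and TOL are illustrative names, not MJX code; the point is only that a data-dependent jax.lax.while_loop evaluates fine in the forward pass but cannot be reverse-differentiated:

```python
import jax
import jax.numpy as jnp

TOL = 1e-6  # illustrative convergence tolerance (hypothetical, not an MJX value)

def solve_while(b):
  # Toy iterative solver: x <- x - 0.5 * (x - b) converges to x = b.
  # The trip count depends on the data, like the MJX solver's while_loop.
  def cond(carry):
    _, step = carry
    return jnp.abs(step) > TOL

  def body(carry):
    x, _ = carry
    step = 0.5 * (x - b)
    return x - step, step

  init = (jnp.float32(0.0), jnp.float32(1.0))
  x, _ = jax.lax.while_loop(cond, body, init)
  return x

print(solve_while(2.0))  # the forward pass works

grad_error = None
try:
  jax.jit(jax.grad(solve_while))(2.0)  # reverse mode fails while tracing
except ValueError as e:
  grad_error = e
print(grad_error)
```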
The current workaround, setting opt.iterations = 1, leads to potentially inaccurate simulation and gradients.
Proposed Solution
Add an option to set a fixed iteration count (e.g., 4) that is compatible with reverse-mode differentiation, using either lax.scan or lax.fori_loop with static bounds.
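A sketch of the proposed fixed-iteration idea on the same kind of toy update (solve_scan and NUM_ITERATIONS are hypothetical names, not MJX code). Because the loop length is a static Python int, lax.scan supports reverse-mode differentiation:

```python
import jax
import jax.numpy as jnp

NUM_ITERATIONS = 4  # static iteration count, e.g. 4 as suggested above

def solve_scan(b):
  # Toy update x <- x - 0.5 * (x - b), run for a fixed number of steps.
  # lax.scan with a static `length` is reverse-mode differentiable.
  def body(x, _):
    return x - 0.5 * (x - b), None

  x, _ = jax.lax.scan(body, jnp.float32(0.0), None, length=NUM_ITERATIONS)
  return x

grad_b = jax.jit(jax.grad(solve_scan))(2.0)
print(grad_b)
```

Starting from x = 0, this update gives x_4 = (1 - 0.5**4) * b, so the gradient with respect to b should be 0.9375.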
Alternatives
No response
Additional context
No response
I like this suggestion and have labeled it as a good one for someone to take on externally. If no one does, we'll eventually implement it ourselves.
If someone would like to try it, I'd recommend briefly proposing (in this issue) how to modify the API to expose this functionality, and then, if we all agree, opening a PR.
@erikfrey are you still looking for a volunteer to tackle this? I'd like to give it a shot.
Hi @erikfrey
I previously tried to contribute to this project, but my PR is still pending (I can complete it now, since I've gained more experience working with a large codebase, LibreOffice).
Problem understanding:
Because jax.lax.while_loop in the solver prevents gradient computation during backpropagation, we need to replace the dynamic loop with a static one when a fixed iteration count is specified.
Solution proposed:
I'm thinking of introducing a new boolean option in Model.opt, named something like static_iterations. Based on that flag, the solver would use a static loop with a fixed number of iterations, enabling gradient computation.
Code modifications:
Check static_iterations Flag: In the solver's solve function, we can use jax.lax.fori_loop instead of jax.lax.while_loop when static_iterations is enabled.
Run Fixed Iterations: When static_iterations is True, execute the solver loop exactly m.opt.iterations times, bypassing the convergence checks.
So basically, replace jax.lax.while_loop with jax.lax.scan or jax.lax.fori_loop when the flag is enabled?
I think jax.lax.fori_loop is the more appropriate choice.
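A minimal sketch of how such a flag could switch between the two loops, assuming a hypothetical Opt dataclass standing in for Model.opt (static_iterations is the proposed option, not an existing MJX field). When fori_loop gets Python-int bounds, JAX implements it with a scan, which is what makes the static path differentiable:

```python
import dataclasses

import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class Opt:
  # Hypothetical stand-in for Model.opt; `static_iterations` is the proposed
  # flag from this comment, not an existing MJX option.
  iterations: int = 4
  tolerance: float = 1e-8
  static_iterations: bool = False

def solve(opt, b):
  def update(x):
    return x - 0.5 * (x - b)  # toy solver step

  x0 = jnp.float32(0.0)
  if opt.static_iterations:
    # Static (Python-int) bounds: fori_loop lowers to a scan, so grad works.
    return jax.lax.fori_loop(0, opt.iterations, lambda i, x: update(x), x0)

  # Default path: dynamic while_loop with an early exit (forward-only).
  def cond(carry):
    i, x = carry
    return (i < opt.iterations) & (jnp.abs(x - b) > opt.tolerance)

  def body(carry):
    i, x = carry
    return i + 1, update(x)

  _, x = jax.lax.while_loop(cond, body, (jnp.int32(0), x0))
  return x

static_opt = Opt(static_iterations=True)
grad_b = jax.grad(lambda b: solve(static_opt, b))(2.0)
```

Because the flag is read in plain Python at trace time, the while_loop never appears in the traced graph on the differentiable path.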
In C MuJoCo there is a trivial way to fix the number of iterations: set mjModel.opt.tolerance = 0.
But I'll let @erikfrey comment on the correct way to do this in JAX
I looked into setting mjModel.opt.tolerance = 0: this effectively disables the convergence check, which means the solver keeps iterating until it reaches the specified maximum number of iterations, regardless of whether the solution has converged.
It might lead to slightly less accurate solutions because the solver may continue even after reaching a satisfactory solution.
But in JAX we use a while (dynamic) loop that stops once the tolerance criterion is met; having a fixed number of iterations could provide deterministic behavior, just like C MuJoCo.
Hello! Can you please let me work on this?
Can I solve this issue?
I have used this method and this seems like a better candidate for this situation :)
https://github.com/google-deepmind/mujoco/blob/2c4fed07d2cd5470a1332cbaefcf3d5a176c8354/mjx/mujoco/mjx/_src/solver.py#L236-L237
The goal is to make the solver compatible with reverse-mode autodiff when m.opt.iterations > 1. In this specific context, setting m.opt.tolerance = 0 to achieve this behavior feels a bit strange.
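I think the simplest solution would be best here, and it is what I am currently using as a workaround:
```python
# Excerpt from the solver in mjx/_src/solver.py, where cond, body, ctx and
# _while_loop_scan are already defined.
if m.opt.iterations == 1:
  ctx = body(ctx)
else:
  ctx = _while_loop_scan(cond, body, ctx, max_iter=m.opt.iterations)
```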
With this change, if you want the solver to run for exactly m.opt.iterations iterations and return the corresponding gradients, you can set m.opt.tolerance = 0, matching C MuJoCo's behavior.
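To illustrate the general pattern of a scan-based while loop, here is a hypothetical while_loop_scan helper. It is a sketch of the technique, not a copy of MJX's private _while_loop_scan, whose exact implementation may differ. The toy solver's stopping rule (keep iterating while the residual is >= tolerance) is an assumption chosen so that tolerance = 0 disables the early exit:

```python
import jax
import jax.numpy as jnp

def while_loop_scan(cond_fun, body_fun, init_val, max_iter):
  # Runs a fixed-length lax.scan of max_iter steps. Steps taken after
  # cond_fun becomes False are no-ops, so the result matches a capped
  # while_loop, but the static length makes it reverse-mode differentiable.
  def step(carry, _):
    val, active = carry
    val = jax.lax.cond(active, body_fun, lambda v: v, val)
    return (val, active & cond_fun(val)), None

  init = (init_val, cond_fun(init_val))
  (val, _), _ = jax.lax.scan(step, init, None, length=max_iter)
  return val

def solve(b, tolerance, iterations):
  # Toy solver; with tolerance = 0 the condition |x - b| >= 0 never fails,
  # so all `iterations` steps run (mirroring the C MuJoCo idiom).
  def cond(carry):
    _, x = carry
    return jnp.abs(x - b) >= tolerance

  def body(carry):
    n, x = carry
    return n + 1, x - 0.5 * (x - b)

  init = (jnp.int32(0), jnp.float32(0.0))
  n, x = while_loop_scan(cond, body, init, iterations)
  return x, n

_, n_steps = solve(2.0, tolerance=0.0, iterations=4)
grad_b = jax.grad(lambda b: solve(b, 0.0, 4)[0])(2.0)
```

With a positive tolerance the same helper still stops doing work once converged (e.g. 11 active steps for tolerance=1e-3), while the scan itself always has max_iter steps.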
Hi @jeh15
I did consider doing something similar to what you did, because always using _while_loop_scan for iterations > 1 is definitely the most straightforward code.
But as I dug into the project and read the maintainers' comments, I found a strong reason for the current approach, rooted in the existing C MuJoCo API. It came down to two key factors:
- API consistency: In the original C MuJoCo, setting mjModel.opt.tolerance = 0 is the established idiom to disable the convergence check and force the solver to run for the full number of opt.iterations. My goal was to map this existing, well-understood behavior to MJX. This way, users familiar with MuJoCo get a consistent experience and a natural way to enable fixed-iteration runs, which in the JAX world enables differentiability.
- Performance: This is the big one. Your proposed solution would use _while_loop_scan for all simulations with iterations > 1. While that's perfect for gradients, I think it would be a performance regression for the many users who only run forward simulations and don't need gradients. jax.lax.while_loop can exit early as soon as the tolerance is met, which is often much faster. The jax.lax.cond in my PR preserves this faster while_loop for the default case and only uses the differentiable scan when the user explicitly opts in by setting tolerance = 0.
Moreover, it forces the user to be explicit: "I want gradients, so I will set tolerance = 0." Maybe this is better than implicitly changing the performance characteristics of the simulation under the hood?
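A toy comparison of loop trip counts behind the performance argument (illustrative only, not a benchmark; count_while and count_scan are hypothetical names). A while_loop stops as soon as the tolerance is met, whereas a fixed-length scan always executes every step. In addition, under jax.vmap with a batched predicate, lax.cond is evaluated as a select, so a masked scan would run the body on every step:

```python
import jax
import jax.numpy as jnp

def count_while(b, tolerance, max_iter):
  # lax.while_loop exits as soon as the residual drops below tolerance.
  def cond(carry):
    n, x = carry
    return (n < max_iter) & (jnp.abs(x - b) >= tolerance)

  def body(carry):
    n, x = carry
    return n + 1, x - 0.5 * (x - b)

  n, _ = jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(0.0)))
  return n

def count_scan(b, max_iter):
  # A fixed-length scan always executes max_iter loop steps.
  def body(carry, _):
    n, x = carry
    return (n + 1, x - 0.5 * (x - b)), None

  init = (jnp.int32(0), jnp.float32(0.0))
  (n, _), _ = jax.lax.scan(body, init, None, length=max_iter)
  return n

print(count_while(2.0, 1e-3, 50), count_scan(2.0, 50))
```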
You bring up a good point; I don't have a good intuition for the performance differences between jax.lax.while_loop and jax.lax.scan. I would have to benchmark it.
I would still find it odd to have to set m.opt.tolerance = 0 to get gradients. It would mean I am forced to always use the gradients after m.opt.iterations iterations, even if the tolerance threshold was met earlier. I am unsure what this means in a practical context; maybe it's not that big of a deal. I would have to think about it more.
Let's let @erikfrey decide on this :)
@jeh15 Yes, this is a JAX constraint. If you want support for reverse-mode autodiff, you need to run a fixed number of solver iterations. In practice, if you want to differentiate and you don't want your JIT time to explode, you will want to keep the number of iterations low.
We have plenty of evidence that for forward-mode, using jax.lax.while_loop helps performance significantly, so it's important to keep that as the default behavior.
I think using _while_loop_scan when m.opt.tolerance == 0 is a decent compromise - would someone like to open a PR?
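A minimal sketch of this compromise on the toy solver, assuming tolerance and iterations are plain Python numbers at trace time (standing in for m.opt.tolerance and m.opt.iterations). As a simplification, the tolerance == 0 branch uses a plain fixed-length lax.scan rather than a masked while-loop emulation: under the ">= tolerance keeps iterating" rule assumed here, a zero tolerance never exits early anyway. A Python-level branch keeps the while_loop out of the traced graph on the differentiable path:

```python
import jax
import jax.numpy as jnp

def solve(b, tolerance, iterations):
  # Toy solver; `tolerance` and `iterations` must be static Python values.
  def update(x):
    return x - 0.5 * (x - b)

  x0 = jnp.float32(0.0)
  if tolerance == 0:
    # Differentiable path: fixed-length scan, no early exit.
    x, _ = jax.lax.scan(lambda x, _: (update(x), None), x0, None, length=iterations)
    return x

  # Default fast path: while_loop with early exit (supports jvp, not grad).
  def cond(carry):
    n, x = carry
    return (n < iterations) & (jnp.abs(x - b) >= tolerance)

  def body(carry):
    n, x = carry
    return n + 1, update(x)

  _, x = jax.lax.while_loop(cond, body, (jnp.int32(0), x0))
  return x

grad_b = jax.grad(lambda b: solve(b, 0.0, 4))(2.0)
x_fast = jax.jit(solve, static_argnums=(1, 2))(2.0, 1e-6, 100)
```

Marking tolerance and iterations as static under jax.jit is what lets the Python `if` select the loop type at trace time.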
@erikfrey That makes sense, thank you for the clarification!
I think @varshneydevansh has a PR with that implementation.
Looking forward to the fix, thanks, guys!