
[MJX] jax.lax.while_loop in solver.py prevents computation of backward gradients

Open EGalahad opened this issue 1 year ago • 6 comments

The feature, motivation and pitch

Problem

The solver's jax.lax.while_loop implementation prevents gradient computation through the environment step during gradient-based trajectory optimization. The problem occurs in the solver implementation whenever opt.iterations > 1.

Error encountered with a jax.jit-compiled grad function:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values.
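For context, the same failure mode can be reproduced in plain JAX (a minimal standalone sketch, not MJX code): reverse-mode AD cannot handle a loop whose trip count depends on traced values.

import jax
import jax.numpy as jnp

def f(x):
    # trip count depends on the carry, so it is dynamic
    def cond(carry):
        _, err = carry
        return err > 1e-6
    def body(carry):
        val, err = carry
        return val * 0.5, err * 0.5
    val, _ = jax.lax.while_loop(cond, body, (x, jnp.float32(1.0)))
    return val

jax.grad(f)(jnp.float32(2.0))  # raises the ValueError quoted above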

The current workaround of setting opt.iterations = 1 leads to potentially inaccurate simulation and gradients.

Proposed Solution

Add an option to run a fixed iteration count (e.g., 4), implemented with either lax.scan or lax.fori_loop with static bounds, which would be compatible with reverse-mode differentiation.
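Continuing the sketch above, a fixed trip count makes the loop reverse-mode differentiable (again plain JAX, not MJX code; the half-step body is a stand-in for one solver iteration):

import jax
import jax.numpy as jnp

def solve_fixed(x, n_iter=4):
    def body(carry, _):
        return carry * 0.5, None  # stand-in for one solver iteration
    carry, _ = jax.lax.scan(body, x, None, length=n_iter)
    return carry

jax.grad(solve_fixed)(jnp.float32(2.0))  # works: the trip count is static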

Alternatives

No response

Additional context

No response

EGalahad avatar Nov 29 '24 15:11 EGalahad

I like this suggestion and have labeled it as a good one for someone to take on externally. If no one does, we'll eventually implement it ourselves.

If someone would like to try it, I'd recommend briefly proposing (in this issue) how to modify the API to expose this functionality, and then, if we all agree, opening a PR.

erikfrey avatar Dec 03 '24 22:12 erikfrey

@erikfrey are you still looking for a volunteer to tackle this? I'd like to give it a shot.

jaraujo98 avatar Dec 30 '24 23:12 jaraujo98

Hi @erikfrey

I previously tried to contribute to this project, but my PR is still pending. (I can complete it now, as I have since gained experience working in a large codebase, LibreOffice.)

Problem understanding:

The jax.lax.while_loop in the solver prevents gradient computation during backpropagation, so we need to replace the dynamic loop with a static one when a fixed iteration count is specified.

Proposed solution:

I'm thinking of introducing a new boolean option in Model.opt; based on it, the solver would use a static loop with a fixed number of iterations, enabling gradient computation.

Named something like static_iterations.

Code modifications:

Check the static_iterations flag: in the solver's solve function, we can use jax.lax.fori_loop instead of jax.lax.while_loop when static_iterations is enabled.

Run fixed iterations: when static_iterations is True, execute the solver loop exactly m.opt.iterations times, bypassing the convergence checks.

So, basically, replace jax.lax.while_loop with jax.lax.scan or jax.lax.fori_loop when the flag is enabled.

I'm thinking jax.lax.fori_loop is the more appropriate choice; a sketch of what I mean is below.
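Something along these lines (just a sketch: static_iterations is the option name I'm proposing, not an existing MJX field, and solve_loop is an illustrative stand-in for the real solver entry point):

import jax

def solve_loop(m, ctx, cond, body):
    if m.opt.static_iterations:
        # static bounds: fori_loop lowers to a scan and supports reverse mode
        return jax.lax.fori_loop(
            0, m.opt.iterations, lambda _, c: body(c), ctx)
    # default: dynamic loop with early exit on convergence (forward mode only)
    return jax.lax.while_loop(cond, body, ctx)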

varshneydevansh avatar Jan 29 '25 19:01 varshneydevansh

In C MuJoCo there is a trivial way to fix the number of iterations: set mjModel.opt.tolerance = 0.
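For example, via the Python bindings (the XML path here is just a placeholder):

import mujoco

model = mujoco.MjModel.from_xml_path("model.xml")  # placeholder path
model.opt.tolerance = 0.0   # disables the convergence check
model.opt.iterations = 4    # solver now always runs exactly 4 iterations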

But I'll let @erikfrey comment on the correct way to do this in JAX

yuvaltassa avatar Jan 29 '25 19:01 yuvaltassa

I tried to understand the effect of setting mjModel.opt.tolerance = 0: it effectively disables the convergence check, which means the solver keeps iterating until it reaches the specified maximum number of iterations, regardless of whether the solution has converged.

That can waste computation, because the solver may keep running even after reaching a satisfactory solution.

But in JAX we are using a dynamic while loop that stops once the tolerance criterion is met, and a fixed number of iterations would give deterministic behavior, just like C MuJoCo.

varshneydevansh avatar Jan 30 '25 08:01 varshneydevansh

Hello! Can you please let me work on this?

tj279 avatar May 03 '25 04:05 tj279

Can I solve this issue?

nithya-23-cyberlion avatar Jun 24 '25 10:06 nithya-23-cyberlion

I have used this method, and it seems like a better candidate for this situation :)

https://github.com/google-deepmind/mujoco/blob/2c4fed07d2cd5470a1332cbaefcf3d5a176c8354/mjx/mujoco/mjx/_src/solver.py#L236-L237
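For anyone following along, the pattern there is roughly this (a simplified sketch; the actual _while_loop_scan in solver.py may differ in its details):

import jax
import jax.numpy as jnp

def while_loop_scan(cond_fn, body_fn, init, max_iter):
    # run a fixed number of scan steps, but freeze the carry once the
    # loop condition turns false, emulating the while_loop early exit
    def step(carry, _):
        pred = cond_fn(carry)
        nxt = body_fn(carry)
        carry = jax.tree_util.tree_map(
            lambda old, new: jnp.where(pred, new, old), carry, nxt)
        return carry, None
    carry, _ = jax.lax.scan(step, init, None, length=max_iter)
    return carry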

varshneydevansh avatar Jul 06 '25 12:07 varshneydevansh

The goal is to make the solver reverse-mode autodiff compatible when m.opt.iterations > 1. In this specific context, setting m.opt.tolerance = 0 to achieve this behavior feels a bit strange. I think the simplest solution would be best here, and it is what I am currently using as a workaround:

if m.opt.iterations == 1:
	# single solver iteration: no loop needed, trivially differentiable
	ctx = body(ctx)
else:
	# fixed-length scan: reverse-mode differentiable
	ctx = _while_loop_scan(cond, body, ctx, max_iter=m.opt.iterations)

With this change, if you want the solver to run and return the gradients for exactly m.opt.iterations, you can set m.opt.tolerance = 0, matching C MuJoCo's behavior.

jeh15 avatar Jul 09 '25 15:07 jeh15

Hi @jeh15

I did consider doing what you did, to a degree, because always using _while_loop_scan for iterations > 1 is definitely the most straightforward code.

But as I dug into the project and read the maintainers' comments, I found that there's a strong reason for my current approach, based on the existing C MuJoCo API. It came down to two key factors:

  1. API Consistency: In the original C MuJoCo, setting mjModel.opt.tolerance = 0 is the established idiom to disable the convergence check and force the solver to run for the full number of opt.iterations. My goal was to map this existing, well-understood behavior to MJX. This way, users familiar with MuJoCo get a consistent experience and a natural way to enable fixed-iteration runs, which, in the JAX world, enables differentiability.

  2. Performance: This is the big one. Your proposed solution would use _while_loop_scan for all simulations with iterations > 1. While that's perfect for gradients, I think it would be a performance regression for the many users who are only running forward simulations and don't need gradients. jax.lax.while_loop can exit early as soon as the tolerance is met, which is often much faster. The jax.lax.cond in my PR preserves the faster while_loop for the default case and only uses the differentiable scan when the user explicitly opts in by setting tolerance = 0.

Moreover, it forces the user to be explicit: "I want gradients, so I will set tolerance = 0." Maybe that's better than implicitly changing the performance characteristics of the simulation under the hood?

varshneydevansh avatar Jul 09 '25 15:07 varshneydevansh

You bring up a good point; I don't have good intuition on the performance differences between jax.lax.while_loop and jax.lax.scan. I would have to benchmark it. I would still find it weird having to set m.opt.tolerance = 0 to get the gradients. That would mean I am forced to always take the gradients at m.opt.iterations, even if I may have met the tolerance threshold earlier. I am unsure what this means in a practical context; maybe it's not that big of a deal. I would have to think about it more.

jeh15 avatar Jul 09 '25 17:07 jeh15

Let's let @erikfrey make the call on this :)

varshneydevansh avatar Jul 09 '25 17:07 varshneydevansh

@jeh15 Yes, this is a JAX constraint. If you want support for reverse-mode autodiff, you need to run a fixed number of solver iterations. In practice, if you want to differentiate and don't want your JIT time to explode, you will want to keep the number of iterations low.

We have plenty of evidence that for forward-mode, using jax.lax.while_loop helps performance significantly, so it's important to keep that as the default behavior.

I think using _while_loop_scan when m.opt.tolerance == 0 is a decent compromise - would someone like to open a PR?
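Putting the pieces together, the dispatch would look roughly like this (a sketch combining the snippets above, not the final implementation):

if m.opt.iterations == 1:
	ctx = body(ctx)
elif m.opt.tolerance == 0:
	# explicit opt-in: fixed iteration count, reverse-mode differentiable
	ctx = _while_loop_scan(cond, body, ctx, max_iter=m.opt.iterations)
else:
	# default: early exit on convergence via while_loop (forward mode only)
	ctx = jax.lax.while_loop(cond, body, ctx)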

erikfrey avatar Jul 10 '25 19:07 erikfrey

@erikfrey That makes sense, thank you for the clarification!

I think @varshneydevansh has a PR with that implementation.

Looking forward to the fix, thanks, guys!

jeh15 avatar Jul 10 '25 20:07 jeh15