mujoco
MuJoCo as a JAX callback
Hi,
I'm working on integrating MuJoCo into JAX via external callbacks, with the end goal of fully jittable / vmappable episode generation (MJX doesn't support everything we use, yet :-)). I have a basic version working nicely and all my tests pass, but I'd like some clarification on a few points:
- 32- vs 64-bit precision: is the 64-bit precision due to NumPy defaulting to 64-bit floats, or is it required for the dynamics? I couldn't find anything related to this in the MJX or MuJoCo code, so my guess is the former, but confirmation would be nice!
- Do I understand correctly that `mujoco.mjtState.mjSTATE_INTEGRATION` contains everything required to do fully deterministic calculations? I.e., instead of working with `mjData`, I can achieve the same with (a pytree of) `mjSTATE_INTEGRATION`?
Thanks!