[`MJX`] `mj_getState` and `mj_setState` equivalents for `mjx`
Hi,
I'm a maintainer of Gymnasium & the project manager of Gymnasium-Robotics, and I'm trying to use MuJoCo-MJX for "prototyping MJX-based RL environments in Gymnasium, Gymnasium-Robotics, Metaworld, MO-Gymnasium".
The Python MuJoCo API has `mj_getState` and `mj_setState`:
https://mujoco.readthedocs.io/en/3.1.0/APIreference/APIfunctions.html#mj-getstate
Example usage:
```python
import mujoco
import numpy as np

# `env` is a Gymnasium MuJoCo environment.
state = np.empty(mujoco.mj_stateSize(env.unwrapped.model, mujoco.mjtState.mjSTATE_PHYSICS))
mujoco.mj_getState(env.unwrapped.model, env.unwrapped.data, state, spec=mujoco.mjtState.mjSTATE_PHYSICS)
mujoco.mj_setState(env.unwrapped.model, env.unwrapped.data, state, spec=mujoco.mjtState.mjSTATE_PHYSICS)
```
MJX does not have an equivalent, but it is easy to write one yourself, for example:
```python
import jax.numpy as jnp
from mujoco import mjx


# TODO unit test these
def mjx_get_physics_state(mjx_data: mjx.Data) -> jnp.ndarray:
    """Get the physics state of `mjx_data`, similar to `mujoco.mj_getState`."""
    return jnp.concatenate([mjx_data.qpos, mjx_data.qvel, mjx_data.act])


def mjx_set_physics_state(mjx_data: mjx.Data, mjx_physics_state: jnp.ndarray) -> mjx.Data:
    """Set the physics state in `mjx_data`, similar to `mujoco.mj_setState`."""
    qpos_end_index = mjx_data.qpos.size
    qvel_end_index = qpos_end_index + mjx_data.qvel.size

    qpos = mjx_physics_state[:qpos_end_index]
    qvel = mjx_physics_state[qpos_end_index:qvel_end_index]
    act = mjx_physics_state[qvel_end_index:]

    # Sizes are static array shapes, so these checks also run under jax.jit.
    assert qpos.size == mjx_data.qpos.size
    assert qvel.size == mjx_data.qvel.size
    assert act.size == mjx_data.act.size
    return mjx_data.replace(qpos=qpos, qvel=qvel, act=act)
```
Is there a plan to add `mj_getState` and `mj_setState` equivalents to MJX, or are users expected to write their own?
Thanks!
The recommended way to do this for now is to call `mjx.get_data` or `mjx.put_data` and use MuJoCo's `mj_getState`/`mj_setState` API on the host. Does that not fit your use case?
Adding Erik in case he has more thoughts.
The problem with your suggestion is that it is not JIT-able (it cannot run inside `jax.jit`):
```python
import mujoco
import numpy as np
from mujoco import mjx


def mjx_get_physics_state_put_version(self, mjx_data: mjx.Data) -> np.ndarray:
    """Version based on @btaba's suggestion."""
    # Copy the device-side mjx.Data into a fresh host-side MjData.
    data = mujoco.MjData(self.model)
    mjx.device_get_into(data, mjx_data)
    # data = mjx.get_data(self.mjx_model, mjx_data)  # TODO figure out how to use get_data instead
    state = np.empty(mujoco.mj_stateSize(self.model, mujoco.mjtState.mjSTATE_PHYSICS))
    mujoco.mj_getState(self.model, data, state, spec=mujoco.mjtState.mjSTATE_PHYSICS)
    return state
```
Hi @Kallinteris-Andreas
You should call `mjx_get_physics_state_put_version` outside of `jax.jit`. Once all the computations are done on device (in MJX-land), transfer the data back to the host with `device_get_into` or `get_data`. Does that make sense?