Description
I'm trying to differentiate the MJX step function with JAX's automatic differentiation function jax.grad(), like this:
import jax
import jax.numpy as jnp
from mujoco import mjx

# mjx_model, goal_pos, vel and pos are defined elsewhere.
def step(vel, pos):
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(qvel=vel, qpos=pos)
    pos = mjx.step(mjx_model, mjx_data).qpos
    return pos

def loss(vel, pos):
    pos = step(vel, pos)
    return jnp.sum((pos - goal_pos)**2)

grad_loss = jax.jit(jax.grad(loss))
grad = grad_loss(vel, pos)
When there is only one rigid body in the scene, everything works. But when a collision has to be resolved, for example with a ball and a plane in the scene,
XML = """
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
     rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="2 2" texuniform="true"
     reflectance=".2"/>
  </asset>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 -.5" size="2 2 .1" material="grid" solimp=".99 .99 .01" solref=".001 1"/>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""
the following error occurs:
  File "/path-to-mujoco/mjx/_src/solver.py", line 347, in cg_solve
    ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
It seems that the jax.lax.while_loop() call used in the CG solver has a dynamic stopping condition, which reverse-mode differentiation does not support. How can I solve this?
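For reference, the same error can be reproduced without MJX. The sketch below is plain JAX and only illustrates what the message is describing; it is not a patch for solver.py. The functions sqrt_while and sqrt_fixed are hypothetical names made up for this example (a Newton iteration for a square root). It shows that jax.grad fails on a lax.while_loop whose trip count depends on the data, that forward mode (jax.jacfwd) goes through the same loop, and that a lax.fori_loop with static bounds can be differentiated in reverse mode.

```python
import jax
import jax.numpy as jnp
from jax import lax


def sqrt_while(a):
    """Newton iteration for sqrt(a); the number of iterations depends on the data."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond(state):
        x, prev = state
        return jnp.abs(x - prev) > 1e-6  # dynamic stopping condition

    def body(state):
        x, _ = state
        return 0.5 * (x + a / x), x

    x, _ = lax.while_loop(cond, body, (a, jnp.zeros_like(a)))
    return x


def sqrt_fixed(a, num_iters=20):
    """Same iteration with a static trip count (hypothetical choice of 20 steps)."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def body(_, x):
        return 0.5 * (x + a / x)

    return lax.fori_loop(0, num_iters, body, a)


a0 = jnp.float32(4.0)

# Reverse mode through the dynamic while_loop raises the error from the traceback.
try:
    jax.grad(sqrt_while)(a0)
    reverse_failed = False
except ValueError as err:
    reverse_failed = True
    print("jax.grad failed:", str(err)[:60], "...")

# Forward mode works on the same function.
fwd = jax.jacfwd(sqrt_while)(a0)

# Reverse mode works once the loop bounds are static.
rev = jax.grad(sqrt_fixed)(a0)

# Both should be close to the analytic derivative 1 / (2 * sqrt(4)) = 0.25.
print(fwd, rev)
```

If the same distinction carries over to the MJX step, trying jax.jacfwd instead of jax.grad on the loss might be worth a test, but I have not verified that against MJX.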