
The CG Solver in MJX doesn't support reverse-mode differentiation #1182

Open
@LyuJ1998

Description

@LyuJ1998

I'm trying to differentiate the MJX step function with JAX's jax.grad(), like this:

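import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

# mjx_model, goal_pos, vel and pos are defined elsewhere, e.g.
# mjx_model = mjx.put_model(mujoco.MjModel.from_xml_string(XML))

def step(vel, pos):
  # Fresh MJX data with the given joint velocities and positions.
  mjx_data = mjx.make_data(mjx_model)
  mjx_data = mjx_data.replace(qvel=vel, qpos=pos)
  # Advance the simulation by one timestep and return the new positions.
  return mjx.step(mjx_model, mjx_data).qpos

def loss(vel, pos):
  pos = step(vel, pos)
  return jnp.sum((pos - goal_pos)**2)

# jax.grad differentiates with respect to the first argument (vel) by default.
grad_loss = jax.jit(jax.grad(loss))
grad = grad_loss(vel, pos)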

When there is only a single rigid body in the scene, everything works. However, as soon as collisions have to be resolved, for example with a ball and a plane in the scene as in this model:

XML = """
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
    rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="2 2" texuniform="true"
    reflectance=".2"/>
  </asset>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 -.5" size="2 2 .1" material="grid" solimp=".99 .99 .01" solref=".001 1"/>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""

The following error occurs:

File "/path-to-mujoco/mjx/_src/solver.py", line 347, in cg_solve
    ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
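The restriction in this message is not specific to MJX. A minimal sketch in plain JAX reproduces it; sqrt_newton below is a hypothetical stand-in for a convergence-checked solver loop, not MJX code. Reverse mode (jax.grad) fails on a lax.while_loop whose trip count depends on the data, roughly because reverse mode has to keep every iteration's intermediates to replay them backwards, and that storage cannot be sized when the loop is traced. Forward mode (jax.jacfwd) only carries tangents alongside the loop, so it works.

```python
import jax
import jax.numpy as jnp


def sqrt_newton(a, tol=1e-6):
    """Hypothetical toy solver: Newton's method for sqrt(a).

    The loop stops once successive iterates agree to within `tol`, so the
    iteration count is data-dependent and lax.while_loop is required.
    """
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond(state):
        x, x_prev = state
        return jnp.abs(x - x_prev) > tol

    def body(state):
        x, _ = state
        return 0.5 * (x + a / x), x

    x, _ = jax.lax.while_loop(cond, body, (a, a + 1.0))
    return x


# Reverse mode raises the same ValueError as in the traceback above.
try:
    jax.grad(sqrt_newton)(4.0)
except ValueError as err:
    print("jax.grad failed:", err)

# Forward mode succeeds; the value should be close to 1 / (2 * sqrt(4)) = 0.25.
print("jax.jacfwd:", jax.jacfwd(sqrt_newton)(4.0))
```

Forward mode costs one pass per input dimension, so for a small input such as qvel, calling jax.jacfwd(loss)(vel, pos) instead of jax.grad might serve as a stopgap, assuming nothing else in mjx.step blocks forward-mode differentiation; that has not been checked against MJX here.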

It seems that the jax.lax.while_loop used by the CG solver does not support reverse-mode differentiation, because its number of iterations is decided at run time by the condition function. How can I solve this?
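The error message itself suggests lax.scan or a fori_loop with static bounds. As a minimal sketch of that idea (a standalone toy, not MJX's cg_solve; the name cg_fixed and the fixed iteration count are assumptions for illustration), a conjugate gradient solve that always runs num_iters iterations under lax.scan can be differentiated with jax.grad:

```python
import jax
import jax.numpy as jnp


def cg_fixed(A, b, num_iters):
    """Hypothetical CG solve of A x = b (A symmetric positive definite).

    Runs exactly `num_iters` iterations under lax.scan instead of looping
    until a tolerance is met, so the loop length is static and reverse-mode
    differentiation can store every iteration. No breakdown safeguard: keep
    num_iters <= len(b) so the denominators stay nonzero.
    """
    x = jnp.zeros_like(b)
    r = b - A @ x   # initial residual
    p = r           # initial search direction
    rs = r @ r      # squared residual norm

    def step(carry, _):
        x, r, p, rs = carry
        Ap = A @ p
        alpha = rs / (p @ Ap)       # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p   # next A-conjugate direction
        return (x, r, p, rs_new), None

    (x, _, _, _), _ = jax.lax.scan(step, (x, r, p, rs), None, length=num_iters)
    return x


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Gradient of sum(x) with respect to b; for x = A^{-1} b this is A^{-1} @ [1, 1].
g = jax.grad(lambda rhs: cg_fixed(A, rhs, num_iters=2).sum())(b)
print(cg_fixed(A, b, num_iters=2), g)
```

The trade-off is that the loop no longer stops early on convergence, so every iteration is paid for, and running past convergence can divide zero by zero (this sketch does not guard against that). num_iters must also be a static Python int, since scan needs a known length at trace time.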

Metadata

Assignees

Labels

MJX (Using JAX to run on GPU), bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests