
The CG Solver in MJX doesn't support reverse-mode differentiation #1182


Open
LyuJ1998 opened this issue Nov 10, 2023 · 7 comments
Labels
bug Something isn't working MJX Using JAX to run on GPU

Comments

@LyuJ1998

I'm trying to differentiate the MJX step function via the autograd function jax.grad() in JAX, like:

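import jax
import jax.numpy as jnp
from mujoco import mjx

# mjx_model (an mjx.Model) and goal_pos are assumed to be defined elsewhere.

def step(vel, pos):
  # Build fresh simulation data, then overwrite the velocity and position.
  mjx_data = mjx.make_data(mjx_model)
  mjx_data = mjx_data.replace(qvel=vel, qpos=pos)
  # Advance one step and return the resulting positions.
  return mjx.step(mjx_model, mjx_data).qpos

def loss(vel, pos):
  pos = step(vel, pos)
  return jnp.sum((pos - goal_pos) ** 2)

# jax.grad differentiates with respect to the first argument (vel) by default.
grad_loss = jax.jit(jax.grad(loss))
grad = grad_loss(vel, pos)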

When there is only one rigid body in the scene, everything works. But when a collision has to be resolved, for example with a ball and a plane in the scene:

XML = """
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
    rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="2 2" texuniform="true"
    reflectance=".2"/>
  </asset>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 -.5" size="2 2 .1" material="grid" solimp=".99 .99 .01" solref=".001 1"/>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""

The following error occurs:

File "/path-to-mujoco/mjx/_src/solver.py", line 347, in cg_solve
    ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

It seems that jax.lax.while_loop(), which the CG solver uses, does not support reverse-mode differentiation when its condition function is dynamic. How can I solve this?
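For reference, the same failure can be reproduced without MJX. The sketch below is a minimal standalone example, not MJX code: the function name halve_until_small is hypothetical, and the loop is a toy stand-in for a solver whose iteration count depends on its input. It suggests that forward-mode differentiation of such a loop goes through, while reverse-mode raises the error quoted above.

```python
import jax
import jax.numpy as jnp

def halve_until_small(x):
    # Hypothetical toy loop: the trip count depends on the input value.
    return jax.lax.while_loop(lambda v: v > 1.0, lambda v: v * 0.5, x)

x = jnp.float32(10.0)

# Forward-mode: 10 -> 5 -> 2.5 -> 1.25 -> 0.625, four halvings in total,
# so the derivative with respect to x is 0.5 ** 4.
forward_derivative = jax.jacfwd(halve_until_small)(x)

# Reverse-mode: JAX cannot transpose a while_loop with a dynamic condition.
try:
    jax.grad(halve_until_small)(x)
    reverse_error = None
except ValueError as err:
    reverse_error = str(err)

print(forward_derivative)
print(reverse_error)
```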

@LyuJ1998 LyuJ1998 added the bug Something isn't working label Nov 10, 2023
@LyuJ1998
Author

I also tried replacing the ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d)) call in mjx/_src/solver.py, line 347, with a plain Python while loop:

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
ctx = while_loop(cond, body, _CGContext.create(m, d))

It works when the gradient function is not compiled with jax.jit(), but with jax.jit() I get another error:

File "/path-to-mujoco/mjx/_src/solver.py", line 349, in while_loop
    while cond_fun(val):
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function loss at /home/lvjun/Mujoco3/demo_mjx.py:55 for jit. This concrete value was not available in Python because it depends on the values of the arguments vel and pos.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Is it because the improvement and gradient values in cond() are not static?

  def cond(ctx: _CGContext) -> jax.Array:
    improvement = _rescale(m, ctx.prev_cost - ctx.cost)
    gradient = _rescale(m, math.norm(ctx.grad))

    done = ctx.solver_niter >= m.opt.iterations
    done |= improvement < m.opt.tolerance
    done |= gradient < m.opt.tolerance

    return ~done

Is there any chance of making this work under JIT compilation?
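The difference between the eager and jitted behaviour can be seen in isolation. The following is a minimal sketch with a hypothetical toy function (python_while), not the MJX solver: a Python while loop runs fine on concrete values, but under jax.jit the comparison yields a tracer that cannot be converted to a Python bool.

```python
import jax
import jax.numpy as jnp

def python_while(x):
    # A Python `while` needs a concrete True/False on every check.
    while x > 1.0:
        x = x * 0.5
    return x

# Eager execution: values are concrete, so the loop runs normally.
eager_result = python_while(jnp.float32(10.0))

# Under jit, `x > 1.0` is a traced array and bool() on it fails.
try:
    jax.jit(python_while)(jnp.float32(10.0))
    jit_error = None
except jax.errors.TracerBoolConversionError as err:
    jit_error = type(err).__name__

print(eager_result, jit_error)
```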

@btaba
Collaborator

btaba commented Nov 10, 2023

Hi @LyuJ1998, this is a known issue with while_loop: "while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements." https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html

You can change the while_loop to a scan; see jax-ml/jax#3850.

The TracerBoolConversionError occurs because cond_fun(val) is a traced JAX array, but you're using it in a Python while loop, which expects a concrete value. Use a scan or a for loop.
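One way to follow this suggestion is to emulate the while loop with a fixed-length jax.lax.scan that turns into a no-op once the condition fails. The sketch below is only an illustration under stated assumptions: the helper name while_loop_scan, the max_iter parameter, and the toy loop body are hypothetical and not part of MJX's public API.

```python
import jax
import jax.numpy as jnp

def while_loop_scan(cond_fun, body_fun, init_val, max_iter):
    """Bounded while loop built on lax.scan (hypothetical helper)."""

    def step(carry, _):
        val, active = carry
        # Once the condition has failed, later iterations leave val unchanged.
        new_val = jax.lax.cond(active, body_fun, lambda v: v, val)
        new_active = jnp.logical_and(active, cond_fun(new_val))
        return (new_val, new_active), None

    init = (init_val, cond_fun(init_val))
    (final_val, _), _ = jax.lax.scan(step, init, None, length=max_iter)
    return final_val

def halve_until_small(x):
    # Toy stand-in for an iterative solver with a data-dependent stop.
    return while_loop_scan(lambda v: v > 1.0, lambda v: v * 0.5, x, max_iter=10)

# Works under both jit and reverse-mode differentiation.
value, grad = jax.jit(jax.value_and_grad(halve_until_small))(jnp.float32(10.0))
print(value, grad)
```

Note that the scan always executes max_iter steps, so the forward pass no longer exits early once the condition fails.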

@sfd158

sfd158 commented Nov 21, 2023

To reduce memory usage, there are some in-place operations in mjx.step.
In-place operations on intermediate matrices, such as X[0] += Y[0], break the back-propagation path.
So jax.grad(mjx.step) doesn't work in MuJoCo 3.0.

@erikfrey
Collaborator

Hi there,

Yes indeed this is by design, but poorly documented. I'll take this as motivation to add to the documentation.

tl;dr: if you would like to experiment with jax.grad(), please update to MuJoCo 3.0.1, which now includes support for the Newton solver in MJX. Newton converges quickly, and for many models a single solver iteration is sufficient. If your XML looks like this:

<option ... solver="Newton" iterations="1" ls_iterations="4">

Then we omit the jax.lax.while_loop and mjx.step is differentiable.

The reason we don't support this for CG is that replacing while with scan harms forward performance in some settings, so we currently accept this tradeoff.

Also please note that we have not investigated whether jax.grad delivers useful gradients in this setting - I would love to hear insights from anyone that tries this.

@sfd158 not quite sure what you mean by inplace operations - jax operations are pure, they do not modify the original. See for example this documentation on jax.numpy.ndarray.at
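To illustrate the point about purity, here is a minimal sketch with hypothetical function names (add_first, loss) and made-up input values. It shows the functional counterpart of X[0] += Y[0] written with .at[...].add(...), which leaves the input untouched and lets gradients flow to both arguments.

```python
import jax
import jax.numpy as jnp

def add_first(x, y):
    # Functional counterpart of `x[0] += y[0]`: returns a new array.
    return x.at[0].add(y[0])

def loss(x, y):
    return jnp.sum(add_first(x, y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

updated = add_first(x, y)  # x itself is not modified
grad_x, grad_y = jax.grad(loss, argnums=(0, 1))(x, y)

print(x, updated)
print(grad_x, grad_y)
```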

@Andrew-Luo1
Contributor

Andrew-Luo1 commented Mar 10, 2024

Not sure if this is the right place to post this - Re: whether jax.grad delivers useful gradients:

I have been playing with gradients over mjx.step (Newton solver; 1 iteration) for my Masters Thesis. Please see an implementation of the Short Horizon Actor Critic (SHAC) algorithm here.

SHAC involves learning control policies using analytical policy gradients; it augments the basic Analytical Policy Gradient (APG) algorithm with several features, such as a value function. I use jax.grad to take the gradient of the loss of an environment rollout with respect to the policy parameters, and these gradients are informative enough to make the algorithm work. This works without contact (inverted pendulum) and with contact (basic hopper).

While the gradients appear to be informative in these simple cases, I can't get quadruped control working. The Jacobian of the MJX step, which is a component of the gradient of the loss with respect to the policy parameters, is unstable; more on this in the README of the repo. I wonder if different simulation parameters could help here, since this issue appears to have gotten worse from MJX 3.1.1 to MJX 3.1.3.

[Animations: inverted pendulum rollout; hopper rollout]

@junqingqiao

Hi erikfrey,
Does setting iterations=1 impact the simulation accuracy?
Thanks,
Bugman

@lvjonok

lvjonok commented Sep 24, 2024

+1 on facing the issue and finding the solution here)

Is there any way I can contribute by enhancing the documentation here? I would love to improve the tools I use on a daily basis.
