NaNs in act when differentiating through muscle-driven arm26.xml with MJX #2489

Open
2 tasks done
jamesheald opened this issue Mar 10, 2025 · 9 comments
@jamesheald

Intro

Hi!

I am a postdoc at University College London, and I use MuJoCo for my research on musculoskeletal control.

I'm very excited by the prospect of using MJX to differentiate through the forward dynamics of musculoskeletal models. However, when I try to differentiate through mjx.step to calculate derivatives of muscle activations with respect to muscle inputs, I get NaNs. I have tested this on the arm26 model in the MJX GitHub repo. In contrast, when I calculate derivatives of qpos with respect to forces applied directly to the joints, I get well-behaved gradients. This suggests that the problem is specific to calculating gradients through muscles.

My setup

OS: macOS Sequoia
python: 3.11.5
mujoco: 3.3.0
mujoco-mjx: 3.3.0
jax: 0.5.1
jaxlib: 0.5.1

What's happening? What did you expect?

The script produces NaNs in the gradients of muscle activations with respect to muscle inputs, but not in the gradients of joint positions with respect to forces applied to the joints.

I expected to be able to apply jax.grad and jax.jacrev to mjx.step to calculate gradients of state variables with respect to muscle inputs.

Steps for reproduction

Run the code below.

Minimal model for reproduction

minimal XML
<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <spatial name="SF" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>

    <spatial name="SE" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s2"/>
    </spatial>

    <spatial name="EF" width="0.01">
      <site site="s3"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="EE" width="0.01">
      <site site="s4"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>

    <spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s5"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s6"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>
  </tendon>

  <actuator>
    <muscle name="SF" tendon="SF"/>
    <muscle name="SE" tendon="SE"/>
    <muscle name="EF" tendon="EF"/>
    <muscle name="EE" tendon="EE"/>
    <muscle name="BF" tendon="BF"/>
    <muscle name="BE" tendon="BE"/>
  </actuator>
</mujoco>

Code required for reproduction

import mujoco
from mujoco import mjx
import jax
from jax import numpy as jp

model_str = """
<!-- Copyright 2021 DeepMind Technologies Limited

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
     You may obtain a copy of the License at

         http://www.apache.org/licenses/LICENSE-2.0

     Unless required by applicable law or agreed to in writing, software
     distributed under the License is distributed on an "AS IS" BASIS,
     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     See the License for the specific language governing permissions and
     limitations under the License.
-->

<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <spatial name="SF" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>

    <spatial name="SE" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s2"/>
    </spatial>

    <spatial name="EF" width="0.01">
      <site site="s3"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="EE" width="0.01">
      <site site="s4"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>

    <spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s5"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s6"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>
  </tendon>

  <actuator>
    <muscle name="SF" tendon="SF"/>
    <muscle name="SE" tendon="SE"/>
    <muscle name="EF" tendon="EF"/>
    <muscle name="EE" tendon="EE"/>
    <muscle name="BF" tendon="BF"/>
    <muscle name="BE" tendon="BE"/>
  </actuator>
</mujoco>
"""

if __name__ == "__main__":

    model = mujoco.MjModel.from_xml_string(model_str)
    mjx_model = mjx.put_model(model)
    mjx_data = mjx.make_data(mjx_model)

    def joint_step(ctrl, mjx_data, mjx_model):
      # Apply generalized forces directly to the joints and return the next qpos.
      mjx_data = mjx_data.replace(qfrc_applied=ctrl)
      mjx_data = mjx.step(mjx_model, mjx_data)

      return mjx_data.qpos

    # Jacobian of the next qpos with respect to the applied joint forces.
    get_jacobian_joints = jax.jit(jax.jacrev(joint_step))

    def muscle_step(ctrl, mjx_data, mjx_model):
      # Set the muscle excitations and return the next activations.
      mjx_data = mjx_data.replace(ctrl=ctrl)
      mjx_data = mjx.step(mjx_model, mjx_data)

      return mjx_data.act

    # Jacobian of the next act with respect to the muscle excitations.
    get_jacobian_muscles = jax.jit(jax.jacrev(muscle_step))

    # differentiation-compatible solver settings - https://github.com/google-deepmind/mujoco/issues/1182
    mjx_model = mjx_model.replace(opt=mjx_model.opt.replace(solver=mujoco.mjtSolver.mjSOL_NEWTON))
    mjx_model = mjx_model.replace(opt=mjx_model.opt.replace(iterations=1))
    mjx_model = mjx_model.replace(opt=mjx_model.opt.replace(ls_iterations=1))

    # qfrc_applied has one entry per degree of freedom (nv); nq == nv for this model.
    jacobian_joints = get_jacobian_joints(jp.ones(mjx_model.nv), mjx_data, mjx_model)

    if not jp.any(jp.isnan(jacobian_joints)):
        print("No NaNs detected when differentiating through joints")

    jacobian_muscles = get_jacobian_muscles(jp.ones(mjx_model.nu), mjx_data, mjx_model)

    assert not jp.any(jp.isnan(jacobian_muscles)), "NaNs detected when differentiating through muscles"

Confirmations

@jamesheald jamesheald added the bug label Mar 10, 2025
@yuvaltassa
Collaborator

Did you try this with regular MuJoCo?

MuJoCo has fast finite-difference Jacobians which are well tested. Differentiating through MJX is possible, but we don't have much experience with it.

@jamesheald
Author

jamesheald commented Mar 10, 2025

When I use mujoco.mjd_transitionFD to calculate the Jacobian in the muscle model, I get all zeros for the derivative of the next state with respect to the muscle inputs.

The Jacobian is (2*nv+na) x nu, where nv = 2 and na = nu = 6:

Jacobian muscle model: 
 [[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1.]]
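
To make the row layout explicit, here is a small sketch that splits this matrix into its blocks. It assumes the state ordering (qpos, qvel, act) for the rows of the matrix returned by mjd_transitionFD and simply re-enters the values printed above; the variable names are illustrative only.

```python
# Dimensions of the muscle model discussed above.
nv, na, nu = 2, 6, 6

# Re-enter the printed matrix: 2*nv zero rows followed by an na x nu identity.
B = [[0.0] * nu for _ in range(2 * nv)]
B += [[1.0 if i == j else 0.0 for j in range(nu)] for i in range(na)]

# Assumed row ordering: qpos rows, then qvel rows, then act rows.
d_qpos_d_ctrl = B[:nv]
d_qvel_d_ctrl = B[nv:2 * nv]
d_act_d_ctrl = B[2 * nv:]

print("d qpos / d ctrl:", d_qpos_d_ctrl)
print("d qvel / d ctrl:", d_qvel_d_ctrl)
print("d act  / d ctrl:", d_act_d_ctrl)
```

Read this way, only the act block responds to ctrl after a single step; the qpos and qvel blocks are entirely zero.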

When I exchange the 6 muscles for a single hinge joint motor, the derivative of the next state with respect to the motor input is nonzero in every entry.

With the motor, na = 0 and nu = 1, so the Jacobian is 4 x 1:

Jacobian motor model: 
 [[-0.00012784]
 [ 0.00044407]
 [-0.02556708]
 [ 0.08881482]]

This is the script I ran:

import mujoco
import numpy as np

muscle_model_str = """
<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <spatial name="SF" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>

    <spatial name="SE" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s2"/>
    </spatial>

    <spatial name="EF" width="0.01">
      <site site="s3"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="EE" width="0.01">
      <site site="s4"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>

    <spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s5"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s6"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>
  </tendon>

  <actuator>
    <muscle name="SF" tendon="SF"/>
    <muscle name="SE" tendon="SE"/>
    <muscle name="EF" tendon="EF"/>
    <muscle name="EE" tendon="EE"/>
    <muscle name="BF" tendon="BF"/>
    <muscle name="BE" tendon="BE"/>
  </actuator>
</mujoco>
"""

motor_model_str = """
<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <actuator>
    <motor name="hinge" joint="elbow"/>
  </actuator>
</mujoco>
"""

if __name__ == "__main__":

    muscle_model = mujoco.MjModel.from_xml_string(muscle_model_str)
    data = mujoco.MjData(muscle_model)
    data.ctrl = np.ones(muscle_model.nu)
    # B holds d(next state)/d(ctrl), with shape (2*nv + na, nu)
    B = np.zeros((2 * muscle_model.nv + muscle_model.na, muscle_model.nu))

    # Compute the transition Jacobian using centered finite differences
    mujoco.mjd_transitionFD(muscle_model, data, 1e-6, 1, None, B, None, None)
    print("Jacobian muscle model: \n", B)

    # Use a separate name for the compiled model instead of rebinding the XML string
    motor_model = mujoco.MjModel.from_xml_string(motor_model_str)
    data = mujoco.MjData(motor_model)
    data.ctrl = np.ones(motor_model.nu)
    B = np.zeros((2 * motor_model.nv + motor_model.na, motor_model.nu))

    # Compute the transition Jacobian using centered finite differences
    mujoco.mjd_transitionFD(motor_model, data, 1e-6, 1, None, B, None, None)
    print("Jacobian motor model: \n", B)

@yuvaltassa
Collaborator

This is working as intended and exactly what I would expect. When you have muscles, the excitation (ctrl) only affects the activation (act). The activation affects the dynamics through the transition matrix A, not through the control matrix B. ctrl -> act -> qacc -> qvel -> qpos.

Does this make sense?
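
As a hedged sketch of the point above, the linearized transition can be written out in block form. The notation is an assumption for illustration: x = (q, v, a) stacks qpos, qvel and act, u is ctrl, and A_t, B_t are the matrices returned by mjd_transitionFD at step t.

```latex
\begin{aligned}
\delta x_{t+1} &= A_t\,\delta x_t + B_t\,\delta u_t,
\qquad x = (q,\ v,\ a),\quad u = \text{ctrl},\\[4pt]
B_t &= \begin{bmatrix} 0 \\ 0 \\ \partial a_{t+1}/\partial u_t \end{bmatrix},
\qquad
\frac{\partial x_{t+2}}{\partial u_t} = A_{t+1}\,B_t .
\end{aligned}
```

The q and v rows of B_t are zero for stateful actuators, but A_{t+1} contains the sensitivity of v to a, so the product A_{t+1} B_t can have nonzero velocity rows: the effect of ctrl on the motion state shows up one step later.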

@jamesheald
Author

jamesheald commented Mar 11, 2025

Maybe I have misunderstood the role of transitionfd. In terms of your flow (ctrl -> act -> qacc -> qvel -> qpos), I thought that B would contain dqpos/dctrl and dqvel/dctrl. How do I compute these quantities in regular MuJoCo?

More generally, I want to compute the derivatives of functions of qpos or qvel (e.g. the position and orientation of bodies) with respect to ctrl (ideally in MJX, but regular MuJoCo would also be useful).

@yuvaltassa
Collaborator

You did not misunderstand; that is exactly what it's giving you. In the model with muscles, dqpos/dctrl and dqvel/dctrl are identically zero, because ctrl only affects act (and then act affects qvel).

@jamesheald
Author

Why are we not just applying the chain rule across all the steps in your flow to get a nonzero dqpos/dctrl?

@yuvaltassa
Collaborator

yuvaltassa commented Mar 11, 2025

This is a discrete-time, stateful system. The act variables introduce a state between ctrl and the motion state. dqpos/dctrl is zero if you have muscles or other stateful actuators; the chain rule has nothing to do with it. That's what statefulness means in discrete time. At time 0, ctrl will affect act; at the next timestep, act will affect the dynamics.
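
A toy illustration of this statefulness, not MuJoCo itself: the sketch below uses a hypothetical one-degree-of-freedom system with first-order activation dynamics and semi-implicit Euler integration. The function names and the values of dt, tau and gain are assumptions chosen for the example. It estimates the derivative of the state with respect to ctrl by centered finite differences after one step and after two steps.

```python
def step(state, ctrl, dt=0.005, tau=0.01, gain=1.0):
    """One step of a toy stateful actuator: ctrl -> act -> qacc -> qvel -> qpos."""
    qpos, qvel, act = state
    qacc = gain * act                        # force uses the *current* activation
    new_qvel = qvel + dt * qacc              # semi-implicit Euler: velocity first
    new_qpos = qpos + dt * new_qvel          # then position from the new velocity
    new_act = act + dt * (ctrl - act) / tau  # first-order activation dynamics
    return (new_qpos, new_qvel, new_act)


def rollout(ctrl, n_steps):
    """Apply the same ctrl for n_steps, starting from rest with zero activation."""
    state = (0.0, 0.0, 0.0)
    for _ in range(n_steps):
        state = step(state, ctrl)
    return state


def d_state_d_ctrl(n_steps, ctrl=1.0, eps=1e-6):
    """Centered finite-difference derivative of (qpos, qvel, act) w.r.t. ctrl."""
    plus = rollout(ctrl + eps, n_steps)
    minus = rollout(ctrl - eps, n_steps)
    return tuple((p - m) / (2 * eps) for p, m in zip(plus, minus))


one_step = d_state_d_ctrl(1)
two_steps = d_state_d_ctrl(2)
print("after 1 step  (dqpos, dqvel, dact):", one_step)
print("after 2 steps (dqpos, dqvel, dact):", two_steps)
```

In this toy system the one-step derivatives of qpos and qvel are zero and only act responds to ctrl; after two steps the velocity and position derivatives become nonzero, because the activation changed in the first step now drives the dynamics.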

@yuvaltassa yuvaltassa closed this as not planned Mar 11, 2025
@jamesheald
Author

jamesheald commented Mar 11, 2025

OK, that makes sense. That doesn't explain why I get NaNs when calculating dact/dctrl (my original post), though, does it?

@yuvaltassa
Collaborator

No, it does not. I'll reopen.

@yuvaltassa yuvaltassa reopened this Mar 12, 2025
@thowell thowell self-assigned this Mar 12, 2025