
NaNs in act when differentiating through muscle-driven arm26.xml with MJX #2489

Closed
@jamesheald


Intro

Hi!

I am a postdoc at University College London and I use MuJoCo for my research on musculoskeletal control.

I'm very excited by the prospect of using MJX to differentiate through the forward dynamics of musculoskeletal models. However, when I differentiate through mjx.step to compute derivatives of muscle activations with respect to muscle inputs (controls), I get NaNs. I have tested this on the arm26 model from the MJX GitHub repo. In contrast, when I compute derivatives of qpos with respect to forces applied directly to the joints, the gradients are well-behaved. This suggests a problem specific to computing gradients through the muscle actuators.
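To make concrete what "derivatives of muscle activations with respect to muscle inputs" involve, here is a minimal pure-JAX sketch. It assumes MuJoCo's documented first-order muscle activation dynamics with the default time constants (0.01 s activation, 0.04 s deactivation), hard switching between them (no smoothing), and one explicit Euler update with the model's 0.005 s timestep. It is not MJX's implementation, and `muscle_act_dot` and `next_act` are hypothetical names. The Jacobian has one row per activation and one column per control, the same layout as `jax.jacrev(muscle_step)` in the script below.

```python
import jax
import jax.numpy as jnp

# Assumed constants: MuJoCo's default muscle time constants and the
# arm26 timestep. Not taken from MJX source code.
TAU_ACT = 0.01
TAU_DEACT = 0.04
DT = 0.005


def muscle_act_dot(ctrl, act):
    """Toy first-order muscle activation dynamics (hypothetical sketch)."""
    ctrl = jnp.clip(ctrl, 0.0, 1.0)
    act_clamped = jnp.clip(act, 0.0, 1.0)
    # Activation is faster than deactivation; both time constants depend on act.
    tau = jnp.where(
        ctrl > act,
        TAU_ACT * (0.5 + 1.5 * act_clamped),
        TAU_DEACT / (0.5 + 1.5 * act_clamped),
    )
    return (ctrl - act) / tau


def next_act(ctrl, act):
    """One explicit Euler step of the activation state."""
    return act + DT * muscle_act_dot(ctrl, act)


# d(next act) / d(ctrl): shape (n_act, n_ctrl), differentiating w.r.t. arg 0.
jac_fn = jax.jit(jax.jacrev(next_act))

act0 = jnp.zeros(6)
ctrl = 0.5 * jnp.ones(6)
jacobian = jac_fn(ctrl, act0)
print(jacobian.shape)                        # (6, 6)
print(bool(jnp.all(jnp.isfinite(jacobian))))  # True
```

In this toy the gradients are finite, so it does not reproduce the NaNs; it only illustrates the quantity being differentiated.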

My setup

OS: macOS Sequoia
python: 3.11.5
mujoco: 3.3.0
mujoco-mjx: 3.3.0
jax: 0.5.1
jaxlib: 0.5.1

What's happening? What did you expect?

The script produces NaNs in the gradients of muscle activations with respect to muscle inputs, but not in the gradients of joint positions with respect to forces applied to the joints.

I expected to be able to apply jax.grad and jax.jacrev to mjx.step to calculate gradients of state variables with respect to muscle inputs.
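As a reminder of how the two transforms differ, here is a small sketch on a hypothetical toy function `step_fn` (a stand-in, not `mjx.step`). `jax.jacrev` returns the full Jacobian of a vector-valued output, while `jax.grad` needs a scalar output, so it is typically applied to a loss built from the state. For a loss `sum(v * f(u))`, the gradient equals `v @ J`.

```python
import jax
import jax.numpy as jnp


def step_fn(u):
    """Hypothetical stand-in for a vector-valued step: R^3 -> R^2."""
    w = jnp.array([[1.0, 2.0, 0.0],
                   [0.0, -1.0, 3.0]])
    return jnp.tanh(w @ u)


u = jnp.array([0.1, -0.2, 0.3])

# jax.jacrev handles vector outputs directly: shape (2, 3).
jac = jax.jacrev(step_fn)(u)

# jax.grad requires a scalar output, so reduce the state to a loss first.
v = jnp.array([0.5, -1.5])
g = jax.grad(lambda x: jnp.sum(v * step_fn(x)))(u)

print(jac.shape)                 # (2, 3)
print(jnp.allclose(g, v @ jac))  # True
```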

Steps for reproduction

Run the code below.

Minimal model for reproduction

The full model XML (arm26.xml) is embedded verbatim as `model_str` in the reproduction code below.

Code required for reproduction

import mujoco
from mujoco import mjx
import jax
from jax import numpy as jp

model_str = """
<!-- Copyright 2021 DeepMind Technologies Limited

     Licensed under the Apache License, Version 2.0 (the "License");
     you may not use this file except in compliance with the License.
     You may obtain a copy of the License at

         http://www.apache.org/licenses/LICENSE-2.0

     Unless required by applicable law or agreed to in writing, software
     distributed under the License is distributed on an "AS IS" BASIS,
     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     See the License for the specific language governing permissions and
     limitations under the License.
-->

<mujoco model="2-link 6-muscle arm">
  <option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>

  <visual>
    <rgba haze=".3 .3 .3 1"/>
  </visual>

  <default>
    <joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
    <muscle ctrllimited="true" ctrlrange="0 1"/>
  </default>

  <asset>
    <texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>

    <texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>

    <material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
  </asset>

  <worldbody>
    <geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>

    <light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>

    <site name="s0" pos="-0.15 0 0" size="0.02"/>
    <site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

    <body pos="0 0 0">
      <geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
      <joint name="shoulder"/>
      <geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

      <site name="s1" pos="0.15 0.06 0" size="0.02"/>
      <site name="s2" pos="0.15 -0.06 0" size="0.02"/>
      <site name="s3" pos="0.4 0.06 0" size="0.02"/>
      <site name="s4" pos="0.4 -0.06 0" size="0.02"/>
      <site name="s5" pos="0.25 0.1 0" size="0.02"/>
      <site name="s6" pos="0.25 -0.1 0" size="0.02"/>
      <site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>

      <body pos="0.5 0 0">
        <geom name="forearm" type="capsule" size="0.035" fromto="0 0 0  0.5 0 0" rgba=".5 .1 .1 1"/>
        <joint name="elbow"/>
        <geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>

        <site name="s7" pos="0.11 0.05 0" size="0.02"/>
        <site name="s8" pos="0.11 -0.05 0" size="0.02"/>
      </body>
    </body>
  </worldbody>

  <tendon>
    <spatial name="SF" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s1"/>
    </spatial>

    <spatial name="SE" width="0.01">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s2"/>
    </spatial>

    <spatial name="EF" width="0.01">
      <site site="s3"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="EE" width="0.01">
      <site site="s4"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>

    <spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder"/>
      <site site="s5"/>
      <geom geom="elbow"/>
      <site site="s7"/>
    </spatial>

    <spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
      <site site="s0"/>
      <geom geom="shoulder" sidesite="x0"/>
      <site site="s6"/>
      <geom geom="elbow" sidesite="x1"/>
      <site site="s8"/>
    </spatial>
  </tendon>

  <actuator>
    <muscle name="SF" tendon="SF"/>
    <muscle name="SE" tendon="SE"/>
    <muscle name="EF" tendon="EF"/>
    <muscle name="EE" tendon="EE"/>
    <muscle name="BF" tendon="BF"/>
    <muscle name="BE" tendon="BE"/>
  </actuator>
</mujoco>
"""

if __name__ == "__main__":
    
    model = mujoco.MjModel.from_xml_string(model_str)
    mjx_model = mjx.put_model(model)
    mjx_data = mjx.make_data(mjx_model)

    def joint_step(ctrl, mjx_data, mjx_model):

      mjx_data = mjx_data.replace(qfrc_applied=ctrl)
      mjx_data = mjx.step(mjx_model, mjx_data)

      return mjx_data.qpos
    
    get_jacobian_joints = jax.jit(jax.jacrev(joint_step))

    def muscle_step(ctrl, mjx_data, mjx_model):

      mjx_data = mjx_data.replace(ctrl=ctrl)
      mjx_data = mjx.step(mjx_model, mjx_data)

      return mjx_data.act
    
    get_jacobian_muscles = jax.jit(jax.jacrev(muscle_step))

    # Differentiation-compatible solver settings - https://github.com/google-deepmind/mujoco/issues/1182
    mjx_model = mjx_model.replace(opt=mjx_model.opt.replace(solver=mujoco.mjtSolver.mjSOL_NEWTON))
    mjx_model = mjx_model.replace(opt=mjx_model.opt.replace(iterations=1))
    mjx_model = mjx_model.replace(opt=mjx_model.opt.replace(ls_iterations=1))

    jacobian_joints = get_jacobian_joints(jp.ones(mjx_model.nq), mjx_data, mjx_model)

    if not jp.any(jp.isnan(jacobian_joints)):
        print("No NaNs detected when differentiating through joints")

    jacobian_muscles = get_jacobian_muscles(jp.ones(mjx_model.nu), mjx_data, mjx_model)

    assert not jp.any(jp.isnan(jacobian_muscles)), "NaNs detected when differentiating through muscles"
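When the assertion fails, it can help to know which activation/control pairs are affected. This is a minimal sketch, assuming a 2-D Jacobian laid out as returned by `jax.jacrev(muscle_step)` (rows: activations, columns: controls). `report_nan_entries` is a hypothetical helper, and the matrix in the example is synthetic, not output from the script.

```python
import jax.numpy as jnp
import numpy as np


def report_nan_entries(jacobian, row_names=None, col_names=None):
    """Return (row, col) labels of every NaN entry in a 2-D Jacobian.

    Hypothetical helper: for jacrev(muscle_step), rows index activations
    and columns index controls.
    """
    idx = np.argwhere(np.isnan(np.asarray(jacobian)))
    pairs = []
    for r, c in idx:
        r, c = int(r), int(c)
        r_label = row_names[r] if row_names is not None else r
        c_label = col_names[c] if col_names is not None else c
        pairs.append((r_label, c_label))
    return pairs


names = ["SF", "SE", "EF", "EE", "BF", "BE"]
toy = jnp.eye(6).at[2, 4].set(jnp.nan)  # synthetic matrix with one NaN
print(report_nan_entries(toy, names, names))  # [('EF', 'BF')]
```

With the script above, the actuator names could be passed as `[model.actuator(i).name for i in range(model.nu)]` for both axes, assuming (as in arm26) one activation per muscle actuator. To find the first operation that produces a NaN, rather than where it ends up, JAX's `jax_debug_nans` flag (`jax.config.update("jax_debug_nans", True)`) is another option.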

Labels: bug
