NaNs in act when differentiating through muscle-driven arm26.xml with MJX #2489
Comments
Did you try this with regular MuJoCo? MuJoCo has fast finite-difference Jacobians which are well tested. Differentiating through MJX is possible, but we don't have much experience with it.
When I use mujoco.mjd_transitionFD to calculate the Jacobian for the muscle model, I get all zeros for the derivatives of the next qpos and qvel with respect to the muscle inputs; only the act rows are nonzero. The Jacobian B is (2*nv + na) x nu, where nv = 2 and na = nu = 6:

Jacobian muscle model:
[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1.]]

When I exchange the 6 muscles for a single motor on the elbow hinge joint, every entry of the derivative of the next state with respect to the motor input is nonzero. Here nu = 1 and na = 0 (a motor has no activation state), so B is 4 x 1:

Jacobian motor model:
[[-0.00012784]
 [ 0.00044407]
 [-0.02556708]
 [ 0.08881482]]

This is the script I ran:

import mujoco
import numpy as np
muscle_model_str = """
<mujoco model="2-link 6-muscle arm">
<option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>
<visual>
<rgba haze=".3 .3 .3 1"/>
</visual>
<default>
<joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
<muscle ctrllimited="true" ctrlrange="0 1"/>
</default>
<asset>
<texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
<material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
</asset>
<worldbody>
<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>
<light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>
<site name="s0" pos="-0.15 0 0" size="0.02"/>
<site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>
<body pos="0 0 0">
<geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0 0.5 0 0" rgba=".5 .1 .1 1"/>
<joint name="shoulder"/>
<geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>
<site name="s1" pos="0.15 0.06 0" size="0.02"/>
<site name="s2" pos="0.15 -0.06 0" size="0.02"/>
<site name="s3" pos="0.4 0.06 0" size="0.02"/>
<site name="s4" pos="0.4 -0.06 0" size="0.02"/>
<site name="s5" pos="0.25 0.1 0" size="0.02"/>
<site name="s6" pos="0.25 -0.1 0" size="0.02"/>
<site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>
<body pos="0.5 0 0">
<geom name="forearm" type="capsule" size="0.035" fromto="0 0 0 0.5 0 0" rgba=".5 .1 .1 1"/>
<joint name="elbow"/>
<geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>
<site name="s7" pos="0.11 0.05 0" size="0.02"/>
<site name="s8" pos="0.11 -0.05 0" size="0.02"/>
</body>
</body>
</worldbody>
<tendon>
<spatial name="SF" width="0.01">
<site site="s0"/>
<geom geom="shoulder"/>
<site site="s1"/>
</spatial>
<spatial name="SE" width="0.01">
<site site="s0"/>
<geom geom="shoulder" sidesite="x0"/>
<site site="s2"/>
</spatial>
<spatial name="EF" width="0.01">
<site site="s3"/>
<geom geom="elbow"/>
<site site="s7"/>
</spatial>
<spatial name="EE" width="0.01">
<site site="s4"/>
<geom geom="elbow" sidesite="x1"/>
<site site="s8"/>
</spatial>
<spatial name="BF" width="0.009" rgba=".4 .6 .4 1">
<site site="s0"/>
<geom geom="shoulder"/>
<site site="s5"/>
<geom geom="elbow"/>
<site site="s7"/>
</spatial>
<spatial name="BE" width="0.009" rgba=".4 .6 .4 1">
<site site="s0"/>
<geom geom="shoulder" sidesite="x0"/>
<site site="s6"/>
<geom geom="elbow" sidesite="x1"/>
<site site="s8"/>
</spatial>
</tendon>
<actuator>
<muscle name="SF" tendon="SF"/>
<muscle name="SE" tendon="SE"/>
<muscle name="EF" tendon="EF"/>
<muscle name="EE" tendon="EE"/>
<muscle name="BF" tendon="BF"/>
<muscle name="BE" tendon="BE"/>
</actuator>
</mujoco>
"""
motor_model_str = """
<mujoco model="2-link 6-muscle arm">
<option timestep="0.005" iterations="50" solver="Newton" tolerance="1e-10"/>
<visual>
<rgba haze=".3 .3 .3 1"/>
</visual>
<default>
<joint type="hinge" pos="0 0 0" axis="0 0 1" limited="true" range="0 120" damping="0.1"/>
<muscle ctrllimited="true" ctrlrange="0 1"/>
</default>
<asset>
<texture type="skybox" builtin="gradient" rgb1="0.6 0.6 0.6" rgb2="0 0 0" width="512" height="512"/>
<texture name="texplane" type="2d" builtin="checker" rgb1=".25 .25 .25" rgb2=".3 .3 .3" width="512" height="512" mark="cross" markrgb=".8 .8 .8"/>
<material name="matplane" reflectance="0.3" texture="texplane" texrepeat="1 1" texuniform="true"/>
</asset>
<worldbody>
<geom name="floor" pos="0 0 -0.5" size="0 0 1" type="plane" material="matplane"/>
<light directional="true" diffuse=".8 .8 .8" specular=".2 .2 .2" pos="0 0 5" dir="0 0 -1"/>
<site name="s0" pos="-0.15 0 0" size="0.02"/>
<site name="x0" pos="0 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>
<body pos="0 0 0">
<geom name="upper arm" type="capsule" size="0.045" fromto="0 0 0 0.5 0 0" rgba=".5 .1 .1 1"/>
<joint name="shoulder"/>
<geom name="shoulder" type="cylinder" pos="0 0 0" size=".1 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>
<site name="s1" pos="0.15 0.06 0" size="0.02"/>
<site name="s2" pos="0.15 -0.06 0" size="0.02"/>
<site name="s3" pos="0.4 0.06 0" size="0.02"/>
<site name="s4" pos="0.4 -0.06 0" size="0.02"/>
<site name="s5" pos="0.25 0.1 0" size="0.02"/>
<site name="s6" pos="0.25 -0.1 0" size="0.02"/>
<site name="x1" pos="0.5 -0.15 0" size="0.02" rgba="0 .7 0 1" group="1"/>
<body pos="0.5 0 0">
<geom name="forearm" type="capsule" size="0.035" fromto="0 0 0 0.5 0 0" rgba=".5 .1 .1 1"/>
<joint name="elbow"/>
<geom name="elbow" type="cylinder" pos="0 0 0" size=".08 .05" rgba=".5 .1 .8 .5" mass="0" group="1"/>
<site name="s7" pos="0.11 0.05 0" size="0.02"/>
<site name="s8" pos="0.11 -0.05 0" size="0.02"/>
</body>
</body>
</worldbody>
<actuator>
<motor name="hinge" joint="elbow"/>
</actuator>
</mujoco>
"""
if __name__ == "__main__":
    # Muscle model: 6 muscles, each with one activation state (na = nu = 6)
    muscle_model = mujoco.MjModel.from_xml_string(muscle_model_str)
    data = mujoco.MjData(muscle_model)
    data.ctrl[:] = 1.0
    B = np.zeros((2 * muscle_model.nv + muscle_model.na, muscle_model.nu))
    # Compute the transition Jacobian using centered finite differences (eps = 1e-6)
    mujoco.mjd_transitionFD(muscle_model, data, 1e-6, 1, None, B, None, None)
    print("Jacobian muscle model: \n", B)

    # Motor model: one elbow motor, no activation state (na = 0, nu = 1)
    motor_model = mujoco.MjModel.from_xml_string(motor_model_str)
    data = mujoco.MjData(motor_model)
    data.ctrl[:] = 1.0
    B = np.zeros((2 * motor_model.nv + motor_model.na, motor_model.nu))
    # Compute the transition Jacobian using centered finite differences
    mujoco.mjd_transitionFD(motor_model, data, 1e-6, 1, None, B, None, None)
    print("Jacobian motor model: \n", B)
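The 1s in the act rows of the muscle Jacobian can be checked by hand. This is a sketch that assumes MuJoCo's documented muscle activation dynamics with the default time constants (tau_act = 0.01 s, tau_deact = 0.04 s, no smoothing between them), an Euler update of act with h = 0.005 s (the model's timestep), and the evaluation point used in the script (a = 0, u = 1). For u > a the time constant does not depend on u, so differentiating the update gives:

```latex
\begin{aligned}
\dot a &= \frac{u - a}{\tau(u, a)}, \qquad
\tau(u, a) =
\begin{cases}
\tau_{\mathrm{act}}\,(0.5 + 1.5\,a), & u - a > 0 \\
\tau_{\mathrm{deact}} / (0.5 + 1.5\,a), & \text{otherwise}
\end{cases} \\
a_{t+1} &= a_t + h\,\dot a_t
\quad\Rightarrow\quad
\frac{\partial a_{t+1}}{\partial u_t} = \frac{h}{\tau(u_t, a_t)}
= \frac{0.005}{0.01\,(0.5 + 1.5 \cdot 0)} = 1
\end{aligned}
```

Each muscle's activation depends only on its own excitation, so under these assumptions the act rows form an identity block.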
This is working as intended and exactly what I would expect. When you have muscles, the excitation (ctrl) only affects the activation (act). The activation affects the dynamics through the transition matrix A, not through the control matrix B:

ctrl -> act -> qacc -> qvel -> qpos

Does this make sense?
Maybe I have misunderstood the role of mjd_transitionFD. In terms of your flow (ctrl -> act -> qacc -> qvel -> qpos), I thought that B would contain dqpos/dctrl and dqvel/dctrl. How do I compute these quantities in regular MuJoCo? More generally, I want to compute the derivatives of functions of qpos or qvel (e.g. the position and orientation of bodies) with respect to ctrl (ideally in MJX, but regular MuJoCo would also be useful).
You did not misunderstand; that is exactly what it's giving you. In the model with muscles, dqpos/dctrl and dqvel/dctrl are identically zero, because ctrl only affects act (and act then affects qvel).
Why are we not just applying the chain rule across all the steps in your flow to get a nonzero dqpos/dctrl?
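The chain rule does apply, but across time steps rather than within one. Here is a sketch in the notation of mjd_transitionFD. It writes the state as x_t = (qpos_t, qvel_t, act_t) and the control as u_t, and assumes (as in the muscle model) that u_t reaches the forces only through act:

```latex
\begin{aligned}
x_{t+1} &= f(x_t, u_t), \qquad
A_t = \frac{\partial x_{t+1}}{\partial x_t}, \qquad
B_t = \frac{\partial x_{t+1}}{\partial u_t} \\
\frac{\partial\,\mathrm{qpos}_{t+1}}{\partial u_t} &= 0, \qquad
\frac{\partial\,\mathrm{qvel}_{t+1}}{\partial u_t} = 0, \qquad
\frac{\partial\,\mathrm{act}_{t+1}}{\partial u_t} \neq 0 \\
\frac{\partial x_{t+2}}{\partial u_t} &= A_{t+1} B_t
\quad\Rightarrow\quad
\frac{\partial\,\mathrm{qvel}_{t+2}}{\partial u_t}
= \frac{\partial\,\mathrm{qvel}_{t+2}}{\partial\,\mathrm{act}_{t+1}}
  \frac{\partial\,\mathrm{act}_{t+1}}{\partial u_t} \\
\frac{\partial x_{t+k}}{\partial u_t} &= A_{t+k-1} \cdots A_{t+1} B_t
\end{aligned}
```

So a nonzero dqpos/dctrl appears only in multi-step derivatives, which are products of the per-step A and B matrices.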
This is a discrete-time, stateful system. The
OK, that makes sense. But that doesn't explain why I get NaNs when calculating dact/dctrl (my original post), does it?
No it does not. I'll reopen. |
Intro
Hi!
I am a postdoc at University College London and use MuJoCo for my research on musculoskeletal control.
I'm very excited by the prospect of using MJX to differentiate through the forward dynamics of musculoskeletal models. However, when I differentiate through mjx.step to calculate the derivatives of muscle activations with respect to muscle inputs, I get NaNs. I have tested this on the arm26 model in the MJX GitHub repo. In contrast, when I calculate the derivatives of qpos with respect to forces applied directly to the joints, I get well-behaved gradients. This suggests that the problem is specific to calculating gradients through muscles.
My setup
OS: macOS Sequoia
Python: 3.11.5
mujoco: 3.3.0
mujoco-mjx: 3.3.0
jax: 0.5.1
jaxlib: 0.5.1
What's happening? What did you expect?
The script produces NaNs in the gradients of muscle activations with respect to muscle inputs, but not in the gradients of joint positions with respect to forces applied to the joints.
I expected to be able to apply jax.grad and jax.jacrev to mjx.step to calculate gradients of state variables with respect to muscle inputs.
Steps for reproduction
Run the code below.
Minimal model for reproduction
Code required for reproduction