Description
MJX (comparable to Nvidia Warp) allows Mujoco to run on the GPU using the JAX framework.
Using MJX allows for parallel training in multiple simulation environments (instead of just the one Mujoco instance running on CPU).
Hello Robot Stretch is on the Deepmind Mujoco Menagerie. However, it is not listed as a mobile manipulator with MJX support.
Implementation Challenges
Missing Feature 1: Cylinder collisions
At the time of writing, none of the other mobile manipulators listed on the Menagerie have MJX support.
This is likely because MJX does not support Cylinder-Mesh collisions, which prevents wheeled robots from being simulated.
This is discussed here: google-deepmind/mujoco#2420.
We attempted replacing the cylindrical collision shape on the wheels with a cylindrical mesh. However, the robot's movement is very bumpy because of the tessellation at the surface of the cylinder.
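The bumpiness can be estimated with a little geometry. Mesh collisions use the mesh's convex hull, so a tessellated cylinder behaves like a rolling regular polygon: the axle sits at the full radius when balanced on an edge and at the apothem when resting on a face. The sketch below computes that vertical travel. `WHEEL_RADIUS_M` is a hypothetical value, not the actual Stretch wheel radius, and the model ignores contact softness.

```python
import math


def bump_height(radius: float, num_sides: int) -> float:
    """Peak-to-peak vertical axle travel of a regular polygon rolling on flat ground.

    The axle height varies between the circumradius (on an edge) and the
    apothem, radius * cos(pi / num_sides) (flat on a face).
    """
    return radius * (1.0 - math.cos(math.pi / num_sides))


WHEEL_RADIUS_M = 0.05  # hypothetical wheel radius, for illustration only

for sides in (16, 32, 64, 128):
    travel_mm = bump_height(WHEEL_RADIUS_M, sides) * 1000.0
    print(f"{sides:4d} sides: {travel_mm:.3f} mm of vertical travel per facet")
```

The travel shrinks roughly with the square of the side count, so a finer tessellation helps quickly. The trade-off is that the end caps of the wheel mesh then become coplanar faces with many vertices, which may be what triggers the `cylinder_wheel` mesh warning reported later in these notes.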
Note: The collision driver implements Cylinder-Cylinder and Sphere-Mesh collisions.
Limited Support: GPU Rendering
According to the Mujoco docs, we should use madrona_mjx for vision RL training, because MJX does not have a .render() method to render cameras on the GPU. Rendering cameras on the CPU would be compute-intensive and defeat the point of using MJX in the first place.
Build Madrona-MJX from source by running:
```
git clone https://github.com/shacklettbp/madrona_mjx.git
cd madrona_mjx
git submodule update --init --recursive
mkdir build
cd build
cmake -DLOAD_VULKAN=OFF ..
make -j
cd ..
uv pip install -e .
```
Madrona tutorial: https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb
However, this only works on machines with CUDA. We attempted to use this on a Mac (Metal series) and it does not work.
Limited Support: macOS Metal
Jax Metal (https://developer.apple.com/metal/jax/) seems to be missing many features of the regular Jax implementation. Running any of the MJX tutorials after uv pip install jax-metal yields this error: error: failed to legalize operation 'mhlo.reduce'
This may be related to: jax-ml/jax#17490
Missing Feature 2: multiccd flag
We are currently using the multiccd flag in the stretch.xml MJCF file. However, we get a NotImplementedError when we try to use this flag in MJX.
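One possible workaround, sketched below, is to generate an MJX-specific copy of the MJCF with the flag switched off before loading it. This assumes the flag is set through a `<flag multiccd="enable"/>` element under `<option>`, as in standard MJCF. The helper name `disable_multiccd` and the example XML are hypothetical, and disabling the flag changes contact behavior, so the result would need to be validated against the CPU simulation.

```python
import xml.etree.ElementTree as ET


def disable_multiccd(mjcf_xml: str) -> str:
    """Return a copy of the MJCF string with every multiccd flag set to "disable"."""
    root = ET.fromstring(mjcf_xml)
    for flag in root.iter("flag"):
        if "multiccd" in flag.attrib:
            flag.set("multiccd", "disable")
    return ET.tostring(root, encoding="unicode")


# Hypothetical fragment standing in for the relevant part of stretch.xml.
EXAMPLE = """<mujoco model="example">
  <option>
    <flag multiccd="enable"/>
  </option>
</mujoco>"""

patched = disable_multiccd(EXAMPLE)
print(patched)
```

Note that `xml.etree.ElementTree` drops XML comments and does not resolve `<include>` elements, so included files would need the same treatment.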
Meshes are too big
For some meshes, we get the following warnings:
Mesh "cylinder_object" has a coplanar face with more than 20 vertices.
Mesh "link_head_10" has a coplanar face with more than 20 vertices.
Mesh "laser" has a coplanar face with more than 20 vertices.
Mesh "cylinder_wheel" has a coplanar face with more than 20 vertices.
Mesh "link_lift_0" has a coplanar face with more than 20 vertices.
Mesh "link_lift_8" has a coplanar face with more than 20 vertices.
Mesh "link_lift_9" has a coplanar face with more than 20 vertices.
Mesh "link_arm_l0_0" has a coplanar face with more than 20 vertices.
Mesh "link_arm_l0_2" has a coplanar face with more than 20 vertices.
Mesh "link_wrist_yaw" has a coplanar face with more than 20 vertices.
Mesh "link_DW3_wrist_yaw_bottom" has a coplanar face with more than 20 vertices.
Mesh "link_DW3_wrist_pitch" has a coplanar face with more than 20 vertices.
Mesh "link_DW3_wrist_quick_connect" has a coplanar face with more than 20 vertices.
Mesh "link_SG3_gripper_body" has a coplanar face with more than 20 vertices.
Mesh "link_d405" has a coplanar face with more than 20 vertices.
Mesh "link_head_pan_1" has a coplanar face with more than 20 vertices.
Mesh "link_head_tilt_0" has a coplanar face with more than 20 vertices.
Mesh "link_head_tilt_1" has a coplanar face with more than 20 vertices.
Each warning also included the following text, which was stripped from the list above to keep the output short:
stretch_mujoco/.venv/lib/python3.13/site-packages/mujoco/mjx/_src/mesh.py:141: UserWarning:
This may lead to performance issues and inaccuracies in collision detection. Consider decimating the mesh.
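To turn output like the above into a to-do list for decimation, the mesh names can be pulled out with a regular expression. This is a minimal sketch: the pattern is written against the message text quoted above, and the helper name `meshes_to_decimate` is hypothetical.

```python
import re

# Matches the message text shown in the warnings above.
MESH_WARNING = re.compile(
    r'Mesh "([^"]+)" has a coplanar face with more than (\d+) vertices'
)


def meshes_to_decimate(log_text: str) -> list[str]:
    """Return the unique mesh names from the warnings, in first-seen order."""
    seen: dict[str, None] = {}
    for match in MESH_WARNING.finditer(log_text):
        seen.setdefault(match.group(1), None)
    return list(seen)


LOG = """
Mesh "cylinder_wheel" has a coplanar face with more than 20 vertices.
Mesh "link_lift_8" has a coplanar face with more than 20 vertices.
Mesh "cylinder_wheel" has a coplanar face with more than 20 vertices.
"""

names = meshes_to_decimate(LOG)
print(names)  # ['cylinder_wheel', 'link_lift_8']
```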
We noticed that when we remove the collision class geom (e.g. <geom mesh="link_lift_8" class="collision" mass="1.5"/>) and move the mass to the visual class geom, the warnings go away. We cannot do that for all of the meshes, however, since some of them need to take part in collisions.
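For meshes that are purely cosmetic, the manual edit described above could be scripted along these lines. This is a sketch under assumptions: it expects the collision geom and a matching visual geom to be direct children of the same body, and the helper name `drop_collision_geom` and the body name in the example are hypothetical.

```python
import xml.etree.ElementTree as ET


def drop_collision_geom(mjcf_xml: str, mesh_name: str) -> str:
    """Remove collision-class geoms for mesh_name, moving their mass to the visual geom."""
    root = ET.fromstring(mjcf_xml)
    for body in root.iter("body"):
        geoms = [g for g in body.findall("geom") if g.get("mesh") == mesh_name]
        collision = [g for g in geoms if g.get("class") == "collision"]
        visual = [g for g in geoms if g.get("class") == "visual"]
        if not collision or not visual:
            continue  # nothing to remove, or no visual geom to receive the mass
        for geom in collision:
            mass = geom.get("mass")
            if mass is not None:
                visual[0].set("mass", mass)  # keep the body's inertia contribution
            body.remove(geom)
    return ET.tostring(root, encoding="unicode")


# Hypothetical fragment; the real stretch.xml layout may differ.
EXAMPLE = """<mujoco>
  <worldbody>
    <body name="link_lift">
      <geom mesh="link_lift_8" class="visual"/>
      <geom mesh="link_lift_8" class="collision" mass="1.5"/>
    </body>
  </worldbody>
</mujoco>"""

patched = drop_collision_geom(EXAMPLE, "link_lift_8")
print(patched)
```

Meshes that must keep colliding would still need to be decimated or replaced with primitives instead.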
Note that some of these meshes are very thin and should only be used for aesthetics.

Resources and misc notes
- Google Deepmind Control (dm_control) library tutorial: https://github.com/google-deepmind/dm_control/blob/main/tutorial.ipynb
- dm_control tutorial: https://arxiv.org/pdf/2006.12983
- dm_control orders its physics sub-steps differently from normal Mujoco, which affects what gets rendered: “Quantities that depend only on the state are computed in the first stage, mj_step1(), and those that also depend on the control (including forces) are computed in the subsequent mj_step2(). Physics.step() calls these sub-functions in reverse order, as follows. … In particular, this means that after a Physics.step(), rendered pixels will correspond to the current state, rather than the previous one.”
- Mujoco Playground provides an MjxEnv abstraction that allows training using MJX, Jax (JIT-compiling MJX) and Flax: https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/_src/mjx_env.py
- MJX is not well suited to single-scene simulation; it can be up to 10x slower when not running environments in parallel: https://mujoco.readthedocs.io/en/stable/mjx.html#mjx-the-sharp-bits
- Jax to Pytorch guide: https://cloud.google.com/blog/products/ai-machine-learning/guide-to-jax-for-pytorch-developers
- Nvidia datasets for ML training: https://huggingface.co/collections/nvidia/physical-ai-67c643edbb024053dcbcd6d8
- Robot Dog MJX training implementation example: https://research.mels.ai/ide?mels=UnitreeGo1.qkazy
Code implementation
MjxEnv
To create a training task, it is recommended to implement the MjxEnv abstract class, so that MJX and dm_control can use their internal mechanisms for training.
There is an unfinished branch that started this work of implementing MjxEnv:
- https://github.com/hello-robot/stretch_mujoco/blob/feature/mjx/stretch_mujoco/mjx/stretch.py
- https://github.com/hello-robot/stretch_mujoco/blob/feature/mjx/stretch_mujoco/mjx/pick.py
However, I do not think this is strictly necessary if you use a PPO implementation that is independent of dm_control's.
Forward/Inverse Kinematics
The FK/IK implementation at https://github.com/hello-robot/stretch_mujoco/blob/feature/mjx/stretch_mujoco/mjx/pinocchio_ik_solver.py has never been tested, and I am unsure whether it will run. It was adapted from https://github.com/hello-robot/stretch_ai/blob/main/src/stretch/motion/pinocchio_ik_solver.py.
pyproject.toml
Add these dependencies to the [project.optional-dependencies] section and then run uv pip install -e ".[mjx]" to install them.
mjx = [
"ipykernel",
"ipython",
"playground",
"brax",
"flax",
"mediapy"
]
Notebooks
Two notebooks were created in this branch:
- RL manipulation training without vision: https://github.com/hello-robot/stretch_mujoco/blob/feature/mjx/stretch_mujoco/mjx/stretch_mjx_manipulation.ipynb
- This one runs without errors, but training is impractically slow: even one epoch with minimal hyperparameters takes too long to complete.
- On macOS, I tried to speed it up using Metal (uv pip install jax-metal) but got this error: error: failed to legalize operation 'mhlo.reduce'. This might be related to jax-ml/jax#17490 (Apple Silicon: error: failed to legalize operation 'mhlo.triangular_solve').
- RL manipulation using Madrona (never tested): https://github.com/hello-robot/stretch_mujoco/blob/feature/mjx/stretch_mujoco/mjx/stretch_mjx_manipulator_with_vision.ipynb