Causes of different behaviour in MJX and CPU MuJoCo #1341

Open
Balint-H opened this issue Jan 16, 2024 · 10 comments

@Balint-H
Collaborator

Balint-H commented Jan 16, 2024

I've been taking a look at MJX, and I'm impressed with how smooth the usage is. As I've been playing around more with the tutorial code on my system, I noticed that, for example, a humanoid running policy that performs well in MJX usually ends up failing eventually when transferred to CPU MuJoCo.

download.mp4

(video from the "MJX Policy in MuJoCo" cell of the tutorial, with rng = jax.random.PRNGKey(2))

This is of course not too unexpected, considering that I didn't use any domain randomization or other methods that would help with sim2sim transfer. I'm aware of this discussion about FP precision differences between MJX and regular MuJoCo: #1203. Beyond FP precision, are there other key differences in the two versions of the engine that can cause a failure to transfer policies (provided only features that are officially supported in both versions are used)? Are there certain settings that can be used with CPU MuJoCo to make it behave closer to MJX (or vice versa)?

Balint-H added the question (Request for help or information) label Jan 16, 2024
@erikfrey
Collaborator

Hi @Balint-H - glad you're finding MJX easy to use. While it's possible you've found some discrepancy between MuJoCo and MJX, it's also possible you picked a bad seed for the rollout. The RL policies in that colab aren't optimized for stability across seeds - you'll see that in the standard deviation bars in the reward evaluation graphs.

I think for us to determine there's an issue here, you might want to try doing, say, 512 policy rollouts on MJX and the same number on MuJoCo, and then showing the reward distributions are actually different. If you suspect there's an issue, please try that and let us know.
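For reference, a minimal sketch of that kind of batched comparison (not code from the tutorial): the model path, the zero-action stand-in policy, and the height-based termination below are all placeholders to be replaced by the actual environment and trained inference function.

```python
import jax
import jax.numpy as jp
import mujoco
import numpy as np
from mujoco import mjx

model = mujoco.MjModel.from_xml_path('humanoid.xml')  # placeholder model path
mjx_model = mjx.put_model(model)

N_EPISODES, N_STEPS, MIN_HEIGHT = 512, 500, 0.8


def policy(obs, rng):
    """Hypothetical stand-in: swap in the trained policy's inference function."""
    del obs, rng
    return jp.zeros(model.nu)


def mjx_episode_length(rng):
    """One MJX rollout; counts steps until the torso drops below MIN_HEIGHT."""
    data = mjx.make_data(mjx_model)

    def step(carry, _):
        data, rng, done, length = carry
        rng, key = jax.random.split(rng)
        data = mjx.step(mjx_model, data.replace(ctrl=policy(data.qpos, key)))
        done = done | (data.qpos[2] < MIN_HEIGHT)      # crude fall detection
        return (data, rng, done, length + (1 - done)), None

    init = (data, rng, jp.array(False), jp.array(0))
    (_, _, _, length), _ = jax.lax.scan(step, init, None, length=N_STEPS)
    return length


# Batched MJX rollouts over many seeds.
rngs = jax.random.split(jax.random.PRNGKey(0), N_EPISODES)
mjx_lengths = np.asarray(jax.jit(jax.vmap(mjx_episode_length))(rngs))


def cpu_episode_length(seed):
    """The same rollout on CPU MuJoCo."""
    data = mujoco.MjData(model)
    rng = jax.random.PRNGKey(seed)
    for t in range(N_STEPS):
        rng, key = jax.random.split(rng)
        data.ctrl[:] = np.asarray(policy(jp.asarray(data.qpos), key))
        mujoco.mj_step(model, data)
        if data.qpos[2] < MIN_HEIGHT:
            return t + 1
    return N_STEPS


cpu_lengths = [cpu_episode_length(s) for s in range(N_EPISODES)]
print('MJX mean episode length:', mjx_lengths.mean())
print('CPU mean episode length:', np.mean(cpu_lengths))
```

Comparing the two resulting distributions (rather than single rollouts) makes it much easier to tell a genuine engine discrepancy from an unlucky seed.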

To answer your question, here are the calculation differences between MJX and MuJoCo that come to mind:

  • The float precision, as you mention
  • MJX's convex<>convex collision algorithms are different from MuJoCo's (this shouldn't matter in your demo)
  • MJX uses a slightly tweaked linesearch op inside its solver compared to MuJoCo, for performance reasons

That's it! Barring any bugs, of course, which we are happy to investigate should you find a repro. Cheers.

@Balint-H
Collaborator Author

Balint-H commented Jan 23, 2024

Hello, running with 20 different seeds, the episode lengths of the trained humanoid running policy are below (terminating when falling over, or reaching 500 decision steps):

MJX:
[500, 500, 500, 500, 500, 77, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500, 500]
Average reward: 4567.945

MJX transferred to MuJoCo CPU:
[363, 170, 173, 212, 404, 500, 204, 286, 321, 500, 391, 404, 252, 466, 427, 500, 249, 149, 279, 500]
Average reward: 1475.5411

@erikfrey Is this much discrepancy expected? If not, is there perhaps a bug in tutorial code (e.g. some postprocessing not being applied to the actions for the CPU version)?

Here's a modified version of the colab tutorial that repeats the evals for RNG seeds 0-19 (used for the results above):
https://github.com/Balint-H/mujoco/blob/colab/mjx/mjx/tutorial.ipynb

Please let me know if I messed up the process of editing the tutorial somewhere.

@erikfrey
Collaborator

This is very helpful - thank you! I have a hunch as to what's going on here. I'll take a look.

@yuvaltassa
Collaborator

@erikfrey any updates?

This sounds like a thread worth pulling on...

@AlexS28

AlexS28 commented Apr 17, 2024

@Balint-H has there been any solution to this? It seems critical. I am also about to transfer a policy to MuJoCo CPU soon, so I'm wondering whether this was resolved or not.

@Balint-H
Collaborator Author

At the moment I think the guideline is to train more robust agents, then fine-tune on CPU if needed. It has been a while since I tried the transfer, though, so some of the tweaks to MJX since then might have changed this.

@AlexS28

AlexS28 commented Apr 17, 2024

I see. I also noticed this decline in my application between MJX and MuJoCo. I currently suspect it's due to something in my code rather than MJX, but then I came across this post and was wondering about it.

kevinzakka added the MJX (Using JAX to run on GPU) label Aug 30, 2024
@erikfrey
Collaborator

@Balint-H I found one culprit that can impede transfer from MJX back to MuJoCo:

While MuJoCo and MJX both converge on numerically close constraint solutions (e.g. contact, joint limit), they can differ if you stop the solver early, e.g. by setting a limit on iterations or raising the tolerance.

Stopping the solver early is fine for many learning workloads, including for sim2real, where you actually want some noise in the physics. But if you want MuJoCo and MJX to have numerically very close dynamics so that you can transfer between them, you should use the default solver iterations/tolerance when training your policy (or try a bit of fine-tuning at the end with higher iterations / lower tolerance).
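For concreteness, a small sketch of the settings being described (the model path and the "fast training" values are illustrative, not taken from the tutorial); the defaults shown match MuJoCo at the time of writing.

```python
import mujoco
from mujoco import mjx

model = mujoco.MjModel.from_xml_path('humanoid.xml')  # placeholder path

# What a speed-tuned training setup might use (illustrative): the constraint
# solver is stopped early, which is fine for learning but hurts MJX<->CPU
# agreement.
# model.opt.iterations = 6
# model.opt.ls_iterations = 6

# For close MJX <-> CPU MuJoCo agreement, keep the defaults (or go tighter):
model.opt.iterations = 100     # default cap on main solver iterations
model.opt.ls_iterations = 50   # default cap on linesearch iterations
model.opt.tolerance = 1e-8     # default early-termination tolerance

mjx_model = mjx.put_model(model)  # the MJX model inherits the same options
```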

When I have some time I'll update the colab to demonstrate this phenomenon.

@P-Schumacher

So, under the assumption of using the same solver settings for MJX and CPU MuJoCo, can we expect both simulators to behave identically?

I did some experiments on this a while back and found that MJX and CPU MuJoCo were identical up to machine precision. I could also transfer policies.

It would be good to know under which conditions this isn't true.
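To make that kind of check concrete, here is a minimal side-by-side stepping sketch (my own, not from the thread); 'humanoid.xml' and the zero control signal are placeholders for your model and policy.

```python
import jax
import jax.numpy as jp
import mujoco
import numpy as np
from mujoco import mjx

# Optionally enable float64 in JAX (before creating any arrays) for a closer
# like-for-like precision comparison with CPU MuJoCo:
# jax.config.update('jax_enable_x64', True)

model = mujoco.MjModel.from_xml_path('humanoid.xml')  # placeholder path
data = mujoco.MjData(model)
mujoco.mj_resetData(model, data)

mjx_model = mjx.put_model(model)
mjx_data = mjx.put_data(model, data)   # start both engines from identical state
step_fn = jax.jit(mjx.step)

for t in range(200):
    ctrl = np.zeros(model.nu)          # placeholder control; use your policy here
    data.ctrl[:] = ctrl
    mujoco.mj_step(model, data)
    mjx_data = step_fn(mjx_model, mjx_data.replace(ctrl=jp.asarray(ctrl)))
    drift = np.max(np.abs(np.asarray(mjx_data.qpos) - data.qpos))
    if t % 50 == 0:
        print(f'step {t:4d}  max |qpos_mjx - qpos_cpu| = {drift:.3e}')
```

Running this with default solver settings versus aggressively lowered iterations/tolerance is a quick way to see when the two engines start to diverge.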

@erikfrey
Collaborator

erikfrey commented Dec 3, 2024

@P-Schumacher the only condition I'm aware of is the one above - aggressively lowering solver_iterations and ls_iterations can impede transfer.
