The CG Solver in MJX doesn't support reverse-mode differentiation #1182
Comments
I'm also trying to replace the
It works when not using
It is because the
Is there any chance of making it work under JIT compilation?
Hi @LyuJ1998, this is a known issue with while_loop: "while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements." (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html)
You can change the while_loop to a scan; see jax-ml/jax#3850. The
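A minimal sketch of the while_loop-to-scan change suggested above. This is not MJX's solver code: the toy Newton iteration for a square root, the function names, and the constants MAX_ITERS and TOL are all assumptions chosen for illustration. The idea is to replace the data-dependent stopping condition with a fixed trip count and turn already-converged iterations into no-ops.

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 50  # assumed static upper bound on the number of iterations
TOL = 1e-6      # assumed convergence tolerance


def sqrt_while(a):
    """Newton iteration for sqrt(a) with a data-dependent stopping rule."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond(state):
        x, i = state
        return (jnp.abs(x * x - a) > TOL) & (i < MAX_ITERS)

    def body(state):
        x, i = state
        return 0.5 * (x + a / x), i + 1

    x, _ = jax.lax.while_loop(cond, body, (a, jnp.int32(0)))
    return x


def sqrt_scan(a):
    """Same iteration with a fixed trip count; converged steps are no-ops."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(x, _):
        done = jnp.abs(x * x - a) <= TOL
        x_next = jnp.where(done, x, 0.5 * (x + a / x))
        return x_next, None

    x, _ = jax.lax.scan(step, a, None, length=MAX_ITERS)
    return x


# Reverse mode works through the scan version.
print(jax.grad(sqrt_scan)(2.0))

# jax.grad(sqrt_while)(2.0) raises an error instead, because while_loop
# with a dynamic condition has no reverse-mode differentiation rule.
```

The trade-off is that the scan version always runs MAX_ITERS iterations and stores the intermediate values needed for the backward pass, so it costs more compute and memory than the early-exit loop.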
To reduce memory usage, there are some in-place operations in mjx.step.
Hi there, Yes indeed this is by design, but poorly documented. I'll take this as motivation to add to the documentation. tl;dr: if you would like to experiment with
Then we omit the
The reason we don't support this for
Also please note that we have not investigated whether
@sfd158 not quite sure what you mean by in-place operations - JAX operations are pure, they do not modify the original. See for example this documentation on jax.numpy.ndarray.at
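A small sketch of the point about purity, using only jax.numpy.ndarray.at. The arrays and the function f are made up for illustration and have nothing to do with mjx.step internals. An indexed update returns a new array, leaves the original untouched, and is differentiable like any other operation.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
y = x.at[1].set(5.0)  # returns a new array; x itself is left unchanged


def f(v):
    # Looks like an in-place update, but it builds a new array.
    v = v.at[0].set(v[0] * 2.0)
    return jnp.sum(v ** 2)


# f(v) = (2 * v[0])**2 + v[1]**2, so the gradient is [8 * v[0], 2 * v[1]].
g = jax.grad(f)(jnp.array([1.0, 3.0]))
```

Under jit, XLA may still reuse the input buffer for such an update when it can prove that is safe, but that is a compiler optimization and does not change the functional semantics.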
Not sure if this is the right place to post this - Re: whether
I have been playing with gradients over mjx.step (Newton solver; 1 iteration) for my Master's thesis. Please see an implementation of the Short Horizon Actor Critic (SHAC) algorithm here. SHAC involves learning control policies using analytical policy gradients; it augments the basic Analytical Policy Gradient (APG) algorithm with several features, such as a value function. I use jax.grad to take the gradient of the loss of an environment rollout with respect to the policy parameters, and these gradients are informative enough to make the algorithm work. This works without contact (inverted pendulum) and with contact (basic hopper).
While the gradients appear to be informative in these simple cases, I can't get quadruped control working; the Jacobian of the MJX step, which is a component of the gradient of the loss with respect to the policy parameters, is unstable; more on this in the README of the repo. I wonder if different simulation parameters could help here, since this issue appears to have gotten worse from MJX 3.1.1 to MJX 3.1.3.
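To make the setup described above concrete, here is a rough sketch of taking the gradient of a rollout loss with respect to policy parameters. It is not the SHAC implementation and does not call mjx.step: toy_step is a hypothetical stand-in for a differentiable simulator step, and the linear policy, cost, DT, and HORIZON are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

DT = 0.05      # assumed integration time step
HORIZON = 32   # assumed short rollout horizon


def toy_step(state, action):
    """Hypothetical differentiable step for a 1-D point mass (not mjx.step)."""
    pos, vel = state
    vel = vel + DT * action
    pos = pos + DT * vel
    return (pos, vel)


def policy(params, state):
    """Linear state-feedback policy; params holds two gains."""
    pos, vel = state
    return -(params[0] * pos + params[1] * vel)


def rollout_loss(params, init_state):
    """Sum of per-step costs over a fixed-length rollout."""
    def body(state, _):
        action = policy(params, state)
        next_state = toy_step(state, action)
        cost = next_state[0] ** 2 + 1e-3 * action ** 2
        return next_state, cost

    _, costs = jax.lax.scan(body, init_state, None, length=HORIZON)
    return jnp.sum(costs)


params = jnp.array([1.0, 1.0])
init_state = (jnp.array(1.0), jnp.array(0.0))

# Backpropagate through every simulated step to get d(loss)/d(params).
loss, grads = jax.value_and_grad(rollout_loss)(params, init_state)
```

The gradient here is a product of per-step Jacobians, which is why an ill-conditioned step Jacobian can make the result unstable over longer horizons.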
Hi erikfrey
+1 on facing the issue and finding the solution here) Is there any way I can contribute by enhancing the documentation here? I would love to improve the tools I use on a daily basis.
I'm trying to differentiate the MJX step function via the autograd function jax.grad() in JAX, like:
When there is only one rigid body in the scene, everything works, but when a collision has to be solved, for example with a ball and a plane in the scene, an error occurs:
It seems the jax.lax.while_loop() function used when solving CG does not support a dynamic condition function. How can I solve this?