Skip to content

Commit 1f1d528

Browse files
quaglacopybara-github
authored andcommitted
Fix bind.set when using a single joint instead of a list.
PiperOrigin-RevId: 738546056 Change-Id: I26ae045be9e983cb0760132b9d47eac018ca33cd
1 parent be64747 commit 1f1d528

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

mjx/mujoco/mjx/_src/support.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -507,14 +507,19 @@ def set(self, name: str, value: jax.Array) -> Data:
507507
iter(value)
508508
except TypeError:
509509
value = [value]
510-
if name == 'qpos':
511-
adr = self.model.jnt_qposadr[self.id]
512-
typ = self.model.jnt_type[self.id]
513-
num = sum((typ == jt) * jt.qpos_width() for jt in JointType)
514-
elif name == 'qvel' or name == 'qacc':
515-
adr = self.model.jnt_dofadr[self.id]
516-
typ = self.model.jnt_type[self.id]
517-
num = sum((typ == jt) * jt.dof_width() for jt in JointType)
510+
if name in ('qpos', 'qvel', 'qacc'):
511+
adr = num = 0
512+
if name == 'qpos':
513+
adr = self.model.jnt_qposadr[self.id]
514+
typ = self.model.jnt_type[self.id]
515+
num = sum((typ == jt) * jt.qpos_width() for jt in JointType)
516+
elif name == 'qvel' or name == 'qacc':
517+
adr = self.model.jnt_dofadr[self.id]
518+
typ = self.model.jnt_type[self.id]
519+
num = sum((typ == jt) * jt.dof_width() for jt in JointType)
520+
if not isinstance(self.id, list):
521+
adr = [adr]
522+
num = [num]
518523
elif isinstance(self.id, list):
519524
adr = self.id * dim
520525
num = [dim for _ in range(len(self.id))]

mjx/mujoco/mjx/_src/support_test.py

+4
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ def test_bind(self):
287287
dx6 = dx.bind(mx, s.joints[::2]).set('qpos', [1, 0, 0, 0, 8])
288288
np.testing.assert_array_equal(dx6.bind(mx, s.joints).qpos, qpos_desired)
289289
np.testing.assert_array_almost_equal(dx.bind(mx, s.joints).qpos, d.qpos)
290+
dx6a = dx.bind(mx, s.joints[0]).set('qpos', qpos_desired[:4])
291+
np.testing.assert_array_equal(
292+
dx6a.bind(mx, s.joints[0]).qpos, qpos_desired[:4]
293+
)
290294
dx7 = dx.bind(mx, s.joints[::2]).set('qvel', [2.0, -1.2, 0.5, 0.3])
291295
np.testing.assert_array_almost_equal(
292296
dx7.bind(mx, s.joints).qvel, [2.0, -1.2, 0.5, 0.0, 0.3], decimal=6

0 commit comments

Comments
 (0)