Skip to content

Commit 343ea73

Browse files
authored
Support for geomgroup filters in raycasting (#370)
* improvements to sim.ray() * typo * group filter is bytes
1 parent d980119 commit 343ea73

File tree

1 file changed

+80
-13
lines changed

1 file changed

+80
-13
lines changed

mujoco_py/mjsim.pyx

+80-13
Original file line numberDiff line numberDiff line change
@@ -335,31 +335,98 @@ cdef class MjSim(object):
335335
else:
336336
raise ValueError("Unsupported format. Valid ones are 'xml' and 'mjb'")
337337

338-
def ray(self,
339-
np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
340-
np.ndarray[np.float64_t, mode="c", ndim=1] vec,
341-
include_static_geoms=True, exclude_body=-1):
338+
def ray(self, pnt, vec, include_static_geoms=True, exclude_body=-1, group_filter=None):
342339
"""
343340
Cast a ray into the scene, and return the first valid geom it intersects.
344341
pnt - origin point of the ray in world coordinates (X Y Z)
345342
vec - direction of the ray in world coordinates (X Y Z)
346343
include_static_geoms - if False, we exclude geoms that are children of worldbody.
347344
exclude_body - if this is a body ID, we exclude all children geoms of this body.
348-
Returns (distance, geom_id) where
345+
group_filter - a vector of booleans of length const.NGROUP
346+
which specifies what geom groups (stored in model.geom_group)
347+
to enable or disable. If none, all groups are used
348+
Returns (distance, geomid) where
349349
distance - distance along ray until first collision with geom
350-
geom_id - id of the geom the ray collided with
350+
geomid - id of the geom the ray collided with
351351
If no collision was found in the scene, return (-1, None)
352352
353353
NOTE: sometimes self.forward() needs to be called before self.ray().
354+
355+
See self.ray_fast_group() and self.ray_fast_nogroup() for versions of this call
356+
with more stringent type requirements.
357+
"""
358+
cdef mjtNum distance
359+
cdef mjtNum[::view.contiguous] pnt_view = pnt
360+
cdef mjtNum[::view.contiguous] vec_view = vec
361+
362+
if group_filter is None:
363+
return self.ray_fast_nogroup(
364+
np.asarray(pnt, dtype=np.float64),
365+
np.asarray(vec, dtype=np.float64),
366+
1 if include_static_geoms else 0,
367+
exclude_body)
368+
else:
369+
return self.ray_fast_group(
370+
np.asarray(pnt, dtype=np.float64),
371+
np.asarray(vec, dtype=np.float64),
372+
np.asarray(group_filter, dtype=np.uint8),
373+
1 if include_static_geoms else 0,
374+
exclude_body)
375+
376+
def ray_fast_group(self,
377+
np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
378+
np.ndarray[np.float64_t, mode="c", ndim=1] vec,
379+
np.ndarray[np.uint8_t, mode="c", ndim=1] geomgroup,
380+
mjtByte flg_static=1,
381+
int bodyexclude=-1):
382+
"""
383+
Faster version of sim.ray(), which avoids extra copies,
384+
but needs to be given all the correct type arrays.
385+
386+
See self.ray() for explanation of arguments
387+
"""
388+
cdef int geomid
389+
cdef mjtNum distance
390+
cdef mjtNum[::view.contiguous] pnt_view = pnt
391+
cdef mjtNum[::view.contiguous] vec_view = vec
392+
cdef mjtByte[::view.contiguous] geomgroup_view = geomgroup
393+
394+
distance = mj_ray(self.model.ptr,
395+
self.data.ptr,
396+
&pnt_view[0],
397+
&vec_view[0],
398+
&geomgroup_view[0],
399+
flg_static,
400+
bodyexclude,
401+
&geomid)
402+
return (distance, geomid)
403+
404+
405+
def ray_fast_nogroup(self,
406+
np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
407+
np.ndarray[np.float64_t, mode="c", ndim=1] vec,
408+
mjtByte flg_static=1,
409+
int bodyexclude=-1):
410+
"""
411+
Faster version of sim.ray(), which avoids extra copies,
412+
but needs to be given all the correct type arrays.
413+
414+
This version hardcodes the geomgroup to NULL.
415+
(Can't easily express a signature that is "numpy array of specific type or None")
416+
417+
See self.ray() for explanation of arguments
354418
"""
355-
cdef int geom_id
419+
cdef int geomid
356420
cdef mjtNum distance
357421
cdef mjtNum[::view.contiguous] pnt_view = pnt
358422
cdef mjtNum[::view.contiguous] vec_view = vec
359423

360-
distance = mj_ray(self.model.ptr, self.data.ptr,
361-
&pnt_view[0], &vec_view[0], NULL,
362-
1 if include_static_geoms else 0,
363-
exclude_body,
364-
&geom_id)
365-
return (distance, geom_id)
424+
distance = mj_ray(self.model.ptr,
425+
self.data.ptr,
426+
&pnt_view[0],
427+
&vec_view[0],
428+
NULL,
429+
flg_static,
430+
bodyexclude,
431+
&geomid)
432+
return (distance, geomid)

0 commit comments

Comments
 (0)