2
2
import subprocess
3
3
import sys
4
4
5
+ import numpy as np
5
6
import pytest
6
7
7
8
from mujoco_py import MjSim , MjRenderPool , load_model_from_xml
8
- from mujoco_py .tests . utils import requires_rendering
9
-
9
+ from mujoco_py .modder import TextureModder
10
+ from mujoco_py . tests . utils import compare_imgs , requires_rendering
10
11
11
12
BASIC_MODEL_XML = """
12
13
<mujoco>
13
14
<worldbody>
14
15
<light name="light1" diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
15
16
<camera name="camera1" pos="3 0 0" zaxis="1 0 0" />
16
- <geom name="geom1 " pos="0.5 0.4 0.3" type="plane" size="1 1 0.1" rgba=".9 0 0 1"/>
17
+ <geom name="g1 " pos="0.5 0.4 0.3" type="plane" size="1 1 0.1" rgba="1 1 1 1" material="m1" />
17
18
<body pos="0 0 1" name="body1">
18
19
<joint name="joint1" type="free"/>
19
- <geom name="geom2 " pos="0 1 0" type="box" size=".1 .2 .3" rgba="0 .9 0 1"/>
20
+ <geom name="g2 " pos="0 1 0" type="box" size=".1 .2 .3" rgba="1 1 1 1" material="m2" />
20
21
<site name="site1" pos="1 0 0" size="0.1" type="sphere"/>
21
22
<site name="sensorsurf" pos="0 0.045 0" size=".03 .03 .03" type="ellipsoid" rgba="0.3 0.2 0.1 0.3"/>
22
23
</body>
23
24
<body pos="1 0 0" name="mocap1" mocap="true">
24
- <geom conaffinity="0" contype="0" pos="0 0 0" size="0.01 0.01 0.01" type="box"/>
25
+ <geom name="g3" conaffinity="0" contype="0" pos="0 0 0" size="0.01 0.01 0.01" type="box" material="m3" rgba="1 1 1 1 "/>
25
26
</body>
26
27
</worldbody>
27
28
<sensor>
28
29
<touch name="touchsensor" site="sensorsurf" />
29
30
</sensor>
31
+ <asset>
32
+ <texture name="t1" width="33" height="36" type="2d" builtin="flat" />
33
+ <texture name="t2" width="34" height="39" type="2d" builtin="flat" />
34
+ <texture name="t3" width="31" height="37" type="2d" builtin="flat" />
35
+ <material name="m1" texture="t1" />
36
+ <material name="m2" texture="t2" />
37
+ <material name="m3" texture="t3" />
38
+ </asset>
30
39
</mujoco>
31
40
"""
32
41
@@ -64,11 +73,59 @@ def mp_test_rendering():
64
73
65
74
images = pool .render (100 , 100 )
66
75
assert images .shape == (3 , 100 , 100 , 3 )
76
+ compare_imgs (images [0 ], 'test_render_pool.mp_test_rendering.0.png' )
77
+ assert np .all (images [0 ] == images [1 ])
67
78
68
79
images , depth = pool .render (101 , 103 , depth = True )
69
80
assert images .shape == (3 , 103 , 101 , 3 )
70
81
assert depth .shape == (3 , 103 , 101 )
82
+ assert np .all (images [0 ] == images [1 ])
83
+ assert np .all (images [1 ] == images [2 ])
84
+
85
+
86
+ def mp_test_cameras ():
87
+ model = load_model_from_xml (BASIC_MODEL_XML )
88
+ pool = MjRenderPool (model , n_workers = 1 )
89
+
90
+ image = pool .render (100 , 100 )
91
+ assert image .shape == (1 , 100 , 100 , 3 )
92
+ compare_imgs (image [0 ], 'test_render_pool.mp_test_cameras.0.png' )
93
+
94
+ image = pool .render (100 , 100 , camera_name = 'camera1' )
95
+ assert image .shape == (1 , 100 , 100 , 3 )
96
+ compare_imgs (image [0 ], 'test_render_pool.mp_test_cameras.1.png' )
97
+
98
+
99
+ def mp_test_modder ():
100
+ model = load_model_from_xml (BASIC_MODEL_XML )
101
+ pool = MjRenderPool (model , n_workers = 2 , modder = TextureModder )
102
+
103
+ images = pool .render (100 , 100 , randomize = True )
104
+ assert images .shape == (2 , 100 , 100 , 3 )
71
105
106
+ # the order of the images aren't guaranteed to be consistent
107
+ # between the render runs
108
+ images1 = pool .render (100 , 100 , randomize = False )
109
+ assert images1 .shape == (2 , 100 , 100 , 3 )
110
+
111
+ if np .all (images [0 ] == images1 [0 ]) and np .all (images [1 ] == images1 [1 ]):
112
+ images_same = True
113
+ elif np .all (images [0 ] == images1 [1 ]) and np .all (images [1 ] == images1 [0 ]):
114
+ images_same = True
115
+ else :
116
+ images_same = False
117
+ assert images_same
118
+
119
+ images2 = pool .render (100 , 100 , randomize = True )
120
+ assert images2 .shape == (2 , 100 , 100 , 3 )
121
+
122
+ if np .all (images [0 ] == images2 [0 ]) and np .all (images [1 ] == images2 [1 ]):
123
+ images_same = True
124
+ elif np .all (images [0 ] == images2 [1 ]) and np .all (images [1 ] == images2 [0 ]):
125
+ images_same = True
126
+ else :
127
+ images_same = False
128
+ assert not images_same
72
129
73
130
def mp_test_states ():
74
131
sim = MjSim (load_model_from_xml (BASIC_MODEL_XML ))
@@ -82,10 +139,15 @@ def mp_test_states():
82
139
83
140
images = pool .render (100 , 100 , states = states [:2 ])
84
141
assert images .shape == (2 , 100 , 100 , 3 )
142
+ compare_imgs (images [0 ], 'test_render_pool.mp_test_states.1.png' )
143
+ compare_imgs (images [1 ], 'test_render_pool.mp_test_states.2.png' )
85
144
86
145
states = list (reversed (states ))
87
146
images = pool .render (100 , 100 , states = states )
88
147
assert images .shape == (3 , 100 , 100 , 3 )
148
+ compare_imgs (images [0 ], 'test_render_pool.mp_test_states.3.png' )
149
+ compare_imgs (images [1 ], 'test_render_pool.mp_test_states.4.png' )
150
+ compare_imgs (images [2 ], 'test_render_pool.mp_test_states.5.png' )
89
151
90
152
91
153
if __name__ == '__main__' :
@@ -95,3 +157,5 @@ def mp_test_states():
95
157
mp_test_create_destroy ()
96
158
mp_test_rendering ()
97
159
mp_test_states ()
160
+ mp_test_cameras ()
161
+ mp_test_modder ()
0 commit comments