Skip to content

Commit 87e871a

Browse files
committed
Fix tests and examples
1 parent a01aa95 commit 87e871a

File tree

6 files changed

+101
-67
lines changed

6 files changed

+101
-67
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ test: build
2727
docker run --rm -i mujoco_py pytest
2828

2929
mount_shell:
30-
docker run --rm -it -v `pwd`:/dev mujoco_py /bin/bash -c "pip uninstall -y mujoco_py; rm -rf /mujoco_py; (cd /dev; /bin/bash)"
30+
docker run --rm -it -v `pwd`:/code mujoco_py /bin/bash -c "pip3 uninstall -y mujoco_py; rm -rf /mujoco_py; (cd /code; /bin/bash)"
3131

3232
shell:
3333
docker run --rm -it mujoco_py /bin/bash

examples/multigpu_rendering.py

+12-54
Original file line numberDiff line numberDiff line change
@@ -4,78 +4,36 @@
44
"""
55
from multiprocessing import set_start_method
66
from time import perf_counter
7-
from mujoco_py import load_model_from_path, MjSim
8-
from mujoco_py.mjrenderpool import RenderProcess
9-
import tensorflow as tf
107

11-
#
12-
# Experiment parameters
13-
#
14-
15-
# Image size for rendering
16-
IMAGE_WIDTH = 255
17-
IMAGE_HEIGHT = 255
18-
# Number of frames to render per sim
19-
N_FRAMES = 2000
20-
# Number of sims to run in parallel (assumes one per GPU),
21-
# so N_SIMS=2 assumes there are 2 GPUs available.
22-
N_SIMS = 2
23-
24-
25-
def setup_sim(device_id):
26-
model = load_model_from_path("xmls/fetch/main.xml")
27-
sim = MjSim(model)
28-
29-
image = sim.render(
30-
IMAGE_WIDTH, IMAGE_HEIGHT, device_id=device_id)
31-
assert image.shape == (IMAGE_HEIGHT, IMAGE_WIDTH, 3)
32-
33-
return sim
34-
35-
36-
def update_sim(sim, device_id):
37-
sim.step()
38-
return sim.render(IMAGE_WIDTH, IMAGE_HEIGHT, device_id=device_id)
8+
from mujoco_py import load_model_from_path, MjRenderPool
399

4010

4111
def main():
42-
print("main(): create processes", flush=True)
43-
processes = []
44-
for device_id in range(N_SIMS):
45-
p = RenderProcess(
46-
device_id, setup_sim, update_sim,
47-
(IMAGE_HEIGHT, IMAGE_WIDTH, 3))
48-
processes.append(p)
12+
# Image size for rendering
13+
IMAGE_WIDTH = 255
14+
IMAGE_HEIGHT = 255
15+
# Number of frames to render per sim
16+
N_FRAMES = 2000
17+
# Number of sims to run in parallel (assumes one per GPU),
18+
# so N_SIMS=2 assumes there are 2 GPUs available.
19+
N_SIMS = 2
4920

50-
for p in processes:
51-
p.wait()
21+
pool = MjRenderPool(load_model_from_path("xmls/tosser.xml"), device_ids=2)
5222

5323
print("main(): start benchmarking", flush=True)
5424
start_t = perf_counter()
5525

5626
for _ in range(N_FRAMES):
57-
for p in processes:
58-
p.update()
59-
60-
for p in processes:
61-
p.wait()
62-
63-
for p in processes:
64-
p.read()
27+
pool.render(IMAGE_WIDTH, IMAGE_HEIGHT)
6528

6629
t = perf_counter() - start_t
6730
print("Completed in %.1fs: %.3fms, %.1f FPS" % (
6831
t, t / (N_FRAMES * N_SIMS) * 1000, (N_FRAMES * N_SIMS) / t),
6932
flush=True)
7033

71-
print("main(): stopping processes", flush=True)
72-
for p in processes:
73-
p.stop()
74-
7534
print("main(): finished", flush=True)
7635

7736

78-
print("XXX about to call main()", flush=True)
7937
if __name__ == "__main__":
80-
set_start_method('fork')
38+
set_start_method('spawn')
8139
main()

mujoco_py/mjrenderpool.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def _worker_init(mjb_bytes, worker_id, device_ids,
122122
s.sim = MjSim(load_model_from_mjb(mjb_bytes))
123123

124124
if modder is not None:
125-
s.modder = modder(s.sim)
125+
s.modder = modder(s.sim, random_state=proc_worker_id)
126+
s.modder.whiten_materials()
126127
else:
127128
s.modder = None
128129

@@ -142,7 +143,7 @@ def _worker_render(worker_id, state, width, height,
142143
s.sim.set_state(state)
143144
forward = True
144145
if randomize and s.modder is not None:
145-
s.modder.rand_all()
146+
s.modder.randomize()
146147
forward = True
147148
if forward:
148149
s.sim.forward()
@@ -162,7 +163,7 @@ def _worker_render(worker_id, state, width, height,
162163
device_id=s.device_id)
163164

164165
def render(self, width, height, states=None, camera_name=None,
165-
depth=False, randomize=False):
166+
depth=False, randomize=False, copy=True):
166167
"""
167168
Renders the simulations in batch. If no states are provided,
168169
the max_batch_size will be used.
@@ -175,6 +176,7 @@ def render(self, width, height, states=None, camera_name=None,
175176
- camera_name (str): name of camera to render from.
176177
- depth (bool): if True, also return depth.
177178
- randomize (bool): calls modder.rand_all() before rendering.
179+
- copy (bool): return a copy rather than a reference
178180
179181
Returns:
180182
- rgbs: NxHxWx3 numpy array of N images in batch of width W
@@ -206,11 +208,11 @@ def render(self, width, height, states=None, camera_name=None,
206208
for i, state in enumerate(states)])
207209

208210
rgbs = self._shared_rgbs_array[:width * height * 3 * batch_size]
209-
rgbs = rgbs.reshape(batch_size, height, width, 3)
211+
rgbs = rgbs.reshape(batch_size, height, width, 3).copy()
210212

211213
if depth:
212214
depths = self._shared_depths_array[:width * height * batch_size]
213-
depths = depths.reshape(batch_size, height, width)
215+
depths = depths.reshape(batch_size, height, width).copy()
214216
return rgbs, depths
215217
else:
216218
return rgbs

mujoco_py/modder.py

+7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ def __init__(self, sim, random_state=None):
1313
self.sim = sim
1414
if random_state is None:
1515
self.random_state = np.random.RandomState()
16+
elif isinstance(random_state, int):
17+
# random_state assumed to be an int
18+
self.random_state = np.random.RandomState(random_state)
1619
else:
1720
self.random_state = random_state
1821

@@ -308,6 +311,10 @@ def set_noise(self, name, rgb1, rgb2, fraction=0.9):
308311
self.upload_texture(name)
309312
return bitmap
310313

314+
def randomize(self):
315+
for name in self.sim.model.geom_names:
316+
self.rand_all(name)
317+
311318
def rand_all(self, name):
312319
choices = [
313320
self.rand_checker,

mujoco_py/tests/test_render_pool.py

+69-5
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,40 @@
22
import subprocess
33
import sys
44

5+
import numpy as np
56
import pytest
67

78
from mujoco_py import MjSim, MjRenderPool, load_model_from_xml
8-
from mujoco_py.tests.utils import requires_rendering
9-
9+
from mujoco_py.modder import TextureModder
10+
from mujoco_py.tests.utils import compare_imgs, requires_rendering
1011

1112
BASIC_MODEL_XML = """
1213
<mujoco>
1314
<worldbody>
1415
<light name="light1" diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
1516
<camera name="camera1" pos="3 0 0" zaxis="1 0 0" />
16-
<geom name="geom1" pos="0.5 0.4 0.3" type="plane" size="1 1 0.1" rgba=".9 0 0 1"/>
17+
<geom name="g1" pos="0.5 0.4 0.3" type="plane" size="1 1 0.1" rgba="1 1 1 1" material="m1" />
1718
<body pos="0 0 1" name="body1">
1819
<joint name="joint1" type="free"/>
19-
<geom name="geom2" pos="0 1 0" type="box" size=".1 .2 .3" rgba="0 .9 0 1"/>
20+
<geom name="g2" pos="0 1 0" type="box" size=".1 .2 .3" rgba="1 1 1 1" material="m2" />
2021
<site name="site1" pos="1 0 0" size="0.1" type="sphere"/>
2122
<site name="sensorsurf" pos="0 0.045 0" size=".03 .03 .03" type="ellipsoid" rgba="0.3 0.2 0.1 0.3"/>
2223
</body>
2324
<body pos="1 0 0" name="mocap1" mocap="true">
24-
<geom conaffinity="0" contype="0" pos="0 0 0" size="0.01 0.01 0.01" type="box"/>
25+
<geom name="g3" conaffinity="0" contype="0" pos="0 0 0" size="0.01 0.01 0.01" type="box" material="m3" rgba="1 1 1 1"/>
2526
</body>
2627
</worldbody>
2728
<sensor>
2829
<touch name="touchsensor" site="sensorsurf" />
2930
</sensor>
31+
<asset>
32+
<texture name="t1" width="33" height="36" type="2d" builtin="flat" />
33+
<texture name="t2" width="34" height="39" type="2d" builtin="flat" />
34+
<texture name="t3" width="31" height="37" type="2d" builtin="flat" />
35+
<material name="m1" texture="t1" />
36+
<material name="m2" texture="t2" />
37+
<material name="m3" texture="t3" />
38+
</asset>
3039
</mujoco>
3140
"""
3241

@@ -64,11 +73,59 @@ def mp_test_rendering():
6473

6574
images = pool.render(100, 100)
6675
assert images.shape == (3, 100, 100, 3)
76+
compare_imgs(images[0], 'test_render_pool.mp_test_rendering.0.png')
77+
assert np.all(images[0] == images[1])
6778

6879
images, depth = pool.render(101, 103, depth=True)
6980
assert images.shape == (3, 103, 101, 3)
7081
assert depth.shape == (3, 103, 101)
82+
assert np.all(images[0] == images[1])
83+
assert np.all(images[1] == images[2])
84+
85+
86+
def mp_test_cameras():
87+
model = load_model_from_xml(BASIC_MODEL_XML)
88+
pool = MjRenderPool(model, n_workers=1)
89+
90+
image = pool.render(100, 100)
91+
assert image.shape == (1, 100, 100, 3)
92+
compare_imgs(image[0], 'test_render_pool.mp_test_cameras.0.png')
93+
94+
image = pool.render(100, 100, camera_name='camera1')
95+
assert image.shape == (1, 100, 100, 3)
96+
compare_imgs(image[0], 'test_render_pool.mp_test_cameras.1.png')
97+
98+
99+
def mp_test_modder():
100+
model = load_model_from_xml(BASIC_MODEL_XML)
101+
pool = MjRenderPool(model, n_workers=2, modder=TextureModder)
102+
103+
images = pool.render(100, 100, randomize=True)
104+
assert images.shape == (2, 100, 100, 3)
71105

106+
# the order of the images aren't guaranteed to be consistent
107+
# between the render runs
108+
images1 = pool.render(100, 100, randomize=False)
109+
assert images1.shape == (2, 100, 100, 3)
110+
111+
if np.all(images[0] == images1[0]) and np.all(images[1] == images1[1]):
112+
images_same = True
113+
elif np.all(images[0] == images1[1]) and np.all(images[1] == images1[0]):
114+
images_same = True
115+
else:
116+
images_same = False
117+
assert images_same
118+
119+
images2 = pool.render(100, 100, randomize=True)
120+
assert images2.shape == (2, 100, 100, 3)
121+
122+
if np.all(images[0] == images2[0]) and np.all(images[1] == images2[1]):
123+
images_same = True
124+
elif np.all(images[0] == images2[1]) and np.all(images[1] == images2[0]):
125+
images_same = True
126+
else:
127+
images_same = False
128+
assert not images_same
72129

73130
def mp_test_states():
74131
sim = MjSim(load_model_from_xml(BASIC_MODEL_XML))
@@ -82,10 +139,15 @@ def mp_test_states():
82139

83140
images = pool.render(100, 100, states=states[:2])
84141
assert images.shape == (2, 100, 100, 3)
142+
compare_imgs(images[0], 'test_render_pool.mp_test_states.1.png')
143+
compare_imgs(images[1], 'test_render_pool.mp_test_states.2.png')
85144

86145
states = list(reversed(states))
87146
images = pool.render(100, 100, states=states)
88147
assert images.shape == (3, 100, 100, 3)
148+
compare_imgs(images[0], 'test_render_pool.mp_test_states.3.png')
149+
compare_imgs(images[1], 'test_render_pool.mp_test_states.4.png')
150+
compare_imgs(images[2], 'test_render_pool.mp_test_states.5.png')
89151

90152

91153
if __name__ == '__main__':
@@ -95,3 +157,5 @@ def mp_test_states():
95157
mp_test_create_destroy()
96158
mp_test_rendering()
97159
mp_test_states()
160+
mp_test_cameras()
161+
mp_test_modder()

mujoco_py/tests/utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def save_test_image(filename, array):
1414
Image.fromarray(array).save(filename)
1515

1616

17-
def compare_imgs(img, truth_filename):
17+
def compare_imgs(img, truth_filename, do_assert=True):
1818
"""
1919
PROTIP: run the following to re-generate the test images:
2020
@@ -30,6 +30,7 @@ def compare_imgs(img, truth_filename):
3030
backup_path = "%s_old%s" % (pre_path, ext)
3131
move(truth_filename, backup_path)
3232
save_test_image(truth_filename, img)
33+
return 0
3334
true_img = np.asarray(Image.open(truth_filename))
3435
assert img.shape == true_img.shape
3536
hash0 = imagehash.dhash(Image.fromarray(img))
@@ -43,7 +44,9 @@ def compare_imgs(img, truth_filename):
4344
save_test_image("/tmp/img.png", img)
4445
save_test_image("/tmp/true_img.png", true_img)
4546
save_test_image("/tmp/diff_img.png", img - true_img)
46-
assert diff <= 1
47+
if do_assert:
48+
assert diff <= 1
49+
return diff
4750

4851

4952
# Skips test when RENDERING_OFF.

0 commit comments

Comments
 (0)