Skip to content

Commit cc17519

Browse files
committed
added multigpu rendering support
1 parent 6d04ef5 commit cc17519

11 files changed

+310
-76
lines changed

examples/multigpu_rendering.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
This is an example for rendering on multiple GPUs in parallel,
3+
using the multiprocessing module.
4+
"""
5+
import ctypes
6+
from multiprocessing import (
7+
Process, set_start_method, freeze_support, Condition,
8+
Value, Array)
9+
from time import perf_counter
10+
import numpy as np
11+
from mujoco_py import load_model_from_path, MjSim
12+
13+
14+
#
# Experiment parameters
#

# Rendered frames are square: both sides are 255 pixels.
IMAGE_WIDTH = IMAGE_HEIGHT = 255
# How many frames each sim renders during the benchmark.
N_FRAMES = 2000
# How many sims run in parallel. One sim is assumed per GPU,
# so N_SIMS=2 assumes there are 2 GPUs available.
N_SIMS = 2
26+
27+
28+
class RenderProcess:
    """
    Wraps a multiprocessing.Process for rendering. Assumes there
    is one MjSim per process.

    The parent and the worker talk through three shared flags that are
    only read or written while holding a single Condition:

    - `_ready`: set by the worker when it is idle and the output
      variable holds the result of the last requested update; cleared
      by `update()`.
    - `_start`: set by `update()` to request one update; cleared by the
      worker when it picks the request up.
    - `_terminate`: set by `stop()` to make the worker exit its loop.
    """

    def __init__(self, device_id, setup_sim, update_sim, output_var_shape):
        """
        Args:
        - device_id (int): GPU device to use for rendering (0-indexed)
        - setup_sim (callback): callback that is given a device_id and
            returns a MjSim. It is responsible for making MjSim render
            to given device.
        - update_sim (callback): callback given a sim and device_id, and
            should return a numpy array of shape `output_var_shape`.
        - output_var_shape (tuple): shape of the synchronized output
            array from `update_sim`.
        """
        self.device_id = device_id
        self.setup_sim = setup_sim
        self.update_sim = update_sim

        # Create a synchronized output variable (flat array of doubles)
        # and a numpy view onto it for the parent side.
        self._shared_output_var = Array(
            ctypes.c_double, int(np.prod(output_var_shape)))
        self._output_var = np.frombuffer(
            self._shared_output_var.get_obj())

        # Variables used to communicate with the process
        self._cv = Condition()
        self._ready = Value('b', 0)
        self._start = Value('b', 0)
        self._terminate = Value('b', 0)

        # Start the actual process
        self._process = Process(target=self._run)
        self._process.start()

    def wait(self):
        """
        Blocks until the process is idle, i.e. setup has finished and
        every update requested so far has completed.
        """
        with self._cv:
            # Re-check the predicate after every wake-up: a notification
            # alone does not prove that the worker has finished.
            while not self._ready.value:
                self._cv.wait()

    def read(self, copy=False):
        """
        Reads the output variable. Returns a copy (taken under the
        output lock) if copy=True, otherwise the parent's numpy view on
        the shared buffer, which the worker overwrites on later updates.
        """
        if copy:
            with self._shared_output_var.get_lock():
                return np.copy(self._output_var)
        else:
            return self._output_var

    def update(self):
        """ Calls update_sim asynchronously. """
        with self._cv:
            # Clear `_ready` together with the request so that a
            # following wait() cannot return before this update is done.
            self._ready.value = 0
            self._start.value = 1
            self._cv.notify_all()

    def stop(self):
        """ Tells process to stop and waits for it to terminate. """
        with self._cv:
            self._terminate.value = 1
            self._cv.notify_all()
        self._process.join()

    def _run(self):
        """ Worker loop: set up the sim, then serve update requests. """
        sim = self.setup_sim(self.device_id)

        # Build the view from the shared Array inside the worker. With
        # the 'spawn' start method `self` is pickled into the child and
        # a numpy array is pickled by value, so `self._output_var` would
        # be a private copy there and writes to it would never reach
        # the parent.
        output = np.frombuffer(self._shared_output_var.get_obj())

        while True:
            with self._cv:
                # Only report ready when no further request is already
                # pending; otherwise go straight on to serve it.
                if not self._start.value:
                    self._ready.value = 1
                    self._cv.notify_all()

                while not self._start.value and not self._terminate.value:
                    self._cv.wait()
                if self._terminate.value:
                    break
                self._start.value = 0

            # Run the update and assign output variable
            with self._shared_output_var.get_lock():
                output[:] = self.update_sim(
                    sim, self.device_id).ravel()
116+
117+
118+
def setup_sim(device_id):
    """
    Loads the fetch model, wraps it in a MjSim and renders one frame on
    `device_id`, checking that the frame has the expected shape.
    Returns the MjSim.
    """
    sim = MjSim(load_model_from_path("xmls/fetch/main.xml"))

    expected_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    first_frame = sim.render(
        IMAGE_WIDTH, IMAGE_HEIGHT, device_id=device_id)
    assert first_frame.shape == expected_shape

    return sim
127+
128+
129+
def update_sim(sim, device_id):
    """ Steps `sim` once and returns a frame rendered on `device_id`. """
    sim.step()
    frame = sim.render(IMAGE_WIDTH, IMAGE_HEIGHT, device_id=device_id)
    return frame
132+
133+
134+
def main():
    """
    Benchmarks parallel rendering: starts N_SIMS render processes (one
    per device id), requests N_FRAMES updates from each in lock-step and
    prints the total time, the time per frame and the frame rate.
    """
    print("main(): create processes", flush=True)
    processes = []
    try:
        for device_id in range(N_SIMS):
            p = RenderProcess(
                device_id, setup_sim, update_sim,
                (IMAGE_HEIGHT, IMAGE_WIDTH, 3))
            processes.append(p)

        # Wait until every process has finished setting up its sim.
        for p in processes:
            p.wait()

        print("main(): start benchmarking", flush=True)
        start_t = perf_counter()

        for _ in range(N_FRAMES):
            # Kick off all updates first so they run concurrently,
            # then wait for all of them before reading the results.
            for p in processes:
                p.update()

            for p in processes:
                p.wait()

            for p in processes:
                p.read()

        t = perf_counter() - start_t
        print("Completed in %.1fs: %.3fms, %.1f FPS" % (
            t, t / (N_FRAMES * N_SIMS) * 1000, (N_FRAMES * N_SIMS) / t),
            flush=True)
    finally:
        # Always stop the workers that were started, even when something
        # above raised: otherwise they stay blocked waiting for the next
        # request and keep the program from exiting.
        print("main(): stopping processes", flush=True)
        for p in processes:
            p.stop()

    print("main(): finished", flush=True)
169+
170+
171+
if __name__ == "__main__":
    # freeze_support()
    # Select the start method before main() creates any render process.
    set_start_method('spawn')
    main()

mujoco_py/__init__.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
MjSim = cymj.MjSim
1010
MjSimState = cymj.MjSimState
1111
MjSimPool = cymj.MjSimPool
12-
PyMjModel = cymj.PyMjModel
13-
PyMjData = cymj.PyMjData
12+
MjRenderContext = cymj.MjRenderContext
13+
MjRenderContextOffscreen = cymj.MjRenderContextOffscreen
14+
MjRenderContextWindow = cymj.MjRenderContextWindow
15+
1416

1517
# Public API:
16-
__all__ = ['MjSim', 'MjSimState', 'MjSimPool', 'MjViewer', "MjViewerBasic", "MujocoException",
17-
'load_model_from_path', 'load_model_from_xml', 'load_model_from_mjb',
18+
__all__ = ['MjSim', 'MjSimState', 'MjSimPool',
19+
'MjRenderContextOffscreen', 'MjRenderContextWindow',
20+
'MjRenderContext', 'MjViewer', 'MjViewerBasic',
21+
'MujocoException',
22+
'load_model_from_path', 'load_model_from_xml',
23+
'load_model_from_mjb',
1824
'ignore_mujoco_warnings', 'const', "functions",
1925
"__version__", "get_version"]

mujoco_py/builder.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from distutils.dist import Distribution
99
from distutils.sysconfig import customize_compiler
1010
from os.path import abspath, dirname, exists, join, getmtime
11+
from shutil import move
1112

1213
import numpy as np
1314
from Cython.Build import cythonize
@@ -50,8 +51,10 @@ def load_cython_ext(mjpro_path):
5051
raise RuntimeError("Unsupported platform %s" % sys.platform)
5152

5253
builder = Builder(mjpro_path)
53-
54-
cext_so_path = builder.build()
54+
cext_so_path = builder.get_so_file_path()
55+
if not exists(cext_so_path):
56+
print("building")
57+
cext_so_path = builder.build()
5558
mod = imp.load_dynamic("cymj", cext_so_path)
5659
return mod
5760

@@ -114,6 +117,12 @@ def __init__(self, mjpro_path):
114117
language='c')
115118

116119
def build(self):
    """
    Builds the extension with `_build_impl()` and moves the resulting
    shared object to the path given by `get_so_file_path()`.

    Returns the path the shared object was moved to.
    """
    freshly_built = self._build_impl()
    destination = self.get_so_file_path()
    move(freshly_built, destination)
    return destination
124+
125+
def _build_impl(self):
117126
dist = Distribution({
118127
"script_name": None,
119128
"script_args": ["build_ext"]
@@ -129,8 +138,13 @@ def build(self):
129138
dist.parse_command_line()
130139
obj_build_ext = dist.get_command_obj("build_ext")
131140
dist.run_commands()
132-
so_file_path, = obj_build_ext.get_outputs()
133-
return so_file_path
141+
built_so_file_path, = obj_build_ext.get_outputs()
142+
return built_so_file_path
143+
144+
def get_so_file_path(self):
    """
    Returns the absolute path of this builder's shared object:
    `generated/cymj_<lowercased class name>.so` next to this module.
    """
    builder_name = self.__class__.__name__.lower()
    package_dir = abspath(dirname(__file__))
    return join(package_dir, "generated", "cymj_%s.so" % builder_name)
134148

135149

136150
class WindowsExtensionBuilder(MujocoExtensionBuilder):
@@ -162,8 +176,8 @@ def __init__(self, mjpro_path):
162176
self.extension.libraries.extend(['glewegl'])
163177
self.extension.runtime_library_dirs = [join(mjpro_path, 'bin')]
164178

165-
def build(self):
166-
so_file_path = super().build()
179+
def _build_impl(self):
180+
so_file_path = super()._build_impl()
167181
nvidia_lib_dir = '/usr/local/nvidia/lib64/'
168182
fix_shared_library(so_file_path, 'libOpenGL.so',
169183
join(nvidia_lib_dir, 'libOpenGL.so.0'))
@@ -182,7 +196,7 @@ def __init__(self, mjpro_path):
182196
self.extension.define_macros = [('ONMAC', None)]
183197
self.extension.runtime_library_dirs = [join(mjpro_path, 'bin')]
184198

185-
def build(self):
199+
def _build_impl(self):
186200
# Prefer GCC 6 for now since GCC 7 may behave differently.
187201
c_compilers = ['/usr/local/bin/gcc-6', '/usr/local/bin/gcc-7']
188202
available_c_compiler = None
@@ -197,7 +211,7 @@ def build(self):
197211
'`brew install gcc --without-multilib`.')
198212
os.environ['CC'] = available_c_compiler
199213

200-
so_file_path = super().build()
214+
so_file_path = super()._build_impl()
201215
del os.environ['CC']
202216
return self.manually_link_libraries(so_file_path)
203217

mujoco_py/cymj.pyx

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ cdef extern from "gl/glshim.h":
2828

2929
cdef int initOpenGL(int device_id)
3030
cdef void closeOpenGL()
31-
cdef int setOpenGLBufferSize(int width, int height)
31+
cdef int makeOpenGLContextCurrent(int device_id)
32+
cdef int setOpenGLBufferSize(int device_id, int width, int height)
3233

3334
# TODO: make this function or class so these comments turn into doc strings:
3435

mujoco_py/gl/dummyshim.c

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ int initOpenGL(int device_id) {
44
return 1;
55
}
66

7-
int setOpenGLBufferSize(int width, int height) {
7+
// Stub: ignores all of its arguments and always returns 1.
int setOpenGLBufferSize(int device_id, int width, int height) {
    (void)device_id;
    (void)width;
    (void)height;
    return 1;
}
10+
11+
// Stub: ignores the device id and always returns 1.
int makeOpenGLContextCurrent(int device_id) {
    (void)device_id;
    return 1;
}
1014

mujoco_py/gl/eglshim.c

+50-17
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@
66

77
#define MAX_DEVICES 8
88

9-
int is_initialized = 0;
9+
int is_device_initialized[MAX_DEVICES] = {0};
10+
EGLDisplay eglDisplays[MAX_DEVICES];
11+
EGLContext eglContexts[MAX_DEVICES];
1012

1113
int initOpenGL(int device_id)
1214
{
15+
if (device_id < 0 || device_id > MAX_DEVICES) {
16+
printf("Device id outside of range.\n");
17+
return -1;
18+
}
19+
int is_initialized = is_device_initialized[device_id];
20+
1321
if (is_initialized)
1422
return 1;
1523

@@ -94,34 +102,59 @@ int initOpenGL(int device_id)
94102
return -8;
95103
}
96104

97-
is_initialized = 1;
105+
is_device_initialized[device_id] = 1;
106+
eglDisplays[device_id] = eglDpy;
107+
eglContexts[device_id] = eglCtx;
98108
return 1;
99109
}
100110

101-
int setOpenGLBufferSize(int width, int height) {
111+
int makeOpenGLContextCurrent(device_id) {
112+
if (device_id < 0 || device_id > MAX_DEVICES) {
113+
printf("Device id outside of range.\n");
114+
return -1;
115+
}
116+
if (!is_device_initialized[device_id])
117+
return -2;
118+
119+
if( eglMakeCurrent(eglDisplays[device_id],
120+
EGL_NO_SURFACE,
121+
EGL_NO_SURFACE,
122+
eglContexts[device_id]) != EGL_TRUE ) {
123+
eglDestroyContext(eglDisplays[device_id],
124+
eglContexts[device_id]);
125+
printf("Could not make EGL context current\n");
126+
return -3;
127+
} else {
128+
return 1;
129+
}
130+
}
131+
132+
// Always returns 1; the arguments are ignored.
int setOpenGLBufferSize(int device_id, int width, int height) {
    // Noop since we don't need to change buffer here.
    (void)device_id;
    (void)width;
    (void)height;
    return 1;
}
105136

106137
void closeOpenGL()
107138
{
108-
if (!is_initialized)
109-
return;
139+
for (int device_id=0; device_id<MAX_DEVICES; device_id++) {
140+
if (!is_device_initialized[device_id])
141+
continue;
110142

111-
EGLDisplay eglDpy = eglGetCurrentDisplay();
112-
if( eglDpy==EGL_NO_DISPLAY )
113-
return;
143+
EGLDisplay eglDpy = eglDisplays[device_id];
144+
if( eglDpy==EGL_NO_DISPLAY )
145+
continue;
114146

115-
// get current context
116-
EGLContext eglCtx = eglGetCurrentContext();
147+
// get current context
148+
EGLContext eglCtx = eglContexts[device_id];
117149

118-
// release context
119-
eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT);
150+
// release context
151+
eglMakeCurrent(eglDpy, EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT);
120152

121-
// destroy context if valid
122-
if( eglCtx!=EGL_NO_CONTEXT )
123-
eglDestroyContext(eglDpy, eglCtx);
153+
// destroy context if valid
154+
if( eglCtx!=EGL_NO_CONTEXT )
155+
eglDestroyContext(eglDpy, eglCtx);
124156

125-
// terminate display
126-
eglTerminate(eglDpy);
157+
// terminate display
158+
eglTerminate(eglDpy);
159+
}
127160
}

0 commit comments

Comments
 (0)