Skip to content

Commit da04688

Browse files
quaglacopybara-github
authored andcommitted
Automatically compile mjSpec in to_xml().
Also, move mj_compile from MjModelWrapper to MjSpec. PiperOrigin-RevId: 737640210 Change-Id: I93d2a31c33eb359452b0d966a29939183b4c5198
1 parent 3577470 commit da04688

File tree

7 files changed

+54
-61
lines changed

7 files changed

+54
-61
lines changed

doc/APIreference/functions.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Free last XML model if loaded. Called internally at each load.
9999
.. mujoco-include:: mj_saveXMLString
100100

101101
Save spec to XML string, return 0 on success, -1 on failure. If the length of the output buffer is too small, returns
102-
the required size. XML saving requires that the spec first be compiled.
102+
the required size. XML saving automatically compiles the spec before saving.
103103

104104
.. _mj_saveXML:
105105

doc/APIreference/functions_override.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ saving mechanisms.
5959
.. _mj_saveXMLString:
6060

6161
Save spec to XML string, return 0 on success, -1 on failure. If the length of the output buffer is too small, returns
62-
the required size. XML saving requires that the spec first be compiled.
62+
the required size. XML saving automatically compiles the spec before saving.
6363

6464
.. _mj_saveXML:
6565

python/mujoco/specs.cc

+42-30
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <array>
1616
#include <cstddef> // IWYU pragma: keep
1717
#include <cstdint>
18+
#include <cstring>
1819
#include <memory>
1920
#include <optional>
2021
#include <string>
@@ -122,6 +123,41 @@ struct MjSpec {
122123
~MjSpec() {
123124
mj_deleteSpec(ptr);
124125
}
126+
127+
raw::MjModel* Compile() {
128+
if (assets.empty()) {
129+
auto m = mj_compile(ptr, 0);
130+
if (!m || mjs_isWarning(ptr)) {
131+
throw py::value_error(mjs_getError(ptr));
132+
}
133+
return m;
134+
}
135+
mjVFS vfs;
136+
mj_defaultVFS(&vfs);
137+
for (const auto& asset : assets) {
138+
std::string buffer_name =
139+
_impl::StripPath(py::cast<std::string>(asset.first).c_str());
140+
std::string buffer = py::cast<std::string>(asset.second);
141+
const int vfs_error = InterceptMjErrors(mj_addBufferVFS)(
142+
&vfs, buffer_name.c_str(), buffer.c_str(), buffer.size());
143+
if (vfs_error) {
144+
mj_deleteVFS(&vfs);
145+
if (vfs_error == 2) {
146+
throw py::value_error("Repeated file name in assets dict: " +
147+
buffer_name);
148+
} else {
149+
throw py::value_error("Asset failed to load: " + buffer_name);
150+
}
151+
}
152+
}
153+
auto m = mj_compile(ptr, &vfs);
154+
if (!m || mjs_isWarning(ptr)) {
155+
throw py::value_error(mjs_getError(ptr));
156+
}
157+
mj_deleteVFS(&vfs);
158+
return m;
159+
}
160+
125161
raw::MjSpec* ptr;
126162
py::dict assets;
127163
bool override_assets = true;
@@ -245,8 +281,8 @@ void SetFrame(raw::MjsBody* body, mjtObj objtype, raw::MjsFrame* frame) {
245281

246282
PYBIND11_MODULE(_specs, m) {
247283
auto structs_m = py::module::import("mujoco._structs");
248-
py::function mjmodel_from_spec_ptr =
249-
structs_m.attr("MjModel").attr("_from_spec_ptr");
284+
py::function mjmodel_from_raw_ptr =
285+
structs_m.attr("MjModel").attr("_from_model_ptr");
250286
py::function mjmodel_mjdata_from_spec_ptr =
251287
structs_m.attr("_recompile_spec_addr");
252288

@@ -416,33 +452,8 @@ PYBIND11_MODULE(_specs, m) {
416452
return mjs_findDefault(self.ptr, classname.c_str());
417453
},
418454
py::return_value_policy::reference_internal);
419-
mjSpec.def("compile", [mjmodel_from_spec_ptr](MjSpec& self) -> py::object {
420-
if (self.assets.empty()) {
421-
return mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr));
422-
}
423-
mjVFS vfs;
424-
mj_defaultVFS(&vfs);
425-
for (const auto& asset : self.assets) {
426-
std::string buffer_name =
427-
_impl::StripPath(py::cast<std::string>(asset.first).c_str());
428-
std::string buffer = py::cast<std::string>(asset.second);
429-
const int vfs_error = InterceptMjErrors(mj_addBufferVFS)(
430-
&vfs, buffer_name.c_str(), buffer.c_str(), buffer.size());
431-
if (vfs_error) {
432-
mj_deleteVFS(&vfs);
433-
if (vfs_error == 2) {
434-
throw py::value_error("Repeated file name in assets dict: " +
435-
buffer_name);
436-
} else {
437-
throw py::value_error("Asset failed to load: " + buffer_name);
438-
}
439-
}
440-
}
441-
auto model =
442-
mjmodel_from_spec_ptr(reinterpret_cast<uintptr_t>(self.ptr),
443-
reinterpret_cast<uintptr_t>(&vfs));
444-
mj_deleteVFS(&vfs);
445-
return model;
455+
mjSpec.def("compile", [mjmodel_from_raw_ptr](MjSpec& self) -> py::object {
456+
return mjmodel_from_raw_ptr(reinterpret_cast<uintptr_t>(self.Compile()));
446457
});
447458
mjSpec.def_property(
448459
"assets",
@@ -463,9 +474,10 @@ PYBIND11_MODULE(_specs, m) {
463474
self.override_assets = override_assets;
464475
});
465476
mjSpec.def("to_xml", [](MjSpec& self) -> std::string {
477+
mj_deleteModel(self.Compile());
478+
std::array<char, 1024> err;
466479
int size = mj_saveXMLString(self.ptr, nullptr, 0, nullptr, 0);
467480
std::unique_ptr<char[]> buf(new char[size + 1]);
468-
std::array<char, 1024> err;
469481
buf[0] = '\0';
470482
err[0] = '\0';
471483
mj_saveXMLString(self.ptr, buf.get(), size + 1, err.data(), err.size());

python/mujoco/specs_test.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -520,13 +520,9 @@ def test_recompile(self):
520520
np.testing.assert_array_equal(data_new.qpos[:4], data.qpos)
521521
np.testing.assert_array_equal(data_new.qvel[:3], data.qvel)
522522

523-
def test_uncompiled_spec_cannot_be_written(self):
523+
def test_uncompiled_spec_can_be_written(self):
524524
spec = mujoco.MjSpec()
525-
526-
# Cannot write XML of an uncompiled spec.
527-
expected_error = 'XML Write error: Only compiled model can be written'
528-
with self.assertRaisesWithLiteralMatch(mujoco.FatalError, expected_error):
529-
spec.to_xml()
525+
spec.to_xml()
530526

531527
def test_modelname_default_class(self):
532528
XML = textwrap.dedent("""\

python/mujoco/structs.cc

+4-18
Original file line numberDiff line numberDiff line change
@@ -408,12 +408,7 @@ MjModelWrapper MjModelWrapper::LoadXML(
408408
return MjModelWrapper(model);
409409
}
410410

411-
MjModelWrapper MjModelWrapper::CompileSpec(raw::MjSpec* spec,
412-
const mjVFS* vfs) {
413-
auto m = mj_compile(spec, vfs);
414-
if (!m || mjs_isWarning(spec)) {
415-
throw py::value_error(mjs_getError(spec));
416-
}
411+
MjModelWrapper MjModelWrapper::WrapRawModel(raw::MjModel* m) {
417412
return MjModelWrapper(m);
418413
}
419414

@@ -1616,18 +1611,9 @@ PYBIND11_MODULE(_structs, m) {
16161611
py::arg("xml"), py::arg_v("assets", py::none()),
16171612
py::doc(
16181613
R"(Loads an MjModel from an XML string and an optional assets dictionary.)"));
1619-
mjModel.def_static(
1620-
"_from_spec_ptr", [](uintptr_t addr) {
1621-
return MjModelWrapper::CompileSpec(
1622-
reinterpret_cast<raw::MjSpec*>(addr),
1623-
nullptr);
1624-
});
1625-
mjModel.def_static(
1626-
"_from_spec_ptr", [](uintptr_t addr, uintptr_t vfs) {
1627-
return MjModelWrapper::CompileSpec(
1628-
reinterpret_cast<raw::MjSpec*>(addr),
1629-
reinterpret_cast<mjVFS*>(vfs));
1630-
});
1614+
mjModel.def_static("_from_model_ptr", [](uintptr_t addr) {
1615+
return MjModelWrapper::WrapRawModel(reinterpret_cast<raw::MjModel*>(addr));
1616+
});
16311617
mjModel.def_static(
16321618
"from_xml_path", &MjModelWrapper::LoadXMLFile,
16331619
py::arg("filename"), py::arg_v("assets", py::none()),

python/mujoco/structs.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ class MjWrapper<raw::MjModel> : public WrapperBase<raw::MjModel> {
533533
const std::optional<
534534
std::unordered_map<std::string, pybind11::bytes>>& assets);
535535

536-
static MjWrapper CompileSpec(raw::MjSpec* spec, const mjVFS* vfs);
536+
static MjWrapper WrapRawModel(raw::MjModel* m);
537537

538538
static constexpr char kFromRawPointer[] =
539539
"__MUJOCO_STRUCTS_MJMODELWRAPPER_LOOKUP";

src/xml/xml_api.cc

+3-4
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,14 @@ int mj_saveXML(const mjSpec* s, const char* filename, char* error, int error_sz)
246246
// if length of the output buffer is too small, returns the required size
247247
int mj_saveXMLString(const mjSpec* s, char* xml, int xml_sz, char* error, int error_sz) {
248248
std::string result = WriteXML(NULL, s, error, error_sz);
249-
if (result.size() >= xml_sz) {
249+
if (result.empty()) {
250+
return -1;
251+
} else if (result.size() >= xml_sz) {
250252
std::string error_msg = "Output string too short, should be at least " +
251253
std::to_string(result.size()+1);
252254
mjCopyError(error, error_msg.c_str(), error_sz);
253255
return result.size();
254256
}
255-
if (result.empty()) {
256-
return -1;
257-
}
258257

259258
result.copy(xml, xml_sz);
260259
xml[result.size()] = 0;

0 commit comments

Comments
 (0)