Skip to content

Commit dc96ed6

Browse files
quaglacopybara-github
authored andcommitted
Remove plugins that are not referenced after detaching a body.
Fixes #2497. Also, use `Release()` instead of `delete` for removing elements when detaching a body in order to preserve correct reference count. PiperOrigin-RevId: 739133888 Change-Id: I7ccfad84446fb15259bad30c4ba1e6f7fa601518
1 parent 8d92a0f commit dc96ed6

File tree

3 files changed

+156
-75
lines changed

3 files changed

+156
-75
lines changed

src/user/user_model.cc

+91-44
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,49 @@ bool IsNullPose(const T pos[3], const T quat[4]) {
116116
}
117117

118118

119+
// set ids, check for repeated names
120+
template <class T>
121+
static void processlist(mjListKeyMap& ids, vector<T*>& list,
122+
mjtObj type, bool checkrepeat = true) {
123+
// assign ids for regular elements
124+
if (type < mjNOBJECT) {
125+
for (size_t i=0; i < list.size(); i++) {
126+
// check for incompatible id setting; SHOULD NOT OCCUR
127+
if (list[i]->id != -1 && list[i]->id != i) {
128+
throw mjCError(list[i], "incompatible id in %s array, position %d", mju_type2Str(type), i);
129+
}
130+
131+
// id equals position in array
132+
list[i]->id = i;
133+
134+
// add to ids map
135+
ids[type][list[i]->name] = i;
136+
}
137+
}
138+
139+
// check for repeated names
140+
if (checkrepeat) {
141+
// created vectors with all names
142+
vector<string> allnames;
143+
for (size_t i=0; i < list.size(); i++) {
144+
if (!list[i]->name.empty()) {
145+
allnames.push_back(list[i]->name);
146+
}
147+
}
148+
149+
// sort and check for duplicates
150+
if (allnames.size() > 1) {
151+
std::sort(allnames.begin(), allnames.end());
152+
auto adjacent = std::adjacent_find(allnames.begin(), allnames.end());
153+
if (adjacent != allnames.end()) {
154+
string msg = "repeated name '" + *adjacent + "' in " + mju_type2Str(type);
155+
throw mjCError(nullptr, "%s", msg.c_str());
156+
}
157+
}
158+
}
159+
}
160+
161+
119162
} // namespace
120163

121164
//---------------------------------- CONSTRUCTOR AND DESTRUCTOR ------------------------------------
@@ -490,7 +533,7 @@ void mjCModel::RemoveFromList(std::vector<T*>& list, const mjCModel& other) {
490533
element->ResolveReferences(this);
491534
} catch (mjCError err) {
492535
ids[element->elemtype].erase(element->name);
493-
delete element;
536+
element->Release();
494537
list.erase(list.begin() + i);
495538
nlist--;
496539
i--;
@@ -515,6 +558,52 @@ void mjCModel::DeleteAll<mjCKey>(std::vector<mjCKey*>& elements) {
515558

516559

517560

561+
template <class T>
562+
void mjCModel::MarkPluginInstance(std::unordered_map<std::string, bool>& instances,
563+
const std::vector<T*>& list) {
564+
for (const auto& element : list) {
565+
if (!element->plugin_instance_name.empty()) {
566+
instances[element->plugin_instance_name] = true;
567+
}
568+
}
569+
}
570+
571+
572+
573+
void mjCModel::RemovePlugins() {
574+
// store elements that reference a plugin instance
575+
std::unordered_map<std::string, bool> instances;
576+
MarkPluginInstance(instances, bodies_);
577+
MarkPluginInstance(instances, geoms_);
578+
MarkPluginInstance(instances, meshes_);
579+
MarkPluginInstance(instances, actuators_);
580+
MarkPluginInstance(instances, sensors_);
581+
582+
// remove plugins that are not referenced
583+
int nlist = (int)plugins_.size();
584+
int removed = 0;
585+
for (int i = 0; i < nlist; i++) {
586+
if (plugins_[i]->name.empty()) {
587+
continue;
588+
}
589+
if (instances.find(plugins_[i]->name) == instances.end()) {
590+
ids[plugins_[i]->elemtype].erase(plugins_[i]->name);
591+
plugins_[i]->Release();
592+
plugins_.erase(plugins_.begin() + i);
593+
nlist--;
594+
i--;
595+
removed++;
596+
}
597+
}
598+
599+
// if any elements were removed, update ids using processlist
600+
if (removed > 0 && !plugins_.empty()) {
601+
processlist(ids, plugins_, plugins_[0]->elemtype, /*checkrepeat=*/false);
602+
}
603+
}
604+
605+
606+
518607
mjCModel& mjCModel::operator-=(const mjCBody& subtree) {
519608
mjCModel oldmodel(*this);
520609

@@ -550,6 +639,7 @@ mjCModel& mjCModel::operator-=(const mjCBody& subtree) {
550639
RemoveFromList(equalities_, oldmodel);
551640
RemoveFromList(actuators_, oldmodel);
552641
RemoveFromList(sensors_, oldmodel);
642+
RemovePlugins();
553643

554644
// restore to the original state
555645
if (!compiled) {
@@ -3923,49 +4013,6 @@ static void reassignid(vector<T*>& list) {
39234013
}
39244014

39254015

3926-
// set ids, check for repeated names
3927-
template <class T>
3928-
static void processlist(mjListKeyMap& ids, vector<T*>& list,
3929-
mjtObj type, bool checkrepeat = true) {
3930-
// assign ids for regular elements
3931-
if (type < mjNOBJECT) {
3932-
for (size_t i=0; i < list.size(); i++) {
3933-
// check for incompatible id setting; SHOULD NOT OCCUR
3934-
if (list[i]->id != -1 && list[i]->id != i) {
3935-
throw mjCError(list[i], "incompatible id in %s array, position %d", mju_type2Str(type), i);
3936-
}
3937-
3938-
// id equals position in array
3939-
list[i]->id = i;
3940-
3941-
// add to ids map
3942-
ids[type][list[i]->name] = i;
3943-
}
3944-
}
3945-
3946-
// check for repeated names
3947-
if (checkrepeat) {
3948-
// created vectors with all names
3949-
vector<string> allnames;
3950-
for (size_t i=0; i < list.size(); i++) {
3951-
if (!list[i]->name.empty()) {
3952-
allnames.push_back(list[i]->name);
3953-
}
3954-
}
3955-
3956-
// sort and check for duplicates
3957-
if (allnames.size() > 1) {
3958-
std::sort(allnames.begin(), allnames.end());
3959-
auto adjacent = std::adjacent_find(allnames.begin(), allnames.end());
3960-
if (adjacent != allnames.end()) {
3961-
string msg = "repeated name '" + *adjacent + "' in " + mju_type2Str(type);
3962-
throw mjCError(nullptr, "%s", msg.c_str());
3963-
}
3964-
}
3965-
}
3966-
}
3967-
3968-
39694016

39704017
// set object ids, check for repeated names
39714018
void mjCModel::ProcessLists(bool checkrepeat) {

src/user/user_model.h

+8
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ class mjCModel : public mjCModel_, private mjSpec {
350350
void CopyPlugins(mjModel*); // copy plugin data
351351
int CountNJmom(const mjModel* m); // compute number of non-zeros in actuator_moment matrix
352352

353+
// remove plugins that are not referenced by any object
354+
void RemovePlugins();
355+
353356
// objects created here
354357
std::vector<mjCFlex*> flexes_; // list of flexes
355358
std::vector<mjCMesh*> meshes_; // list of meshes
@@ -430,6 +433,11 @@ class mjCModel : public mjCModel_, private mjSpec {
430433
// return true if body has valid mass and inertia
431434
bool CheckBodyMassInertia(mjCBody* body);
432435

436+
// Mark plugin instances mentioned in the list
437+
template <class T>
438+
void MarkPluginInstance(std::unordered_map<std::string, bool>& instances,
439+
const std::vector<T*>& list);
440+
433441

434442
mjListKeyMap ids; // map from object names to ids
435443
mjCError errInfo; // last error info

test/user/user_api_test.cc

+57-31
Original file line numberDiff line numberDiff line change
@@ -230,41 +230,41 @@ TEST_F(PluginTest, DeletePlugin) {
230230
mj_deleteModel(newmodel);
231231
}
232232

233-
TEST_F(PluginTest, AttachPlugin) {
234-
static constexpr char xml_1[] = R"(
235-
<mujoco model="MuJoCo Model">
236-
<worldbody>
237-
<body name="body"/>
238-
</worldbody>
239-
</mujoco>)";
233+
static constexpr char xml_plugin_1[] = R"(
234+
<mujoco model="MuJoCo Model">
235+
<worldbody>
236+
<body name="body"/>
237+
</worldbody>
238+
</mujoco>)";
240239

241-
static constexpr char xml_2[] = R"(
242-
<mujoco model="MuJoCo Model">
243-
<extension>
244-
<plugin plugin="mujoco.pid">
245-
<instance name="actuator-1">
246-
<config key="ki" value="4.0"/>
247-
<config key="slewmax" value="3.14159"/>
248-
</instance>
249-
</plugin>
250-
</extension>
251-
<worldbody>
252-
<body name="empty"/>
253-
<body name="body">
254-
<joint name="joint"/>
255-
<geom size="0.1"/>
256-
</body>
257-
</worldbody>
258-
<actuator>
259-
<plugin name="actuator-1" plugin="mujoco.pid" instance="actuator-1"
260-
joint="joint" actdim="2"/>
261-
</actuator>
262-
</mujoco>)";
240+
static constexpr char xml_plugin_2[] = R"(
241+
<mujoco model="MuJoCo Model">
242+
<extension>
243+
<plugin plugin="mujoco.pid">
244+
<instance name="actuator-1">
245+
<config key="ki" value="4.0"/>
246+
<config key="slewmax" value="3.14159"/>
247+
</instance>
248+
</plugin>
249+
</extension>
250+
<worldbody>
251+
<body name="empty"/>
252+
<body name="body">
253+
<joint name="joint"/>
254+
<geom size="0.1"/>
255+
</body>
256+
</worldbody>
257+
<actuator>
258+
<plugin name="actuator-1" plugin="mujoco.pid" instance="actuator-1"
259+
joint="joint" actdim="2"/>
260+
</actuator>
261+
</mujoco>)";
263262

263+
TEST_F(PluginTest, AttachPlugin) {
264264
std::array<char, 1000> err;
265-
mjSpec* parent = mj_parseXMLString(xml_1, 0, err.data(), err.size());
265+
mjSpec* parent = mj_parseXMLString(xml_plugin_1, 0, err.data(), err.size());
266266
ASSERT_THAT(parent, NotNull()) << err.data();
267-
mjSpec* spec_1 = mj_parseXMLString(xml_2, 0, err.data(), err.size());
267+
mjSpec* spec_1 = mj_parseXMLString(xml_plugin_2, 0, err.data(), err.size());
268268
ASSERT_THAT(spec_1, NotNull()) << err.data();
269269

270270
// do a copy before attaching
@@ -306,6 +306,32 @@ TEST_F(PluginTest, AttachPlugin) {
306306
mj_deleteSpec(spec_3);
307307
}
308308

309+
TEST_F(PluginTest, DetachPlugin) {
310+
std::array<char, 1000> err;
311+
mjSpec* parent = mj_parseXMLString(xml_plugin_1, 0, err.data(), err.size());
312+
ASSERT_THAT(parent, NotNull()) << err.data();
313+
mjSpec* child = mj_parseXMLString(xml_plugin_2, 0, err.data(), err.size());
314+
ASSERT_THAT(child, NotNull()) << err.data();
315+
316+
// attach a body referencing the plugin to the frame
317+
mjsFrame* frame = mjs_addFrame(mjs_findBody(parent, "world"), 0);
318+
mjsBody* body = mjs_findBody(child, "body");
319+
EXPECT_THAT(mjs_attachBody(frame, body, "child-", ""), NotNull());
320+
321+
// detach the body and compile
322+
mjsBody* body_to_detach = mjs_findBody(parent, "child-body");
323+
EXPECT_THAT(body_to_detach, NotNull());
324+
EXPECT_THAT(mjs_detachBody(parent, body_to_detach), 0);
325+
mjModel* model = mj_compile(parent, nullptr);
326+
EXPECT_THAT(model, NotNull());
327+
EXPECT_THAT(model->nbody, 2);
328+
EXPECT_THAT(model->nplugin, 0);
329+
330+
mj_deleteModel(model);
331+
mj_deleteSpec(parent);
332+
mj_deleteSpec(child);
333+
}
334+
309335
TEST_F(PluginTest, AttachExplicitPlugin) {
310336
static constexpr char xml_parent[] = R"(
311337
<mujoco model="MuJoCo Model">

0 commit comments

Comments
 (0)