Skip to content

Commit 157b074

Browse files
quaglacopybara-github
authored andcommitted
Add signature to mjSpec and mjModel and use it to perform safe bind to mjModel and mjData.
PiperOrigin-RevId: 740378879 Change-Id: If14b326942529494f172e7aedcae30195798b458
1 parent c931565 commit 157b074

File tree

15 files changed

+165
-2
lines changed

15 files changed

+165
-2
lines changed

doc/includes/references.h

+7
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ struct mjData_ {
414414

415415
// thread pool pointer
416416
uintptr_t threadpool;
417+
418+
// compilation signature
419+
uint64_t signature; // also held by the mjSpec that compiled the model
417420
};
418421
typedef struct mjData_ mjData;
419422
typedef enum mjtDisableBit_ { // disable default feature bitflags
@@ -1451,6 +1454,9 @@ struct mjModel_ {
14511454

14521455
// paths
14531456
char* paths; // paths to assets, 0-terminated (npaths x 1)
1457+
1458+
// compilation signature
1459+
uint64_t signature; // also held by the mjSpec that compiled this model
14541460
};
14551461
typedef struct mjModel_ mjModel;
14561462
struct mjResource_ {
@@ -1713,6 +1719,7 @@ typedef enum mjtOrientation_ { // type of orientation specifier
17131719
} mjtOrientation;
17141720
typedef struct mjsElement_ { // element type, do not modify
17151721
mjtObj elemtype; // element type
1722+
uint64_t signature; // compilation signature
17161723
} mjsElement;
17171724
typedef struct mjsCompiler_ { // compiler options
17181725
mjtByte autolimits; // infer "limited" attribute based on range

include/mujoco/mjdata.h

+3
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,9 @@ struct mjData_ {
442442

443443
// thread pool pointer
444444
uintptr_t threadpool;
445+
446+
// compilation signature
447+
uint64_t signature; // also held by the mjSpec that compiled the model
445448
};
446449
typedef struct mjData_ mjData;
447450

include/mujoco/mjmodel.h

+3
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,9 @@ struct mjModel_ {
11551155

11561156
// paths
11571157
char* paths; // paths to assets, 0-terminated (npaths x 1)
1158+
1159+
// compilation signature
1160+
uint64_t signature; // also held by the mjSpec that compiled this model
11581161
};
11591162
typedef struct mjModel_ mjModel;
11601163

include/mujoco/mjspec.h

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
// this is a C-API
2424
#ifdef __cplusplus
2525
#include <cstddef>
26+
#include <cstdint>
2627
#include <string>
2728
#include <vector>
2829

@@ -119,6 +120,7 @@ typedef enum mjtOrientation_ { // type of orientation specifier
119120

120121
typedef struct mjsElement_ { // element type, do not modify
121122
mjtObj elemtype; // element type
123+
uint64_t signature; // compilation signature
122124
} mjsElement;
123125

124126

python/mujoco/codegen/generate_spec_bindings.py

+15
Original file line numberDiff line numberDiff line change
@@ -613,12 +613,27 @@ def generate_find() -> None:
613613
print(code)
614614

615615

616+
def generate_signature() -> None:
617+
"""Generate signature functions."""
618+
for key, _, _, _, _ in SPECS:
619+
elem = key.removeprefix('mjs')
620+
titlecase = 'Mjs' + elem
621+
code = f"""\n
622+
{key}.def_property_readonly("signature",
623+
[](raw::{titlecase}& self) -> uint64_t {{
624+
return mjs_getSpec(self.element)->element->signature;
625+
}});
626+
"""
627+
print(code)
628+
629+
616630
def main(argv: Sequence[str]) -> None:
617631
if len(argv) > 1:
618632
raise app.UsageError('Too many command-line arguments.')
619633
generate()
620634
generate_add()
621635
generate_find()
636+
generate_signature()
622637

623638

624639
if __name__ == '__main__':

python/mujoco/introspect/structs.py

+15
Original file line numberDiff line numberDiff line change
@@ -4454,6 +4454,11 @@
44544454
doc='paths to assets, 0-terminated',
44554455
array_extent=('npaths',),
44564456
),
4457+
StructFieldDecl(
4458+
name='signature',
4459+
type=ValueType(name='uint64_t'),
4460+
doc='also held by the mjSpec that compiled this model',
4461+
),
44574462
),
44584463
)),
44594464
('mjThreadPool',
@@ -6037,6 +6042,11 @@
60376042
type=ValueType(name='uintptr_t'),
60386043
doc='thread pool pointer',
60396044
),
6045+
StructFieldDecl(
6046+
name='signature',
6047+
type=ValueType(name='uint64_t'),
6048+
doc='also held by the mjSpec that compiled the model',
6049+
),
60406050
),
60416051
)),
60426052
('mjvPerturb',
@@ -9044,6 +9054,11 @@
90449054
type=ValueType(name='mjtObj'),
90459055
doc='element type',
90469056
),
9057+
StructFieldDecl(
9058+
name='signature',
9059+
type=ValueType(name='uint64_t'),
9060+
doc='compilation signature',
9061+
),
90479062
),
90489063
)),
90499064
('mjsCompiler',

python/mujoco/specs_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -1153,6 +1153,13 @@ def test_bind(self):
11531153
AttributeError, "object has no attribute 'invalid'"
11541154
):
11551155
print(mj_model.bind(joints).invalid)
1156+
invalid_spec = mujoco.MjSpec()
1157+
invalid_spec.worldbody.add_body(name='main')
1158+
with self.assertRaisesRegex(
1159+
ValueError,
1160+
'The mjSpec does not match mjModel. Please recompile the mjSpec.',
1161+
):
1162+
print(mj_model.bind(invalid_spec.body('main')))
11561163

11571164
def test_incorrect_hfield_size(self):
11581165
nrow = 300

python/mujoco/structs.cc

+22-2
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,10 @@ This is useful for example when the MJB is not available as a file on disk.)"));
17101710
// Return the full bytes array of concatenated paths
17111711
return m.paths_bytes;
17121712
});
1713+
mjModel.def_property_readonly(
1714+
"signature", [](const MjModelWrapper& m) -> const uint64_t& {
1715+
return m.get()->signature;
1716+
});
17131717

17141718
#define XGROUP(MjModelGroupedViews, field, nfield, FIELD_XMACROS) \
17151719
mjModel.def( \
@@ -1730,7 +1734,13 @@ This is useful for example when the MJB is not available as a file on disk.)"));
17301734
mjModel.def( \
17311735
"bind_scalar", \
17321736
[](MjModelWrapper& m, spectype& spec) -> auto& { \
1733-
return m.indexer().field##_by_name(mjs_getString(spec.name)); \
1737+
if (mjs_getSpec(spec.element)->element->signature != \
1738+
m.get()->signature) { \
1739+
throw py::value_error( \
1740+
"The mjSpec does not match mjModel. Please recompile " \
1741+
"the mjSpec."); \
1742+
} \
1743+
return m.indexer().field(mjs_getId(spec.element)); \
17341744
}, \
17351745
py::return_value_policy::reference_internal, \
17361746
py::arg_v("spec", py::none()));
@@ -2018,6 +2028,10 @@ This is useful for example when the MJB is not available as a file on disk.)"));
20182028
std::istringstream input(b, std::ios::in | std::ios::binary);
20192029
return MjDataWrapper::Deserialize(input);
20202030
}));
2031+
mjData.def_property_readonly(
2032+
"signature", [](const MjDataWrapper& d) -> uint64_t {
2033+
return d.get()->signature;
2034+
});
20212035

20222036
#define X(type, var) \
20232037
mjData.def_property( \
@@ -2076,7 +2090,13 @@ This is useful for example when the MJB is not available as a file on disk.)"));
20762090
mjData.def( \
20772091
"bind_scalar", \
20782092
[](MjDataWrapper& d, spectype& spec) -> auto& { \
2079-
return d.indexer().field##_by_name(mjs_getString(spec.name)); \
2093+
if (mjs_getSpec(spec.element)->element->signature != \
2094+
d.get()->signature) { \
2095+
throw py::value_error( \
2096+
"The mjSpec does not match mjData. Please recompile "\
2097+
"the mjSpec."); \
2098+
} \
2099+
return d.indexer().field(mjs_getId(spec.element)); \
20802100
}, \
20812101
py::return_value_policy::reference_internal, \
20822102
py::arg_v("spec", py::none()));

src/engine/engine_io.c

+3
Original file line numberDiff line numberDiff line change
@@ -2014,6 +2014,9 @@ static void _resetData(const mjModel* m, mjData* d, unsigned char debug_value) {
20142014
}
20152015
}
20162016
}
2017+
2018+
// copy signature from model
2019+
d->signature = m->signature;
20172020
}
20182021

20192022

src/user/user_flexcomp.cc

+2
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,14 @@ bool mjCFlexcomp::Make(mjsBody* body, char* error, int error_sz) {
403403
mjCFlex* flex = model->AddFlex();
404404
mjsFlex* pf = &flex->spec;
405405
int id = flex->id;
406+
int uid = flex->uid;
406407

407408
*flex = def.Flex();
408409
flex->PointToLocal();
409410

410411
flex->model = model;
411412
flex->id = id;
413+
flex->uid = uid;
412414
mjs_setString(pf->name, name.c_str());
413415
mjs_setInt(pf->elem, element.data(), element.size());
414416
mjs_setFloat(pf->texcoord, texcoord.data(), texcoord.size());

src/user/user_model.cc

+40
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ mjCModel::mjCModel() {
201201
world->mass = 0;
202202
mjuu_zerovec(world->inertia, 3);
203203
world->id = 0;
204+
world->uid = GetUid();
204205
world->parent = nullptr;
205206
world->weldid = 0;
206207
world->name = "world";
@@ -213,6 +214,9 @@ mjCModel::mjCModel() {
213214

214215
// the source spec is the model itself, overwritten in the copy constructor
215216
source_spec_ = &spec;
217+
218+
// set the signature
219+
spec.element->signature = 0;
216220
}
217221

218222

@@ -289,6 +293,7 @@ void mjCModel::CopyList(std::vector<T*>& dest,
289293
// copy the element from the other model to this model
290294
if (deepcopy_) {
291295
source[i]->ForgetKeyframes();
296+
candidate->uid = GetUid();
292297
} else {
293298
candidate->AddRef();
294299
}
@@ -499,6 +504,9 @@ mjCModel& mjCModel::operator+=(const mjCModel& other) {
499504
nq = nv = na = nu = nmocap = 0;
500505
}
501506

507+
// update signature before we reset the tree lists
508+
spec.element->signature = Signature();
509+
502510
PointToLocal();
503511
return *this;
504512
}
@@ -633,6 +641,9 @@ mjCModel& mjCModel::operator-=(const mjCBody& subtree) {
633641
RemoveFromList(sensors_, oldmodel);
634642
RemovePlugins();
635643

644+
// update signature before we reset the tree lists
645+
spec.element->signature = Signature();
646+
636647
return *this;
637648
}
638649

@@ -803,6 +814,9 @@ void mjCModel::DeleteElement(mjsElement* el) {
803814
break;
804815
}
805816

817+
// update signature before we reset the tree lists
818+
spec.element->signature = Signature();
819+
806820
ResetTreeLists(); // in case of a nested delete
807821
MakeTreeLists();
808822
ProcessLists(/*checkrepeat=*/false);
@@ -1019,7 +1033,9 @@ template <class T>
10191033
T* mjCModel::AddObject(vector<T*>& list, string type) {
10201034
T* obj = new T(this);
10211035
obj->id = (int)list.size();
1036+
obj->uid = GetUid();
10221037
list.push_back(obj);
1038+
spec.element->signature = Signature();
10231039
return obj;
10241040
}
10251041

@@ -1030,7 +1046,9 @@ T* mjCModel::AddObjectDefault(vector<T*>& list, string type, mjCDef* def) {
10301046
T* obj = new T(this, def ? def : defaults_[0]);
10311047
obj->id = (int)list.size();
10321048
obj->classname = def ? def->name : "main";
1049+
obj->uid = GetUid();
10331050
list.push_back(obj);
1051+
spec.element->signature = Signature();
10341052
return obj;
10351053
}
10361054

@@ -4572,6 +4590,28 @@ void mjCModel::TryCompile(mjModel*& m, mjData*& d, const mjVFS* vfs) {
45724590
mju::strcpy_arr(errInfo.message, warningtext);
45734591
errInfo.warning = true;
45744592
}
4593+
4594+
// save signature
4595+
m->signature = Signature();
4596+
}
4597+
4598+
4599+
4600+
uint64_t mjCModel::Signature() {
4601+
std::string uid_str;
4602+
for (int i = 0; i < mjNOBJECT; ++i) {
4603+
if (i == mjOBJ_XBODY || i == mjOBJ_UNKNOWN || i == mjOBJ_DOF) {
4604+
continue;
4605+
}
4606+
if (object_lists_[i] == nullptr) {
4607+
throw mjCError(0, "object list %s is null", std::to_string(i).c_str());
4608+
}
4609+
uid_str += '|';
4610+
for (mjCBase* object : *object_lists_[i]) {
4611+
uid_str += std::to_string(object->uid) + " ";
4612+
}
4613+
}
4614+
return mj_hashString(uid_str.c_str(), UINT64_MAX);
45754615
}
45764616

45774617

src/user/user_model.h

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define MUJOCO_SRC_USER_USER_MODEL_H_
1717

1818
#include <array>
19+
#include <cstdint>
1920
#include <functional>
2021
#include <map>
2122
#include <string>
@@ -324,6 +325,9 @@ class mjCModel : public mjCModel_, private mjSpec {
324325
// set attached flag
325326
void SetAttached(bool deepcopy) { attached_ |= !deepcopy; }
326327

328+
// get new uid
329+
int GetUid() { return uid_count_++; }
330+
327331
private:
328332
// settings for each defaults class
329333
std::vector<mjCDef*> defaults_;
@@ -440,11 +444,14 @@ class mjCModel : public mjCModel_, private mjSpec {
440444
void MarkPluginInstance(std::unordered_map<std::string, bool>& instances,
441445
const std::vector<T*>& list);
442446

447+
// generate a signature for the model
448+
uint64_t Signature();
443449

444450
mjListKeyMap ids; // map from object names to ids
445451
mjCError errInfo; // last error info
446452
std::vector<mjKeyInfo> key_pending_; // attached keyframes
447453
bool deepcopy_; // copy objects when attaching
448454
bool attached_ = false; // true if model is attached to a parent model
455+
int uid_count_ = 0; // unique id count for all objects
449456
};
450457
#endif // MUJOCO_SRC_USER_USER_MODEL_H_

0 commit comments

Comments
 (0)