-
Notifications
You must be signed in to change notification settings - Fork 822
/
Copy pathmjsimstate.pyx
89 lines (75 loc) · 2.88 KB
/
mjsimstate.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class MjSimState(namedtuple('SimStateBase', 'time qpos qvel act udd_state')):
    """Represents a snapshot of the simulator's state.

    This includes time, qpos, qvel, act, and udd_state.

    Fields:
        time: scalar placed first in the flattened representation.
        qpos, qvel: arrays, compared and flattened element-wise.
        act: array, or None. ``flatten`` treats None as an empty segment and
            ``from_flattened`` yields None when ``sim.model.na == 0``.
        udd_state: dict whose values are scalar ``Number``s or numpy arrays
            (``_unflatten_dict`` asserts that non-scalar schema values are
            ``np.ndarray``).
    """
    __slots__ = ()

    # need to implement this because numpy doesn't support == on arrays
    def __eq__(self, other):
        """Compare two states field by field, using ``np.array_equal`` for arrays.

        Returns NotImplemented when ``other`` is not an instance of this class,
        letting Python fall back to its default comparison.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented
        if set(self.udd_state.keys()) != set(other.udd_state.keys()):
            return False
        # np.array_equal handles scalars and arrays alike; on a shape mismatch
        # (e.g. scalar vs. multi-element array) it returns False instead of
        # raising on an ambiguous array truth value.
        for k in self.udd_state:
            if not np.array_equal(self.udd_state[k], other.udd_state[k]):
                return False
        return (self.time == other.time and
                np.array_equal(self.qpos, other.qpos) and
                np.array_equal(self.qvel, other.qvel) and
                np.array_equal(self.act, other.act))

    def __ne__(self, other):
        """Return the negation of ``__eq__``.

        NotImplemented is propagated rather than negated: ``not NotImplemented``
        is False (and deprecated since Python 3.9), which previously made a
        state compare as "not unequal" to any object of another type.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def flatten(self):
        """Flattens a state into a numpy array of numbers.

        Layout: ``[time, *qpos, *qvel, *act, *udd_state]``. The udd_state values
        are emitted in sorted key order, with arrays raveled. A None ``act``
        contributes no elements.
        """
        act = np.empty(0) if self.act is None else self.act
        state_tuple = ([self.time], self.qpos, self.qvel, act,
                       MjSimState._flatten_dict(self.udd_state))
        return np.concatenate(state_tuple)

    @staticmethod
    def _flatten_dict(d):
        """Concatenate the values of ``d``, in sorted-key order, into a 1-D array.

        Scalar ``Number`` values contribute a single element. Any other value
        is raveled, so it must provide ``.ravel()`` (e.g. a numpy array).
        """
        a = []
        for k in sorted(d.keys()):
            v = d[k]
            if isinstance(v, Number):
                a.append(v)
            else:
                a.extend(v.ravel())
        return np.array(a)

    @staticmethod
    def from_flattened(array, sim):
        """Inverse of ``flatten``: rebuild an MjSimState from a flat array.

        ``sim.model.nq``, ``sim.model.nv`` and ``sim.model.na`` give the lengths
        of the qpos, qvel and act segments. ``sim.udd_state`` serves only as a
        schema (keys, shapes, scalar-vs-array) for the trailing udd_state
        segment. ``act`` is None when ``sim.model.na == 0``, mirroring
        ``flatten``.

        Note: qpos, qvel and act are slices of ``array``. For an ndarray input
        these are views, so they share memory with ``array``.
        """
        idx_time = 0
        idx_qpos = idx_time + 1
        idx_qvel = idx_qpos + sim.model.nq
        idx_act = idx_qvel + sim.model.nv
        idx_udd = idx_act + sim.model.na

        time = array[idx_time]
        qpos = array[idx_qpos:idx_qpos + sim.model.nq]
        qvel = array[idx_qvel:idx_qvel + sim.model.nv]
        if sim.model.na == 0:
            act = None
        else:
            act = array[idx_act:idx_act + sim.model.na]
        flat_udd_state = array[idx_udd:]
        udd_state = MjSimState._unflatten_dict(flat_udd_state, sim.udd_state)
        return MjSimState(time, qpos, qvel, act, udd_state)

    @staticmethod
    def _unflatten_dict(a, schema_example):
        """Inverse of ``_flatten_dict``: split ``a`` into a dict shaped like the schema.

        Keys are consumed in sorted order. A scalar ``Number`` schema value takes
        one element of ``a``. An ndarray schema value takes ``size`` elements,
        which are reshaped to the schema value's shape. Elements of ``a`` beyond
        what the schema needs are ignored.
        """
        d = {}
        idx = 0
        for k in sorted(schema_example.keys()):
            schema_val = schema_example[k]
            if isinstance(schema_val, Number):
                d[k] = a[idx]
                idx += 1
            else:
                assert isinstance(schema_val, np.ndarray)
                val_array = a[idx:idx + schema_val.size]
                idx += schema_val.size
                d[k] = np.array(val_array).reshape(schema_val.shape)
        return d