Skip to content

Commit e6d8886

Browse files
NixGD, AdamGleave, ZiyueWang25
authored
Add partial support for dictionary observation spaces (bc, density) (#785)
* first pass of dict obs functionality * cleanup DictObs * add dict space to test_types.py, fix some problems * add dict-obs test for rollout * add bc.py test * cleanup * small fixes * small fixes * fix type error in interactive.py * fix introduced error in mce_irl.py * fix minor ci complaint * add basic dictobs tests * change default bc policy for dict obs space * refine rollout.py typechecks, comments * check rollout produces dictobs of correct shape * cleanup types and dictobs helpers * clean useless lines * clean up print statements * fix typos Co-authored-by: Adam Gleave <[email protected]> * assert matching keys in from_obs_list * move maybe_wrap, clean rollout * change policy callable to take dict[str, np.ndarray] not dictobs * rollout info wrapper supports dictobs * fix from_obs_list key consistency check * xfail save/load tests with dictobs * doc for dictobs wrapper * don't error on int observations * lint fixes * cleanup bc test for dict obs * cleanup bc.py unwrapping * cleanup rollout.py * cleanup dictobs interface * small cleanups * coverage fixes, test fix * adjust error types * docstrings for type helpers * add dict obs space support for density * fix typos Co-authored-by: Adam Gleave <[email protected]> * Adam suggestions from code review Co-authored-by: Adam Gleave <[email protected]> * small changes for code review * fix docstring * remove FloatReward * Fix test_bc * Turn off GPU finding to avoid using gpu device * Check None to ensure __add__ can work * fix docstring * bypass pytype and lint test * format with black * Test dict space in density algo * black format * small fix * Add DictObs into test_wrappers * fix format * minor fix * type and lint fix * Add policy training test * suppress line too long lint check on a line * acts to obs for clarity * Add HumanReadableWrapper * fix dict env observation space * adjust wrapper and not set render_mode inside * Add additional obs check * Upgrade pytype and remove workaround for old versions * Fix 
test_rollout test * add RemoveHumanReadableWrapper and update ob space * Revert "add RemoveHumanReadableWrapper and update ob space" This reverts commit ee83ec5. * Revert "adjust wrapper and not set render_mode inside" This reverts commit a9b32bd. * Revert "fix dict env observation space" This reverts commit ba6a6a7. * Revert "Add HumanReadableWrapper" This reverts commit 6e5c3e8. * Revert "acts to obs for clarity" This reverts commit be79cf5. * address comments * new pytype need input directory or file * fix np.dtype * ignore typed-dict-error * context manager related fix * keep pytype checking more failures * Revert "keep pytype checking more failures" This reverts commit f5288c6. * Revert "context manager related fix" This reverts commit 5c1d751. * Revert "ignore typed-dict-error" This reverts commit 5c6e5b8. * Revert "fix np.dtype" This reverts commit 6884538. * Revert "new pytype need input directory or file" This reverts commit 15541cd. * Revert "Upgrade pytype and remove workaround for old versions" This reverts commit 194ec1a. * lint fix * fix type check * fix lint --------- Co-authored-by: Adam Gleave <[email protected]> Co-authored-by: ZiyueWang25 <[email protected]>
1 parent 573b086 commit e6d8886

25 files changed

+881
-195
lines changed

src/imitation/algorithms/bc.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import (
1010
Any,
1111
Callable,
12+
Dict,
1213
Iterable,
1314
Iterator,
1415
Mapping,
@@ -22,7 +23,7 @@
2223
import numpy as np
2324
import torch as th
2425
import tqdm
25-
from stable_baselines3.common import policies, utils, vec_env
26+
from stable_baselines3.common import policies, torch_layers, utils, vec_env
2627

2728
from imitation.algorithms import base as algo_base
2829
from imitation.data import rollout, types
@@ -99,7 +100,12 @@ class BehaviorCloningLossCalculator:
99100
def __call__(
100101
self,
101102
policy: policies.ActorCriticPolicy,
102-
obs: Union[th.Tensor, np.ndarray],
103+
obs: Union[
104+
types.AnyTensor,
105+
types.DictObs,
106+
Dict[str, np.ndarray],
107+
Dict[str, th.Tensor],
108+
],
103109
acts: Union[th.Tensor, np.ndarray],
104110
) -> BCTrainingMetrics:
105111
"""Calculate the supervised learning loss used to train the behavioral clone.
@@ -113,9 +119,18 @@ def __call__(
113119
A BCTrainingMetrics object with the loss and all the components it
114120
consists of.
115121
"""
116-
obs = util.safe_to_tensor(obs)
122+
tensor_obs = types.map_maybe_dict(
123+
util.safe_to_tensor,
124+
types.maybe_unwrap_dictobs(obs),
125+
)
117126
acts = util.safe_to_tensor(acts)
118-
_, log_prob, entropy = policy.evaluate_actions(obs, acts)
127+
128+
# policy.evaluate_actions's type signatures are incorrect.
129+
# See https://github.com/DLR-RM/stable-baselines3/issues/1679
130+
(_, log_prob, entropy) = policy.evaluate_actions(
131+
tensor_obs, # type: ignore[arg-type]
132+
acts,
133+
)
119134
prob_true_act = th.exp(log_prob).mean()
120135
log_prob = log_prob.mean()
121136
entropy = entropy.mean() if entropy is not None else None
@@ -324,12 +339,18 @@ def __init__(
324339
self.rng = rng
325340

326341
if policy is None:
342+
extractor = (
343+
torch_layers.CombinedExtractor
344+
if isinstance(observation_space, gym.spaces.Dict)
345+
else torch_layers.FlattenExtractor
346+
)
327347
policy = policy_base.FeedForward32Policy(
328348
observation_space=observation_space,
329349
action_space=action_space,
330350
# Set lr_schedule to max value to force error if policy.optimizer
331351
# is used by mistake (should use self.optimizer instead).
332352
lr_schedule=lambda _: th.finfo(th.float32).max,
353+
features_extractor_class=extractor,
333354
)
334355
self._policy = policy.to(utils.get_device(device))
335356
# TODO(adam): make policy mandatory and delete observation/action space params?
@@ -464,9 +485,14 @@ def process_batch():
464485
minibatch_size,
465486
num_samples_so_far,
466487
), batch in batches_with_stats:
467-
obs = th.as_tensor(batch["obs"], device=self.policy.device).detach()
468-
acts = th.as_tensor(batch["acts"], device=self.policy.device).detach()
469-
training_metrics = self.loss_calculator(self.policy, obs, acts)
488+
obs_tensor: Union[th.Tensor, Dict[str, th.Tensor]]
489+
# unwraps the observation if it's a dictobs and converts arrays to tensors
490+
obs_tensor = types.map_maybe_dict(
491+
lambda x: util.safe_to_tensor(x, device=self.policy.device),
492+
types.maybe_unwrap_dictobs(batch["obs"]),
493+
)
494+
acts = util.safe_to_tensor(batch["acts"], device=self.policy.device)
495+
training_metrics = self.loss_calculator(self.policy, obs_tensor, acts)
470496

471497
# Renormalise the loss to be averaged over the whole
472498
# batch size instead of the minibatch size.

src/imitation/algorithms/density.py

Lines changed: 52 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ def __init__(
134134

135135
def _get_demo_from_batch(
136136
self,
137-
obs_b: np.ndarray,
137+
obs_b: types.Observation,
138138
act_b: np.ndarray,
139-
next_obs_b: Optional[np.ndarray],
139+
next_obs_b: Optional[types.Observation],
140140
) -> Dict[Optional[int], List[np.ndarray]]:
141141
if next_obs_b is None and self.density_type == DensityType.STATE_STATE_DENSITY:
142142
raise ValueError(
@@ -145,11 +145,18 @@ def _get_demo_from_batch(
145145
)
146146

147147
assert act_b.shape[1:] == self.venv.action_space.shape
148-
assert obs_b.shape[1:] == self.venv.observation_space.shape
148+
ob_space = self.venv.observation_space
149+
if isinstance(obs_b, types.DictObs):
150+
exp_shape = {
151+
k: v.shape for k, v in ob_space.items() # type: ignore[attr-defined]
152+
}
153+
obs_shape = {k: v.shape[1:] for k, v in obs_b.items()}
154+
assert exp_shape == obs_shape, f"Expected {exp_shape}, got {obs_shape}"
155+
else:
156+
assert obs_b.shape[1:] == ob_space.shape
149157
assert len(act_b) == len(obs_b)
150158
if next_obs_b is not None:
151-
assert next_obs_b.shape[1:] == self.venv.observation_space.shape
152-
assert len(next_obs_b) == len(obs_b)
159+
assert next_obs_b.shape == obs_b.shape
153160

154161
if next_obs_b is not None:
155162
next_obs_b_iterator: Iterable = next_obs_b
@@ -200,14 +207,17 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
200207
# analogous to cast above.
201208
demonstrations = cast(Iterable[types.TransitionMapping], demonstrations)
202209

210+
def to_np_maybe_dictobs(x):
211+
if isinstance(x, types.DictObs):
212+
return x
213+
else:
214+
return util.safe_to_numpy(x, warn=True)
215+
203216
for batch in demonstrations:
204-
transitions.update(
205-
self._get_demo_from_batch(
206-
util.safe_to_numpy(batch["obs"], warn=True),
207-
util.safe_to_numpy(batch["acts"], warn=True),
208-
util.safe_to_numpy(batch.get("next_obs"), warn=True),
209-
),
210-
)
217+
obs = to_np_maybe_dictobs(batch["obs"])
218+
acts = util.safe_to_numpy(batch["acts"], warn=True)
219+
next_obs = to_np_maybe_dictobs(batch.get("next_obs"))
220+
transitions.update(self._get_demo_from_batch(obs, acts, next_obs))
211221
else:
212222
raise TypeError(
213223
f"Unsupported demonstration type {type(demonstrations)}",
@@ -253,65 +263,40 @@ def _fit_density(self, transitions: np.ndarray) -> neighbors.KernelDensity:
253263

254264
def _preprocess_transition(
255265
self,
256-
obs: np.ndarray,
266+
obs: types.Observation,
257267
act: np.ndarray,
258-
next_obs: Optional[np.ndarray],
268+
next_obs: Optional[types.Observation],
259269
) -> np.ndarray:
260270
"""Compute flattened transition on subset specified by `self.density_type`."""
271+
flattened_obs = space_utils.flatten(
272+
self.venv.observation_space,
273+
types.maybe_unwrap_dictobs(obs),
274+
)
275+
flattened_obs = _check_data_is_np_array(flattened_obs, "observation")
261276
if self.density_type == DensityType.STATE_DENSITY:
262-
flat_observations = space_utils.flatten(self.venv.observation_space, obs)
263-
if not isinstance(flat_observations, np.ndarray):
264-
raise ValueError(
265-
"The density estimator only supports spaces that "
266-
"flatten to a numpy array but the observation space "
267-
f"flattens to {type(flat_observations)}",
268-
)
269-
270-
return flat_observations
277+
return flattened_obs
271278
elif self.density_type == DensityType.STATE_ACTION_DENSITY:
272-
flat_observation = space_utils.flatten(self.venv.observation_space, obs)
273-
flat_action = space_utils.flatten(self.venv.action_space, act)
274-
275-
if not isinstance(flat_observation, np.ndarray):
276-
raise ValueError(
277-
"The density estimator only supports spaces that "
278-
"flatten to a numpy array but the observation space "
279-
f"flattens to {type(flat_observation)}",
280-
)
281-
if not isinstance(flat_action, np.ndarray):
282-
raise ValueError(
283-
"The density estimator only supports spaces that "
284-
"flatten to a numpy array but the action space "
285-
f"flattens to {type(flat_action)}",
286-
)
287-
288-
return np.concatenate([flat_observation, flat_action])
279+
flattened_action = space_utils.flatten(self.venv.action_space, act)
280+
flattened_action = _check_data_is_np_array(flattened_action, "action")
281+
return np.concatenate([flattened_obs, flattened_action])
289282
elif self.density_type == DensityType.STATE_STATE_DENSITY:
290283
assert next_obs is not None
291-
flat_observation = space_utils.flatten(self.venv.observation_space, obs)
292-
flat_next_observation = space_utils.flatten(
284+
flat_next_obs = space_utils.flatten(
293285
self.venv.observation_space,
294-
next_obs,
286+
types.maybe_unwrap_dictobs(next_obs),
295287
)
288+
flat_next_obs = _check_data_is_np_array(flat_next_obs, "observation")
289+
assert type(flattened_obs) is type(flat_next_obs)
296290

297-
if not isinstance(flat_observation, np.ndarray):
298-
raise ValueError(
299-
"The density estimator only supports spaces that "
300-
"flatten to a numpy array but the observation space "
301-
f"flattens to {type(flat_observation)}",
302-
)
303-
304-
assert type(flat_observation) is type(flat_next_observation)
305-
306-
return np.concatenate([flat_observation, flat_next_observation])
291+
return np.concatenate([flattened_obs, flat_next_obs])
307292
else:
308293
raise ValueError(f"Unknown density type {self.density_type}")
309294

310295
def __call__(
311296
self,
312-
state: np.ndarray,
297+
state: types.Observation,
313298
action: np.ndarray,
314-
next_state: np.ndarray,
299+
next_state: types.Observation,
315300
done: np.ndarray,
316301
steps: Optional[np.ndarray] = None,
317302
) -> np.ndarray:
@@ -347,6 +332,8 @@ def __call__(
347332

348333
rew_list = []
349334
assert len(state) == len(action) and len(state) == len(next_state)
335+
state = types.maybe_wrap_in_dictobs(state)
336+
next_state = types.maybe_wrap_in_dictobs(next_state)
350337
for idx, (obs, act, next_obs) in enumerate(zip(state, action, next_state)):
351338
flat_trans = self._preprocess_transition(obs, act, next_obs)
352339
assert self._scaler is not None
@@ -424,3 +411,13 @@ def policy(self) -> base_class.BasePolicy:
424411
assert self.rl_algo is not None
425412
assert self.rl_algo.policy is not None
426413
return self.rl_algo.policy
414+
415+
416+
def _check_data_is_np_array(data: space_utils.FlatType, name: str) -> np.ndarray:
    """Return `data` unchanged after verifying that it is a numpy array.

    Args:
        data: the flattened data to check.
        name: label of the space the data belongs to (e.g. "observation" or
            "action"); only used to build the error message.

    Returns:
        `data` itself, now known to be an `np.ndarray`.

    Raises:
        AssertionError: if `data` is not an `np.ndarray`.
    """
    if not isinstance(data, np.ndarray):
        # Raise explicitly instead of using a bare `assert` so the check is not
        # stripped under `python -O`. The message is one plain string: no
        # trailing comma after the last fragment, which would turn it into a
        # 1-tuple.
        raise AssertionError(
            "The density estimator only supports spaces that "
            f"flatten to a numpy array but the {name} space "
            f"flattens to {type(data)}"
        )
    return data

src/imitation/algorithms/mce_irl.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,18 @@
77
"""
88
import collections
99
import warnings
10-
from typing import Any, Iterable, List, Mapping, NoReturn, Optional, Tuple, Type, Union
10+
from typing import (
11+
Any,
12+
Dict,
13+
Iterable,
14+
List,
15+
Mapping,
16+
NoReturn,
17+
Optional,
18+
Tuple,
19+
Type,
20+
Union,
21+
)
1122

1223
import gymnasium as gym
1324
import numpy as np
@@ -347,7 +358,7 @@ def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None
347358
num_demos = 0
348359
for traj in trajs:
349360
cum_discount = 1.0
350-
for obs in traj.obs:
361+
for obs in types.assert_not_dictobs(traj.obs):
351362
self.demo_state_om[obs] += cum_discount
352363
cum_discount *= self.discount
353364
num_demos += 1
@@ -411,23 +422,32 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
411422

412423
if isinstance(demonstrations, types.Transitions):
413424
self._set_demo_from_obs(
414-
demonstrations.obs,
425+
types.assert_not_dictobs(demonstrations.obs),
415426
demonstrations.dones,
416-
demonstrations.next_obs,
427+
types.assert_not_dictobs(demonstrations.next_obs),
417428
)
418429
elif isinstance(demonstrations, types.TransitionsMinimal):
419-
self._set_demo_from_obs(demonstrations.obs, None, None)
430+
self._set_demo_from_obs(
431+
types.assert_not_dictobs(demonstrations.obs),
432+
None,
433+
None,
434+
)
420435
elif isinstance(demonstrations, Iterable):
421436
# Demonstrations are a Torch DataLoader or other Mapping iterable
422437
# Collect them together into one big NumPy array. This is inefficient,
423438
# we could compute the running statistics instead, but in practice do
424439
# not expect large dataset sizes together with MCE IRL.
425-
collated_list = collections.defaultdict(list)
440+
collated_list: Dict[
441+
str,
442+
List[types.AnyTensor],
443+
] = collections.defaultdict(list)
426444
for batch in demonstrations:
427445
assert isinstance(batch, Mapping)
428446
for k in ("obs", "dones", "next_obs"):
429-
if k in batch:
430-
collated_list[k].append(batch[k])
447+
x = batch.get(k)
448+
if x is not None:
449+
assert isinstance(x, (np.ndarray, th.Tensor))
450+
collated_list[k].append(x)
431451
collated = {k: np.concatenate(v) for k, v in collated_list.items()}
432452

433453
assert "obs" in collated

src/imitation/algorithms/preference_comparisons.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,9 @@ def rewards(self, transitions: Transitions) -> th.Tensor:
465465
Shape - (num_transitions, ) for Single reward network and
466466
(num_transitions, num_networks) for ensemble of networks.
467467
"""
468-
state = transitions.obs
468+
state = types.assert_not_dictobs(transitions.obs)
469469
action = transitions.acts
470-
next_state = transitions.next_obs
470+
next_state = types.assert_not_dictobs(transitions.next_obs)
471471
done = transitions.dones
472472
if self.ensemble_model is not None:
473473
rews_np = self.ensemble_model.predict_processed_all(

src/imitation/data/buffer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Buffers to store NumPy arrays and transitions in."""
22

3-
import dataclasses
43
from typing import Any, Mapping, Optional, Tuple
54

65
import numpy as np
@@ -368,15 +367,16 @@ def from_data(
368367
Returns:
369368
A new ReplayBuffer.
370369
"""
371-
obs_shape = transitions.obs.shape[1:]
370+
obs = types.assert_not_dictobs(transitions.obs)
371+
obs_shape = obs.shape[1:]
372372
act_shape = transitions.acts.shape[1:]
373373
if capacity is None:
374-
capacity = transitions.obs.shape[0]
374+
capacity = obs.shape[0]
375375
instance = cls(
376376
capacity=capacity,
377377
obs_shape=obs_shape,
378378
act_shape=act_shape,
379-
obs_dtype=transitions.obs.dtype,
379+
obs_dtype=obs.dtype,
380380
act_dtype=transitions.acts.dtype,
381381
)
382382
instance.store(transitions, truncate_ok=truncate_ok)
@@ -406,7 +406,7 @@ def store(self, transitions: types.Transitions, truncate_ok: bool = True) -> Non
406406
Raises:
407407
ValueError: The arguments didn't have the same length.
408408
""" # noqa: DAR402
409-
trans_dict = dataclasses.asdict(transitions)
409+
trans_dict = types.dataclass_quick_asdict(transitions)
410410
# Remove unnecessary fields
411411
trans_dict = {k: trans_dict[k] for k in self._buffer.sample_shapes.keys()}
412412
self._buffer.store(trans_dict, truncate_ok=truncate_ok)

src/imitation/data/huggingface_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def trajectories_to_dict(
124124
],
125125
terminal=[traj.terminal for traj in trajectories],
126126
)
127+
if any(isinstance(traj.obs, types.DictObs) for traj in trajectories):
128+
raise ValueError("DictObs are not currently supported")
127129

128130
# Encode infos as jsonpickled strings
129131
trajectory_dict["infos"] = [

0 commit comments

Comments
 (0)