Merged
Changes from 4 commits
Commits
89 commits
5182ecf
first pass of dict obs functionality
NixGD Sep 13, 2023
61d816b
cleanup DictObs
NixGD Sep 13, 2023
c3331f6
add dict space to test_types.py, fix some problems
NixGD Sep 14, 2023
fc9838d
add dict-obs test for rollout
NixGD Sep 14, 2023
fb9498b
add bc.py test
NixGD Sep 14, 2023
e54c36c
cleanup
NixGD Sep 14, 2023
ee04383
small fixes
NixGD Sep 14, 2023
6e2218a
small fixes
NixGD Sep 14, 2023
68fe666
fix type error in interactive.py
NixGD Sep 14, 2023
9ad2aaf
fix introduced error in mce_irl.py
NixGD Sep 14, 2023
67341d5
fix minor ci complaint
NixGD Sep 14, 2023
c497b56
add basic dictobs tests
NixGD Sep 14, 2023
d3f79bf
change default bc policy for dict obs space
NixGD Sep 14, 2023
2de9e49
refine rollout.py typechecks, comments
NixGD Sep 14, 2023
c47cca6
check rollout produces dictobs of correct shape
NixGD Sep 14, 2023
276294b
cleanup types and dictobs helpers
NixGD Sep 14, 2023
071d2a7
clean useless lines
NixGD Sep 14, 2023
a2ccd7e
clean up print statements
NixGD Sep 14, 2023
93baa2d
fix typos
NixGD Sep 15, 2023
54f33af
assert matching keys in from_obs_list
NixGD Sep 15, 2023
c711abf
move maybe_wrap, clean rollout
NixGD Sep 15, 2023
58a0d70
change policy callable to take dict[str, np.ndarray] not dictobs
NixGD Sep 15, 2023
0f080d4
rollout info wrapper supports dictobs
NixGD Sep 15, 2023
c4d3e11
fix from_obs_list key consistency check
NixGD Sep 15, 2023
b93294a
xfail save/load tests with dictobs
NixGD Sep 15, 2023
3f17ff2
doc for dictobs wrapper
NixGD Sep 15, 2023
0212e0e
don't error on int observations
NixGD Sep 15, 2023
070ebf9
lint fixes
NixGD Sep 15, 2023
657e17e
cleanup bc test for dict obs
NixGD Sep 15, 2023
1f8c12a
cleanup bc.py unwrapping
NixGD Sep 15, 2023
bd70ecd
cleanup rollout.py
NixGD Sep 15, 2023
bec464c
cleanup dictobs interface
NixGD Sep 15, 2023
bef19e6
small cleanups
NixGD Sep 15, 2023
9aaf73f
coverage fixes, test fix
NixGD Sep 15, 2023
5d6aa77
adjust error types
NixGD Sep 15, 2023
86fbcf1
docstrings for type helpers
NixGD Sep 15, 2023
8d1e0d6
add dict obs space support for density
NixGD Sep 15, 2023
96978d5
fix typos
NixGD Sep 15, 2023
e95df9d
Adam suggestions from code review
NixGD Sep 16, 2023
161ec95
small changes for code review
NixGD Sep 16, 2023
90bdf57
fix docstring
NixGD Sep 16, 2023
6aa25ff
remove FloatReward
ZiyueWang25 Oct 2, 2023
bf48c76
Merge remote-tracking branch 'origin/master' into support-dict-obs-space
ZiyueWang25 Oct 2, 2023
4ce1b57
Fix test_bc
ZiyueWang25 Oct 2, 2023
de1b1c8
Turn off GPU finding to avoid using gpu device
ZiyueWang25 Oct 2, 2023
1a1a458
Check None to ensure __add__ can work
ZiyueWang25 Oct 2, 2023
f7866f4
fix docstring
ZiyueWang25 Oct 2, 2023
daa838d
bypass pytype and lint test
ZiyueWang25 Oct 2, 2023
803eab0
format with black
ZiyueWang25 Oct 2, 2023
0ac6f54
Test dict space in density algo
ZiyueWang25 Oct 2, 2023
be9798b
black format
ZiyueWang25 Oct 2, 2023
c7e6809
small fix
ZiyueWang25 Oct 2, 2023
82fb558
Add DictObs into test_wrappers
ZiyueWang25 Oct 3, 2023
03714cc
fix format
ZiyueWang25 Oct 3, 2023
187e881
minor fix
ZiyueWang25 Oct 3, 2023
ae96521
type and lint fix
ZiyueWang25 Oct 3, 2023
535a986
Add policy training test
ZiyueWang25 Oct 3, 2023
de027c4
suppress line too long lint check on a line
ZiyueWang25 Oct 3, 2023
be79cf5
acts to obs for clarity
ZiyueWang25 Oct 3, 2023
6e5c3e8
Add HumanReadableWrapper
ZiyueWang25 Oct 3, 2023
ba6a6a7
fix dict env observation space
ZiyueWang25 Oct 3, 2023
a9b32bd
adjust wrapper and not set render_mode inside
ZiyueWang25 Oct 3, 2023
77eab66
Add additional obs check
AdamGleave Oct 4, 2023
194ec1a
Upgrade pytype and remove workaround for old versions
AdamGleave Oct 4, 2023
44b357e
Fix test_rollout test
AdamGleave Oct 4, 2023
ee83ec5
add RemoveHumanReadableWrapper and update ob space
ZiyueWang25 Oct 4, 2023
27f9dc8
Revert "add RemoveHumanReadableWrapper and update ob space"
ZiyueWang25 Oct 4, 2023
d954fed
Revert "adjust wrapper and not set render_mode inside"
ZiyueWang25 Oct 4, 2023
d1131d0
Revert "fix dict env observation space"
ZiyueWang25 Oct 4, 2023
31f8887
Revert "Add HumanReadableWrapper"
ZiyueWang25 Oct 4, 2023
ae9fa64
Revert "acts to obs for clarity"
ZiyueWang25 Oct 4, 2023
3dfafd0
Merge branch 'support-dict-obs-space' of github.com:HumanCompatibleAI…
ZiyueWang25 Oct 4, 2023
7a2b7ce
address comments
ZiyueWang25 Oct 4, 2023
15541cd
new pytype need input directory or file
ZiyueWang25 Oct 4, 2023
6884538
fix np.dtype
ZiyueWang25 Oct 4, 2023
5c6e5b8
ignore typed-dict-error
ZiyueWang25 Oct 4, 2023
5c1d751
context manager related fix
ZiyueWang25 Oct 4, 2023
f5288c6
keep pytype checking more failures
ZiyueWang25 Oct 4, 2023
6e94dea
Revert "keep pytype checking more failures"
ZiyueWang25 Oct 4, 2023
bb1f9cd
Revert "context manager related fix"
ZiyueWang25 Oct 4, 2023
a07ea26
Revert "ignore typed-dict-error"
ZiyueWang25 Oct 4, 2023
b2cca2e
Revert "fix np.dtype"
ZiyueWang25 Oct 4, 2023
1a24ae5
Revert "new pytype need input directory or file"
ZiyueWang25 Oct 4, 2023
b989af8
Revert "Upgrade pytype and remove workaround for old versions"
ZiyueWang25 Oct 4, 2023
4817c2f
lint fix
ZiyueWang25 Oct 4, 2023
94c3ecf
fix type check
ZiyueWang25 Oct 4, 2023
d5d1918
fix lint
ZiyueWang25 Oct 4, 2023
4df8f83
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
0af3037
Merge branch 'master' of github.com:HumanCompatibleAI/imitation into …
ZiyueWang25 Oct 5, 2023
33 changes: 27 additions & 6 deletions src/imitation/algorithms/bc.py
@@ -9,6 +9,7 @@
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
Mapping,
@@ -100,7 +101,7 @@ class BehaviorCloningLossCalculator:
def __call__(
self,
policy: policies.ActorCriticPolicy,
obs: Union[th.Tensor, np.ndarray],
obs: Union[th.Tensor, np.ndarray, types.DictObs],
acts: Union[th.Tensor, np.ndarray],
) -> BCTrainingMetrics:
"""Calculate the supervised learning loss used to train the behavioral clone.
@@ -114,9 +115,18 @@ def __call__(
A BCTrainingMetrics object with the loss and all the components it
consists of.
"""
obs = util.safe_to_tensor(obs)
tensor_obs: Union[th.Tensor, Dict[str, th.Tensor]]
if isinstance(obs, types.DictObs):
tensor_obs = {k: util.safe_to_tensor(v) for k, v in obs.unwrap().items()}
else:
tensor_obs = util.safe_to_tensor(obs)
acts = util.safe_to_tensor(acts)
_, log_prob, entropy = policy.evaluate_actions(obs, acts)
# TODO: add check obs is proper type?
# policy.evaluate_actions's type signature seems wrong to me.
# it declares it only takes a tensor but it calls
# extract_features which is happy with Dict[str, tensor].
# In reality the required type of obs depends on the feature extractor.
_, log_prob, entropy = policy.evaluate_actions(tensor_obs, acts) # type: ignore
prob_true_act = th.exp(log_prob).mean()
log_prob = log_prob.mean()
entropy = entropy.mean() if entropy is not None else None
@@ -325,6 +335,7 @@ def __init__(
self.rng = rng

if policy is None:
# TODO: maybe default to a policy with a combined (dict) feature extractor when the obs space is a Dict?
policy = policy_base.FeedForward32Policy(
observation_space=observation_space,
action_space=action_space,
@@ -465,9 +476,19 @@ def process_batch():
minibatch_size,
num_samples_so_far,
), batch in batches_with_stats:
obs = th.as_tensor(batch["obs"], device=self.policy.device).detach()
acts = th.as_tensor(batch["acts"], device=self.policy.device).detach()
training_metrics = self.loss_calculator(self.policy, obs, acts)
obs_tensor: Union[th.Tensor, Dict[str, th.Tensor]]
if isinstance(batch["obs"], types.DictObs):
obs_dict = batch["obs"].unwrap()
obs_tensor = {
k: util.safe_to_tensor(v, device=self.policy.device)
for k, v in obs_dict.items()
}
else:
obs_tensor = util.safe_to_tensor(
batch["obs"], device=self.policy.device
)
acts = util.safe_to_tensor(batch["acts"], device=self.policy.device)
training_metrics = self.loss_calculator(self.policy, obs_tensor, acts)

# Renormalise the loss to be averaged over the whole
# batch size instead of the minibatch size.
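The DictObs-to-tensor conversion above appears in two places (in `BehaviorCloningLossCalculator.__call__` and in the minibatch loop of `process_batch`). Below is a minimal sketch of the shared logic; it assumes only that `types.DictObs.unwrap()` returns a plain `Dict[str, np.ndarray]` and that `util.safe_to_tensor` accepts a `device` keyword, both of which are used in this diff. The helper name `obs_to_tensor` is hypothetical and not part of the PR.

```python
from typing import Dict, Optional, Union

import numpy as np
import torch as th

from imitation.data import types
from imitation.util import util


def obs_to_tensor(
    obs: Union[th.Tensor, np.ndarray, types.DictObs],
    device: Optional[th.device] = None,
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
    """Convert observations to tensors, key-wise for DictObs inputs."""
    if isinstance(obs, types.DictObs):
        # Dict[str, th.Tensor] is what SB3's dict feature extractors consume,
        # as noted in the comment on evaluate_actions above.
        return {
            k: util.safe_to_tensor(v, device=device)
            for k, v in obs.unwrap().items()
        }
    return util.safe_to_tensor(obs, device=device)
```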
7 changes: 5 additions & 2 deletions src/imitation/algorithms/density.py
@@ -168,9 +168,11 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:

if isinstance(demonstrations, types.TransitionsMinimal):
next_obs_b = getattr(demonstrations, "next_obs", None)
if next_obs_b is not None:
next_obs_b = types.assert_not_dictobs(next_obs_b)
transitions.update(
self._get_demo_from_batch(
demonstrations.obs,
types.assert_not_dictobs(demonstrations.obs),
demonstrations.acts,
next_obs_b,
),
@@ -191,8 +193,9 @@ def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None:
demonstrations = cast(Iterable[types.Trajectory], demonstrations)

for traj in demonstrations:
traj_obs = types.assert_not_dictobs(traj.obs)
for i, (obs, act, next_obs) in enumerate(
zip(traj.obs[:-1], traj.acts, traj.obs[1:]),
zip(traj_obs[:-1], traj.acts, traj_obs[1:]),
):
flat_trans = self._preprocess_transition(obs, act, next_obs)
transitions.setdefault(i, []).append(flat_trans)
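`types.assert_not_dictobs` is used here (and in mce_irl.py and preference_comparisons.py below) to narrow `Union[np.ndarray, DictObs]` down to an array for algorithms that do not support dictionary observations. The real helper lives in `imitation.data.types`; the sketch below only illustrates its contract, and the exact exception type and message are assumptions.

```python
from typing import Union

import numpy as np

from imitation.data import types


def assert_not_dictobs(x: Union[np.ndarray, types.DictObs]) -> np.ndarray:
    """Return `x` unchanged if it is an array, raising if it is a DictObs."""
    if isinstance(x, types.DictObs):
        raise ValueError("Dictionary observations are not supported here.")
    return x
```

Callers then operate on the returned `np.ndarray` without further isinstance checks, which also keeps the narrowed type visible to mypy/pytype.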
18 changes: 14 additions & 4 deletions src/imitation/algorithms/mce_irl.py
@@ -347,6 +347,10 @@ def _set_demo_from_trajectories(self, trajs: Iterable[types.Trajectory]) -> None
num_demos = 0
for traj in trajs:
cum_discount = 1.0
if isinstance(traj.obs, types.DictObs):
raise ValueError(
"Dictionary observations are not currently supported for mce_irl"
)
for obs in traj.obs:
self.demo_state_om[obs] += cum_discount
cum_discount *= self.discount
@@ -411,12 +415,14 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:

if isinstance(demonstrations, types.Transitions):
self._set_demo_from_obs(
demonstrations.obs,
types.assert_not_dictobs(demonstrations.obs),
demonstrations.dones,
demonstrations.next_obs,
types.assert_not_dictobs(demonstrations.next_obs),
)
elif isinstance(demonstrations, types.TransitionsMinimal):
self._set_demo_from_obs(demonstrations.obs, None, None)
self._set_demo_from_obs(
types.assert_not_dictobs(demonstrations.obs), None, None
)
elif isinstance(demonstrations, Iterable):
# Demonstrations are a Torch DataLoader or other Mapping iterable
# Collect them together into one big NumPy array. This is inefficient,
@@ -427,7 +433,11 @@ def set_demonstrations(self, demonstrations: MCEDemonstrations) -> None:
assert isinstance(batch, Mapping)
for k in ("obs", "dones", "next_obs"):
if k in batch:
collated_list[k].append(batch[k])
if isinstance(batch[k], types.DictObs):
raise ValueError(
"Dictionary observations are not currently supported for buffers"
)
collated_list[k].append(batch[k])
collated = {k: np.concatenate(v) for k, v in collated_list.items()}

assert "obs" in collated
4 changes: 2 additions & 2 deletions src/imitation/algorithms/preference_comparisons.py
@@ -465,9 +465,9 @@ def rewards(self, transitions: Transitions) -> th.Tensor:
Shape - (num_transitions, ) for Single reward network and
(num_transitions, num_networks) for ensemble of networks.
"""
state = transitions.obs
state = types.assert_not_dictobs(transitions.obs)
action = transitions.acts
next_state = transitions.next_obs
next_state = types.assert_not_dictobs(transitions.next_obs)
done = transitions.dones
if self.ensemble_model is not None:
rews_np = self.ensemble_model.predict_processed_all(
5 changes: 5 additions & 0 deletions src/imitation/data/buffer.py
@@ -345,6 +345,11 @@ def from_data(
Returns:
A new ReplayBuffer.
"""
if isinstance(transitions.obs, types.DictObs):
raise ValueError(
"Dictionary observations are not currently supported for buffers"
)

obs_shape = transitions.obs.shape[1:]
act_shape = transitions.acts.shape[1:]
if capacity is None:
2 changes: 2 additions & 0 deletions src/imitation/data/huggingface_utils.py
@@ -124,6 +124,8 @@ def trajectories_to_dict(
],
terminal=[traj.terminal for traj in trajectories],
)
if any(isinstance(traj.obs, types.DictObs) for traj in trajectories):
raise ValueError("DictObs are not currently supported")

# Encode infos as jsonpickled strings
trajectory_dict["infos"] = [
89 changes: 66 additions & 23 deletions src/imitation/data/rollout.py
@@ -18,6 +18,7 @@
)

import numpy as np
from gym import spaces
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.utils import check_for_correct_spaces
@@ -69,7 +70,7 @@ def __init__(self):

def add_step(
self,
step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any]]],
step_dict: Mapping[str, Union[np.ndarray, Mapping[str, Any], types.DictObs]],
key: Hashable = None,
) -> None:
"""Add a single step to the partial trajectory identified by `key`.
@@ -107,17 +108,22 @@ def finish_trajectory(
for part_dict in part_dicts:
for k, array in part_dict.items():
out_dict_unstacked[k].append(array)
out_dict_stacked = {
k: np.stack(arr_list, axis=0) for k, arr_list in out_dict_unstacked.items()
}
traj = types.TrajectoryWithRew(**out_dict_stacked, terminal=terminal)
assert traj.rews.shape[0] == traj.acts.shape[0] == traj.obs.shape[0] - 1

# TODO: what about infos? Does this actually handle them well?
traj = types.TrajectoryWithRew(
obs=types.stack_maybe_dictobs(out_dict_unstacked["obs"]),
acts=np.stack(out_dict_unstacked["acts"], axis=0),
infos=np.stack(out_dict_unstacked["infos"], axis=0), # TODO: confused
rews=np.stack(out_dict_unstacked["rews"], axis=0),
terminal=terminal,
)
assert traj.rews.shape[0] == traj.acts.shape[0] == len(traj.obs) - 1
return traj

def add_steps_and_auto_finish(
self,
acts: np.ndarray,
obs: np.ndarray,
obs: Union[np.ndarray, Dict[str, np.ndarray], types.DictObs],
rews: np.ndarray,
dones: np.ndarray,
infos: List[dict],
@@ -142,20 +148,26 @@
each `True` in the `dones` argument.
"""
trajs: List[types.TrajectoryWithRew] = []
for env_idx in range(len(obs)):
wrapped_obs = types.DictObs.maybe_wrap(obs)

# len of dictobs is the shape[0] of each value array - which here is # of envs
for env_idx in range(len(wrapped_obs)):
assert env_idx in self.partial_trajectories
assert list(self.partial_trajectories[env_idx][0].keys()) == ["obs"], (
"Need to first initialize partial trajectory using "
"self._traj_accum.add_step({'obs': ob}, key=env_idx)"
)

zip_iter = enumerate(zip(acts, obs, rews, dones, infos))
zip_iter = enumerate(zip(acts, wrapped_obs, rews, dones, infos))
for env_idx, (act, ob, rew, done, info) in zip_iter:
if done:
# When dones[i] from VecEnv.step() is True, obs[i] is the first
# observation following reset() of the ith VecEnv, and
# infos[i]["terminal_observation"] is the actual final observation.
real_ob = info["terminal_observation"]
if isinstance(real_ob, dict):
# TODO: does this need to be unsqueezed or something?
real_ob = types.DictObs(real_ob)
else:
real_ob = ob
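The hunk above handles array and dict observations uniformly: `types.DictObs.maybe_wrap` wraps a raw `dict` returned by the VecEnv, `len()` gives the leading (per-env) dimension, and iterating yields one observation per environment; `maybe_unwrap` (used in `policy_to_callable` below) goes back the other way before `policy.predict`. The stand-in class below is inferred from that usage rather than copied from the PR; the real `DictObs` in `imitation.data.types` also supports indexing and iteration over the leading dimension.

```python
from dataclasses import dataclass
from typing import Dict, Union

import numpy as np


@dataclass(frozen=True)
class DictObsSketch:
    """Illustrative stand-in for imitation.data.types.DictObs."""

    d: Dict[str, np.ndarray]

    @classmethod
    def maybe_wrap(
        cls,
        obs: Union[np.ndarray, Dict[str, np.ndarray], "DictObsSketch"],
    ) -> Union[np.ndarray, "DictObsSketch"]:
        """Wrap raw dict observations; pass arrays (and DictObs) through unchanged."""
        return cls(obs) if isinstance(obs, dict) else obs

    @staticmethod
    def maybe_unwrap(
        obs: Union[np.ndarray, "DictObsSketch"],
    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        """Inverse of maybe_wrap, e.g. before calling policy.predict()."""
        return obs.unwrap() if isinstance(obs, DictObsSketch) else obs

    def unwrap(self) -> Dict[str, np.ndarray]:
        return self.d

    def __len__(self) -> int:
        # All value arrays share a leading dimension; in the rollout code this
        # is the number of parallel environments.
        lengths = {v.shape[0] for v in self.d.values()}
        assert len(lengths) == 1, "value arrays must share a leading dimension"
        return lengths.pop()
```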

@@ -268,7 +280,11 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool:
# array of states, and an optional array of episode starts and returns an array of
# corresponding actions.
PolicyCallable = Callable[
[np.ndarray, Optional[Tuple[np.ndarray, ...]], Optional[np.ndarray]],
[
Union[np.ndarray, types.DictObs],
Optional[Tuple[np.ndarray, ...]],
Optional[np.ndarray],
],
Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]],
]
AnyPolicy = Union[BaseAlgorithm, BasePolicy, PolicyCallable, None]
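With the widened `PolicyCallable` signature, a user-supplied policy callable may now receive either an array or a `DictObs`. Below is a usage sketch (the `make_random_policy` factory is illustrative, not part of the PR); it relies only on `len(observations)` giving the number of environments, which holds for both observation types.

```python
from typing import Optional, Tuple

import numpy as np
from gym import spaces


def make_random_policy(action_space: spaces.Space):
    """Build a PolicyCallable that ignores observations and samples random actions."""

    def random_policy(
        observations,  # np.ndarray or types.DictObs
        states: Optional[Tuple[np.ndarray, ...]],
        episode_starts: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        n_envs = len(observations)  # leading dimension for arrays and DictObs alike
        acts = np.stack([action_space.sample() for _ in range(n_envs)])
        return acts, states

    return random_policy
```

Such a callable can be passed anywhere `AnyPolicy` is accepted, for example to `generate_trajectories` below.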
@@ -284,7 +300,7 @@ def policy_to_callable(
if policy is None:

def get_actions(
observations: np.ndarray,
observations: Union[np.ndarray, types.DictObs],
states: Optional[Tuple[np.ndarray, ...]],
episode_starts: Optional[np.ndarray],
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
@@ -298,15 +314,15 @@ def get_actions(
# (which would call .forward()). So this elif clause must come first!

def get_actions(
observations: np.ndarray,
observations: Union[np.ndarray, types.DictObs],
states: Optional[Tuple[np.ndarray, ...]],
episode_starts: Optional[np.ndarray],
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
assert isinstance(policy, (BaseAlgorithm, BasePolicy))
# pytype doesn't seem to understand that policy is a BaseAlgorithm
# or BasePolicy here, rather than a Callable
(acts, states) = policy.predict( # pytype: disable=attribute-error
observations,
types.DictObs.maybe_unwrap(observations),
state=states,
episode_start=episode_starts,
deterministic=deterministic_policy,
@@ -403,7 +419,21 @@ def generate_trajectories(
# accumulator for incomplete trajectories
trajectories_accum = TrajectoryAccumulator()
obs = venv.reset()
for env_idx, ob in enumerate(obs):

assert isinstance(
obs,
(
np.ndarray,
dict,
),
), "Tuple observations are not supported."

# need to wrap here to iterate over envs properly
wrapped_obs = types.DictObs.maybe_wrap(obs)
# TODO: make this nicer, it's currently non-mypy compliant
# probably want helper

for env_idx, ob in enumerate(wrapped_obs):
# Seed with first obs only. Inside loop, we'll only add second obs from
# each (s,a,r,s') tuple, under the same "obs" key again. That way we still
# get all observations, but they're not duplicated into "next obs" and
@@ -419,13 +449,16 @@
#
# To start with, all environments are active.
active = np.ones(venv.num_envs, dtype=bool)
assert isinstance(obs, np.ndarray), "Dict/tuple observations are not supported."
state = None
dones = np.zeros(venv.num_envs, dtype=bool)
while np.any(active):
acts, state = get_actions(obs, state, dones)
acts, state = get_actions(wrapped_obs, state, dones)
obs, rews, dones, infos = venv.step(acts)
assert isinstance(obs, np.ndarray)
assert isinstance(
obs,
(np.ndarray, types.DictObs),
), "Tuple observations are not supported."
wrapped_obs = types.DictObs.maybe_wrap(obs)

# If an environment is inactive, i.e. the episode completed for that
# environment after `sample_until(trajectories)` was true, then we do
@@ -435,7 +468,7 @@

new_trajs = trajectories_accum.add_steps_and_auto_finish(
acts,
obs,
wrapped_obs,
rews,
dones,
infos,
@@ -460,9 +493,10 @@
for trajectory in trajectories:
n_steps = len(trajectory.acts)
# extra 1 for the end
exp_obs = (n_steps + 1,) + venv.observation_space.shape
real_obs = trajectory.obs.shape
assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}"
if not isinstance(venv.observation_space, spaces.dict.Dict):
exp_obs = (n_steps + 1,) + venv.observation_space.shape
real_obs = types.assert_not_dictobs(trajectory.obs).shape
assert real_obs == exp_obs, f"expected shape {exp_obs}, got {real_obs}"
exp_act = (n_steps,) + venv.action_space.shape
real_act = trajectory.acts.shape
assert real_act == exp_act, f"expected shape {exp_act}, got {real_act}"
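Note that the observation shape check above is simply skipped when `venv.observation_space` is a `Dict`. A hypothetical per-key variant (not part of this PR) could compare each value array of the stacked `DictObs` against its sub-space; the only assumption is that `trajectory.obs.unwrap()` returns a `Dict[str, np.ndarray]` stacked over time, as elsewhere in this diff.

```python
from gym import spaces


def check_dict_obs_shapes(trajectory, observation_space: spaces.Dict) -> None:
    """Check each key of a stacked DictObs against its sub-space shape (sketch)."""
    n_steps = len(trajectory.acts)
    obs_dict = trajectory.obs.unwrap()
    for key, subspace in observation_space.spaces.items():
        exp_shape = (n_steps + 1,) + subspace.shape
        real_shape = obs_dict[key].shape
        assert real_shape == exp_shape, (
            f"{key}: expected shape {exp_shape}, got {real_shape}"
        )
```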
@@ -539,7 +573,8 @@ def flatten_trajectories(
The trajectories flattened into a single batch of Transitions.
"""
keys = ["obs", "next_obs", "acts", "dones", "infos"]
parts: Mapping[str, List[np.ndarray]] = {key: [] for key in keys}
# TODO: sad to use Any here
parts: Mapping[str, List[Any]] = {key: [] for key in keys}
for traj in trajectories:
parts["acts"].append(traj.acts)

Expand All @@ -558,11 +593,19 @@ def flatten_trajectories(
parts["infos"].append(infos)

cat_parts = {
key: np.concatenate(part_list, axis=0) for key, part_list in parts.items()
key: types.concatenate_maybe_dictobs(part_list)
for key, part_list in parts.items()
}
lengths = set(map(len, cat_parts.values()))
assert len(lengths) == 1, f"expected one length, got {lengths}"
return types.Transitions(**cat_parts)
# TODO: clean
# cat_parts["obs"],
# types.assert_not_dictobs(cat_parts["acts"]),
# types.assert_not_dictobs(cat_parts["infos"]),
# cat_parts["next_obs"],
# types.assert_not_dictobs(cat_parts["done"]),
# )


def flatten_trajectories_with_rew(
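`flatten_trajectories` now defers to `types.concatenate_maybe_dictobs`, the concatenating counterpart of the `stack_maybe_dictobs` call used in `finish_trajectory` above. Below is a rough sketch of both helpers, assuming a `DictObs` can be constructed from and unwrapped to a `Dict[str, np.ndarray]` as elsewhere in this diff; the real helpers in `imitation.data.types` also validate that all inputs share the same keys.

```python
from typing import List, Union

import numpy as np

from imitation.data import types


def stack_maybe_dictobs(
    arrs: List[Union[np.ndarray, types.DictObs]],
) -> Union[np.ndarray, types.DictObs]:
    """Stack along a new leading axis, key-wise for DictObs inputs (sketch)."""
    if isinstance(arrs[0], types.DictObs):
        keys = arrs[0].unwrap().keys()
        return types.DictObs(
            {k: np.stack([a.unwrap()[k] for a in arrs], axis=0) for k in keys}
        )
    return np.stack(arrs, axis=0)


def concatenate_maybe_dictobs(
    arrs: List[Union[np.ndarray, types.DictObs]],
) -> Union[np.ndarray, types.DictObs]:
    """Concatenate along the existing leading axis, key-wise for DictObs (sketch)."""
    if isinstance(arrs[0], types.DictObs):
        keys = arrs[0].unwrap().keys()
        return types.DictObs(
            {k: np.concatenate([a.unwrap()[k] for a in arrs], axis=0) for k in keys}
        )
    return np.concatenate(arrs, axis=0)
```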