157 changes: 141 additions & 16 deletions src/imitation/algorithms/adversarial/common.py
@@ -8,12 +8,17 @@
import torch as th
import torch.utils.tensorboard as thboard
import tqdm
from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
from stable_baselines3.common import base_class
from stable_baselines3.common import buffers as sb3_buffers
from stable_baselines3.common import on_policy_algorithm, policies, type_aliases
from stable_baselines3.common import utils as sb3_utils
from stable_baselines3.common import vec_env
from stable_baselines3.sac import policies as sac_policies
from torch.nn import functional as F

from imitation.algorithms import base
from imitation.data import buffer, rollout, types, wrappers
from imitation.policies import replay_buffer_wrapper
from imitation.rewards import reward_nets, reward_wrapper
from imitation.util import logger, networks, util

@@ -246,6 +251,38 @@ def __init__(
else:
self.gen_train_timesteps = gen_train_timesteps

self.is_gen_on_policy = isinstance(
self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm
)
if self.is_gen_on_policy:
rollout_buffer = self.gen_algo.rollout_buffer
self.gen_algo.rollout_buffer = (
replay_buffer_wrapper.RolloutBufferRewardWrapper(
buffer_size=self.gen_train_timesteps // rollout_buffer.n_envs,
observation_space=rollout_buffer.observation_space,
action_space=rollout_buffer.action_space,
rollout_buffer_class=rollout_buffer.__class__,
reward_fn=self.reward_train.predict_processed,
device=rollout_buffer.device,
gae_lambda=rollout_buffer.gae_lambda,
gamma=rollout_buffer.gamma,
n_envs=rollout_buffer.n_envs,
)
)
else:
replay_buffer = self.gen_algo.replay_buffer
self.gen_algo.replay_buffer = (
replay_buffer_wrapper.ReplayBufferRewardWrapper(
buffer_size=self.gen_train_timesteps,
observation_space=replay_buffer.observation_space,
action_space=replay_buffer.action_space,
replay_buffer_class=sb3_buffers.ReplayBuffer,
reward_fn=self.reward_train.predict_processed,
device=replay_buffer.device,
n_envs=replay_buffer.n_envs,
)
)

if gen_replay_buffer_capacity is None:
gen_replay_buffer_capacity = self.gen_train_timesteps
self._gen_replay_buffer = buffer.ReplayBuffer(
@@ -382,41 +419,126 @@ def train_disc(

return train_stats

def train_gen(
def collect_rollouts(
self,
total_timesteps: Optional[int] = None,
callback: type_aliases.MaybeCallback = None,
learn_kwargs: Optional[Mapping] = None,
) -> None:
"""Trains the generator to maximize the discriminator loss.

After the end of training populates the generator replay buffer (used in
discriminator training) with `self.disc_batch_size` transitions.
):
"""Collect rollouts.

Args:
total_timesteps: The number of transitions to sample from
`self.venv_train` during training. By default,
`self.gen_train_timesteps`.
callback: Callback that will be called at each step
(and at the beginning and end of the rollout)
learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
method.
"""
if total_timesteps is None:
total_timesteps = self.gen_train_timesteps
if learn_kwargs is None:
learn_kwargs = {}

with self.logger.accumulate_means("gen"):
self.gen_algo.learn(
total_timesteps=total_timesteps,
reset_num_timesteps=False,
callback=self.gen_callback,
if total_timesteps is None:
total_timesteps = self.gen_train_timesteps

# Convert total timesteps across all envs into per-environment timesteps.
total_timesteps = total_timesteps // self.gen_algo.n_envs
# NOTE (Taufeeque): call setup_learn or not?
if "eval_env" not in learn_kwargs:
total_timesteps, callback = self.gen_algo._setup_learn(
total_timesteps,
eval_env=None,
callback=callback,
**learn_kwargs,
)
self._global_step += 1
else:
total_timesteps, callback = self.gen_algo._setup_learn(
total_timesteps,
callback=callback,
**learn_kwargs,
)
callback.on_training_start(locals(), globals())
if self.is_gen_on_policy:
self.gen_algo.collect_rollouts(
self.gen_algo.env,
callback,
self.gen_algo.rollout_buffer,
n_rollout_steps=total_timesteps,
)
rollouts = None
else:
self.gen_algo.train_freq = total_timesteps
self.gen_algo._convert_train_freq()
rollouts = self.gen_algo.collect_rollouts(
self.gen_algo.env,
train_freq=self.gen_algo.train_freq,
action_noise=self.gen_algo.action_noise,
callback=callback,
learning_starts=self.gen_algo.learning_starts,
replay_buffer=self.gen_algo.replay_buffer,
)

if self.is_gen_on_policy:
if (
len(self.gen_algo.ep_info_buffer) > 0
and len(self.gen_algo.ep_info_buffer[0]) > 0
):
self.logger.record(
"rollout/ep_rew_mean",
sb3_utils.safe_mean(
[ep_info["r"] for ep_info in self.gen_algo.ep_info_buffer]
),
)
self.logger.record(
"rollout/ep_len_mean",
sb3_utils.safe_mean(
[ep_info["l"] for ep_info in self.gen_algo.ep_info_buffer]
),
)
self.logger.record(
"time/total_timesteps",
self.gen_algo.num_timesteps,
exclude="tensorboard",
)
else:
self.gen_algo._dump_logs()

gen_trajs, ep_lens = self.venv_buffering.pop_trajectories()
self._check_fixed_horizon(ep_lens)
gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
self._gen_replay_buffer.store(gen_samples)
callback.on_training_end()
return rollouts

def train_gen(
self,
rollouts,
) -> None:
"""Trains the generator to maximize the discriminator loss.

After the end of training populates the generator replay buffer (used in
discriminator training) with `self.disc_batch_size` transitions.
"""
with self.logger.accumulate_means("gen"):
# self.gen_algo.learn(
# total_timesteps=total_timesteps,
# reset_num_timesteps=False,
# callback=self.gen_callback,
# **learn_kwargs,
# )
if self.is_gen_on_policy:
self.gen_algo.train()
else:
if self.gen_algo.gradient_steps >= 0:
gradient_steps = self.gen_algo.gradient_steps
else:
gradient_steps = rollouts.episode_timesteps
self.gen_algo.train(
batch_size=self.gen_algo.batch_size,
gradient_steps=gradient_steps,
)
self._global_step += 1

def train(
self,
@@ -445,11 +567,14 @@ def train(
f"total_timesteps={total_timesteps})!"
)
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps)
rollouts = self.collect_rollouts(
self.gen_train_timesteps, self.gen_callback
)
for _ in range(self.n_disc_updates_per_round):
with networks.training(self.reward_train):
# switch to training mode (affects dropout, normalization)
self.train_disc()
self.train_gen(rollouts)
if callback:
callback(r)
self.logger.dump(self._global_step)
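The training loop in `train()` above drives the new split API. As a minimal usage sketch (not part of this diff), assuming `trainer` is an already-constructed `AdversarialTrainer` subclass such as GAIL:

# Hypothetical usage sketch of the collect/train split introduced above.
# `trainer` is assumed to be a constructed AdversarialTrainer (e.g. GAIL).
total_timesteps = 100_000
n_rounds = total_timesteps // trainer.gen_train_timesteps
for _ in range(n_rounds):
    # Roll out the generator policy; rewards are relabeled lazily by the
    # buffer wrapper installed in __init__.
    rollouts = trainer.collect_rollouts(
        trainer.gen_train_timesteps, trainer.gen_callback
    )
    # Update the discriminator on expert vs. generator transitions.
    for _ in range(trainer.n_disc_updates_per_round):
        trainer.train_disc()
    # Update the generator on the relabeled rewards.
    trainer.train_gen(rollouts)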
140 changes: 136 additions & 4 deletions src/imitation/policies/replay_buffer_wrapper.py
@@ -3,15 +3,16 @@
from typing import Mapping, Type

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples

from imitation.rewards.reward_function import RewardFn
from imitation.util import util


def _samples_to_reward_fn_input(
def _replay_samples_to_reward_fn_input(
samples: ReplayBufferSamples,
) -> Mapping[str, np.ndarray]:
"""Convert a sample from a replay buffer to a numpy array."""
@@ -23,6 +24,18 @@ def _samples_to_reward_fn_input(
)


def _rollout_samples_to_reward_fn_input(
buffer: RolloutBuffer,
) -> Mapping[str, np.ndarray]:
"""Convert a sample from a rollout buffer to a numpy array."""
return dict(
state=buffer.observations,
action=buffer.actions,
next_state=buffer.next_observations,
done=buffer.dones,
)


class ReplayBufferRewardWrapper(ReplayBuffer):
"""Relabel the rewards in transitions sampled from a ReplayBuffer."""

@@ -50,7 +63,9 @@ def __init__(
# DictReplayBuffer because the current RewardFn only takes in NumPy array-based
# inputs, and SAC is the only use case for ReplayBuffer relabeling. See:
# https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194
assert replay_buffer_class is ReplayBuffer, "only ReplayBuffer is supported"
assert (
replay_buffer_class is ReplayBuffer
), f"only ReplayBuffer is supported: given {replay_buffer_class}"
assert not isinstance(observation_space, spaces.Dict)
self.replay_buffer = replay_buffer_class(
buffer_size,
@@ -80,7 +95,7 @@ def full(self, full: bool):

def sample(self, *args, **kwargs):
samples = self.replay_buffer.sample(*args, **kwargs)
rewards = self.reward_fn(**_samples_to_reward_fn_input(samples))
rewards = self.reward_fn(**_replay_samples_to_reward_fn_input(samples))
shape = samples.rewards.shape
device = samples.rewards.device
rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device)
@@ -101,3 +116,120 @@ def _get_samples(self):
"_get_samples() is intentionally not implemented."
"This method should not be called.",
)


class RolloutBufferRewardWrapper(BaseBuffer):
"""Relabel the rewards in transitions sampled from a RolloutBuffer."""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
*,
rollout_buffer_class: Type[RolloutBuffer],
reward_fn: RewardFn,
**kwargs,
):
"""Builds RolloutBufferRewardWrapper.

Args:
buffer_size: Max number of elements in the buffer
observation_space: Observation space
action_space: Action space
rollout_buffer_class: Class of the rollout buffer.
reward_fn: Reward function for reward relabeling.
**kwargs: keyword arguments for RolloutBuffer.
"""
# Note(yawen-d): we wrap a plain RolloutBuffer (inheriting only BaseBuffer) and
# leave out the case of
# DictRolloutBuffer because the current RewardFn only takes in NumPy array-based
# inputs, and GAIL/AIRL is the only use case for RolloutBuffer relabeling. See:
# https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194
assert (
rollout_buffer_class is RolloutBuffer
), f"only RolloutBuffer is supported: given {rollout_buffer_class}"
assert not isinstance(observation_space, spaces.Dict)
self.rollout_buffer = rollout_buffer_class(
buffer_size,
observation_space,
action_space,
**kwargs,
)
self.reward_fn = reward_fn
_base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)

@property
def pos(self) -> int:
return self.rollout_buffer.pos

@property
def values(self):
return self.rollout_buffer.values

@property
def observations(self):
return self.rollout_buffer.observations

@property
def actions(self):
return self.rollout_buffer.actions

@property
def log_probs(self):
return self.rollout_buffer.log_probs

@property
def advantages(self):
return self.rollout_buffer.advantages

@property
def rewards(self):
return self.rollout_buffer.rewards

@property
def returns(self):
return self.rollout_buffer.returns

@pos.setter
def pos(self, pos: int):
self.rollout_buffer.pos = pos

@property
def full(self) -> bool:
return self.rollout_buffer.full

@full.setter
def full(self, full: bool):
self.rollout_buffer.full = full

def reset(self):
self.rollout_buffer.reset()

def get(self, *args, **kwargs):
if not self.rollout_buffer.generator_ready:
input_dict = _rollout_samples_to_reward_fn_input(self.rollout_buffer)
rewards = np.zeros_like(self.rollout_buffer.rewards)
for i in range(self.buffer_size):
rewards[i] = self.reward_fn(**{k: v[i] for k, v in input_dict.items()})

self.rollout_buffer.rewards = rewards
self.rollout_buffer.compute_returns_and_advantage(
self.last_values, self.last_dones
)
ret = self.rollout_buffer.get(*args, **kwargs)
return ret

def add(self, *args, **kwargs):
self.rollout_buffer.add(*args, **kwargs)

def _get_samples(self):
raise NotImplementedError(
"_get_samples() is intentionally not implemented."
"This method should not be called.",
)

def compute_returns_and_advantage(
self, last_values: th.Tensor, dones: np.ndarray
) -> None:
self.last_values = last_values
self.last_dones = dones
self.rollout_buffer.compute_returns_and_advantage(last_values, dones)
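For illustration only (not part of this diff), a sketch of how this wrapper is installed around an on-policy algorithm's buffer, mirroring the swap performed in AdversarialTrainer.__init__ in the first file; `ppo` and `reward_fn` are assumed to already exist:

# Hypothetical wiring sketch: relabel PPO's rollout rewards with a learned reward.
from stable_baselines3.common.buffers import RolloutBuffer

from imitation.policies import replay_buffer_wrapper

old_buffer = ppo.rollout_buffer  # `ppo` is an assumed, already-built PPO instance
ppo.rollout_buffer = replay_buffer_wrapper.RolloutBufferRewardWrapper(
    buffer_size=old_buffer.buffer_size,
    observation_space=old_buffer.observation_space,
    action_space=old_buffer.action_space,
    rollout_buffer_class=RolloutBuffer,
    reward_fn=reward_fn,  # e.g. a reward net's predict_processed (assumed)
    device=old_buffer.device,
    gae_lambda=old_buffer.gae_lambda,
    gamma=old_buffer.gamma,
    n_envs=old_buffer.n_envs,
)
# On the next call to get(), the wrapper recomputes the stored rewards with
# reward_fn and recomputes returns/advantages before PPO consumes the batch.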