Skip to content

Commit eb8f67e

Browse files
committed
support for SB3 callbacks in adversarial training
1 parent cb93fb0 commit eb8f67e

File tree

3 files changed

+82
-16
lines changed

3 files changed

+82
-16
lines changed

src/imitation/algorithms/adversarial/common.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import abc
33
import dataclasses
44
import logging
5-
from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
5+
from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload
66

77
import numpy as np
88
import torch as th
99
import torch.utils.tensorboard as thboard
1010
import tqdm
1111
from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
12+
from stable_baselines3.common.type_aliases import MaybeCallback
13+
from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
1214
from stable_baselines3.sac import policies as sac_policies
1315
from torch.nn import functional as F
1416

@@ -386,6 +388,7 @@ def train_gen(
386388
self,
387389
total_timesteps: Optional[int] = None,
388390
learn_kwargs: Optional[Mapping] = None,
391+
callback: MaybeCallback = None,
389392
) -> None:
390393
"""Trains the generator to maximize the discriminator loss.
391394
@@ -398,17 +401,27 @@ def train_gen(
398401
`self.gen_train_timesteps`.
399402
learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
400403
method.
404+
callback: additional callback(s) passed to the generator's `learn` method.
401405
"""
402406
if total_timesteps is None:
403407
total_timesteps = self.gen_train_timesteps
404408
if learn_kwargs is None:
405409
learn_kwargs = {}
406410

411+
callbacks = [self.gen_callback]
412+
413+
if isinstance(callback, list):
414+
callbacks.extend(callback)
415+
elif isinstance(callback, BaseCallback):
416+
callbacks.append(callback)
417+
elif callback is not None:
418+
callbacks.append(ConvertCallback(callback))
419+
407420
with self.logger.accumulate_means("gen"):
408421
self.gen_algo.learn(
409422
total_timesteps=total_timesteps,
410423
reset_num_timesteps=False,
411-
callback=self.gen_callback,
424+
callback=callbacks,
412425
**learn_kwargs,
413426
)
414427
self._global_step += 1
@@ -421,37 +434,33 @@ def train_gen(
421434
def train(
422435
self,
423436
total_timesteps: int,
424-
callback: Optional[Callable[[int], None]] = None,
437+
callback: MaybeCallback = None,
425438
) -> None:
426439
"""Alternates between training the generator and discriminator.
427440
428-
Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
429-
a call to `train_disc`, and finally a call to `callback(round)`.
441+
Every "round" consists of a call to
442+
`train_gen(self.gen_train_timesteps, callback)`, then a call to `train_disc`.
430443
431444
Training ends once an additional "round" would cause the number of transitions
432445
sampled from the environment to exceed `total_timesteps`.
433446
434447
Args:
435448
total_timesteps: An upper bound on the number of transitions to sample
436449
from the environment during training.
437-
callback: A function called at the end of every round which takes in a
438-
single argument, the round number. Round numbers are in
439-
`range(total_timesteps // self.gen_train_timesteps)`.
450+
callback: callback(s) passed to the generator's `learn` method.
440451
"""
441452
n_rounds = total_timesteps // self.gen_train_timesteps
442453
assert n_rounds >= 1, (
443454
"No updates (need at least "
444455
f"{self.gen_train_timesteps} timesteps, have only "
445456
f"total_timesteps={total_timesteps})!"
446457
)
447-
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
448-
self.train_gen(self.gen_train_timesteps)
458+
for _r in tqdm.tqdm(range(0, n_rounds), desc="round"):
459+
self.train_gen(self.gen_train_timesteps, callback=callback)
449460
for _ in range(self.n_disc_updates_per_round):
450461
with networks.training(self.reward_train):
451462
# switch to training mode (affects dropout, normalization)
452463
self.train_disc()
453-
if callback:
454-
callback(r)
455464
self.logger.dump(self._global_step)
456465

457466
@overload

src/imitation/scripts/train_adversarial.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sacred.commands
99
import torch as th
1010
from sacred.observers import FileStorageObserver
11+
from stable_baselines3.common.callbacks import BaseCallback
1112

1213
from imitation.algorithms.adversarial import airl as airl_algo
1314
from imitation.algorithms.adversarial import common
@@ -22,6 +23,28 @@
2223
logger = logging.getLogger("imitation.scripts.train_adversarial")
2324

2425

26+
class CheckpointCallback(BaseCallback):
    """SB3 callback that checkpoints an adversarial trainer every few rounds.

    The callback counts how many times generator training has ended and, every
    `interval`-th time, saves the trainer under
    `log_dir / "checkpoints" / <zero-padded round number>`.
    """

    def __init__(
        self,
        trainer: common.AdversarialTrainer,
        log_dir: pathlib.Path,
        interval: int,
    ):
        """Builds the callback.

        Args:
            trainer: the adversarial trainer passed to `save` at checkpoints.
            log_dir: directory under which the `checkpoints` folder is written.
            interval: save every `interval` rounds; values <= 0 disable saving.
        """
        # BaseCallback.__init__ only accepts an optional verbosity level.
        # Passing `self` here would store the callback as its own `verbose`.
        super().__init__()
        self.trainer = trainer
        self.log_dir = log_dir
        self.interval = interval
        self.round_num = 0

    def _on_step(self) -> bool:
        """Does nothing per environment step; always lets training continue."""
        return True

    def _on_training_end(self) -> None:
        """Counts a finished round and saves a checkpoint when one is due."""
        # The counter is incremented before the check, so round numbers used in
        # checkpoint names start at 1.
        self.round_num += 1
        if self.interval > 0 and self.round_num % self.interval == 0:
            save(self.trainer, self.log_dir / "checkpoints" / f"{self.round_num:05d}")
46+
47+
2548
def save(trainer: common.AdversarialTrainer, save_path: pathlib.Path):
2649
"""Save discriminator and generator."""
2750
# We implement this here and not in Trainer since we do not want to actually
@@ -153,10 +176,7 @@ def train_adversarial(
153176
**algorithm_kwargs,
154177
)
155178

156-
def callback(round_num: int, /) -> None:
157-
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
158-
save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")
159-
179+
callback = CheckpointCallback(trainer, log_dir, checkpoint_interval)
160180
trainer.train(total_timesteps, callback)
161181
imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train)
162182

tests/algorithms/test_adversarial.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import stable_baselines3
1111
import torch as th
1212
from stable_baselines3.common import policies
13+
from stable_baselines3.common.callbacks import BaseCallback
1314
from torch.utils import data as th_data
1415

1516
from imitation.algorithms.adversarial import airl, common, gail
@@ -464,3 +465,39 @@ def test_regression_gail_with_sac(
464465
reward_net=reward_net,
465466
)
466467
gail_trainer.train(8)
468+
469+
470+
def test_gen_callback(trainer: common.AdversarialTrainer):
    """Checks that each supported callback form is invoked during training.

    Three forms are passed to `trainer.train`: a plain function, a single
    `BaseCallback` instance, and a list of `BaseCallback` instances. Each
    callback increments its own counter in a shared dict on every call.
    """

    def make_fn_callback(calls, key):
        def cb(_locals, _globals):
            calls[key] += 1
            # Return True explicitly: SB3 treats a falsy return from a
            # callback as a request to stop training early.
            return True

        return cb

    class SB3Callback(BaseCallback):
        def __init__(self, calls, key):
            # BaseCallback.__init__ only takes an optional verbosity level.
            super().__init__()
            self.calls = calls
            self.key = key

        def _on_step(self):
            self.calls[self.key] += 1
            return True

    n_steps = trainer.gen_train_timesteps * 2
    calls = {"fn": 0, "sb3": 0, "list.0": 0, "list.1": 0}

    trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
    trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
    trainer.train(
        n_steps,
        callback=[
            SB3Callback(calls, "list.0"),
            SB3Callback(calls, "list.1"),
        ],
    )

    # Env steps for off-policy algos (DQN) may exceed `total_timesteps`,
    # so we check if the callback was called *at least* that many times.
    assert calls["fn"] >= n_steps
    assert calls["sb3"] >= n_steps
    assert calls["list.0"] >= n_steps
    assert calls["list.1"] >= n_steps

0 commit comments

Comments
 (0)