From fa9b2f48e7f2b03537de58aa3c183a1c3a514914 Mon Sep 17 00:00:00 2001
From: Tri Wahyu Guntara
Date: Tue, 15 Nov 2022 18:25:36 +0900
Subject: [PATCH] change callback mechanism

---
 src/imitation/algorithms/adversarial/common.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py
index 8729d3c30..6959d282e 100644
--- a/src/imitation/algorithms/adversarial/common.py
+++ b/src/imitation/algorithms/adversarial/common.py
@@ -2,13 +2,14 @@
 import abc
 import dataclasses
 import logging
-from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
+from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload
 
 import numpy as np
 import torch as th
 import torch.utils.tensorboard as thboard
 import tqdm
 from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
+from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.sac import policies as sac_policies
 from torch.nn import functional as F
 
@@ -421,7 +422,7 @@ def train_gen(
     def train(
         self,
         total_timesteps: int,
-        callback: Optional[Callable[[int], None]] = None,
+        callback: Optional[List[BaseCallback]] = None,
     ) -> None:
         """Alternates between training the generator and discriminator.
 
@@ -434,10 +435,15 @@ def train(
         Args:
             total_timesteps: An upper bound on the number of transitions to sample
                 from the environment during training.
-            callback: A function called at the end of every round which takes in a
-                single argument, the round number. Round numbers are in
-                `range(total_timesteps // self.gen_train_timesteps)`.
+            callback: List of stable_baselines3 callbacks to be passed to the
+                policy learning function.
         """
+        if callback is not None:
+            if self.gen_callback is None:
+                self.gen_callback = callback
+            else:
+                self.gen_callback = callback + [self.gen_callback]
+
         n_rounds = total_timesteps // self.gen_train_timesteps
         assert n_rounds >= 1, (
             "No updates (need at least "
@@ -450,8 +456,6 @@
                 with networks.training(self.reward_train):
                     # switch to training mode (affects dropout, normalization)
                     self.train_disc()
-            if callback:
-                callback(r)
             self.logger.dump(self._global_step)
 
     @overload
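
Usage note (not part of the patch): with this change, `train()` takes a list of
stable_baselines3 callbacks that is merged into `self.gen_callback` and forwarded
to the generator's policy-learning step, instead of a per-round function. A minimal
sketch, assuming an already-constructed GAIL/AIRL trainer; the variable name
`gail_trainer` and the hyperparameters below are hypothetical, only the new
`callback` parameter comes from this patch:

    from stable_baselines3.common.callbacks import CheckpointCallback

    # Hypothetical trainer built elsewhere via imitation's adversarial API.
    checkpoint_cb = CheckpointCallback(save_freq=10_000, save_path="./checkpoints/")
    # Pass SB3 callbacks as a list; they run during generator training rather
    # than once per round as with the old Callable[[int], None] interface.
    gail_trainer.train(total_timesteps=100_000, callback=[checkpoint_cb])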