22import abc
33import dataclasses
44import logging
5- from typing import Callable , Iterable , Iterator , Mapping , Optional , Type , overload
5+ from typing import Iterable , Iterator , List , Mapping , Optional , Type , overload
66
77import numpy as np
88import torch as th
99import torch .utils .tensorboard as thboard
1010import tqdm
1111from stable_baselines3 .common import base_class , on_policy_algorithm , policies , vec_env
12+ from stable_baselines3 .common .type_aliases import MaybeCallback
13+ from stable_baselines3 .common .callbacks import BaseCallback , ConvertCallback
1214from stable_baselines3 .sac import policies as sac_policies
1315from torch .nn import functional as F
1416
@@ -386,6 +388,7 @@ def train_gen(
386388 self ,
387389 total_timesteps : Optional [int ] = None ,
388390 learn_kwargs : Optional [Mapping ] = None ,
391+ callback : MaybeCallback = None ,
389392 ) -> None :
390393 """Trains the generator to maximize the discriminator loss.
391394
@@ -398,17 +401,27 @@ def train_gen(
398401 `self.gen_train_timesteps`.
399402 learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
400403 method.
404+ callback: additional callback(s) passed to the generator's `learn` method.
401405 """
402406 if total_timesteps is None :
403407 total_timesteps = self .gen_train_timesteps
404408 if learn_kwargs is None :
405409 learn_kwargs = {}
406410
411+ callbacks = [self .gen_callback ]
412+
413+ if isinstance (callback , list ):
414+ callbacks .extend (callback )
415+ elif isinstance (callback , BaseCallback ):
416+ callbacks .append (callback )
417+ elif callback is not None :
418+ callbacks .append (ConvertCallback (callback ))
419+
407420 with self .logger .accumulate_means ("gen" ):
408421 self .gen_algo .learn (
409422 total_timesteps = total_timesteps ,
410423 reset_num_timesteps = False ,
411- callback = self . gen_callback ,
424+ callback = callbacks ,
412425 ** learn_kwargs ,
413426 )
414427 self ._global_step += 1
@@ -421,37 +434,33 @@ def train_gen(
def train(
    self,
    total_timesteps: int,
    callback: MaybeCallback = None,
) -> None:
    """Alternates between generator and discriminator training rounds.

    Each round calls `train_gen(self.gen_train_timesteps, callback)` and then
    performs `self.n_disc_updates_per_round` discriminator updates via
    `train_disc`.

    Training stops once one additional round would push the number of
    transitions sampled from the environment past `total_timesteps`.

    Args:
        total_timesteps: An upper bound on the number of transitions to sample
            from the environment during training.
        callback: callback(s) passed to the generator's `learn` method.
    """
    n_rounds = total_timesteps // self.gen_train_timesteps
    assert n_rounds >= 1, (
        "No updates (need at least "
        f"{self.gen_train_timesteps} timesteps, have only "
        f"total_timesteps={total_timesteps})!"
    )
    for _ in tqdm.tqdm(range(n_rounds), desc="round"):
        self.train_gen(self.gen_train_timesteps, callback=callback)
        for _ in range(self.n_disc_updates_per_round):
            # Enable training mode so dropout/normalization layers behave
            # as they should during discriminator updates.
            with networks.training(self.reward_train):
                self.train_disc()
    self.logger.dump(self._global_step)
456465
457466 @overload
0 commit comments