@@ -101,6 +101,8 @@ def __init__(self, adversarial_trainer, *args, **kwargs):
101101 """Builds TrainDiscriminatorCallback.
102102
103103 Args:
104+ adversarial_trainer: The AdversarialTrainer instance in which
105+ this callback will be called.
104106 *args: Passed through to `callbacks.BaseCallback`.
105107 **kwargs: Passed through to `callbacks.BaseCallback`.
106108 """
@@ -276,7 +278,7 @@ def __init__(
             # Would use an identity reward fn here, but RewardFns can't see rewards.
             self.venv_wrapped = self.venv_buffering
             self.gen_callback: List[callbacks.BaseCallback] = [
-                self.disc_trainer_callback
+                self.disc_trainer_callback,
             ]
         else:
             self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
@@ -369,7 +371,7 @@ def update_rewards_of_rollouts(self) -> None:
             buffer = self.gen_algo.rollout_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._rollout_buffer_to_reward_fn_input(
-                self.gen_algo.rollout_buffer
+                self.gen_algo.rollout_buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             rewards = rewards.reshape(buffer.rewards.shape)
@@ -380,13 +382,14 @@ def update_rewards_of_rollouts(self) -> None:
             last_dones = last_values == 0.0
             self.gen_algo.rollout_buffer.rewards[:] = rewards
             self.gen_algo.rollout_buffer.compute_returns_and_advantage(
-                th.tensor(last_values), last_dones
+                th.tensor(last_values),
+                last_dones,
             )
         elif isinstance(self.gen_algo, off_policy_algorithm.OffPolicyAlgorithm):
             buffer = self.gen_algo.replay_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._replay_buffer_to_reward_fn_input(
-                buffer
+                buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             buffer.rewards[:] = rewards.reshape(buffer.rewards.shape)
@@ -465,13 +468,15 @@ def train_disc(

         return train_stats

-    def train_gen(
+    def train_gen_with_disc(
         self,
         total_timesteps: Optional[int] = None,
         learn_kwargs: Optional[Mapping] = None,
     ) -> None:
         """Trains the generator to maximize the discriminator loss.

+        The discriminator is also trained after the rollouts are collected and
+        before the generator is trained.
         After the end of training populates the generator replay buffer (used in
         discriminator training) with `self.disc_batch_size` transitions.

@@ -502,7 +507,7 @@ def train(
     ) -> None:
         """Alternates between training the generator and discriminator.

-        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
+        Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
         a call to `train_disc`, and finally a call to `callback(round)`.

         Training ends once an additional "round" would cause the number of transitions
@@ -522,7 +527,7 @@ def train(
522527 f"total_timesteps={ total_timesteps } )!"
523528 )
524529 for r in tqdm .tqdm (range (0 , n_rounds ), desc = "round" ):
525- self .train_gen (self .gen_train_timesteps )
530+ self .train_gen_with_disc (self .gen_train_timesteps )
526531 if callback :
527532 callback (r )
528533 self .logger .dump (self ._global_step )
@@ -610,7 +615,8 @@ def _make_disc_train_batches(
         if gen_samples is None:
             if self._gen_replay_buffer.size() == 0:
                 raise RuntimeError(
-                    "No generator samples for training. " "Call `train_gen()` first.",
+                    "No generator samples for training. "
+                    "Call `train_gen_with_disc()` first.",
                 )
             gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
             gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)
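For reference, a minimal sketch of how a caller could drive round-based training with the renamed method, assuming `trainer` is an already constructed `AdversarialTrainer` subclass (e.g. GAIL) built from this branch. The helper name `run_rounds` and its arguments are hypothetical, and the `total_timesteps` handling is simplified (no `n_rounds < 1` check, no per-round logger dump); in practice `trainer.train(total_timesteps)` performs this loop internally, as documented above.

# Hypothetical driver, not part of this diff: mirrors AdversarialTrainer.train().
def run_rounds(trainer, total_timesteps, callback=None):
    # Each round collects rollouts, trains the discriminator on them, and then
    # trains the generator -- the combined behaviour of train_gen_with_disc.
    n_rounds = total_timesteps // trainer.gen_train_timesteps
    for r in range(n_rounds):
        trainer.train_gen_with_disc(trainer.gen_train_timesteps)
        if callback is not None:
            callback(r)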