@@ -114,6 +114,8 @@ def _on_step(self) -> bool:
         return True
 
     def _on_rollout_end(self) -> None:
+        if self.gen_ctx_manager is not None:
+            self.exit_gen_ctx_manager()
         gen_trajs, ep_lens = self.adversarial_trainer.venv_buffering.pop_trajectories()
         self.adversarial_trainer._check_fixed_horizon(ep_lens)
         gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
@@ -132,9 +134,13 @@ def _on_rollout_end(self) -> None:
         self.gen_ctx_manager = self.adversarial_trainer.logger.accumulate_means("gen")
         self.gen_ctx_manager.__enter__()
 
-    def _on_training_end(self) -> None:
+    def exit_gen_ctx_manager(self) -> None:
         assert self.gen_ctx_manager is not None
         self.gen_ctx_manager.__exit__(None, None, None)
+        self.gen_ctx_manager = None
+
+    def _on_training_end(self) -> None:
+        self.exit_gen_ctx_manager()
 
 
 class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
@@ -507,8 +513,8 @@ def train(
     ) -> None:
         """Alternates between training the generator and discriminator.
 
-        Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
-        a call to `train_disc`, and finally a call to `callback(round)`.
+        Every "round" consists of a call to
+        `train_gen_with_disc(self.gen_train_timesteps)` and a call to `callback(round)`.
 
         Training ends once an additional "round" would cause the number of transitions
         sampled from the environment to exceed `total_timesteps`.
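
For context, here is a minimal, self-contained sketch of the lifecycle pattern this change introduces. The `accumulate_means` stand-in and the `GenCallback` class are illustrative, not the library's actual classes; the point is the `exit_gen_ctx_manager` helper, which closes the manually entered logging context and resets the handle to `None`, so that `_on_rollout_end` and `_on_training_end` can each close it without a double `__exit__`.

```python
from contextlib import contextmanager


@contextmanager
def accumulate_means(prefix):
    # Stand-in for a logger context that accumulates metrics under a prefix.
    print(f"open {prefix!r} accumulator")
    try:
        yield
    finally:
        print(f"close {prefix!r} accumulator")


class GenCallback:
    """Illustrative callback that keeps a logging context open across rollouts."""

    def __init__(self):
        self.gen_ctx_manager = None

    def _on_rollout_end(self):
        # Close any context left open by the previous rollout, then open a
        # fresh one so metrics logged until the next rollout end are
        # accumulated under the "gen" prefix.
        if self.gen_ctx_manager is not None:
            self.exit_gen_ctx_manager()
        self.gen_ctx_manager = accumulate_means("gen")
        self.gen_ctx_manager.__enter__()

    def exit_gen_ctx_manager(self):
        # Close the context exactly once; resetting the handle to None lets
        # callers guard with an `is not None` check before exiting again.
        assert self.gen_ctx_manager is not None
        self.gen_ctx_manager.__exit__(None, None, None)
        self.gen_ctx_manager = None

    def _on_training_end(self):
        self.exit_gen_ctx_manager()


cb = GenCallback()
cb._on_rollout_end()   # opens the "gen" context
cb._on_rollout_end()   # closes the previous context, opens a new one
cb._on_training_end()  # closes the final context exactly once
```

As in the diff, `_on_training_end` here still asserts that a context is open; if training could end before any rollout completes, it would need the same `is not None` guard used in `_on_rollout_end`.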