Skip to content

Commit 5c23650

Browse files
committed
Fix test errors
1 parent 25d1eef commit 5c23650

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/imitation/algorithms/adversarial/common.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def _on_step(self) -> bool:
114114
return True
115115

116116
def _on_rollout_end(self) -> None:
117+
if self.gen_ctx_manager is not None:
118+
self.exit_gen_ctx_manager()
117119
gen_trajs, ep_lens = self.adversarial_trainer.venv_buffering.pop_trajectories()
118120
self.adversarial_trainer._check_fixed_horizon(ep_lens)
119121
gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
@@ -132,9 +134,13 @@ def _on_rollout_end(self) -> None:
132134
self.gen_ctx_manager = self.adversarial_trainer.logger.accumulate_means("gen")
133135
self.gen_ctx_manager.__enter__()
134136

135-
def _on_training_end(self) -> None:
137+
def exit_gen_ctx_manager(self) -> None:
136138
assert self.gen_ctx_manager is not None
137139
self.gen_ctx_manager.__exit__(None, None, None)
140+
self.gen_ctx_manager = None
141+
142+
def _on_training_end(self) -> None:
143+
self.exit_gen_ctx_manager()
138144

139145

140146
class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
@@ -507,8 +513,8 @@ def train(
507513
) -> None:
508514
"""Alternates between training the generator and discriminator.
509515
510-
Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
511-
a call to `train_disc`, and finally a call to `callback(round)`.
516+
Every "round" consists of a call to
517+
`train_gen_with_disc(self.gen_train_timesteps)` and a call to `callback(round)`.
512518
513519
Training ends once an additional "round" would cause the number of transitions
514520
sampled from the environment to exceed `total_timesteps`.

tests/algorithms/test_adversarial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_train_gen_train_disc_no_crash(
232232
n_updates: int = 2,
233233
) -> None:
234234
trainer_parametrized.train_gen_with_disc(
235-
n_updates * trainer_parametrized.gen_train_timesteps
235+
n_updates * trainer_parametrized.gen_train_timesteps,
236236
)
237237

238238

0 commit comments

Comments
 (0)