
Commit 25d1eef

Fix test errors
1 parent a14c7d2 commit 25d1eef

File tree

src/imitation/algorithms/adversarial/common.py
tests/algorithms/test_adversarial.py

2 files changed: +17 -10 lines changed

src/imitation/algorithms/adversarial/common.py

Lines changed: 14 additions & 8 deletions
@@ -101,6 +101,8 @@ def __init__(self, adversarial_trainer, *args, **kwargs):
         """Builds TrainDiscriminatorCallback.
 
         Args:
+            adversarial_trainer: The AdversarialTrainer instance in which
+                this callback will be called.
            *args: Passed through to `callbacks.BaseCallback`.
            **kwargs: Passed through to `callbacks.BaseCallback`.
            """
@@ -276,7 +278,7 @@ def __init__(
             # Would use an identity reward fn here, but RewardFns can't see rewards.
             self.venv_wrapped = self.venv_buffering
             self.gen_callback: List[callbacks.BaseCallback] = [
-                self.disc_trainer_callback
+                self.disc_trainer_callback,
             ]
         else:
             self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
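
Because `gen_callback` is typed as a list, it can be handed directly to the generator's `learn` call, which is how SB3 consumes multiple callbacks. A hedged, self-contained sketch: the PPO/CartPole setup and the no-op callback are assumptions for illustration, not code from this repo.

from typing import List

from stable_baselines3 import PPO
from stable_baselines3.common import callbacks


class NoOpCallback(callbacks.BaseCallback):
    # Stand-in for disc_trainer_callback; the real one trains the discriminator.
    def _on_step(self) -> bool:
        return True


gen_algo = PPO("MlpPolicy", "CartPole-v1")
gen_callback: List[callbacks.BaseCallback] = [NoOpCallback()]
# SB3 wraps a list of callbacks in a CallbackList internally.
gen_algo.learn(total_timesteps=2048, callback=gen_callback)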
@@ -369,7 +371,7 @@ def update_rewards_of_rollouts(self) -> None:
             buffer = self.gen_algo.rollout_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._rollout_buffer_to_reward_fn_input(
-                self.gen_algo.rollout_buffer
+                self.gen_algo.rollout_buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             rewards = rewards.reshape(buffer.rewards.shape)
@@ -380,13 +382,14 @@ def update_rewards_of_rollouts(self) -> None:
             last_dones = last_values == 0.0
             self.gen_algo.rollout_buffer.rewards[:] = rewards
             self.gen_algo.rollout_buffer.compute_returns_and_advantage(
-                th.tensor(last_values), last_dones
+                th.tensor(last_values),
+                last_dones,
             )
         elif isinstance(self.gen_algo, off_policy_algorithm.OffPolicyAlgorithm):
             buffer = self.gen_algo.replay_buffer
             assert buffer is not None
             reward_fn_inputs = replay_buffer_wrapper._replay_buffer_to_reward_fn_input(
-                buffer
+                buffer,
             )
             rewards = self._reward_net.predict(**reward_fn_inputs)
             buffer.rewards[:] = rewards.reshape(buffer.rewards.shape)
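
A condensed sketch of the on-policy branch above, assuming an SB3 `RolloutBuffer`: `predict_reward` stands in for `self._reward_net.predict`, and the bootstrap-value extraction is an assumption, since the code computing `last_values` lies outside this hunk.

import torch as th
from stable_baselines3.common.buffers import RolloutBuffer


def relabel_rollout_rewards(buffer: RolloutBuffer, predict_reward) -> None:
    # Score every stored transition with the learned reward net and write
    # the results back over the environment rewards, preserving shape.
    obs = buffer.observations.reshape((-1,) + buffer.obs_shape)
    acts = buffer.actions.reshape((-1, buffer.action_dim))
    rewards = predict_reward(obs, acts).reshape(buffer.rewards.shape)
    buffer.rewards[:] = rewards
    # Returns/advantages were computed from the old rewards, so recompute
    # them; treating zero bootstrap values as episode ends mirrors the
    # `last_dones = last_values == 0.0` heuristic in the hunk above.
    last_values = buffer.values[-1]
    last_dones = last_values == 0.0
    buffer.compute_returns_and_advantage(th.tensor(last_values), last_dones)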
@@ -465,13 +468,15 @@ def train_disc(
 
         return train_stats
 
-    def train_gen(
+    def train_gen_with_disc(
         self,
         total_timesteps: Optional[int] = None,
         learn_kwargs: Optional[Mapping] = None,
     ) -> None:
         """Trains the generator to maximize the discriminator loss.
 
+        The discriminator is also trained after the rollouts are collected and before
+        the generator is trained.
         After the end of training populates the generator replay buffer (used in
         discriminator training) with `self.disc_batch_size` transitions.
 
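
For callers, the rename is a drop-in change; the behavioral difference is that discriminator updates now happen inside the call. `trainer` is an assumed, already-constructed `AdversarialTrainer` subclass instance (e.g. GAIL or AIRL).

# Before this commit:
#     trainer.train_gen(trainer.gen_train_timesteps)
# After it:
trainer.train_gen_with_disc(trainer.gen_train_timesteps)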
@@ -502,7 +507,7 @@ def train(
     ) -> None:
         """Alternates between training the generator and discriminator.
 
-        Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
+        Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
         a call to `train_disc`, and finally a call to `callback(round)`.
 
         Training ends once an additional "round" would cause the number of transitions
@@ -522,7 +527,7 @@ def train(
                 f"total_timesteps={total_timesteps})!"
             )
         for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
-            self.train_gen(self.gen_train_timesteps)
+            self.train_gen_with_disc(self.gen_train_timesteps)
             if callback:
                 callback(r)
         self.logger.dump(self._global_step)
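
The public `train` entry point keeps its signature; only the internal call changes. A hedged usage sketch, with `trainer` again an assumed `AdversarialTrainer` instance:

# Each round runs train_gen_with_disc(gen_train_timesteps), then
# train_disc(), then the optional per-round callback.
trainer.train(
    total_timesteps=20_000,
    callback=lambda r: print(f"finished round {r}"),
)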
@@ -610,7 +615,8 @@ def _make_disc_train_batches(
         if gen_samples is None:
             if self._gen_replay_buffer.size() == 0:
                 raise RuntimeError(
-                    "No generator samples for training. " "Call `train_gen()` first.",
+                    "No generator samples for training. "
+                    "Call `train_gen_with_disc()` first.",
                 )
             gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
             gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)

tests/algorithms/test_adversarial.py

Lines changed: 3 additions & 2 deletions
@@ -231,8 +231,9 @@ def test_train_gen_train_disc_no_crash(
     trainer_parametrized: common.AdversarialTrainer,
     n_updates: int = 2,
 ) -> None:
-    trainer_parametrized.train_gen(n_updates * trainer_parametrized.gen_train_timesteps)
-    trainer_parametrized.train_disc()
+    trainer_parametrized.train_gen_with_disc(
+        n_updates * trainer_parametrized.gen_train_timesteps
+    )
 
 
 @pytest.fixture
