Skip to content

Commit 459abc3

Browse files
committed
fix lint errors
1 parent 771179c commit 459abc3

File tree

8 files changed

+11
-14
lines changed

8 files changed

+11
-14
lines changed

ci/clean_notebooks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def clean_notebook(file: pathlib.Path, check_only=False) -> None:
6363
print(f"Checking {file}")
6464

6565
for cell in nb.cells:
66-
6766
# Remove empty cells
6867
if cell["cell_type"] == "code" and not cell["source"]:
6968
if check_only:

src/imitation/algorithms/adversarial/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import torch.utils.tensorboard as thboard
1010
import tqdm
1111
from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
12-
from stable_baselines3.common.type_aliases import MaybeCallback
1312
from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
13+
from stable_baselines3.common.type_aliases import MaybeCallback
1414
from stable_baselines3.sac import policies as sac_policies
1515
from torch.nn import functional as F
1616

src/imitation/algorithms/bc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ def __post_init__(self) -> None:
5757

5858
def __iter__(self) -> Iterator[types.TransitionMapping]:
5959
def batch_iterator() -> Iterator[types.TransitionMapping]:
60-
6160
# Note: the islice here ensures we do not exceed self.n_epochs
6261
for epoch_num in itertools.islice(itertools.count(), self.n_epochs):
6362
some_batch_was_yielded = False

src/imitation/data/huggingface_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class TrajectoryDatasetSequence(Sequence[types.Trajectory]):
1616

1717
def __init__(self, dataset: datasets.Dataset):
1818
"""Construct a TrajectoryDatasetSequence."""
19+
1920
# TODO: this is just a temporary workaround for
2021
# https://github.com/huggingface/datasets/issues/5517
2122
# switch to .with_format("numpy") once it's fixed
@@ -31,7 +32,6 @@ def __len__(self) -> int:
3132
return len(self._dataset)
3233

3334
def __getitem__(self, idx):
34-
3535
if isinstance(idx, slice):
3636
# Note: we could use self._dataset[idx] here and then convert the result of
3737
# that to a series of trajectories, but if we do that, we run into trouble

src/imitation/scripts/train_adversarial.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@
2424

2525

2626
class CheckpointCallback(BaseCallback):
27+
"""A callback for calling `save` at regular intervals."""
28+
2729
def __init__(
2830
self,
2931
trainer: common.AdversarialTrainer,
3032
log_dir: pathlib.Path,
31-
interval: int
33+
interval: int,
3234
):
35+
"""Creates new Checkpoint callback."""
3336
super().__init__(self)
3437
self.trainer = trainer
3538
self.log_dir = log_dir

tests/algorithms/test_adversarial.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,10 @@ def test_regression_gail_with_sac(
468468

469469

470470
def test_gen_callback(trainer: common.AdversarialTrainer):
471-
learner = stable_baselines3.PPO("MlpPolicy", env=trainer.venv)
472-
473471
def make_fn_callback(calls, key):
474472
def cb(_a, _b):
475473
calls[key] += 1
474+
476475
return cb
477476

478477
class SB3Callback(BaseCallback):
@@ -490,10 +489,10 @@ def _on_step(self):
490489

491490
trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
492491
trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
493-
trainer.train(n_steps, callback=[
494-
SB3Callback(calls, "list.0"),
495-
SB3Callback(calls, "list.1")
496-
])
492+
trainer.train(
493+
n_steps,
494+
callback=[SB3Callback(calls, "list.0"), SB3Callback(calls, "list.1")],
495+
)
497496

498497
# Env steps for off-plicy algos (DQN) may exceed `total_timesteps`,
499498
# so we check if the callback was called *at least* that many times.

tests/data/test_huggingface_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def test_sliced_access(data: st.DataObject, trajectories: Sequence[types.Traject
7171

7272
# Note: we test if for 10 slices at a time because creating the dataset is slow
7373
for _ in range(10):
74-
7574
# GIVEN
7675
the_slice = data.draw(slices_strategy)
7776
indices_of_slice = list(range(*the_slice.indices(len(trajectories))))
@@ -113,7 +112,6 @@ def test_sliced_info_dict_access(
113112

114113
# Note: we test if for 10 slices at a time because creating the dataset is slow
115114
for _ in range(10):
116-
117115
# GIVEN
118116
the_slice = data.draw(slices_strategy)
119117
indices_of_slice = list(range(*the_slice.indices(len(wrapped_info_dicts))))

tests/util/test_wb_logger.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def log(
6262
commit: bool = False,
6363
sync: bool = False,
6464
):
65-
6665
assert self._initialized
6766
if sync:
6867
raise NotImplementedError("usage of sync to MockWandb.log not implemented")

0 commit comments

Comments
 (0)