Skip to content

Commit 1d5739e

Browse files
committed
fix mypy errors
1 parent 7da3e8f commit 1d5739e

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/imitation/algorithms/adversarial/common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import abc
33
import dataclasses
44
import logging
5-
from typing import Iterable, Iterator, Mapping, Optional, Type, overload
5+
from typing import Iterable, Iterator, Mapping, Optional, Type, List, overload
66

77
import numpy as np
88
import torch as th
@@ -408,7 +408,10 @@ def train_gen(
408408
if learn_kwargs is None:
409409
learn_kwargs = {}
410410

411-
callbacks = [self.gen_callback]
411+
callbacks: List[BaseCallback] = []
412+
413+
if self.gen_callback:
414+
callbacks.append(self.gen_callback)
412415

413416
if isinstance(callback, list):
414417
callbacks.extend(callback)

src/imitation/scripts/train_adversarial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
interval: int,
3434
):
3535
"""Creates new Checkpoint callback."""
36-
super().__init__(self)
36+
super().__init__()
3737
self.trainer = trainer
3838
self.log_dir = log_dir
3939
self.interval = interval

0 commit comments

Comments
 (0)