Skip to content

Commit 7b8b4bf

Browse files
Upgrade pytype (#801)
* Upgrade pytype and remove workaround for old versions
* New pytype needs an input directory or file
* Fix np.dtype
* Ignore typed-dict-error
* Context-manager-related fix
* Keep pytype checking after failures (--keep-going)
* Move pytype config to pyproject.toml
* Use inputs specified in pyproject.toml
* Fix lint
* Fix undefined variable
* Fix end of file
* Fix typo

---------

Co-authored-by: Adam Gleave <[email protected]>
1 parent aca4c07 commit 7b8b4bf

File tree

12 files changed

+30
-27
lines changed

12 files changed

+30
-27
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -78,7 +78,7 @@ repos:
7878
name: pytype
7979
language: system
8080
types: [python]
81-
entry: "bash -c 'pytype -j ${NUM_CPUS:-auto}'"
81+
entry: "bash -c 'pytype --keep-going -j ${NUM_CPUS:-auto}'"
8282
require_serial: true
8383
verbose: true
8484
- id: docs

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -14,3 +14,12 @@ build-backend = "setuptools.build_meta"
1414

1515
[tool.black]
1616
target-version = ["py38"]
17+
18+
[tool.pytype]
19+
inputs = [
20+
"src/",
21+
"tests/",
22+
"experiments/",
23+
"setup.py"
24+
]
25+
python_version = "3.8"

setup.cfg

Lines changed: 0 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -59,11 +59,3 @@ omit =
5959
source =
6060
src/imitation
6161
*venv/lib/python*/site-packages/imitation
62-
63-
[pytype]
64-
inputs =
65-
src/
66-
tests/
67-
experiments/
68-
setup.py
69-
python_version >= 3.8

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -17,7 +17,7 @@
1717
ATARI_REQUIRE = [
1818
"seals[atari]~=0.2.1",
1919
]
20-
PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else []
20+
PYTYPE = ["pytype==2023.9.27"] if IS_NOT_WINDOWS else []
2121

2222
# Note: the versions of the test and doc requirements should be tightly pinned to known
2323
# working versions to make our CI/CD pipeline as stable as possible.

src/imitation/scripts/eval_policy.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -96,7 +96,7 @@ def eval_policy(
9696
sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
9797
post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None
9898
render_mode = "rgb_array" if videos else None
99-
with environment.make_venv(
99+
with environment.make_venv( # type: ignore[wrong-keyword-args]
100100
post_wrappers=post_wrappers,
101101
env_make_kwargs=dict(render_mode=render_mode),
102102
) as venv:

src/imitation/scripts/ingredients/demonstrations.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -143,10 +143,10 @@ def _generate_expert_trajs(
143143
raise ValueError("n_expert_demos must be specified when generating demos.")
144144

145145
logger.info(f"Generating {n_expert_demos} expert trajectories")
146-
with environment.make_rollout_venv() as rollout_env:
146+
with environment.make_rollout_venv() as env: # type: ignore[wrong-arg-count]
147147
return rollout.rollout(
148-
expert.get_expert_policy(rollout_env),
149-
rollout_env,
148+
expert.get_expert_policy(env),
149+
env,
150150
rollout.make_sample_until(min_episodes=n_expert_demos),
151151
rng=_rnd,
152152
)

src/imitation/scripts/train_adversarial.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -119,7 +119,7 @@ def train_adversarial(
119119
custom_logger, log_dir = logging_ingredient.setup_logging()
120120
expert_trajs = demonstrations.get_expert_trajectories()
121121

122-
with environment.make_venv() as venv:
122+
with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
123123
reward_net = reward.make_reward_net(venv)
124124
relabel_reward_fn = functools.partial(
125125
reward_net.predict_processed,

src/imitation/scripts/train_imitation.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -73,7 +73,7 @@ def bc(
7373
custom_logger, log_dir = logging_ingredient.setup_logging()
7474

7575
expert_trajs = demonstrations.get_expert_trajectories()
76-
with environment.make_venv() as venv:
76+
with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
7777
bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)
7878

7979
bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -115,7 +115,7 @@ def dagger(
115115
if dagger["use_offline_rollouts"]:
116116
expert_trajs = demonstrations.get_expert_trajectories()
117117

118-
with environment.make_venv() as venv:
118+
with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
119119
bc_trainer = bc_ingredient.make_bc(venv, expert_trajs, custom_logger)
120120

121121
bc_train_kwargs = dict(log_rollouts_venv=venv, **bc["train_kwargs"])
@@ -161,7 +161,7 @@ def sqil(
161161
custom_logger, log_dir = logging_ingredient.setup_logging()
162162
expert_trajs = demonstrations.get_expert_trajectories()
163163

164-
with environment.make_venv() as venv:
164+
with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
165165
sqil_trainer = sqil_algorithm.SQIL(
166166
venv=venv,
167167
demonstrations=expert_trajs,

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -166,7 +166,7 @@ def train_preference_comparisons(
166166

167167
custom_logger, log_dir = logging_ingredient.setup_logging()
168168

169-
with environment.make_venv() as venv:
169+
with environment.make_venv() as venv: # type: ignore[wrong-arg-count]
170170
reward_net = reward.make_reward_net(venv)
171171
relabel_reward_fn = functools.partial(
172172
reward_net.predict_processed,

src/imitation/scripts/train_rl.py

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -99,7 +99,9 @@ def train_rl(
9999
policy_dir.mkdir(parents=True, exist_ok=True)
100100

101101
post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)]
102-
with environment.make_venv(post_wrappers=post_wrappers) as venv:
102+
with environment.make_venv( # type: ignore[wrong-keyword-args]
103+
post_wrappers=post_wrappers,
104+
) as venv:
103105
callback_objs = []
104106
if reward_type is not None:
105107
reward_fn = load_reward(

0 commit comments

Comments (0)