Skip to content

Commit 0480c35

Browse files
committed
track if typing.TYPE_CHECKING to warn about non runtime bindings
When importing or defining values in ``if typing.TYPE_CHECKING`` blocks the bound names will not be available at runtime and may cause errors when used in the following way:: import typing if typing.TYPE_CHECKING: from module import Type # some slow import or circular reference def method(value) -> Type: # the import is needed by the type checker assert isinstance(value, Type) # this is a runtime error This change allows pyflakes to track what names are bound for runtime use, and allows it to warn when a non runtime name is used in a runtime context.
1 parent 59ec459 commit 0480c35

File tree

3 files changed

+176
-39
lines changed

3 files changed

+176
-39
lines changed

pyflakes/checker.py

Lines changed: 102 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,11 @@ class Binding:
215215
the node that this binding was last used.
216216
"""
217217

218-
def __init__(self, name, source):
218+
def __init__(self, name, source, *, runtime=True):
219219
self.name = name
220220
self.source = source
221221
self.used = False
222+
self.runtime = runtime
222223

223224
def __str__(self):
224225
return self.name
@@ -249,8 +250,8 @@ def redefines(self, other):
249250
class Builtin(Definition):
250251
"""A definition created for all Python builtins."""
251252

252-
def __init__(self, name):
253-
super().__init__(name, None)
253+
def __init__(self, name, *, runtime=True):
254+
super().__init__(name, None, runtime=runtime)
254255

255256
def __repr__(self):
256257
return '<{} object {!r} at 0x{:x}>'.format(
@@ -294,10 +295,10 @@ class Importation(Definition):
294295
@type fullName: C{str}
295296
"""
296297

297-
def __init__(self, name, source, full_name=None):
298+
def __init__(self, name, source, full_name=None, *, runtime=True):
298299
self.fullName = full_name or name
299300
self.redefined = []
300-
super().__init__(name, source)
301+
super().__init__(name, source, runtime=runtime)
301302

302303
def redefines(self, other):
303304
if isinstance(other, SubmoduleImportation):
@@ -342,11 +343,11 @@ class SubmoduleImportation(Importation):
342343
name is also the same, to avoid false positives.
343344
"""
344345

345-
def __init__(self, name, source):
346+
def __init__(self, name, source, *, runtime=True):
346347
# A dot should only appear in the name when it is a submodule import
347348
assert '.' in name and (not source or isinstance(source, ast.Import))
348349
package_name = name.split('.')[0]
349-
super().__init__(package_name, source)
350+
super().__init__(package_name, source, runtime=runtime)
350351
self.fullName = name
351352

352353
def redefines(self, other):
@@ -364,7 +365,9 @@ def source_statement(self):
364365

365366
class ImportationFrom(Importation):
366367

367-
def __init__(self, name, source, module, real_name=None):
368+
def __init__(
369+
self, name, source, module, real_name=None, *, runtime=True
370+
):
368371
self.module = module
369372
self.real_name = real_name or name
370373

@@ -373,7 +376,7 @@ def __init__(self, name, source, module, real_name=None):
373376
else:
374377
full_name = module + '.' + self.real_name
375378

376-
super().__init__(name, source, full_name)
379+
super().__init__(name, source, full_name, runtime=runtime)
377380

378381
def __str__(self):
379382
"""Return import full name with alias."""
@@ -393,8 +396,8 @@ def source_statement(self):
393396
class StarImportation(Importation):
394397
"""A binding created by a 'from x import *' statement."""
395398

396-
def __init__(self, name, source):
397-
super().__init__('*', source)
399+
def __init__(self, name, source, *, runtime=True):
400+
super().__init__('*', source, runtime=runtime)
398401
# Each star importation needs a unique name, and
399402
# may not be the module name otherwise it will be deemed imported
400403
self.name = name + '.*'
@@ -483,7 +486,7 @@ class ExportBinding(Binding):
483486
C{__all__} will not have an unused import warning reported for them.
484487
"""
485488

486-
def __init__(self, name, source, scope):
489+
def __init__(self, name, source, scope, *, runtime=True):
487490
if '__all__' in scope and isinstance(source, ast.AugAssign):
488491
self.names = list(scope['__all__'].names)
489492
else:
@@ -514,7 +517,7 @@ def _add_to_names(container):
514517
# If not list concatenation
515518
else:
516519
break
517-
super().__init__(name, source)
520+
super().__init__(name, source, runtime=runtime)
518521

519522

520523
class Scope(dict):
@@ -722,10 +725,6 @@ class Checker:
722725
ast.DictComp: GeneratorScope,
723726
}
724727

725-
nodeDepth = 0
726-
offset = None
727-
_in_annotation = AnnotationState.NONE
728-
729728
builtIns = set(builtin_vars).union(_MAGIC_GLOBALS)
730729
_customBuiltIns = os.environ.get('PYFLAKES_BUILTINS')
731730
if _customBuiltIns:
@@ -734,6 +733,10 @@ class Checker:
734733

735734
def __init__(self, tree, filename='(none)', builtins=None,
736735
withDoctest='PYFLAKES_DOCTEST' in os.environ, file_tokens=()):
736+
self.nodeDepth = 0
737+
self.offset = None
738+
self._in_annotation = AnnotationState.NONE
739+
self._in_type_check_guard = False
737740
self._nodeHandlers = {}
738741
self._deferred = collections.deque()
739742
self.deadScopes = []
@@ -1000,9 +1003,11 @@ def addBinding(self, node, value):
10001003
# then assume the rebound name is used as a global or within a loop
10011004
value.used = self.scope[value.name].used
10021005

1003-
# don't treat annotations as assignments if there is an existing value
1004-
# in scope
1005-
if value.name not in self.scope or not isinstance(value, Annotation):
1006+
# always allow the first assignment or if not already a runtime value,
1007+
# but do not shadow an existing assignment with an annotation or non
1008+
# runtime value.
1009+
if (not existing or not existing.runtime
1010+
or (not isinstance(value, Annotation) and value.runtime)):
10061011
if isinstance(value, NamedExprAssignment):
10071012
# PEP 572: use scope in which outermost generator is defined
10081013
scope = next(
@@ -1080,20 +1085,28 @@ def handleNodeLoad(self, node, parent):
10801085
self.report(messages.InvalidPrintSyntax, node)
10811086

10821087
try:
1083-
scope[name].used = (self.scope, node)
1088+
binding = scope[name]
1089+
except KeyError:
1090+
pass
1091+
else:
1092+
# check if the binding is used in the wrong context
1093+
if (not binding.runtime
1094+
and not (self._in_type_check_guard or self._in_annotation)):
1095+
self.report(messages.TypeCheckingOnly, node, name)
1096+
return
1097+
1098+
# mark the binding as used
1099+
binding.used = (self.scope, node)
10841100

10851101
# if the name of SubImportation is same as
10861102
# alias of other Importation and the alias
10871103
# is used, SubImportation also should be marked as used.
1088-
n = scope[name]
1089-
if isinstance(n, Importation) and n._has_alias():
1104+
if isinstance(binding, Importation) and binding._has_alias():
10901105
try:
1091-
scope[n.fullName].used = (self.scope, node)
1106+
scope[binding.fullName].used = (self.scope, node)
10921107
except KeyError:
10931108
pass
1094-
except KeyError:
1095-
pass
1096-
else:
1109+
10971110
return
10981111

10991112
importStarred = importStarred or scope.importStarred
@@ -1150,12 +1163,13 @@ def handleNodeStore(self, node):
11501163
break
11511164

11521165
parent_stmt = self.getParent(node)
1166+
runtime = not self._in_type_check_guard
11531167
if isinstance(parent_stmt, ast.AnnAssign) and parent_stmt.value is None:
11541168
binding = Annotation(name, node)
11551169
elif isinstance(parent_stmt, (FOR_TYPES, ast.comprehension)) or (
11561170
parent_stmt != node._pyflakes_parent and
11571171
not self.isLiteralTupleUnpacking(parent_stmt)):
1158-
binding = Binding(name, node)
1172+
binding = Binding(name, node, runtime=runtime)
11591173
elif (
11601174
name == '__all__' and
11611175
isinstance(self.scope, ModuleScope) and
@@ -1164,11 +1178,13 @@ def handleNodeStore(self, node):
11641178
(ast.Assign, ast.AugAssign, ast.AnnAssign)
11651179
)
11661180
):
1167-
binding = ExportBinding(name, node._pyflakes_parent, self.scope)
1181+
binding = ExportBinding(
1182+
name, node._pyflakes_parent, self.scope, runtime=runtime
1183+
)
11681184
elif isinstance(parent_stmt, ast.NamedExpr):
1169-
binding = NamedExprAssignment(name, node)
1185+
binding = NamedExprAssignment(name, node, runtime=runtime)
11701186
else:
1171-
binding = Assignment(name, node)
1187+
binding = Assignment(name, node, runtime=runtime)
11721188
self.addBinding(node, binding)
11731189

11741190
def handleNodeDelete(self, node):
@@ -1832,7 +1848,37 @@ def DICT(self, node):
18321848
def IF(self, node):
18331849
if isinstance(node.test, ast.Tuple) and node.test.elts != []:
18341850
self.report(messages.IfTuple, node)
1835-
self.handleChildren(node)
1851+
1852+
# check for typing.TYPE_CHECKING, and if so handle each node specifically
1853+
if_type_checking = _is_typing(node.test, 'TYPE_CHECKING', self.scopeStack)
1854+
if if_type_checking or (
1855+
# handle else TYPE_CHECKING
1856+
isinstance(node.test, ast.UnaryOp)
1857+
and isinstance(node.test.op, ast.Not)
1858+
and _is_typing(node.test.operand, 'TYPE_CHECKING', self.scopeStack)
1859+
):
1860+
self.handleNode(node.test, node)
1861+
_in_type_check_guard = self._in_type_check_guard
1862+
1863+
# update the current TYPE_CHECKING state and handle the if-node(s)
1864+
self._in_type_check_guard = if_type_checking
1865+
if isinstance(node.body, list):
1866+
for child in node.body:
1867+
self.handleNode(child, node)
1868+
else:
1869+
self.handleNode(node.body, node)
1870+
1871+
# update the current TYPE_CHECKING state and handle the else-node(s)
1872+
self._in_type_check_guard = not if_type_checking or _in_type_check_guard
1873+
if isinstance(node.orelse, list):
1874+
for child in node.orelse:
1875+
self.handleNode(child, node)
1876+
else:
1877+
self.handleNode(node.orelse, node)
1878+
1879+
self._in_type_check_guard = _in_type_check_guard
1880+
else:
1881+
self.handleChildren(node)
18361882

18371883
IFEXP = IF
18381884

@@ -1943,7 +1989,12 @@ def FUNCTIONDEF(self, node):
19431989
with self._type_param_scope(node):
19441990
self.LAMBDA(node)
19451991

1946-
self.addBinding(node, FunctionDefinition(node.name, node))
1992+
self.addBinding(
1993+
node,
1994+
FunctionDefinition(
1995+
node.name, node, runtime=not self._in_type_check_guard
1996+
),
1997+
)
19471998
# doctest does not process doctest within a doctest,
19481999
# or in nested functions.
19492000
if (self.withDoctest and
@@ -2028,7 +2079,12 @@ def CLASSDEF(self, node):
20282079
for stmt in node.body:
20292080
self.handleNode(stmt, node)
20302081

2031-
self.addBinding(node, ClassDefinition(node.name, node))
2082+
self.addBinding(
2083+
node,
2084+
ClassDefinition(
2085+
node.name, node, runtime=not self._in_type_check_guard
2086+
),
2087+
)
20322088

20332089
def AUGASSIGN(self, node):
20342090
self.handleNodeLoad(node.target, node)
@@ -2061,12 +2117,17 @@ def TUPLE(self, node):
20612117
LIST = TUPLE
20622118

20632119
def IMPORT(self, node):
2120+
runtime = not self._in_type_check_guard
20642121
for alias in node.names:
20652122
if '.' in alias.name and not alias.asname:
2066-
importation = SubmoduleImportation(alias.name, node)
2123+
importation = SubmoduleImportation(
2124+
alias.name, node, runtime=runtime
2125+
)
20672126
else:
20682127
name = alias.asname or alias.name
2069-
importation = Importation(name, node, alias.name)
2128+
importation = Importation(
2129+
name, node, alias.name, runtime=runtime
2130+
)
20702131
self.addBinding(node, importation)
20712132

20722133
def IMPORTFROM(self, node):
@@ -2078,6 +2139,7 @@ def IMPORTFROM(self, node):
20782139

20792140
module = ('.' * node.level) + (node.module or '')
20802141

2142+
runtime = not self._in_type_check_guard
20812143
for alias in node.names:
20822144
name = alias.asname or alias.name
20832145
if node.module == '__future__':
@@ -2095,10 +2157,11 @@ def IMPORTFROM(self, node):
20952157

20962158
self.scope.importStarred = True
20972159
self.report(messages.ImportStarUsed, node, module)
2098-
importation = StarImportation(module, node)
2160+
importation = StarImportation(module, node, runtime=runtime)
20992161
else:
2100-
importation = ImportationFrom(name, node,
2101-
module, alias.name)
2162+
importation = ImportationFrom(
2163+
name, node, module, alias.name, runtime=runtime
2164+
)
21022165
self.addBinding(node, importation)
21032166

21042167
def TRY(self, node):

pyflakes/messages.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def __init__(self, filename, loc, name, from_list):
6565
self.message_args = (name, from_list)
6666

6767

68+
class TypeCheckingOnly(Message):
69+
message = 'name only defined for TYPE_CHECKIN: %r'
70+
71+
def __init__(self, filename, loc, name):
72+
Message.__init__(self, filename, loc)
73+
self.message_args = (name,)
74+
75+
6876
class UndefinedName(Message):
6977
message = 'undefined name %r'
7078

0 commit comments

Comments
 (0)