Skip to content

Commit 7126c05

Browse files
committed
Refactor to support typevars, and more tests
1 parent eb9cccf commit 7126c05

File tree

5 files changed

+67
-27
lines changed

5 files changed

+67
-27
lines changed

mypy/plugins/attrs.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
Var,
4747
is_class_var,
4848
)
49-
from mypy.plugin import FunctionContext, SemanticAnalyzerPluginInterface
49+
from mypy.plugin import SemanticAnalyzerPluginInterface
5050
from mypy.plugins.common import (
5151
_get_argument,
5252
_get_bool_argument,
@@ -1062,27 +1062,41 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
10621062
)
10631063

10641064

1065-
def _get_cls_from_init(t: Type) -> TypeInfo | None:
1066-
proper_type = get_proper_type(t)
1067-
if isinstance(proper_type, CallableType):
1068-
return proper_type.type_object()
1069-
return None
1065+
def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
1066+
"""Provide the proper signature for `attrs.fields`."""
1067+
if ctx.args and len(ctx.args) == 1 and ctx.args[0] and ctx.args[0][0]:
1068+
# <hack>
1069+
assert isinstance(ctx.api, TypeChecker)
1070+
inst_type = ctx.api.expr_checker.accept(ctx.args[0][0])
1071+
# </hack>
1072+
proper_type = get_proper_type(inst_type)
10701073

1074+
if isinstance(proper_type, AnyType): # fields(Any) -> Any
1075+
return ctx.default_signature
1076+
1077+
cls = None
1078+
arg_types = ctx.default_signature.arg_types
1079+
1080+
if isinstance(proper_type, TypeVarType):
1081+
inner = get_proper_type(proper_type.upper_bound)
1082+
if isinstance(inner, Instance):
1083+
# We need to work arg_types to compensate for the attrs stubs.
1084+
arg_types = [inst_type]
1085+
cls = inner.type
1086+
elif isinstance(proper_type, CallableType):
1087+
cls = proper_type.type_object()
10711088

1072-
def fields_function_callback(ctx: FunctionContext) -> Type:
1073-
"""Provide the proper return value for `attrs.fields`."""
1074-
if ctx.arg_types and ctx.arg_types[0] and ctx.arg_types[0][0]:
1075-
first_arg_type = ctx.arg_types[0][0]
1076-
cls = _get_cls_from_init(first_arg_type)
10771089
if cls is not None:
10781090
if MAGIC_ATTR_NAME in cls.names:
10791091
# This is a proper attrs class.
10801092
ret_type = cls.names[MAGIC_ATTR_NAME].type
10811093
if ret_type is not None:
1082-
return ret_type
1083-
else:
1084-
ctx.api.fail(
1085-
f'Argument 1 to "fields" has incompatible type "{format_type_bare(first_arg_type)}"; expected an attrs class',
1086-
ctx.context,
1087-
)
1088-
return ctx.default_return_type
1094+
return ctx.default_signature.copy_modified(
1095+
arg_types=arg_types, ret_type=ret_type
1096+
)
1097+
1098+
ctx.api.fail(
1099+
f'Argument 1 to "fields" has incompatible type "{format_type_bare(proper_type)}"; expected an attrs class',
1100+
ctx.context,
1101+
)
1102+
return ctx.default_signature

mypy/plugins/default.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,13 @@ class DefaultPlugin(Plugin):
3939
"""Type checker plugin that is enabled by default."""
4040

4141
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
42-
from mypy.plugins import attrs, ctypes, singledispatch
42+
from mypy.plugins import ctypes, singledispatch
4343

4444
if fullname == "_ctypes.Array":
4545
return ctypes.array_constructor_callback
4646
elif fullname == "functools.singledispatch":
4747
return singledispatch.create_singledispatch_function_callback
48-
elif fullname in ("attr.fields", "attrs.fields"):
49-
return attrs.fields_function_callback
48+
5049
return None
5150

5251
def get_function_signature_hook(
@@ -56,6 +55,8 @@ def get_function_signature_hook(
5655

5756
if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
5857
return attrs.evolve_function_sig_callback
58+
elif fullname in ("attr.fields", "attrs.fields"):
59+
return attrs.fields_function_sig_callback
5960
return None
6061

6162
def get_method_signature_hook(

test-data/unit/check-plugin-attrs.test

+29-4
Original file line numberDiff line numberDiff line change
@@ -1549,6 +1549,24 @@ takes_attrs_instance(A) # E: Argument 1 to "takes_attrs_instance" has incompati
15491549
[builtins fixtures/plugin_attrs.pyi]
15501550

15511551
[case testAttrsFields]
1552+
import attr
1553+
from attrs import fields as f # Common usage.
1554+
1555+
@attr.define
1556+
class A:
1557+
b: int
1558+
c: str
1559+
1560+
reveal_type(f(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1561+
reveal_type(f(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1562+
reveal_type(f(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1563+
f(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1564+
1565+
[builtins fixtures/attr.pyi]
1566+
1567+
[case testAttrsGenericFields]
1568+
from typing import TypeVar
1569+
15521570
import attr
15531571
from attrs import fields
15541572

@@ -1557,21 +1575,28 @@ class A:
15571575
b: int
15581576
c: str
15591577

1560-
reveal_type(fields(A)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1561-
reveal_type(fields(A)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1562-
reveal_type(fields(A).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1563-
fields(A).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1578+
TA = TypeVar('TA', bound=A)
1579+
1580+
def f(t: TA) -> None:
1581+
reveal_type(fields(t)) # N: Revealed type is "Tuple[attr.Attribute[builtins.int], attr.Attribute[builtins.str], fallback=__main__.A.____main___A_AttrsAttributes__]"
1582+
reveal_type(fields(t)[0]) # N: Revealed type is "attr.Attribute[builtins.int]"
1583+
reveal_type(fields(t).b) # N: Revealed type is "attr.Attribute[builtins.int]"
1584+
fields(t).x # E: "____main___A_AttrsAttributes__" has no attribute "x"
1585+
15641586

15651587
[builtins fixtures/attr.pyi]
15661588

15671589
[case testNonattrsFields]
1590+
from typing import Any, cast
15681591
from attrs import fields
15691592

15701593
class A:
15711594
b: int
15721595
c: str
15731596

15741597
fields(A) # E: Argument 1 to "fields" has incompatible type "Type[A]"; expected an attrs class
1598+
fields(None) # E: Argument 1 to "fields" has incompatible type "None"; expected an attrs class
1599+
fields(cast(Any, 42))
15751600

15761601
[builtins fixtures/attr.pyi]
15771602

test-data/unit/lib-stub/attr/__init__.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -248,4 +248,4 @@ def field(
248248
def evolve(inst: _T, **changes: Any) -> _T: ...
249249
def assoc(inst: _T, **changes: Any) -> _T: ...
250250

251-
def fields(cls: _C) -> Any: ...
251+
def fields(cls: type) -> Any: ...

test-data/unit/lib-stub/attrs/__init__.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,4 @@ def field(
130130
def evolve(inst: _T, **changes: Any) -> _T: ...
131131
def assoc(inst: _T, **changes: Any) -> _T: ...
132132

133-
def fields(cls: _C) -> Any: ...
133+
def fields(cls: type) -> Any: ...

0 commit comments

Comments
 (0)