Skip to content

Commit 17f7064

Browse files
authored
feat: infer type of parameter when user doesn't specify type (#403)
* feat: infer type of parameter when user doesn't specify type * chore: delete _infer_task_on_kart_attr_init_type
1 parent 599837b commit 17f7064

File tree

3 files changed

+161
-53
lines changed

3 files changed

+161
-53
lines changed

gokart/mypy.py

Lines changed: 102 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,23 @@
1010
from typing import Callable, Final, Iterator, Literal, Optional
1111

1212
import luigi
13-
from mypy.expandtype import expand_type, expand_type_by_instance
13+
from mypy.expandtype import expand_type
1414
from mypy.nodes import (
1515
ARG_NAMED_OPT,
16-
ARG_POS,
1716
Argument,
1817
AssignmentStmt,
1918
Block,
2019
CallExpr,
2120
ClassDef,
22-
Context,
2321
EllipsisExpr,
2422
Expression,
25-
FuncDef,
2623
IfStmt,
2724
JsonDict,
2825
MemberExpr,
2926
NameExpr,
3027
PlaceholderNode,
3128
RefExpr,
3229
Statement,
33-
SymbolTableNode,
3430
TempNode,
3531
TypeInfo,
3632
Var,
@@ -45,12 +41,11 @@
4541
from mypy.typeops import map_type_from_supertype
4642
from mypy.types import (
4743
AnyType,
48-
CallableType,
4944
Instance,
5045
NoneType,
5146
Type,
5247
TypeOfAny,
53-
get_proper_type,
48+
UnionType,
5449
)
5550
from mypy.typevars import fill_typevars
5651

@@ -327,7 +322,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
327322

328323
current_attr_names.add(lhs.name)
329324
with state.strict_optional_set(self._api.options.strict_optional):
330-
init_type = self._infer_task_on_kart_attr_init_type(sym, stmt)
325+
init_type = sym.type
326+
327+
# infer Parameter type
328+
if init_type is None:
329+
init_type = self._infer_type_from_parameters(stmt.rvalue)
331330

332331
found_attrs[lhs.name] = TaskOnKartAttribute(
333332
name=lhs.name,
@@ -361,65 +360,115 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp
361360
return True, args
362361
return False, {}
363362

364-
def _infer_task_on_kart_attr_init_type(self, sym: SymbolTableNode, context: Context) -> Type | None:
365-
"""Infer __init__ argument type for an attribute.
363+
def _infer_type_from_parameters(self, parameter: Expression) -> Optional[Type]:
    """
    Generate the default attribute type from a Parameter call expression.

    For example, when parameter is `luigi.parameter.Parameter`, this method returns the `str` type.
    When the parameter name contains `Optional`, the result is a union of the underlying type and `None`.

    Returns None when the name of the parameter cannot be extracted, when the parameter is not
    one of the names handled below, or when the `var_type` / `enum` argument needed to
    determine the type cannot be resolved.
    """
    parameter_name = _extract_parameter_name(parameter)
    if parameter_name is None:
        return None

    # Parameters whose value type is a plain (non-generic) class, keyed by parameter fullname.
    simple_types = {
        'luigi.parameter.Parameter': 'builtins.str',
        'luigi.parameter.OptionalParameter': 'builtins.str',
        'luigi.parameter.IntParameter': 'builtins.int',
        'luigi.parameter.OptionalIntParameter': 'builtins.int',
        'luigi.parameter.FloatParameter': 'builtins.float',
        'luigi.parameter.OptionalFloatParameter': 'builtins.float',
        'luigi.parameter.BoolParameter': 'builtins.bool',
        'luigi.parameter.OptionalBoolParameter': 'builtins.bool',
        'luigi.parameter.DateParameter': 'datetime.date',
        'luigi.parameter.MonthParameter': 'datetime.date',
        'luigi.parameter.YearParameter': 'datetime.date',
        'luigi.parameter.DateHourParameter': 'datetime.datetime',
        'luigi.parameter.DateMinuteParameter': 'datetime.datetime',
        'luigi.parameter.DateSecondParameter': 'datetime.datetime',
        'luigi.parameter.TimeDeltaParameter': 'datetime.timedelta',
        'luigi.parameter.PathParameter': 'pathlib.Path',
        'luigi.parameter.OptionalPathParameter': 'pathlib.Path',
        'gokart.parameter.ExplicitBoolParameter': 'builtins.bool',
    }
    # List-like parameters are all typed as a tuple of unannotated elements.
    sequence_parameters = (
        'luigi.parameter.ListParameter',
        'luigi.parameter.OptionalListParameter',
        'luigi.parameter.TupleParameter',
        'luigi.parameter.OptionalTupleParameter',
    )
    # Parameters whose type comes from a keyword argument: fullname -> (argument name, wrap in tuple).
    # NOTE: the ChoiceList entry used to be misspelled ('ChoiceListPareameter') and therefore never matched.
    argument_typed_parameters = {
        'luigi.parameter.NumericalParameter': ('var_type', False),
        'luigi.parameter.ChoiceParameter': ('var_type', False),
        'luigi.parameter.ChoiceListParameter': ('var_type', True),
        'luigi.parameter.EnumParameter': ('enum', False),
        'luigi.parameter.EnumListParameter': ('enum', True),
    }

    def task_on_kart_type() -> Instance:
        return self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])

    # named_type is only called for the matching entry, so modules of unrelated
    # types are never looked up.
    underlying_type: Optional[Type] = None
    if parameter_name in simple_types:
        underlying_type = self._api.named_type(simple_types[parameter_name], [])
    elif parameter_name in ('luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter'):
        underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)])
    elif parameter_name in sequence_parameters:
        underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
    elif parameter_name == 'gokart.parameter.TaskInstanceParameter':
        underlying_type = task_on_kart_type()
    elif parameter_name == 'gokart.parameter.ListTaskInstanceParameter':
        underlying_type = self._api.named_type('builtins.list', [task_on_kart_type()])
    elif parameter_name in argument_typed_parameters:
        arg_key, wrap_in_tuple = argument_typed_parameters[parameter_name]
        base_type = self._get_type_from_args(parameter, arg_key)
        if base_type is not None and wrap_in_tuple:
            underlying_type = self._api.named_type('builtins.tuple', [base_type])
        else:
            underlying_type = base_type

    if underlying_type is None:
        return None

    # An Optional* parameter may also hold None.
    if 'Optional' in parameter_name:
        return UnionType([underlying_type, NoneType()])

    return underlying_type
424+
425+
def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Optional[Type]:
426+
"""
427+
Get the type from the parameter's arguments.
366428
367-
In particular, possibly use the signature of __set__.
429+
e.g.)
430+
When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
368431
"""
369-
default = sym.type
370-
if sym.implicit:
371-
return default
372-
t = get_proper_type(sym.type)
373-
374-
# Perform a simple-minded inference from the signature of __set__, if present.
375-
# We can't use mypy.checkmember here, since this plugin runs before type checking.
376-
# We only support some basic scanerios here, which is hopefully sufficient for
377-
# the vast majority of use cases.
378-
if not isinstance(t, Instance):
379-
return default
380-
setter = t.type.get('__set__')
381-
382-
if not setter:
383-
return default
384-
385-
if isinstance(setter.node, FuncDef):
386-
super_info = t.type.get_containing_type_info('__set__')
387-
assert super_info
388-
if setter.type:
389-
setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info))
390-
else:
391-
return AnyType(TypeOfAny.unannotated)
392-
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
393-
ARG_POS,
394-
ARG_POS,
395-
ARG_POS,
396-
]:
397-
return expand_type_by_instance(setter_type.arg_types[2], t)
398-
else:
399-
self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context)
400-
else:
401-
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)
402-
403-
return default
432+
ok, args = self._collect_parameter_args(parameter)
433+
if not ok:
434+
return None
435+
436+
if arg_key not in args:
437+
return None
438+
439+
arg = args[arg_key]
440+
if not isinstance(arg, NameExpr):
441+
return None
442+
if not isinstance(arg.node, TypeInfo):
443+
return None
444+
return Instance(arg.node, [])
404445

405446

406447
def is_parameter_call(expr: Expression) -> bool:
    """Checks if the expression is a call to luigi.Parameter()

    The name is obtained with `_extract_parameter_name`; the expression counts as a
    parameter call only when that name matches `PARAMETER_FULLNAME_MATCHER`.
    """
    extracted_name = _extract_parameter_name(expr)
    return extracted_name is not None and PARAMETER_FULLNAME_MATCHER.match(extracted_name) is not None
453+
454+
455+
def _extract_parameter_name(expr: Expression) -> Optional[str]:
456+
"""Extract name if the expression is a call to luigi.Parameter()"""
457+
if not isinstance(expr, CallExpr):
458+
return None
410459

411460
callee = expr.callee
412461
if isinstance(callee, MemberExpr):
413462
type_info = callee.node
414463
if type_info is None and isinstance(callee.expr, NameExpr):
415-
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
464+
return f'{callee.expr.name}.{callee.name}'
416465
elif isinstance(callee, NameExpr):
417466
type_info = callee.node
418467
else:
419-
return False
468+
return None
420469

421470
if isinstance(type_info, TypeInfo):
422-
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
471+
return type_info.fullname
423472

424473
# Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
425474
# https://github.com/spotify/luigi/pull/3297
@@ -429,8 +478,9 @@ def is_parameter_call(expr: Expression) -> bool:
429478
# class MyTask(gokart.TaskOnKart):
430479
# param = Parameter()
431480
if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1':
432-
return PARAMETER_TMP_MATCHER.match(type_info.name) is not None
433-
return False
481+
return type_info.name
482+
483+
return None
434484

435485

436486
def plugin(version: str) -> type[Plugin]:

test/test_mypy.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,61 @@ class MyTask(gokart.TaskOnKart):
6464
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
6565
self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
6666
self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])
67+
68+
def test_parameter_has_default_type_invalid_pattern(self):
69+
"""
70+
If the user doesn't set the type of the parameter, mypy infers the default type from the Parameter type.
71+
"""
72+
test_code = """
73+
import enum
74+
import luigi
75+
import gokart
76+
77+
78+
class MyEnum(enum.Enum):
79+
FOO = enum.auto()
80+
81+
class MyTask(gokart.TaskOnKart):
82+
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
83+
foo = luigi.IntParameter()
84+
bar = luigi.DateParameter()
85+
baz = gokart.TaskInstanceParameter()
86+
qux = luigi.NumericalParameter(var_type=int)
87+
quux = luigi.ChoiceParameter(choices=[1, 2, 3], var_type=int)
88+
corge = luigi.EnumParameter(enum=MyEnum)
89+
90+
MyTask(foo="1", bar=1, baz=1, qux='1', quux='1', corge=1)
91+
"""
92+
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
93+
test_file.write(test_code.encode('utf-8'))
94+
test_file.flush()
95+
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
96+
self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
97+
self.assertIn('error: Argument "bar" to "MyTask" has incompatible type "int"; expected "date" [arg-type]', result[0])
98+
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "int"; expected "TaskOnKart[Any]" [arg-type]', result[0])
99+
self.assertIn('error: Argument "qux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
100+
self.assertIn('error: Argument "quux" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
101+
self.assertIn('error: Argument "corge" to "MyTask" has incompatible type "int"; expected "MyEnum" [arg-type]', result[0])
102+
103+
def test_parameter_has_default_type_no_issue_pattern(self):
104+
"""
105+
If the user doesn't set the type of the parameter, mypy infers the default type from the Parameter type.
106+
"""
107+
test_code = """
108+
from datetime import date
109+
import luigi
110+
import gokart
111+
112+
class MyTask(gokart.TaskOnKart):
113+
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
114+
foo = luigi.IntParameter()
115+
bar = luigi.DateParameter()
116+
baz = gokart.TaskInstanceParameter()
117+
118+
MyTask(foo=1, bar=date.today(), baz=gokart.TaskOnKart())
119+
"""
120+
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
121+
test_file.write(test_code.encode('utf-8'))
122+
test_file.flush()
123+
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
124+
self.assertIn('Success: no issues found', result[0])

test/test_task_instance_parameter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class _DummyPipelineD(TaskOnKart):
9090
subtask = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)
9191

9292
with self.assertRaises(TypeError):
93-
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask])
93+
_DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])
9494

9595

9696
if __name__ == '__main__':

0 commit comments

Comments
 (0)