1010from typing import Callable , Final , Iterator , Literal , Optional
1111
1212import luigi
13- from mypy .expandtype import expand_type , expand_type_by_instance
13+ from mypy .expandtype import expand_type
1414from mypy .nodes import (
1515 ARG_NAMED_OPT ,
16- ARG_POS ,
1716 Argument ,
1817 AssignmentStmt ,
1918 Block ,
2019 CallExpr ,
2120 ClassDef ,
22- Context ,
2321 EllipsisExpr ,
2422 Expression ,
25- FuncDef ,
2623 IfStmt ,
2724 JsonDict ,
2825 MemberExpr ,
2926 NameExpr ,
3027 PlaceholderNode ,
3128 RefExpr ,
3229 Statement ,
33- SymbolTableNode ,
3430 TempNode ,
3531 TypeInfo ,
3632 Var ,
4541from mypy .typeops import map_type_from_supertype
4642from mypy .types import (
4743 AnyType ,
48- CallableType ,
4944 Instance ,
5045 NoneType ,
5146 Type ,
5247 TypeOfAny ,
53- get_proper_type ,
48+ UnionType ,
5449)
5550from mypy .typevars import fill_typevars
5651
@@ -327,7 +322,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
327322
328323 current_attr_names .add (lhs .name )
329324 with state .strict_optional_set (self ._api .options .strict_optional ):
330- init_type = self ._infer_task_on_kart_attr_init_type (sym , stmt )
325+ init_type = sym .type
326+
327+ # infer Parameter type
328+ if init_type is None :
329+ init_type = self ._infer_type_from_parameters (stmt .rvalue )
331330
332331 found_attrs [lhs .name ] = TaskOnKartAttribute (
333332 name = lhs .name ,
@@ -361,65 +360,115 @@ def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Exp
361360 return True , args
362361 return False , {}
363362
364- def _infer_task_on_kart_attr_init_type (self , sym : SymbolTableNode , context : Context ) -> Type | None :
365- """Infer __init__ argument type for an attribute.
363+ def _infer_type_from_parameters (self , parameter : Expression ) -> Optional [Type ]:
364+ """
365+ Generate default type from Parameter.
366+ For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type.
367+ """
368+ parameter_name = _extract_parameter_name (parameter )
369+ if parameter_name is None :
370+ return None
371+
372+ underlying_type : Optional [Type ] = None
373+ if parameter_name in ['luigi.parameter.Parameter' , 'luigi.parameter.OptionalParameter' ]:
374+ underlying_type = self ._api .named_type ('builtins.str' , [])
375+ elif parameter_name in ['luigi.parameter.IntParameter' , 'luigi.parameter.OptionalIntParameter' ]:
376+ underlying_type = self ._api .named_type ('builtins.int' , [])
377+ elif parameter_name in ['luigi.parameter.FloatParameter' , 'luigi.parameter.OptionalFloatParameter' ]:
378+ underlying_type = self ._api .named_type ('builtins.float' , [])
379+ elif parameter_name in ['luigi.parameter.BoolParameter' , 'luigi.parameter.OptionalBoolParameter' ]:
380+ underlying_type = self ._api .named_type ('builtins.bool' , [])
381+ elif parameter_name in ['luigi.parameter.DateParameter' , 'luigi.parameter.MonthParameter' , 'luigi.parameter.YearParameter' ]:
382+ underlying_type = self ._api .named_type ('datetime.date' , [])
383+ elif parameter_name in ['luigi.parameter.DateHourParameter' , 'luigi.parameter.DateMinuteParameter' , 'luigi.parameter.DateSecondParameter' ]:
384+ underlying_type = self ._api .named_type ('datetime.datetime' , [])
385+ elif parameter_name in ['luigi.parameter.TimeDeltaParameter' ]:
386+ underlying_type = self ._api .named_type ('datetime.timedelta' , [])
387+ elif parameter_name in ['luigi.parameter.DictParameter' , 'luigi.parameter.OptionalDictParameter' ]:
388+ underlying_type = self ._api .named_type ('builtins.dict' , [AnyType (TypeOfAny .unannotated ), AnyType (TypeOfAny .unannotated )])
389+ elif parameter_name in ['luigi.parameter.ListParameter' , 'luigi.parameter.OptionalListParameter' ]:
390+ underlying_type = self ._api .named_type ('builtins.tuple' , [AnyType (TypeOfAny .unannotated )])
391+ elif parameter_name in ['luigi.parameter.TupleParameter' , 'luigi.parameter.OptionalTupleParameter' ]:
392+ underlying_type = self ._api .named_type ('builtins.tuple' , [AnyType (TypeOfAny .unannotated )])
393+ elif parameter_name in ['luigi.parameter.PathParameter' , 'luigi.parameter.OptionalPathParameter' ]:
394+ underlying_type = self ._api .named_type ('pathlib.Path' , [])
395+ elif parameter_name in ['gokart.parameter.TaskInstanceParameter' ]:
396+ underlying_type = self ._api .named_type ('gokart.task.TaskOnKart' , [AnyType (TypeOfAny .unannotated )])
397+ elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter' ]:
398+ underlying_type = self ._api .named_type ('builtins.list' , [self ._api .named_type ('gokart.task.TaskOnKart' , [AnyType (TypeOfAny .unannotated )])])
399+ elif parameter_name in ['gokart.parameter.ExplicitBoolParameter' ]:
400+ underlying_type = self ._api .named_type ('builtins.bool' , [])
401+ elif parameter_name in ['luigi.parameter.NumericalParameter' ]:
402+ underlying_type = self ._get_type_from_args (parameter , 'var_type' )
403+ elif parameter_name in ['luigi.parameter.ChoiceParameter' ]:
404+ underlying_type = self ._get_type_from_args (parameter , 'var_type' )
405+ elif parameter_name in ['luigi.parameter.ChoiceListPareameter' ]:
406+ base_type = self ._get_type_from_args (parameter , 'var_type' )
407+ if base_type is not None :
408+ underlying_type = self ._api .named_type ('builtins.tuple' , [base_type ])
409+ elif parameter_name in ['luigi.parameter.EnumParameter' ]:
410+ underlying_type = self ._get_type_from_args (parameter , 'enum' )
411+ elif parameter_name in ['luigi.parameter.EnumListParameter' ]:
412+ base_type = self ._get_type_from_args (parameter , 'enum' )
413+ if base_type is not None :
414+ underlying_type = self ._api .named_type ('builtins.tuple' , [base_type ])
415+
416+ if underlying_type is None :
417+ return None
418+
419+ # When parameter has Optional, it can be none value.
420+ if 'Optional' in parameter_name :
421+ return UnionType ([underlying_type , NoneType ()])
422+
423+ return underlying_type
424+
425+ def _get_type_from_args (self , parameter : Expression , arg_key : str ) -> Optional [Type ]:
426+ """
427+ get type from parameter arguments.
366428
367- In particular, possibly use the signature of __set__.
429+ e.x)
430+ When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
368431 """
369- default = sym .type
370- if sym .implicit :
371- return default
372- t = get_proper_type (sym .type )
373-
374- # Perform a simple-minded inference from the signature of __set__, if present.
375- # We can't use mypy.checkmember here, since this plugin runs before type checking.
376- # We only support some basic scanerios here, which is hopefully sufficient for
377- # the vast majority of use cases.
378- if not isinstance (t , Instance ):
379- return default
380- setter = t .type .get ('__set__' )
381-
382- if not setter :
383- return default
384-
385- if isinstance (setter .node , FuncDef ):
386- super_info = t .type .get_containing_type_info ('__set__' )
387- assert super_info
388- if setter .type :
389- setter_type = get_proper_type (map_type_from_supertype (setter .type , t .type , super_info ))
390- else :
391- return AnyType (TypeOfAny .unannotated )
392- if isinstance (setter_type , CallableType ) and setter_type .arg_kinds == [
393- ARG_POS ,
394- ARG_POS ,
395- ARG_POS ,
396- ]:
397- return expand_type_by_instance (setter_type .arg_types [2 ], t )
398- else :
399- self ._api .fail (f'Unsupported signature for "__set__" in "{ t .type .name } "' , context )
400- else :
401- self ._api .fail (f'Unsupported "__set__" in "{ t .type .name } "' , context )
402-
403- return default
432+ ok , args = self ._collect_parameter_args (parameter )
433+ if not ok :
434+ return None
435+
436+ if arg_key not in args :
437+ return None
438+
439+ arg = args [arg_key ]
440+ if not isinstance (arg , NameExpr ):
441+ return None
442+ if not isinstance (arg .node , TypeInfo ):
443+ return None
444+ return Instance (arg .node , [])
404445
405446
406447def is_parameter_call (expr : Expression ) -> bool :
407448 """Checks if the expression is a call to luigi.Parameter()"""
408- if not isinstance (expr , CallExpr ):
449+ parameter_name = _extract_parameter_name (expr )
450+ if parameter_name is None :
409451 return False
452+ return PARAMETER_FULLNAME_MATCHER .match (parameter_name ) is not None
453+
454+
455+ def _extract_parameter_name (expr : Expression ) -> Optional [str ]:
456+ """Extract name if the expression is a call to luigi.Parameter()"""
457+ if not isinstance (expr , CallExpr ):
458+ return None
410459
411460 callee = expr .callee
412461 if isinstance (callee , MemberExpr ):
413462 type_info = callee .node
414463 if type_info is None and isinstance (callee .expr , NameExpr ):
415- return PARAMETER_FULLNAME_MATCHER . match ( f'{ callee .expr .name } .{ callee .name } ' ) is not None
464+ return f'{ callee .expr .name } .{ callee .name } '
416465 elif isinstance (callee , NameExpr ):
417466 type_info = callee .node
418467 else :
419- return False
468+ return None
420469
421470 if isinstance (type_info , TypeInfo ):
422- return PARAMETER_FULLNAME_MATCHER . match ( type_info .fullname ) is not None
471+ return type_info .fullname
423472
424473 # Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
425474 # https://github.com/spotify/luigi/pull/3297
@@ -429,8 +478,9 @@ def is_parameter_call(expr: Expression) -> bool:
429478 # class MyTask(gokart.TaskOnKart):
430479 # param = Parameter()
431480 if isinstance (type_info , Var ) and luigi .__version__ <= '3.5.1' :
432- return PARAMETER_TMP_MATCHER .match (type_info .name ) is not None
433- return False
481+ return type_info .name
482+
483+ return None
434484
435485
436486def plugin (version : str ) -> type [Plugin ]:
0 commit comments