diff --git a/jaclang/compiler/absyntree.py b/jaclang/compiler/absyntree.py index b7524db5a..ea38eb14d 100644 --- a/jaclang/compiler/absyntree.py +++ b/jaclang/compiler/absyntree.py @@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None: self.meta: dict[str, str] = {} self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range()) + # NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the + # subclasses, Adding it here, this needs a review. + self.expr_type: str = "" + @property def sym_tab(self) -> SymbolTable: """Get symbol table.""" diff --git a/jaclang/compiler/passes/main/fuse_typeinfo_pass.py b/jaclang/compiler/passes/main/fuse_typeinfo_pass.py index 8e44c5313..e8a2c7992 100644 --- a/jaclang/compiler/passes/main/fuse_typeinfo_pass.py +++ b/jaclang/compiler/passes/main/fuse_typeinfo_pass.py @@ -6,15 +6,16 @@ from __future__ import annotations -from typing import Callable, TypeVar +from types import MethodType +from typing import Callable, Optional, TypeVar import jaclang.compiler.absyntree as ast from jaclang.compiler.passes import Pass +from jaclang.compiler.passes.transform import Transform from jaclang.settings import settings from jaclang.utils.helpers import pascal_to_snake from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack - import mypy.nodes as MypyNodes # noqa N812 import mypy.types as MypyTypes # noqa N812 from mypy.checkexpr import Type as MyType @@ -23,11 +24,82 @@ T = TypeVar("T", bound=ast.AstSymbolNode) +# List of expression nodes which we'll be extracting the type info from. +JAC_EXPR_NODES = ( + ast.AwaitExpr, + ast.BinaryExpr, + ast.CompareExpr, + ast.BoolExpr, + ast.LambdaExpr, + ast.UnaryExpr, + ast.IfElseExpr, + ast.AtomTrailer, + ast.AtomUnit, + ast.YieldExpr, + ast.YieldExpr, + ast.FuncCall, + ast.EdgeRefTrailer, + ast.ListVal, + ast.SetVal, + ast.TupleVal, + ast.DictVal, + ast.ListCompr, + ast.DictCompr, +) + + class FuseTypeInfoPass(Pass): """Python and bytecode file self.__debug_printing pass.""" node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {} + @staticmethod + def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None: + """ + Enter an expression node. + + This function is dynamically bound as a method on insntace of this class, since the + group of functions to handle expressions has a the exact same logic. + """ + if len(node.gen.mypy_ast) == 0: + return + + # If the corrosponding mypy ast node type has stored here, get the values. + mypy_node = node.gen.mypy_ast[0] + if mypy_node in self.node_type_hash: + mytype: MyType = self.node_type_hash[mypy_node] + node.expr_type = str(mytype) + + # TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an + # expression. Time and memory wasted here. + collection_types_map = { + ast.ListVal: "builtins.list", + ast.SetVal: "builtins.set", + ast.TupleVal: "builtins.tuple", + ast.DictVal: "builtins.dict", + ast.ListCompr: None, + ast.DictCompr: None, + } + + # Set they symbol type for collection expression. + if type(node) in tuple(collection_types_map.keys()): + assert isinstance(node, ast.AtomExpr) # To make mypy happy. + if mypy_node in self.node_type_hash: + node.name_spec.sym_type = str(mytype) + collection_type = collection_types_map[type(node)] + if collection_type is not None: + node.name_spec.sym_type = collection_type + + def __init__(self, input_ir: T, prior: Optional[Transform]) -> None: + """Initialize the FuseTpeInfoPass instance.""" + for expr_node in JAC_EXPR_NODES: + method_name = "enter_" + pascal_to_snake(expr_node.__name__) + method = MethodType( + FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self + ) + setattr(self, method_name, method) + super().__init__(input_ir, prior) + def __debug_print(self, *argv: object) -> None: if settings.fuse_type_info_debug: self.log_info("FuseTypeInfo::", *argv) @@ -310,54 +382,6 @@ def enter_f_string(self, node: ast.FString) -> None: """Pass handler for FString nodes.""" self.__debug_print("Getting type not supported in", type(node)) - @__handle_node - def enter_list_val(self, node: ast.ListVal) -> None: - """Pass handler for ListVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.list" - - @__handle_node - def enter_set_val(self, node: ast.SetVal) -> None: - """Pass handler for SetVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.set" - - @__handle_node - def enter_tuple_val(self, node: ast.TupleVal) -> None: - """Pass handler for TupleVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.tuple" - - @__handle_node - def enter_dict_val(self, node: ast.DictVal) -> None: - """Pass handler for DictVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.dict" - - @__handle_node - def enter_list_compr(self, node: ast.ListCompr) -> None: - """Pass handler for ListCompr nodes.""" - mypy_node = node.gen.mypy_ast[0] - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - - @__handle_node - def enter_dict_compr(self, node: ast.DictCompr) -> None: - """Pass handler for DictCompr nodes.""" - mypy_node = node.gen.mypy_ast[0] - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - @__handle_node def enter_index_slice(self, node: ast.IndexSlice) -> None: """Pass handler for IndexSlice nodes.""" diff --git a/jaclang/compiler/passes/utils/mypy_ast_build.py b/jaclang/compiler/passes/utils/mypy_ast_build.py index 89619c36b..904a0e0a4 100644 --- a/jaclang/compiler/passes/utils/mypy_ast_build.py +++ b/jaclang/compiler/passes/utils/mypy_ast_build.py @@ -4,17 +4,20 @@ import ast import os +from types import MethodType from jaclang.compiler.absyntree import AstNode from jaclang.compiler.passes import Pass from jaclang.compiler.passes.main.fuse_typeinfo_pass import ( FuseTypeInfoPass, ) +from jaclang.utils.helpers import pascal_to_snake import mypy.build as myb import mypy.checkexpr as mycke import mypy.errors as mye import mypy.fastparse as myfp +import mypy.nodes as mypy_nodes from mypy.build import BuildSource from mypy.build import BuildSourceSet from mypy.build import FileSystemCache @@ -29,6 +32,55 @@ from mypy.semanal_main import semantic_analysis_for_scc +# All the expression nodes of mypy. +EXPRESSION_NODES = ( + mypy_nodes.AssertTypeExpr, + mypy_nodes.AssignmentExpr, + mypy_nodes.AwaitExpr, + mypy_nodes.BytesExpr, + mypy_nodes.CallExpr, + mypy_nodes.CastExpr, + mypy_nodes.ComparisonExpr, + mypy_nodes.ComplexExpr, + mypy_nodes.ConditionalExpr, + mypy_nodes.DictionaryComprehension, + mypy_nodes.DictExpr, + mypy_nodes.EllipsisExpr, + mypy_nodes.EnumCallExpr, + mypy_nodes.Expression, + mypy_nodes.FloatExpr, + mypy_nodes.GeneratorExpr, + mypy_nodes.IndexExpr, + mypy_nodes.IntExpr, + mypy_nodes.LambdaExpr, + mypy_nodes.ListComprehension, + mypy_nodes.ListExpr, + mypy_nodes.MemberExpr, + mypy_nodes.NamedTupleExpr, + mypy_nodes.NameExpr, + mypy_nodes.NewTypeExpr, + mypy_nodes.OpExpr, + mypy_nodes.ParamSpecExpr, + mypy_nodes.PromoteExpr, + mypy_nodes.RefExpr, + mypy_nodes.RevealExpr, + mypy_nodes.SetComprehension, + mypy_nodes.SetExpr, + mypy_nodes.SliceExpr, + mypy_nodes.StarExpr, + mypy_nodes.StrExpr, + mypy_nodes.SuperExpr, + mypy_nodes.TupleExpr, + mypy_nodes.TypeAliasExpr, + mypy_nodes.TypedDictExpr, + mypy_nodes.TypeVarExpr, + mypy_nodes.TypeVarTupleExpr, + mypy_nodes.UnaryExpr, + mypy_nodes.YieldExpr, + mypy_nodes.YieldFromExpr, +) + + mypy_to_jac_node_map: dict[ tuple[int, int | None, int | None, int | None], list[AstNode] ] = {} @@ -87,63 +139,45 @@ def __init__( """Override to mypy expression checker for direct AST pass through.""" super().__init__(tc, msg, plugin, per_line_checking_time_ns) - def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type: - """Type check a list expression [...].""" - out = super().visit_list_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type: - """Type check a set expression {...}.""" - out = super().visit_set_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type: - """Type check a tuple expression (...).""" - out = super().visit_tuple_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type: - """Type check a dictionary expression {...}.""" - out = super().visit_dict_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type: - """Type check a list comprehension.""" - out = super().visit_list_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type: - """Type check a set comprehension.""" - out = super().visit_set_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type: - """Type check a generator expression.""" - out = super().visit_generator_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_dictionary_comprehension( - self, e: myfp.DictionaryComprehension - ) -> myb.Type: - """Type check a dict comprehension.""" - out = super().visit_dictionary_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_member_expr( - self, e: myfp.MemberExpr, is_lvalue: bool = False - ) -> myb.Type: - """Type check a member expr.""" - out = super().visit_member_expr(e, is_lvalue) - FuseTypeInfoPass.node_type_hash[e] = out - return out + # For every expression there, create attach a method on this instance (self) named "enter_expr()" + for expr_node in EXPRESSION_NODES: + method_name = "visit_" + pascal_to_snake(expr_node.__name__) + + # We call the super() version of the method so ensure the parent class has the method or else continue. + if not hasattr(mycke.ExpressionChecker, method_name): + continue + + # If the method already overriden then don't override it again here. Continue. Note that the method exists + # on the parent class and if it's also exists on this class and it's a different object that means it's + # overrident method. + if getattr(mycke.ExpressionChecker, method_name) != getattr( + ExpressionChecker, method_name + ): + continue + + # Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the + # "method_name" variable is used inside a loop and by the time the closure close the "method_name" value, + # it'll be changed by the loop, so we need another method ("make_closure") to persist the value. + def make_closure(method_name: str): # noqa: ANN201 + def closure( + self: ExpressionChecker, + e: mycke.Expression, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> mycke.Type: + # Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm + # (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269). + out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023 + self, e, *args, **kwargs + ) + FuseTypeInfoPass.node_type_hash[e] = out + return out + + return closure + + # Attach the new "visit_expr()" method to this instance. + method = make_closure(method_name) + setattr(self, method_name, MethodType(method, self)) class State(myb.State):