Skip to content
This repository was archived by the owner on Sep 12, 2024. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
self.meta: dict[str, str] = {}
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())

# NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the
# subclasses, Adding it here, this needs a review.
self.expr_type: str = ""

@property
def sym_tab(self) -> SymbolTable:
"""Get symbol table."""
Expand Down
124 changes: 74 additions & 50 deletions jaclang/compiler/passes/main/fuse_typeinfo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

from __future__ import annotations

from typing import Callable, TypeVar
from types import MethodType
from typing import Callable, Optional, TypeVar

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes import Pass
from jaclang.compiler.passes.transform import Transform
from jaclang.settings import settings
from jaclang.utils.helpers import pascal_to_snake
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack


import mypy.nodes as MypyNodes # noqa N812
import mypy.types as MypyTypes # noqa N812
from mypy.checkexpr import Type as MyType
Expand All @@ -23,11 +24,82 @@
T = TypeVar("T", bound=ast.AstSymbolNode)


# List of expression nodes which we'll be extracting the type info from.
JAC_EXPR_NODES = (
ast.AwaitExpr,
ast.BinaryExpr,
ast.CompareExpr,
ast.BoolExpr,
ast.LambdaExpr,
ast.UnaryExpr,
ast.IfElseExpr,
ast.AtomTrailer,
ast.AtomUnit,
ast.YieldExpr,
ast.YieldExpr,
ast.FuncCall,
ast.EdgeRefTrailer,
ast.ListVal,
ast.SetVal,
ast.TupleVal,
ast.DictVal,
ast.ListCompr,
ast.DictCompr,
)


class FuseTypeInfoPass(Pass):
"""Python and bytecode file self.__debug_printing pass."""

node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}

@staticmethod
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
"""
Enter an expression node.

This function is dynamically bound as a method on insntace of this class, since the
group of functions to handle expressions has a the exact same logic.
"""
if len(node.gen.mypy_ast) == 0:
return

# If the corrosponding mypy ast node type has stored here, get the values.
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
mytype: MyType = self.node_type_hash[mypy_node]
node.expr_type = str(mytype)

# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
# expression. Time and memory wasted here.
collection_types_map = {
ast.ListVal: "builtins.list",
ast.SetVal: "builtins.set",
ast.TupleVal: "builtins.tuple",
ast.DictVal: "builtins.dict",
ast.ListCompr: None,
ast.DictCompr: None,
}

# Set they symbol type for collection expression.
if type(node) in tuple(collection_types_map.keys()):
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(mytype)
collection_type = collection_types_map[type(node)]
if collection_type is not None:
node.name_spec.sym_type = collection_type

def __init__(self, input_ir: T, prior: Optional[Transform]) -> None:
"""Initialize the FuseTpeInfoPass instance."""
for expr_node in JAC_EXPR_NODES:
method_name = "enter_" + pascal_to_snake(expr_node.__name__)
method = MethodType(
FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self
)
setattr(self, method_name, method)
super().__init__(input_ir, prior)

def __debug_print(self, *argv: object) -> None:
if settings.fuse_type_info_debug:
self.log_info("FuseTypeInfo::", *argv)
Expand Down Expand Up @@ -310,54 +382,6 @@ def enter_f_string(self, node: ast.FString) -> None:
"""Pass handler for FString nodes."""
self.__debug_print("Getting type not supported in", type(node))

@__handle_node
def enter_list_val(self, node: ast.ListVal) -> None:
"""Pass handler for ListVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.list"

@__handle_node
def enter_set_val(self, node: ast.SetVal) -> None:
"""Pass handler for SetVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.set"

@__handle_node
def enter_tuple_val(self, node: ast.TupleVal) -> None:
"""Pass handler for TupleVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.tuple"

@__handle_node
def enter_dict_val(self, node: ast.DictVal) -> None:
"""Pass handler for DictVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.dict"

@__handle_node
def enter_list_compr(self, node: ast.ListCompr) -> None:
"""Pass handler for ListCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_dict_compr(self, node: ast.DictCompr) -> None:
"""Pass handler for DictCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These deleted methods are dynamically added above in init

@__handle_node
def enter_index_slice(self, node: ast.IndexSlice) -> None:
"""Pass handler for IndexSlice nodes."""
Expand Down
148 changes: 91 additions & 57 deletions jaclang/compiler/passes/utils/mypy_ast_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

import ast
import os
from types import MethodType

from jaclang.compiler.absyntree import AstNode
from jaclang.compiler.passes import Pass
from jaclang.compiler.passes.main.fuse_typeinfo_pass import (
FuseTypeInfoPass,
)
from jaclang.utils.helpers import pascal_to_snake

import mypy.build as myb
import mypy.checkexpr as mycke
import mypy.errors as mye
import mypy.fastparse as myfp
import mypy.nodes as mypy_nodes
from mypy.build import BuildSource
from mypy.build import BuildSourceSet
from mypy.build import FileSystemCache
Expand All @@ -29,6 +32,55 @@
from mypy.semanal_main import semantic_analysis_for_scc


# All the expression nodes of mypy.
EXPRESSION_NODES = (
mypy_nodes.AssertTypeExpr,
mypy_nodes.AssignmentExpr,
mypy_nodes.AwaitExpr,
mypy_nodes.BytesExpr,
mypy_nodes.CallExpr,
mypy_nodes.CastExpr,
mypy_nodes.ComparisonExpr,
mypy_nodes.ComplexExpr,
mypy_nodes.ConditionalExpr,
mypy_nodes.DictionaryComprehension,
mypy_nodes.DictExpr,
mypy_nodes.EllipsisExpr,
mypy_nodes.EnumCallExpr,
mypy_nodes.Expression,
mypy_nodes.FloatExpr,
mypy_nodes.GeneratorExpr,
mypy_nodes.IndexExpr,
mypy_nodes.IntExpr,
mypy_nodes.LambdaExpr,
mypy_nodes.ListComprehension,
mypy_nodes.ListExpr,
mypy_nodes.MemberExpr,
mypy_nodes.NamedTupleExpr,
mypy_nodes.NameExpr,
mypy_nodes.NewTypeExpr,
mypy_nodes.OpExpr,
mypy_nodes.ParamSpecExpr,
mypy_nodes.PromoteExpr,
mypy_nodes.RefExpr,
mypy_nodes.RevealExpr,
mypy_nodes.SetComprehension,
mypy_nodes.SetExpr,
mypy_nodes.SliceExpr,
mypy_nodes.StarExpr,
mypy_nodes.StrExpr,
mypy_nodes.SuperExpr,
mypy_nodes.TupleExpr,
mypy_nodes.TypeAliasExpr,
mypy_nodes.TypedDictExpr,
mypy_nodes.TypeVarExpr,
mypy_nodes.TypeVarTupleExpr,
mypy_nodes.UnaryExpr,
mypy_nodes.YieldExpr,
mypy_nodes.YieldFromExpr,
)


mypy_to_jac_node_map: dict[
tuple[int, int | None, int | None, int | None], list[AstNode]
] = {}
Expand Down Expand Up @@ -87,63 +139,45 @@ def __init__(
"""Override to mypy expression checker for direct AST pass through."""
super().__init__(tc, msg, plugin, per_line_checking_time_ns)

def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type:
"""Type check a list expression [...]."""
out = super().visit_list_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type:
"""Type check a set expression {...}."""
out = super().visit_set_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type:
"""Type check a tuple expression (...)."""
out = super().visit_tuple_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type:
"""Type check a dictionary expression {...}."""
out = super().visit_dict_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type:
"""Type check a list comprehension."""
out = super().visit_list_comprehension(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type:
"""Type check a set comprehension."""
out = super().visit_set_comprehension(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type:
"""Type check a generator expression."""
out = super().visit_generator_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_dictionary_comprehension(
self, e: myfp.DictionaryComprehension
) -> myb.Type:
"""Type check a dict comprehension."""
out = super().visit_dictionary_comprehension(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_member_expr(
self, e: myfp.MemberExpr, is_lvalue: bool = False
) -> myb.Type:
"""Type check a member expr."""
out = super().visit_member_expr(e, is_lvalue)
FuseTypeInfoPass.node_type_hash[e] = out
return out
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the visit_expr method are deleted as the logic are the same, and added as method bound when the instance is initialized at __init__ with also extrace the type info.

# For every expression there, create attach a method on this instance (self) named "enter_expr()"
for expr_node in EXPRESSION_NODES:
method_name = "visit_" + pascal_to_snake(expr_node.__name__)

# We call the super() version of the method so ensure the parent class has the method or else continue.
if not hasattr(mycke.ExpressionChecker, method_name):
continue

# If the method already overriden then don't override it again here. Continue. Note that the method exists
# on the parent class and if it's also exists on this class and it's a different object that means it's
# overrident method.
if getattr(mycke.ExpressionChecker, method_name) != getattr(
ExpressionChecker, method_name
):
continue

# Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the
# "method_name" variable is used inside a loop and by the time the closure close the "method_name" value,
# it'll be changed by the loop, so we need another method ("make_closure") to persist the value.
def make_closure(method_name: str): # noqa: ANN201
def closure(
self: ExpressionChecker,
e: mycke.Expression,
*args, # noqa: ANN002
**kwargs, # noqa: ANN003
) -> mycke.Type:
# Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm
# (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269).
out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023
self, e, *args, **kwargs
)
FuseTypeInfoPass.node_type_hash[e] = out
return out

return closure

# Attach the new "visit_expr()" method to this instance.
method = make_closure(method_name)
setattr(self, method_name, MethodType(method, self))


class State(myb.State):
Expand Down