18 changes: 10 additions & 8 deletions mypyc/analysis/capsule_deps.py
@@ -1,27 +1,29 @@
from __future__ import annotations

from mypyc.ir.deps import Dependency
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.ops import CallC, PrimitiveOp


def find_implicit_capsule_dependencies(fn: FuncIR) -> set[str] | None:
"""Find implicit dependencies on capsules that need to be imported.
def find_implicit_op_dependencies(fn: FuncIR) -> set[Dependency] | None:
"""Find implicit dependencies that need to be imported.

Using primitives or types defined in librt submodules such as "librt.base64"
requires a capsule import.
requires dependency imports (e.g., capsule imports).

Note that a module can depend on a librt module even if it doesn't explicitly
import it, for example via re-exported names or via return types of functions
defined in other modules.
"""
deps: set[str] | None = None
deps: set[Dependency] | None = None
for block in fn.blocks:
for op in block.ops:
# TODO: Also determine implicit type object dependencies (e.g. cast targets)
if isinstance(op, CallC) and op.capsule is not None:
if deps is None:
deps = set()
deps.add(op.capsule)
if isinstance(op, CallC) and op.dependencies is not None:
for dep in op.dependencies:
if deps is None:
deps = set()
deps.add(dep)
else:
assert not isinstance(op, PrimitiveOp), "Lowered IR is expected"
return deps
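
For orientation, this is roughly how the per-function results feed into the module-level dependency set (a minimal sketch based on the compile_scc_to_ir change further down, not the exact pipeline code):

    # Sketch: merge implicit per-function dependencies into the module's set.
    # Assumes `module` is a ModuleIR whose functions have already been lowered.
    for fn in module.functions:
        deps = find_implicit_op_dependencies(fn)
        if deps is not None:
            module.dependencies.update(deps)
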
38 changes: 26 additions & 12 deletions mypyc/build.py
@@ -39,6 +39,7 @@
from mypyc.codegen import emitmodule
from mypyc.common import IS_FREE_THREADED, RUNTIME_C_FILES, shared_lib_name
from mypyc.errors import Errors
from mypyc.ir.deps import SourceDep
from mypyc.ir.pprint import format_modules
from mypyc.namegen import exported_name
from mypyc.options import CompilerOptions
@@ -282,14 +283,14 @@ def generate_c(
groups: emitmodule.Groups,
fscache: FileSystemCache,
compiler_options: CompilerOptions,
) -> tuple[list[list[tuple[str, str]]], str]:
) -> tuple[list[list[tuple[str, str]]], str, list[SourceDep]]:
"""Drive the actual core compilation step.

The groups argument describes how modules are assigned to C
extension modules. See the comments on the Groups type in
mypyc.emitmodule for details.

Returns the C source code and (for debugging) the pretty printed IR.
Returns the C source code, (for debugging) the pretty printed IR, and list of SourceDeps.
"""
t0 = time.time()

@@ -325,7 +326,10 @@ def generate_c(
if options.mypyc_annotation_file:
generate_annotated_html(options.mypyc_annotation_file, result, modules, mapper)

return ctext, "\n".join(format_modules(modules))
# Collect SourceDep dependencies
source_deps = sorted(emitmodule.collect_source_dependencies(modules), key=lambda d: d.path)

return ctext, "\n".join(format_modules(modules)), source_deps


def build_using_shared_lib(
@@ -486,9 +490,9 @@ def mypyc_build(
*,
separate: bool | list[tuple[list[str], str | None]] = False,
only_compile_paths: Iterable[str] | None = None,
skip_cgen_input: Any | None = None,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
always_use_shared_lib: bool = False,
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]]]:
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]:
"""Do the front and middle end of mypyc building, producing and writing out C source."""
fscache = FileSystemCache()
mypyc_sources, all_sources, options = get_mypy_config(
@@ -511,14 +515,16 @@

# We let the test harness just pass in the c file contents instead
# so that it can do a corner-cutting version without full stubs.
source_deps: list[SourceDep] = []
if not skip_cgen_input:
group_cfiles, ops_text = generate_c(
group_cfiles, ops_text, source_deps = generate_c(
all_sources, options, groups, fscache, compiler_options=compiler_options
)
# TODO: unique names?
write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text)
else:
group_cfiles = skip_cgen_input
group_cfiles = skip_cgen_input[0]
source_deps = [SourceDep(d) for d in skip_cgen_input[1]]

# Write out the generated C and collect the files for each group
# Should this be here??
@@ -535,7 +541,7 @@ def mypyc_build(
deps = [os.path.join(compiler_options.target_dir, dep) for dep in get_header_deps(cfiles)]
group_cfilenames.append((cfilenames, deps))

return groups, group_cfilenames
return groups, group_cfilenames, source_deps


def mypycify(
@@ -548,7 +554,7 @@ def mypycify(
strip_asserts: bool = False,
multi_file: bool = False,
separate: bool | list[tuple[list[str], str | None]] = False,
skip_cgen_input: Any | None = None,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
target_dir: str | None = None,
include_runtime_files: bool | None = None,
strict_dunder_typing: bool = False,
@@ -633,7 +639,7 @@ def mypycify(
)

# Generate all the actual important C code
groups, group_cfilenames = mypyc_build(
groups, group_cfilenames, source_deps = mypyc_build(
paths,
only_compile_paths=only_compile_paths,
compiler_options=compiler_options,
@@ -708,11 +714,19 @@ def mypycify(
# compiler invocations.
shared_cfilenames = []
if not compiler_options.include_runtime_files:
for name in RUNTIME_C_FILES:
# Collect all files to copy: runtime files + conditional source files
files_to_copy = list(RUNTIME_C_FILES)
for source_dep in source_deps:
files_to_copy.append(source_dep.path)
files_to_copy.append(source_dep.get_header())

# Copy all files
for name in files_to_copy:
rt_file = os.path.join(build_dir, name)
with open(os.path.join(include_dir(), name), encoding="utf-8") as f:
write_file(rt_file, f.read())
shared_cfilenames.append(rt_file)
if name.endswith(".c"):
shared_cfilenames.append(rt_file)

extensions = []
for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames):
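
As a rough illustration of the copy step above (assuming a single extra source dependency, using the constants defined in mypyc/ir/deps.py below):

    # Sketch: a SourceDep adds both its .c file and its header to the copy list,
    # but only the .c file joins the shared C sources that get compiled.
    source_deps = [SourceDep("bytes_extra_ops.c")]
    files_to_copy = list(RUNTIME_C_FILES)
    for dep in source_deps:
        files_to_copy.extend([dep.path, dep.get_header()])
    # files_to_copy == RUNTIME_C_FILES + ["bytes_extra_ops.c", "bytes_extra_ops.h"];
    # after copying into build_dir, only the .c entries end up in shared_cfilenames.
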
35 changes: 27 additions & 8 deletions mypyc/codegen/emitmodule.py
@@ -27,7 +27,7 @@
from mypy.options import Options
from mypy.plugin import Plugin, ReportConfigContext
from mypy.util import hash_digest, json_dumps
from mypyc.analysis.capsule_deps import find_implicit_capsule_dependencies
from mypyc.analysis.capsule_deps import find_implicit_op_dependencies
from mypyc.codegen.cstring import c_string_initializer
from mypyc.codegen.emit import (
Emitter,
@@ -56,6 +56,7 @@
short_id_from_name,
)
from mypyc.errors import Errors
from mypyc.ir.deps import LIBRT_BASE64, LIBRT_STRINGS, SourceDep
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.module_ir import ModuleIR, ModuleIRs, deserialize_modules
from mypyc.ir.ops import DeserMaps, LoadLiteral
@@ -263,9 +264,9 @@ def compile_scc_to_ir(
# Switch to lower abstraction level IR.
lower_ir(fn, compiler_options)
# Calculate implicit module dependencies (needed for librt)
capsules = find_implicit_capsule_dependencies(fn)
if capsules is not None:
module.capsules.update(capsules)
deps = find_implicit_op_dependencies(fn)
if deps is not None:
module.dependencies.update(deps)
# Perform optimizations.
do_copy_propagation(fn, compiler_options)
do_flag_elimination(fn, compiler_options)
@@ -427,6 +428,16 @@ def load_scc_from_cache(
return modules


def collect_source_dependencies(modules: dict[str, ModuleIR]) -> set[SourceDep]:
"""Collect all SourceDep dependencies from all modules."""
source_deps: set[SourceDep] = set()
for module in modules.values():
for dep in module.dependencies:
if isinstance(dep, SourceDep):
source_deps.add(dep)
return source_deps


def compile_modules_to_c(
result: BuildResult, compiler_options: CompilerOptions, errors: Errors, groups: Groups
) -> tuple[ModuleIRs, list[FileContents], Mapper]:
@@ -560,6 +571,10 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
if self.compiler_options.include_runtime_files:
for name in RUNTIME_C_FILES:
base_emitter.emit_line(f'#include "{name}"')
# Include conditional source files
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
base_emitter.emit_line(f'#include "{source_dep.path}"')
base_emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
base_emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"')
emitter = base_emitter
@@ -611,10 +626,14 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
ext_declarations.emit_line("#include <CPy.h>")
if self.compiler_options.depends_on_librt_internal:
ext_declarations.emit_line("#include <librt_internal.h>")
if any("librt.base64" in mod.capsules for mod in self.modules.values()):
if any(LIBRT_BASE64 in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <librt_base64.h>")
if any("librt.strings" in mod.capsules for mod in self.modules.values()):
if any(LIBRT_STRINGS in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <librt_strings.h>")
# Include headers for conditional source files
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
ext_declarations.emit_line(f'#include "{source_dep.get_header()}"')

declarations = Emitter(self.context)
declarations.emit_line(f"#ifndef MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
@@ -1072,11 +1091,11 @@ def emit_module_exec_func(
emitter.emit_line("if (import_librt_internal() < 0) {")
emitter.emit_line("return -1;")
emitter.emit_line("}")
if "librt.base64" in module.capsules:
if LIBRT_BASE64 in module.dependencies:
emitter.emit_line("if (import_librt_base64() < 0) {")
emitter.emit_line("return -1;")
emitter.emit_line("}")
if "librt.strings" in module.capsules:
if LIBRT_STRINGS in module.dependencies:
emitter.emit_line("if (import_librt_strings() < 0) {")
emitter.emit_line("return -1;")
emitter.emit_line("}")
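
The include emission above deduplicates source dependencies across all modules in a group; a small sketch of the behaviour that collect_source_dependencies implements (dependency values taken from mypyc/ir/deps.py below):

    # Sketch: two modules that both need bytes_extra_ops.c yield a single entry.
    deps_a = {BYTES_EXTRA_OPS, LIBRT_BASE64}
    deps_b = {BYTES_EXTRA_OPS}
    source_deps = {d for deps in (deps_a, deps_b) for d in deps if isinstance(d, SourceDep)}
    # source_deps == {SourceDep("bytes_extra_ops.c")}, so the .c file and its
    # header are each emitted as a single #include rather than once per module.
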
3 changes: 2 additions & 1 deletion mypyc/common.py
@@ -65,7 +65,8 @@
BITMAP_TYPE: Final = "uint32_t"
BITMAP_BITS: Final = 32

# Runtime C library files
# Runtime C library files that are always included (some ops may bring
# extra dependencies via mypyc.ir.SourceDep)
RUNTIME_C_FILES: Final = [
"init.c",
"getargs.c",
52 changes: 52 additions & 0 deletions mypyc/ir/deps.py
@@ -0,0 +1,52 @@
from typing import Final


class Capsule:
"""Defines a C extension capsule that a primitive may require."""

def __init__(self, name: str) -> None:
# Module fullname, e.g. 'librt.base64'
self.name: Final = name

def __repr__(self) -> str:
return f"Capsule(name={self.name!r})"

def __eq__(self, other: object) -> bool:
return isinstance(other, Capsule) and self.name == other.name

def __hash__(self) -> int:
return hash(("Capsule", self.name))


class SourceDep:
"""Defines a C source file that a primitive may require.

Each source file must also have a corresponding .h file (replace .c with .h)
that gets implicitly #included if the source is used.
"""

def __init__(self, path: str) -> None:
# Relative path from mypyc/lib-rt, e.g. 'bytes_extra_ops.c'
self.path: Final = path

def __repr__(self) -> str:
return f"SourceDep(path={self.path!r})"

def __eq__(self, other: object) -> bool:
return isinstance(other, SourceDep) and self.path == other.path

def __hash__(self) -> int:
return hash(("SourceDep", self.path))

def get_header(self) -> str:
"""Get the header file path by replacing .c with .h"""
return self.path.replace(".c", ".h")


Dependency = Capsule | SourceDep


LIBRT_STRINGS: Final = Capsule("librt.strings")
LIBRT_BASE64: Final = Capsule("librt.base64")

BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")
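
A quick usage sketch of the new dependency value types, based on the definitions above (hashable, compared by value):

    from mypyc.ir.deps import BYTES_EXTRA_OPS, LIBRT_BASE64, Capsule, SourceDep

    assert BYTES_EXTRA_OPS.get_header() == "bytes_extra_ops.h"
    assert Capsule("librt.base64") == LIBRT_BASE64   # compared by name, not identity
    deps = {LIBRT_BASE64, Capsule("librt.base64"), BYTES_EXTRA_OPS}
    assert len(deps) == 2                            # duplicates collapse in sets
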
26 changes: 22 additions & 4 deletions mypyc/ir/module_ir.py
@@ -4,6 +4,7 @@

from mypyc.common import JsonDict
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.deps import Capsule, Dependency, SourceDep
from mypyc.ir.func_ir import FuncDecl, FuncIR
from mypyc.ir.ops import DeserMaps
from mypyc.ir.rtypes import RType, deserialize_type
@@ -30,17 +31,25 @@ def __init__(
# These are only visible in the module that defined them, so no need
# to serialize.
self.type_var_names = type_var_names
# Capsules needed by the module, specified via module names such as "librt.base64"
self.capsules: set[str] = set()
# Dependencies needed by the module (such as capsules or source files)
self.dependencies: set[Dependency] = set()

def serialize(self) -> JsonDict:
# Serialize dependencies as a list of dicts with type information
serialized_deps = []
for dep in sorted(self.dependencies, key=lambda d: (type(d).__name__, str(d))):
if isinstance(dep, Capsule):
serialized_deps.append({"type": "Capsule", "name": dep.name})
elif isinstance(dep, SourceDep):
serialized_deps.append({"type": "SourceDep", "path": dep.path})

return {
"fullname": self.fullname,
"imports": self.imports,
"functions": [f.serialize() for f in self.functions],
"classes": [c.serialize() for c in self.classes],
"final_names": [(k, t.serialize()) for k, t in self.final_names],
"capsules": sorted(self.capsules),
"dependencies": serialized_deps,
}

@classmethod
@@ -53,7 +62,16 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR:
[(k, deserialize_type(t, ctx)) for k, t in data["final_names"]],
[],
)
module.capsules = set(data["capsules"])

# Deserialize dependencies
deps: set[Dependency] = set()
for dep_dict in data["dependencies"]:
if dep_dict["type"] == "Capsule":
deps.add(Capsule(dep_dict["name"]))
elif dep_dict["type"] == "SourceDep":
deps.add(SourceDep(dep_dict["path"]))
module.dependencies = deps

return module
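
For reference, the serialized shape produced by serialize() above (and accepted by deserialize()) looks roughly like this for a module using both kinds of dependency:

    # Sketch of a serialized ModuleIR fragment (other keys omitted):
    # "dependencies": [
    #     {"type": "Capsule", "name": "librt.base64"},
    #     {"type": "SourceDep", "path": "bytes_extra_ops.c"},
    # ]
    # deserialize() turns this back into the equivalent value set:
    deps = {Capsule("librt.base64"), SourceDep("bytes_extra_ops.c")}
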

