diff --git a/mypyc/analysis/capsule_deps.py b/mypyc/analysis/capsule_deps.py index ada42ee03f28..e2e8563db7fe 100644 --- a/mypyc/analysis/capsule_deps.py +++ b/mypyc/analysis/capsule_deps.py @@ -1,27 +1,29 @@ from __future__ import annotations +from mypyc.ir.deps import Dependency from mypyc.ir.func_ir import FuncIR from mypyc.ir.ops import CallC, PrimitiveOp -def find_implicit_capsule_dependencies(fn: FuncIR) -> set[str] | None: - """Find implicit dependencies on capsules that need to be imported. +def find_implicit_op_dependencies(fn: FuncIR) -> set[Dependency] | None: + """Find implicit dependencies that need to be imported. Using primitives or types defined in librt submodules such as "librt.base64" - requires a capsule import. + requires dependency imports (e.g., capsule imports). Note that a module can depend on a librt module even if it doesn't explicitly import it, for example via re-exported names or via return types of functions defined in other modules. """ - deps: set[str] | None = None + deps: set[Dependency] | None = None for block in fn.blocks: for op in block.ops: # TODO: Also determine implicit type object dependencies (e.g. 
cast targets) - if isinstance(op, CallC) and op.capsule is not None: - if deps is None: - deps = set() - deps.add(op.capsule) + if isinstance(op, CallC) and op.dependencies is not None: + for dep in op.dependencies: + if deps is None: + deps = set() + deps.add(dep) else: assert not isinstance(op, PrimitiveOp), "Lowered IR is expected" return deps diff --git a/mypyc/build.py b/mypyc/build.py index 9de5ccaabaaa..757f0a49737a 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -39,6 +39,7 @@ from mypyc.codegen import emitmodule from mypyc.common import IS_FREE_THREADED, RUNTIME_C_FILES, shared_lib_name from mypyc.errors import Errors +from mypyc.ir.deps import SourceDep from mypyc.ir.pprint import format_modules from mypyc.namegen import exported_name from mypyc.options import CompilerOptions @@ -282,14 +283,14 @@ def generate_c( groups: emitmodule.Groups, fscache: FileSystemCache, compiler_options: CompilerOptions, -) -> tuple[list[list[tuple[str, str]]], str]: +) -> tuple[list[list[tuple[str, str]]], str, list[SourceDep]]: """Drive the actual core compilation step. The groups argument describes how modules are assigned to C extension modules. See the comments on the Groups type in mypyc.emitmodule for details. - Returns the C source code and (for debugging) the pretty printed IR. + Returns the C source code, (for debugging) the pretty printed IR, and list of SourceDeps. 
""" t0 = time.time() @@ -325,7 +326,10 @@ def generate_c( if options.mypyc_annotation_file: generate_annotated_html(options.mypyc_annotation_file, result, modules, mapper) - return ctext, "\n".join(format_modules(modules)) + # Collect SourceDep dependencies + source_deps = sorted(emitmodule.collect_source_dependencies(modules), key=lambda d: d.path) + + return ctext, "\n".join(format_modules(modules)), source_deps def build_using_shared_lib( @@ -486,9 +490,9 @@ def mypyc_build( *, separate: bool | list[tuple[list[str], str | None]] = False, only_compile_paths: Iterable[str] | None = None, - skip_cgen_input: Any | None = None, + skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None, always_use_shared_lib: bool = False, -) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]]]: +) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]: """Do the front and middle end of mypyc building, producing and writing out C source.""" fscache = FileSystemCache() mypyc_sources, all_sources, options = get_mypy_config( @@ -511,14 +515,16 @@ def mypyc_build( # We let the test harness just pass in the c file contents instead # so that it can do a corner-cutting version without full stubs. + source_deps: list[SourceDep] = [] if not skip_cgen_input: - group_cfiles, ops_text = generate_c( + group_cfiles, ops_text, source_deps = generate_c( all_sources, options, groups, fscache, compiler_options=compiler_options ) # TODO: unique names? write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text) else: - group_cfiles = skip_cgen_input + group_cfiles = skip_cgen_input[0] + source_deps = [SourceDep(d) for d in skip_cgen_input[1]] # Write out the generated C and collect the files for each group # Should this be here?? 
@@ -535,7 +541,7 @@ def mypyc_build( deps = [os.path.join(compiler_options.target_dir, dep) for dep in get_header_deps(cfiles)] group_cfilenames.append((cfilenames, deps)) - return groups, group_cfilenames + return groups, group_cfilenames, source_deps def mypycify( @@ -548,7 +554,7 @@ def mypycify( strip_asserts: bool = False, multi_file: bool = False, separate: bool | list[tuple[list[str], str | None]] = False, - skip_cgen_input: Any | None = None, + skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None, target_dir: str | None = None, include_runtime_files: bool | None = None, strict_dunder_typing: bool = False, @@ -633,7 +639,7 @@ def mypycify( ) # Generate all the actual important C code - groups, group_cfilenames = mypyc_build( + groups, group_cfilenames, source_deps = mypyc_build( paths, only_compile_paths=only_compile_paths, compiler_options=compiler_options, @@ -708,11 +714,19 @@ def mypycify( # compiler invocations. shared_cfilenames = [] if not compiler_options.include_runtime_files: - for name in RUNTIME_C_FILES: + # Collect all files to copy: runtime files + conditional source files + files_to_copy = list(RUNTIME_C_FILES) + for source_dep in source_deps: + files_to_copy.append(source_dep.path) + files_to_copy.append(source_dep.get_header()) + + # Copy all files + for name in files_to_copy: rt_file = os.path.join(build_dir, name) with open(os.path.join(include_dir(), name), encoding="utf-8") as f: write_file(rt_file, f.read()) - shared_cfilenames.append(rt_file) + if name.endswith(".c"): + shared_cfilenames.append(rt_file) extensions = [] for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames): diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 959ad3a6d73d..a56083796e79 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -27,7 +27,7 @@ from mypy.options import Options from mypy.plugin import Plugin, ReportConfigContext from mypy.util import hash_digest, 
json_dumps -from mypyc.analysis.capsule_deps import find_implicit_capsule_dependencies +from mypyc.analysis.capsule_deps import find_implicit_op_dependencies from mypyc.codegen.cstring import c_string_initializer from mypyc.codegen.emit import ( Emitter, @@ -56,6 +56,7 @@ short_id_from_name, ) from mypyc.errors import Errors +from mypyc.ir.deps import LIBRT_BASE64, LIBRT_STRINGS, SourceDep from mypyc.ir.func_ir import FuncIR from mypyc.ir.module_ir import ModuleIR, ModuleIRs, deserialize_modules from mypyc.ir.ops import DeserMaps, LoadLiteral @@ -263,9 +264,9 @@ def compile_scc_to_ir( # Switch to lower abstraction level IR. lower_ir(fn, compiler_options) # Calculate implicit module dependencies (needed for librt) - capsules = find_implicit_capsule_dependencies(fn) - if capsules is not None: - module.capsules.update(capsules) + deps = find_implicit_op_dependencies(fn) + if deps is not None: + module.dependencies.update(deps) # Perform optimizations. do_copy_propagation(fn, compiler_options) do_flag_elimination(fn, compiler_options) @@ -427,6 +428,16 @@ def load_scc_from_cache( return modules +def collect_source_dependencies(modules: dict[str, ModuleIR]) -> set[SourceDep]: + """Collect all SourceDep dependencies from all modules.""" + source_deps: set[SourceDep] = set() + for module in modules.values(): + for dep in module.dependencies: + if isinstance(dep, SourceDep): + source_deps.add(dep) + return source_deps + + def compile_modules_to_c( result: BuildResult, compiler_options: CompilerOptions, errors: Errors, groups: Groups ) -> tuple[ModuleIRs, list[FileContents], Mapper]: @@ -560,6 +571,10 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]: if self.compiler_options.include_runtime_files: for name in RUNTIME_C_FILES: base_emitter.emit_line(f'#include "{name}"') + # Include conditional source files + source_deps = collect_source_dependencies(self.modules) + for source_dep in sorted(source_deps, key=lambda d: d.path): + base_emitter.emit_line(f'#include 
"{source_dep.path}"') base_emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"') base_emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"') emitter = base_emitter @@ -611,10 +626,14 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]: ext_declarations.emit_line("#include ") if self.compiler_options.depends_on_librt_internal: ext_declarations.emit_line("#include ") - if any("librt.base64" in mod.capsules for mod in self.modules.values()): + if any(LIBRT_BASE64 in mod.dependencies for mod in self.modules.values()): ext_declarations.emit_line("#include ") - if any("librt.strings" in mod.capsules for mod in self.modules.values()): + if any(LIBRT_STRINGS in mod.dependencies for mod in self.modules.values()): ext_declarations.emit_line("#include ") + # Include headers for conditional source files + source_deps = collect_source_dependencies(self.modules) + for source_dep in sorted(source_deps, key=lambda d: d.path): + ext_declarations.emit_line(f'#include "{source_dep.get_header()}"') declarations = Emitter(self.context) declarations.emit_line(f"#ifndef MYPYC_LIBRT_INTERNAL{self.group_suffix}_H") @@ -1072,11 +1091,11 @@ def emit_module_exec_func( emitter.emit_line("if (import_librt_internal() < 0) {") emitter.emit_line("return -1;") emitter.emit_line("}") - if "librt.base64" in module.capsules: + if LIBRT_BASE64 in module.dependencies: emitter.emit_line("if (import_librt_base64() < 0) {") emitter.emit_line("return -1;") emitter.emit_line("}") - if "librt.strings" in module.capsules: + if LIBRT_STRINGS in module.dependencies: emitter.emit_line("if (import_librt_strings() < 0) {") emitter.emit_line("return -1;") emitter.emit_line("}") diff --git a/mypyc/common.py b/mypyc/common.py index 28b386f450a6..4ee004c0dd0f 100644 --- a/mypyc/common.py +++ b/mypyc/common.py @@ -65,7 +65,8 @@ BITMAP_TYPE: Final = "uint32_t" BITMAP_BITS: Final = 32 -# Runtime C library files +# Runtime C library files that are always included (some ops may 
bring +# extra dependencies via mypyc.ir.deps.SourceDep) RUNTIME_C_FILES: Final = [ "init.c", "getargs.c", diff --git a/mypyc/ir/deps.py b/mypyc/ir/deps.py new file mode 100644 index 000000000000..b7747646628b --- /dev/null +++ b/mypyc/ir/deps.py @@ -0,0 +1,52 @@ +from typing import Final + + +class Capsule: + """Defines a C extension capsule that a primitive may require.""" + + def __init__(self, name: str) -> None: + # Module fullname, e.g. 'librt.base64' + self.name: Final = name + + def __repr__(self) -> str: + return f"Capsule(name={self.name!r})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, Capsule) and self.name == other.name + + def __hash__(self) -> int: + return hash(("Capsule", self.name)) + + +class SourceDep: + """Defines a C source file that a primitive may require. + + Each source file must also have a corresponding .h file (replace .c with .h) + that gets implicitly #included if the source is used. + """ + + def __init__(self, path: str) -> None: + # Relative path from mypyc/lib-rt, e.g. 
'bytes_extra_ops.c' + self.path: Final = path + + def __repr__(self) -> str: + return f"SourceDep(path={self.path!r})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, SourceDep) and self.path == other.path + + def __hash__(self) -> int: + return hash(("SourceDep", self.path)) + + def get_header(self) -> str: + """Get the header file path by replacing .c with .h""" + return self.path.replace(".c", ".h") + + +Dependency = Capsule | SourceDep + + +LIBRT_STRINGS: Final = Capsule("librt.strings") +LIBRT_BASE64: Final = Capsule("librt.base64") + +BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c") diff --git a/mypyc/ir/module_ir.py b/mypyc/ir/module_ir.py index 5aef414490f9..e978bb35e7ab 100644 --- a/mypyc/ir/module_ir.py +++ b/mypyc/ir/module_ir.py @@ -4,6 +4,7 @@ from mypyc.common import JsonDict from mypyc.ir.class_ir import ClassIR +from mypyc.ir.deps import Capsule, Dependency, SourceDep from mypyc.ir.func_ir import FuncDecl, FuncIR from mypyc.ir.ops import DeserMaps from mypyc.ir.rtypes import RType, deserialize_type @@ -30,17 +31,25 @@ def __init__( # These are only visible in the module that defined them, so no need # to serialize. 
self.type_var_names = type_var_names - # Capsules needed by the module, specified via module names such as "librt.base64" - self.capsules: set[str] = set() + # Dependencies needed by the module (such as capsules or source files) + self.dependencies: set[Dependency] = set() def serialize(self) -> JsonDict: + # Serialize dependencies as a list of dicts with type information + serialized_deps = [] + for dep in sorted(self.dependencies, key=lambda d: (type(d).__name__, str(d))): + if isinstance(dep, Capsule): + serialized_deps.append({"type": "Capsule", "name": dep.name}) + elif isinstance(dep, SourceDep): + serialized_deps.append({"type": "SourceDep", "path": dep.path}) + return { "fullname": self.fullname, "imports": self.imports, "functions": [f.serialize() for f in self.functions], "classes": [c.serialize() for c in self.classes], "final_names": [(k, t.serialize()) for k, t in self.final_names], - "capsules": sorted(self.capsules), + "dependencies": serialized_deps, } @classmethod @@ -53,7 +62,16 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR: [(k, deserialize_type(t, ctx)) for k, t in data["final_names"]], [], ) - module.capsules = set(data["capsules"]) + + # Deserialize dependencies + deps: set[Dependency] = set() + for dep_dict in data["dependencies"]: + if dep_dict["type"] == "Capsule": + deps.add(Capsule(dep_dict["name"])) + elif dep_dict["type"] == "SourceDep": + deps.add(SourceDep(dep_dict["path"])) + module.dependencies = deps + return module diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index beb709adfbd0..c09872ca3826 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -31,6 +31,7 @@ class to enable the new behavior. 
Sometimes adding a new abstract from mypy_extensions import trait +from mypyc.ir.deps import Dependency from mypyc.ir.rtypes import ( RArray, RInstance, @@ -710,7 +711,7 @@ def __init__( priority: int, is_pure: bool, experimental: bool, - capsule: str | None, + dependencies: list[Dependency] | None, ) -> None: # Each primitive much have a distinct name, but otherwise they are arbitrary. self.name: Final = name @@ -736,9 +737,9 @@ def __init__( # Experimental primitives are not used unless mypyc experimental features are # explicitly enabled self.experimental = experimental - # Capsule that needs to imported and configured to call the primitive - # (name of the target module, e.g. "librt.base64"). - self.capsule = capsule + # Dependencies for the primitive, such as a capsule that needs to be imported + # and configured to call the primitive. + self.dependencies = dependencies # Native integer types such as u8 can cause ambiguity in primitive # matching, since these are assignable to plain int *and* vice versa. # If this flag is set, the primitive has native integer types and must @@ -1252,7 +1253,7 @@ def __init__( *, is_pure: bool = False, returns_null: bool = False, - capsule: str | None = None, + dependencies: list[Dependency] | None = None, ) -> None: self.error_kind = error_kind super().__init__(line) @@ -1270,9 +1271,9 @@ def __init__( # The function might return a null value that does not indicate # an error. self.returns_null = returns_null - # A capsule from this module must be imported and initialized before calling this - # function (used for C functions exported from librt). Example value: "librt.base64" - self.capsule = capsule + # Dependencies (such as capsules) that must be imported and initialized before + # calling this function (used for C functions exported from librt). 
+ self.dependencies = dependencies if is_pure or returns_null: assert error_kind == ERR_NEVER diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 0a9e2d230ca5..dcb352679cf8 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -2075,7 +2075,7 @@ def call_c( var_arg_idx, is_pure=desc.is_pure, returns_null=desc.returns_null, - capsule=desc.capsule, + dependencies=desc.dependencies, ) ) if desc.is_borrowed: @@ -2160,7 +2160,7 @@ def primitive_op( desc.priority, is_pure=desc.is_pure, returns_null=False, - capsule=desc.capsule, + dependencies=desc.dependencies, ) return self.call_c(c_desc, args, line, result_type=result_type) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 15995db3dfd0..5526508f5aca 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -783,7 +783,6 @@ PyObject *CPyBytes_Concat(PyObject *a, PyObject *b); PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count); -PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table); int CPyBytes_Compare(PyObject *left, PyObject *right); diff --git a/mypyc/lib-rt/bytes_extra_ops.c b/mypyc/lib-rt/bytes_extra_ops.c new file mode 100644 index 000000000000..b5d2d9996d52 --- /dev/null +++ b/mypyc/lib-rt/bytes_extra_ops.c @@ -0,0 +1,50 @@ +#include "bytes_extra_ops.h" + +PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table) { + // Fast path: exact bytes object with exact bytes table + if (PyBytes_CheckExact(bytes) && PyBytes_CheckExact(table)) { + Py_ssize_t table_len = PyBytes_GET_SIZE(table); + if (table_len != 256) { + PyErr_SetString(PyExc_ValueError, + "translation table must be 256 characters long"); + return NULL; + } + + Py_ssize_t len = PyBytes_GET_SIZE(bytes); + const char *input = PyBytes_AS_STRING(bytes); + const char *trans_table = PyBytes_AS_STRING(table); + + PyObject *result = PyBytes_FromStringAndSize(NULL, len); + if (result == 
NULL) { + return NULL; + } + + char *output = PyBytes_AS_STRING(result); + bool changed = false; + + // Without a loop unrolling hint performance can be worse than CPython + CPY_UNROLL_LOOP(4) + for (Py_ssize_t i = len; --i >= 0;) { + char c = *input++; + if ((*output++ = trans_table[(unsigned char)c]) != c) + changed = true; + } + + // If nothing changed, discard result and return the original object + if (!changed) { + Py_DECREF(result); + Py_INCREF(bytes); + return bytes; + } + + return result; + } + + // Fallback to Python method call for non-exact types or non-standard tables + _Py_IDENTIFIER(translate); + PyObject *name = _PyUnicode_FromId(&PyId_translate); + if (name == NULL) { + return NULL; + } + return PyObject_CallMethodOneArg(bytes, name, table); +} diff --git a/mypyc/lib-rt/bytes_extra_ops.h b/mypyc/lib-rt/bytes_extra_ops.h new file mode 100644 index 000000000000..eebb5a345438 --- /dev/null +++ b/mypyc/lib-rt/bytes_extra_ops.h @@ -0,0 +1,10 @@ +#ifndef MYPYC_BYTES_EXTRA_OPS_H +#define MYPYC_BYTES_EXTRA_OPS_H + +#include +#include "CPy.h" + +// Optimized bytes translate operation +PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table); + +#endif diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 6c138ad90db0..8ecf9337c28b 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -171,52 +171,3 @@ PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) { } return PySequence_Repeat(bytes, temp_count); } - -PyObject *CPyBytes_Translate(PyObject *bytes, PyObject *table) { - // Fast path: exact bytes object with exact bytes table - if (PyBytes_CheckExact(bytes) && PyBytes_CheckExact(table)) { - Py_ssize_t table_len = PyBytes_GET_SIZE(table); - if (table_len != 256) { - PyErr_SetString(PyExc_ValueError, - "translation table must be 256 characters long"); - return NULL; - } - - Py_ssize_t len = PyBytes_GET_SIZE(bytes); - const char *input = PyBytes_AS_STRING(bytes); - const char *trans_table = 
PyBytes_AS_STRING(table); - - PyObject *result = PyBytes_FromStringAndSize(NULL, len); - if (result == NULL) { - return NULL; - } - - char *output = PyBytes_AS_STRING(result); - bool changed = false; - - // Without a loop unrolling hint performance can be worse than CPython - CPY_UNROLL_LOOP(4) - for (Py_ssize_t i = len; --i >= 0;) { - char c = *input++; - if ((*output++ = trans_table[(unsigned char)c]) != c) - changed = true; - } - - // If nothing changed, discard result and return the original object - if (!changed) { - Py_DECREF(result); - Py_INCREF(bytes); - return bytes; - } - - return result; - } - - // Fallback to Python method call for non-exact types or non-standard tables - _Py_IDENTIFIER(translate); - PyObject *name = _PyUnicode_FromId(&PyId_translate); - if (name == NULL) { - return NULL; - } - return PyObject_CallMethodOneArg(bytes, name, table); -} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 982c41dc25b1..0669ddac00df 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -2,6 +2,7 @@ from __future__ import annotations +from mypyc.ir.deps import BYTES_EXTRA_OPS from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( RUnion, @@ -135,6 +136,7 @@ return_type=bytes_rprimitive, c_function_name="CPyBytes_Translate", error_kind=ERR_MAGIC, + dependencies=[BYTES_EXTRA_OPS], ) # Join bytes objects and return a new bytes. 
diff --git a/mypyc/primitives/librt_strings_ops.py b/mypyc/primitives/librt_strings_ops.py index 23a4033dced4..786784d15f60 100644 --- a/mypyc/primitives/librt_strings_ops.py +++ b/mypyc/primitives/librt_strings_ops.py @@ -1,5 +1,6 @@ from typing import Final +from mypyc.ir.deps import LIBRT_STRINGS from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( KNOWN_NATIVE_TYPES, @@ -20,7 +21,7 @@ c_function_name="LibRTStrings_BytesWriter_internal", error_kind=ERR_MAGIC, experimental=True, - capsule="librt.strings", + dependencies=[LIBRT_STRINGS], ) method_op( @@ -30,7 +31,7 @@ c_function_name="LibRTStrings_BytesWriter_getvalue_internal", error_kind=ERR_MAGIC, experimental=True, - capsule="librt.strings", + dependencies=[LIBRT_STRINGS], ) method_op( @@ -40,7 +41,7 @@ c_function_name="LibRTStrings_BytesWriter_write_internal", error_kind=ERR_MAGIC, experimental=True, - capsule="librt.strings", + dependencies=[LIBRT_STRINGS], ) method_op( @@ -50,7 +51,7 @@ c_function_name="LibRTStrings_BytesWriter_append_internal", error_kind=ERR_MAGIC, experimental=True, - capsule="librt.strings", + dependencies=[LIBRT_STRINGS], ) method_op( @@ -68,5 +69,5 @@ c_function_name="LibRTStrings_BytesWriter_len_internal", error_kind=ERR_NEVER, experimental=True, - capsule="librt.strings", + dependencies=[LIBRT_STRINGS], ) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index f6cb4fee9759..01853341b8bf 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -2,6 +2,7 @@ from __future__ import annotations +from mypyc.ir.deps import LIBRT_BASE64 from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_MAGIC_OVERLAPPING, ERR_NEVER from mypyc.ir.rtypes import ( KNOWN_NATIVE_TYPES, @@ -475,7 +476,7 @@ error_kind=ERR_MAGIC, extra_int_constants=[(0, bool_rprimitive)], experimental=True, - capsule="librt.base64", + dependencies=[LIBRT_BASE64], ) function_op( @@ -486,7 +487,7 @@ error_kind=ERR_MAGIC, extra_int_constants=[(1, bool_rprimitive)], 
experimental=True, - capsule="librt.base64", + dependencies=[LIBRT_BASE64], ) function_op( @@ -497,7 +498,7 @@ error_kind=ERR_MAGIC, extra_int_constants=[(0, bool_rprimitive)], experimental=True, - capsule="librt.base64", + dependencies=[LIBRT_BASE64], ) function_op( @@ -508,7 +509,7 @@ error_kind=ERR_MAGIC, extra_int_constants=[(1, bool_rprimitive)], experimental=True, - capsule="librt.base64", + dependencies=[LIBRT_BASE64], ) cpyfunction_get_name = function_op( diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index c7ab6925b43d..a59599f693c4 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -39,6 +39,7 @@ from typing import Final, NamedTuple +from mypyc.ir.deps import Dependency from mypyc.ir.ops import PrimitiveDescription, StealsDescription from mypyc.ir.rtypes import RType @@ -62,7 +63,7 @@ class CFunctionDescription(NamedTuple): priority: int is_pure: bool returns_null: bool - capsule: str | None + dependencies: list[Dependency] | None # A description for C load operations including LoadGlobal and LoadAddress @@ -102,7 +103,7 @@ def method_op( priority: int = 1, is_pure: bool = False, experimental: bool = False, - capsule: str | None = None, + dependencies: list[Dependency] | None = None, ) -> PrimitiveDescription: """Define a c function call op that replaces a method call. @@ -148,7 +149,7 @@ def method_op( priority, is_pure=is_pure, experimental=experimental, - capsule=capsule, + dependencies=dependencies, ) ops.append(desc) return desc @@ -168,7 +169,7 @@ def function_op( is_borrowed: bool = False, priority: int = 1, experimental: bool = False, - capsule: str | None = None, + dependencies: list[Dependency] | None = None, ) -> PrimitiveDescription: """Define a C function call op that replaces a function call. 
@@ -198,7 +199,7 @@ def function_op( priority=priority, is_pure=False, experimental=experimental, - capsule=capsule, + dependencies=dependencies, ) ops.append(desc) return desc @@ -218,7 +219,7 @@ def binary_op( steals: StealsDescription = False, is_borrowed: bool = False, priority: int = 1, - capsule: str | None = None, + dependencies: list[Dependency] | None = None, ) -> PrimitiveDescription: """Define a c function call op for a binary operation. @@ -247,7 +248,7 @@ def binary_op( priority=priority, is_pure=False, experimental=False, - capsule=capsule, + dependencies=dependencies, ) ops.append(desc) return desc @@ -289,7 +290,7 @@ def custom_op( 0, is_pure=is_pure, returns_null=returns_null, - capsule=None, + dependencies=None, ) @@ -307,7 +308,7 @@ def custom_primitive_op( is_borrowed: bool = False, is_pure: bool = False, experimental: bool = False, - capsule: str | None = None, + dependencies: list[Dependency] | None = None, ) -> PrimitiveDescription: """Define a primitive op that can't be automatically generated based on the AST. @@ -330,7 +331,7 @@ def custom_primitive_op( priority=0, is_pure=is_pure, experimental=experimental, - capsule=capsule, + dependencies=dependencies, ) @@ -347,7 +348,7 @@ def unary_op( is_borrowed: bool = False, priority: int = 1, is_pure: bool = False, - capsule: str | None = None, + dependencies: list[Dependency] | None = None, ) -> PrimitiveDescription: """Define a primitive op for an unary operation. 
@@ -374,7 +375,7 @@ def unary_op( priority=priority, is_pure=is_pure, experimental=False, - capsule=capsule, + dependencies=dependencies, ) ops.append(desc) return desc diff --git a/mypyc/test-data/run-multimodule.test b/mypyc/test-data/run-multimodule.test index 9323612cb4fb..216aed25a5e5 100644 --- a/mypyc/test-data/run-multimodule.test +++ b/mypyc/test-data/run-multimodule.test @@ -950,3 +950,15 @@ import native 15 NT(x=4) {'x': 5} + +[case testExtraLibRtSourceFileDep] +table_list = list(range(256)) +table_list[ord('A')] = ord('B') +table = bytes(table_list) + +def translate(b: bytes) -> bytes: + # The primitive for bytes.translate requires an optional C file from lib-rt + return b.translate(table) +[file driver.py] +import native +assert native.translate(b'ABCD') == b'BBCD' diff --git a/mypyc/test/test_cheader.py b/mypyc/test/test_cheader.py index ec9e2c4cf450..d21eefdb9bc5 100644 --- a/mypyc/test/test_cheader.py +++ b/mypyc/test/test_cheader.py @@ -7,6 +7,7 @@ import re import unittest +from mypyc.ir.deps import SourceDep from mypyc.primitives import registry @@ -24,6 +25,7 @@ def check_name(name: str) -> None: rf"\b{name}\b", header ), f'"{name}" is used in mypyc.primitives but not declared in CPy.h' + all_ops = [] for values in [ registry.method_call_ops.values(), registry.binary_ops.values(), @@ -31,9 +33,21 @@ def check_name(name: str) -> None: registry.function_ops.values(), ]: for ops in values: - for op in ops: - if op.c_function_name is not None: - check_name(op.c_function_name) + all_ops.extend(ops) + + # Find additional headers via extra C source file dependencies. 
+ for op in all_ops: + if op.dependencies: + for dep in op.dependencies: + if isinstance(dep, SourceDep): + header_fnam = os.path.join(base_dir, dep.get_header()) + if os.path.isfile(header_fnam): + with open(os.path.join(base_dir, header_fnam)) as f: + header += f.read() + + for op in all_ops: + if op.c_function_name is not None: + check_name(op.c_function_name) primitives_path = os.path.join(os.path.dirname(__file__), "..", "primitives") for fnam in glob.glob(f"{primitives_path}/*.py"): diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index d332a7466af4..681e15b58844 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -22,6 +22,7 @@ from mypy.test.helpers import assert_module_equivalence, perform_file_operations from mypyc.build import construct_groups from mypyc.codegen import emitmodule +from mypyc.codegen.emitmodule import collect_source_dependencies from mypyc.errors import Errors from mypyc.options import CompilerOptions from mypyc.test.config import test_data_prefix @@ -266,6 +267,7 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> ir, cfiles, _ = emitmodule.compile_modules_to_c( result, compiler_options=compiler_options, errors=errors, groups=groups ) + deps = sorted(dep.path for dep in collect_source_dependencies(ir)) if errors.num_errors: errors.flush_errors() assert False, "Compile error" @@ -288,7 +290,7 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> setup_format.format( module_paths, separate, - cfiles, + (cfiles, deps), self.multi_file, opt_level, librt, diff --git a/mypyc/transform/exceptions.py b/mypyc/transform/exceptions.py index 33dfeb693cf7..28bbd80c52cc 100644 --- a/mypyc/transform/exceptions.py +++ b/mypyc/transform/exceptions.py @@ -146,6 +146,7 @@ def primitive_call(desc: CFunctionDescription, args: list[Value], line: int) -> desc.is_borrowed, desc.error_kind, line, + dependencies=desc.dependencies, )