|
| 1 | +import sys |
| 2 | +from importlib import import_module |
| 3 | +from importlib.util import find_spec |
| 4 | +from inspect import getmembers, isfunction, signature |
| 5 | +from pathlib import Path |
| 6 | +from types import FunctionType, ModuleType, SimpleNamespace |
| 7 | +from typing import Dict, List, Optional |
| 8 | +from unittest import TestCase |
| 9 | + |
| 10 | +__all__ = ["make_stubs_namespace"] |
| 11 | + |
| 12 | +API_VERSIONS = {"2012.12"} # TODO: infer released versions dynamically |
| 13 | + |
| 14 | + |
| 15 | +def make_stubs_namespace(api_version: Optional[str] = None) -> SimpleNamespace: |
| 16 | + """ |
| 17 | + Returns a ``SimpleNamespace`` where |
| 18 | +
|
| 19 | + * ``functions`` (``dict[str, FunctionType]``) maps names of top-level |
| 20 | + functions to their respective stubs. |
| 21 | + * ``array_methods`` (``dict[str, FunctionType]``) maps names of array |
| 22 | + methods to their respective stubs. |
| 23 | + * ``dtype_methods`` (``dict[str, FunctionType]``) maps names of dtype object |
| 24 | + methods to their respective stubs. |
| 25 | + * ``category_to_functions`` (``dict[str, dict[str, FunctionType]]``) maps |
| 26 | + names of categories to their respective function mappings. |
| 27 | + * ``extension_to_functions`` (``dict[str, dict[str, FunctionType]]``) maps |
| 28 | + names of extensions to their respective function mappings. |
| 29 | +
|
| 30 | + Examples |
| 31 | + -------- |
| 32 | +
|
| 33 | + Make a stubs namespace. |
| 34 | +
|
| 35 | + >>> from array_api_stubs import make_stubs_namespace |
| 36 | + >>> stubs = make_stubs_namespace() |
| 37 | +
|
| 38 | + Access the ``array_api.square()`` stub. |
| 39 | +
|
| 40 | + >>> stubs.functions["square"] |
| 41 | + <function array_api.square(x: ~array, /) -> ~array> |
| 42 | +
|
| 43 | + Find names of all set functions. |
| 44 | +
|
| 45 | + >>> stubs.category_to_functions["set"].keys() |
| 46 | + dict_keys(['unique_all', 'unique_counts', 'unique_inverse', 'unique_values']) |
| 47 | +
|
| 48 | + Access the array object's ``__add__`` stub. |
| 49 | +
|
| 50 | + >>> stubs.array_methods["__add__"].keys() |
| 51 | + <function array_api._Array.__add__(self: 'array', other: 'Union[int, float, array]', /) -> 'array'> |
| 52 | +
|
| 53 | + Access the ``array_api.linalg.cross()`` stub. |
| 54 | +
|
| 55 | + >>> stubs.extension_to_functions["linalg"]["cross"] |
| 56 | + <function array_api.linalg.cross(x1: ~array, x2: ~array, /, *, axis: int = -1) -> ~array> |
| 57 | +
|
| 58 | + """ |
| 59 | + if api_version is None: |
| 60 | + api_version = "draft" |
| 61 | + if api_version in API_VERSIONS or api_version == "latest": |
| 62 | + raise NotImplementedError("{api_version=} not yet supported") |
| 63 | + else: |
| 64 | + raise ValueError( |
| 65 | + f"{api_version=} not 'draft', 'latest', " |
| 66 | + f"or a released version ({API_VERSIONS})" |
| 67 | + ) |
| 68 | + |
| 69 | + spec_dir = Path(__file__).parent / "spec" / "API_specification" |
| 70 | + signatures_dir = spec_dir / "array_api" |
| 71 | + assert signatures_dir.exists() # sanity check |
| 72 | + spec_abs_path: str = str(spec_dir.resolve()) |
| 73 | + sys.path.append(spec_abs_path) |
| 74 | + assert find_spec("array_api") is not None # sanity check |
| 75 | + |
| 76 | + name_to_mod: Dict[str, ModuleType] = {} |
| 77 | + for path in signatures_dir.glob("*.py"): |
| 78 | + name = path.name.replace(".py", "") |
| 79 | + name_to_mod[name] = import_module(f"array_api.{name}") |
| 80 | + |
| 81 | + array = name_to_mod["array_object"].array |
| 82 | + array_methods: Dict[str, FunctionType] = {} |
| 83 | + for name, func in getmembers(array, predicate=isfunction): |
| 84 | + func.__module__ = "array_api" |
| 85 | + assert "Alias" not in func.__doc__ # sanity check |
| 86 | + func.__qualname__ = f"_Array.{func.__name__}" |
| 87 | + array_methods[name] = func |
| 88 | + |
| 89 | + dtype_eq = name_to_mod["data_types"].__eq__ |
| 90 | + assert isinstance(dtype_eq, FunctionType) # for mypy |
| 91 | + dtype_eq.__module__ = "array_api" |
| 92 | + dtype_eq.__qualname__ = "_DataType.__eq__" |
| 93 | + dtype_methods: Dict[str, FunctionType] = {"__eq__": dtype_eq} |
| 94 | + |
| 95 | + functions: Dict[str, FunctionType] = {} |
| 96 | + category_to_functions: Dict[str, Dict[str, FunctionType]] = {} |
| 97 | + for name, mod in name_to_mod.items(): |
| 98 | + if name.endswith("_functions"): |
| 99 | + category = name.replace("_functions", "") |
| 100 | + name_to_func = {} |
| 101 | + for name in mod.__all__: |
| 102 | + func = getattr(mod, name) |
| 103 | + assert isinstance(func, FunctionType) # sanity check |
| 104 | + func.__module__ = "array_api" |
| 105 | + name_to_func[name] = func |
| 106 | + functions.update(name_to_func) |
| 107 | + category_to_functions[category] = name_to_func |
| 108 | + |
| 109 | + extensions: List[str] = ["linalg"] # TODO: infer on runtime |
| 110 | + extension_to_functions: Dict[str, Dict[str, FunctionType]] = {} |
| 111 | + for ext in extensions: |
| 112 | + mod = name_to_mod[ext] |
| 113 | + name_to_func = {name: getattr(mod, name) for name in mod.__all__} |
| 114 | + name_to_func = {} |
| 115 | + for name in mod.__all__: |
| 116 | + func = getattr(mod, name) |
| 117 | + assert isinstance(func, FunctionType) # sanity check |
| 118 | + assert func.__doc__ is not None # for mypy |
| 119 | + if "Alias" in func.__doc__: |
| 120 | + func.__doc__ = functions[name].__doc__ |
| 121 | + func.__module__ = f"array_api.{ext}" |
| 122 | + name_to_func[name] = func |
| 123 | + extension_to_functions[ext] = name_to_func |
| 124 | + |
| 125 | + return SimpleNamespace( |
| 126 | + functions=functions, |
| 127 | + array_methods=array_methods, |
| 128 | + dtype_methods=dtype_methods, |
| 129 | + category_to_functions=category_to_functions, |
| 130 | + extension_to_functions=extension_to_functions, |
| 131 | + ) |
| 132 | + |
| 133 | + |
| 134 | +class TestMakeStubsNamespace(TestCase): |
| 135 | + def setUp(self): |
| 136 | + self.stubs = make_stubs_namespace() |
| 137 | + |
| 138 | + def test_attributes(self): |
| 139 | + assert isinstance(self.stubs, SimpleNamespace) |
| 140 | + for attr in ["functions", "array_methods", "dtype_methods"]: |
| 141 | + mapping = getattr(self.stubs, attr) |
| 142 | + assert isinstance(mapping, dict) |
| 143 | + assert all(isinstance(k, str) for k in mapping.keys()) |
| 144 | + assert all(isinstance(v, FunctionType) for v in mapping.values()) |
| 145 | + for attr in ["category_to_functions", "extension_to_functions"]: |
| 146 | + mapping = getattr(self.stubs, attr) |
| 147 | + assert isinstance(mapping, dict) |
| 148 | + assert all(isinstance(k, str) for k in mapping.keys()) |
| 149 | + for sub_mapping in mapping.values(): |
| 150 | + assert isinstance(sub_mapping, dict) |
| 151 | + assert all(isinstance(k, str) for k in sub_mapping.keys()) |
| 152 | + assert all(isinstance(v, FunctionType) for v in sub_mapping.values()) |
| 153 | + |
| 154 | + def test_function_meta(self): |
| 155 | + toplevel_stub = self.stubs.functions["matmul"] |
| 156 | + assert toplevel_stub.__module__ == "array_api" |
| 157 | + extension_stub = self.stubs.extension_to_functions["linalg"]["matmul"] |
| 158 | + assert extension_stub.__module__ == "array_api.linalg" |
| 159 | + assert extension_stub.__doc__ == toplevel_stub.__doc__ |
| 160 | + |
| 161 | + def test_array_method_meta(self): |
| 162 | + stub = self.stubs.array_methods["__add__"] |
| 163 | + assert stub.__module__ == "array_api" |
| 164 | + assert stub.__qualname__ == "_Array.__add__" |
| 165 | + first_arg = next(iter(signature(stub).parameters.values())) |
| 166 | + assert first_arg.name == "self" |
| 167 | + |
| 168 | + def test_dtype_method_meta(self): |
| 169 | + stub = self.stubs.dtype_methods["__eq__"] |
| 170 | + assert stub.__module__ == "array_api" |
| 171 | + assert stub.__qualname__ == "_DataType.__eq__" |
| 172 | + first_arg = next(iter(signature(stub).parameters.values())) |
| 173 | + assert first_arg.name == "self" |
0 commit comments