Skip to content

Commit 503537f

Browse files
committed
Introduce array_api_stubs.py
1 parent 792f11e commit 503537f

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed

array_api_stubs.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import sys
2+
from importlib import import_module
3+
from importlib.util import find_spec
4+
from inspect import getmembers, isfunction, signature
5+
from pathlib import Path
6+
from types import FunctionType, ModuleType, SimpleNamespace
7+
from typing import Dict, List, Optional
8+
from unittest import TestCase
9+
10+
__all__ = ["make_stubs_namespace"]
11+
12+
API_VERSIONS = {"2012.12"} # TODO: infer released versions dynamically
13+
14+
15+
def make_stubs_namespace(api_version: Optional[str] = None) -> SimpleNamespace:
16+
"""
17+
Returns a ``SimpleNamespace`` where
18+
19+
* ``functions`` (``dict[str, FunctionType]``) maps names of top-level
20+
functions to their respective stubs.
21+
* ``array_methods`` (``dict[str, FunctionType]``) maps names of array
22+
methods to their respective stubs.
23+
* ``dtype_methods`` (``dict[str, FunctionType]``) maps names of dtype object
24+
methods to their respective stubs.
25+
* ``category_to_functions`` (``dict[str, dict[str, FunctionType]]``) maps
26+
names of categories to their respective function mappings.
27+
* ``extension_to_functions`` (``dict[str, dict[str, FunctionType]]``) maps
28+
names of extensions to their respective function mappings.
29+
30+
Examples
31+
--------
32+
33+
Make a stubs namespace.
34+
35+
>>> from array_api_stubs import make_stubs_namespace
36+
>>> stubs = make_stubs_namespace()
37+
38+
Access the ``array_api.square()`` stub.
39+
40+
>>> stubs.functions["square"]
41+
<function array_api.square(x: ~array, /) -> ~array>
42+
43+
Find names of all set functions.
44+
45+
>>> stubs.category_to_functions["set"].keys()
46+
dict_keys(['unique_all', 'unique_counts', 'unique_inverse', 'unique_values'])
47+
48+
Access the array object's ``__add__`` stub.
49+
50+
>>> stubs.array_methods["__add__"].keys()
51+
<function array_api._Array.__add__(self: 'array', other: 'Union[int, float, array]', /) -> 'array'>
52+
53+
Access the ``array_api.linalg.cross()`` stub.
54+
55+
>>> stubs.extension_to_functions["linalg"]["cross"]
56+
<function array_api.linalg.cross(x1: ~array, x2: ~array, /, *, axis: int = -1) -> ~array>
57+
58+
"""
59+
if api_version is None:
60+
api_version = "draft"
61+
if api_version in API_VERSIONS or api_version == "latest":
62+
raise NotImplementedError("{api_version=} not yet supported")
63+
else:
64+
raise ValueError(
65+
f"{api_version=} not 'draft', 'latest', "
66+
f"or a released version ({API_VERSIONS})"
67+
)
68+
69+
spec_dir = Path(__file__).parent / "spec" / "API_specification"
70+
signatures_dir = spec_dir / "array_api"
71+
assert signatures_dir.exists() # sanity check
72+
spec_abs_path: str = str(spec_dir.resolve())
73+
sys.path.append(spec_abs_path)
74+
assert find_spec("array_api") is not None # sanity check
75+
76+
name_to_mod: Dict[str, ModuleType] = {}
77+
for path in signatures_dir.glob("*.py"):
78+
name = path.name.replace(".py", "")
79+
name_to_mod[name] = import_module(f"array_api.{name}")
80+
81+
array = name_to_mod["array_object"].array
82+
array_methods: Dict[str, FunctionType] = {}
83+
for name, func in getmembers(array, predicate=isfunction):
84+
func.__module__ = "array_api"
85+
assert "Alias" not in func.__doc__ # sanity check
86+
func.__qualname__ = f"_Array.{func.__name__}"
87+
array_methods[name] = func
88+
89+
dtype_eq = name_to_mod["data_types"].__eq__
90+
assert isinstance(dtype_eq, FunctionType) # for mypy
91+
dtype_eq.__module__ = "array_api"
92+
dtype_eq.__qualname__ = "_DataType.__eq__"
93+
dtype_methods: Dict[str, FunctionType] = {"__eq__": dtype_eq}
94+
95+
functions: Dict[str, FunctionType] = {}
96+
category_to_functions: Dict[str, Dict[str, FunctionType]] = {}
97+
for name, mod in name_to_mod.items():
98+
if name.endswith("_functions"):
99+
category = name.replace("_functions", "")
100+
name_to_func = {}
101+
for name in mod.__all__:
102+
func = getattr(mod, name)
103+
assert isinstance(func, FunctionType) # sanity check
104+
func.__module__ = "array_api"
105+
name_to_func[name] = func
106+
functions.update(name_to_func)
107+
category_to_functions[category] = name_to_func
108+
109+
extensions: List[str] = ["linalg"] # TODO: infer on runtime
110+
extension_to_functions: Dict[str, Dict[str, FunctionType]] = {}
111+
for ext in extensions:
112+
mod = name_to_mod[ext]
113+
name_to_func = {name: getattr(mod, name) for name in mod.__all__}
114+
name_to_func = {}
115+
for name in mod.__all__:
116+
func = getattr(mod, name)
117+
assert isinstance(func, FunctionType) # sanity check
118+
assert func.__doc__ is not None # for mypy
119+
if "Alias" in func.__doc__:
120+
func.__doc__ = functions[name].__doc__
121+
func.__module__ = f"array_api.{ext}"
122+
name_to_func[name] = func
123+
extension_to_functions[ext] = name_to_func
124+
125+
return SimpleNamespace(
126+
functions=functions,
127+
array_methods=array_methods,
128+
dtype_methods=dtype_methods,
129+
category_to_functions=category_to_functions,
130+
extension_to_functions=extension_to_functions,
131+
)
132+
133+
134+
class TestMakeStubsNamespace(TestCase):
135+
def setUp(self):
136+
self.stubs = make_stubs_namespace()
137+
138+
def test_attributes(self):
139+
assert isinstance(self.stubs, SimpleNamespace)
140+
for attr in ["functions", "array_methods", "dtype_methods"]:
141+
mapping = getattr(self.stubs, attr)
142+
assert isinstance(mapping, dict)
143+
assert all(isinstance(k, str) for k in mapping.keys())
144+
assert all(isinstance(v, FunctionType) for v in mapping.values())
145+
for attr in ["category_to_functions", "extension_to_functions"]:
146+
mapping = getattr(self.stubs, attr)
147+
assert isinstance(mapping, dict)
148+
assert all(isinstance(k, str) for k in mapping.keys())
149+
for sub_mapping in mapping.values():
150+
assert isinstance(sub_mapping, dict)
151+
assert all(isinstance(k, str) for k in sub_mapping.keys())
152+
assert all(isinstance(v, FunctionType) for v in sub_mapping.values())
153+
154+
def test_function_meta(self):
155+
toplevel_stub = self.stubs.functions["matmul"]
156+
assert toplevel_stub.__module__ == "array_api"
157+
extension_stub = self.stubs.extension_to_functions["linalg"]["matmul"]
158+
assert extension_stub.__module__ == "array_api.linalg"
159+
assert extension_stub.__doc__ == toplevel_stub.__doc__
160+
161+
def test_array_method_meta(self):
162+
stub = self.stubs.array_methods["__add__"]
163+
assert stub.__module__ == "array_api"
164+
assert stub.__qualname__ == "_Array.__add__"
165+
first_arg = next(iter(signature(stub).parameters.values()))
166+
assert first_arg.name == "self"
167+
168+
def test_dtype_method_meta(self):
169+
stub = self.stubs.dtype_methods["__eq__"]
170+
assert stub.__module__ == "array_api"
171+
assert stub.__qualname__ == "_DataType.__eq__"
172+
first_arg = next(iter(signature(stub).parameters.values()))
173+
assert first_arg.name == "self"

0 commit comments

Comments
 (0)