Skip to content

Add an Array Protocol & improve static typing support #589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -48,3 +48,29 @@ repos:
rev: 23.7.0
hooks:
- id: black

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.0.0"
hooks:
- id: mypy
additional_dependencies: [typing_extensions>=4.4.0]
args:
- --ignore-missing-imports
- --config=pyproject.toml
files: ".*(_draft.*)$"
exclude: |
(?x)^(
.*creation_functions.py|
.*data_type_functions.py|
.*elementwise_functions.py|
.*fft.py|
.*indexing_functions.py|
.*linalg.py|
.*linear_algebra_functions.py|
.*manipulation_functions.py|
.*searching_functions.py|
.*set_functions.py|
.*sorting_functions.py|
.*statistical_functions.py|
.*utility_functions.py|
)$
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -30,5 +30,15 @@ doc = [
requires = ["setuptools"]
build-backend = "setuptools.build_meta"


[tool.black]
line-length = 88


[tool.mypy]
python_version = "3.9"
mypy_path = "$MYPY_CONFIG_FILE_DIR/src/array_api_stubs/_draft/"
files = [
"src/array_api_stubs/_draft/**/*.py"
]
follow_imports = "silent"
128 changes: 64 additions & 64 deletions spec/draft/API_specification/array_object.rst
Original file line number Diff line number Diff line change
@@ -30,47 +30,47 @@ Arithmetic Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python arithmetic operators.

- ``+x``: :meth:`.array.__pos__`
- ``+x``: :meth:`.Array.__pos__`

- `operator.pos(x) <https://docs.python.org/3/library/operator.html#operator.pos>`_
- `operator.__pos__(x) <https://docs.python.org/3/library/operator.html#operator.__pos__>`_

- `-x`: :meth:`.array.__neg__`
- `-x`: :meth:`.Array.__neg__`

- `operator.neg(x) <https://docs.python.org/3/library/operator.html#operator.neg>`_
- `operator.__neg__(x) <https://docs.python.org/3/library/operator.html#operator.__neg__>`_

- `x1 + x2`: :meth:`.array.__add__`
- `x1 + x2`: :meth:`.Array.__add__`

- `operator.add(x1, x2) <https://docs.python.org/3/library/operator.html#operator.add>`_
- `operator.__add__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__add__>`_

- `x1 - x2`: :meth:`.array.__sub__`
- `x1 - x2`: :meth:`.Array.__sub__`

- `operator.sub(x1, x2) <https://docs.python.org/3/library/operator.html#operator.sub>`_
- `operator.__sub__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__sub__>`_

- `x1 * x2`: :meth:`.array.__mul__`
- `x1 * x2`: :meth:`.Array.__mul__`

- `operator.mul(x1, x2) <https://docs.python.org/3/library/operator.html#operator.mul>`_
- `operator.__mul__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__mul__>`_

- `x1 / x2`: :meth:`.array.__truediv__`
- `x1 / x2`: :meth:`.Array.__truediv__`

- `operator.truediv(x1,x2) <https://docs.python.org/3/library/operator.html#operator.truediv>`_
- `operator.__truediv__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__truediv__>`_

- `x1 // x2`: :meth:`.array.__floordiv__`
- `x1 // x2`: :meth:`.Array.__floordiv__`

- `operator.floordiv(x1, x2) <https://docs.python.org/3/library/operator.html#operator.floordiv>`_
- `operator.__floordiv__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__floordiv__>`_

- `x1 % x2`: :meth:`.array.__mod__`
- `x1 % x2`: :meth:`.Array.__mod__`

- `operator.mod(x1, x2) <https://docs.python.org/3/library/operator.html#operator.mod>`_
- `operator.__mod__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__mod__>`_

- `x1 ** x2`: :meth:`.array.__pow__`
- `x1 ** x2`: :meth:`.Array.__pow__`

- `operator.pow(x1, x2) <https://docs.python.org/3/library/operator.html#operator.pow>`_
- `operator.__pow__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__pow__>`_
@@ -82,7 +82,7 @@ Array Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python array operators.

- `x1 @ x2`: :meth:`.array.__matmul__`
- `x1 @ x2`: :meth:`.Array.__matmul__`

- `operator.matmul(x1, x2) <https://docs.python.org/3/library/operator.html#operator.matmul>`_
- `operator.__matmul__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__matmul__>`_
@@ -94,34 +94,34 @@ Bitwise Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python bitwise operators.

- `~x`: :meth:`.array.__invert__`
- `~x`: :meth:`.Array.__invert__`

- `operator.inv(x) <https://docs.python.org/3/library/operator.html#operator.inv>`_
- `operator.invert(x) <https://docs.python.org/3/library/operator.html#operator.invert>`_
- `operator.__inv__(x) <https://docs.python.org/3/library/operator.html#operator.__inv__>`_
- `operator.__invert__(x) <https://docs.python.org/3/library/operator.html#operator.__invert__>`_

- `x1 & x2`: :meth:`.array.__and__`
- `x1 & x2`: :meth:`.Array.__and__`

- `operator.and(x1, x2) <https://docs.python.org/3/library/operator.html#operator.and>`_
- `operator.__and__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__and__>`_

- `x1 | x2`: :meth:`.array.__or__`
- `x1 | x2`: :meth:`.Array.__or__`

- `operator.or(x1, x2) <https://docs.python.org/3/library/operator.html#operator.or>`_
- `operator.__or__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__or__>`_

- `x1 ^ x2`: :meth:`.array.__xor__`
- `x1 ^ x2`: :meth:`.Array.__xor__`

- `operator.xor(x1, x2) <https://docs.python.org/3/library/operator.html#operator.xor>`_
- `operator.__xor__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__xor__>`_

- `x1 << x2`: :meth:`.array.__lshift__`
- `x1 << x2`: :meth:`.Array.__lshift__`

- `operator.lshift(x1, x2) <https://docs.python.org/3/library/operator.html#operator.lshift>`_
- `operator.__lshift__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__lshift__>`_

- `x1 >> x2`: :meth:`.array.__rshift__`
- `x1 >> x2`: :meth:`.Array.__rshift__`

- `operator.rshift(x1, x2) <https://docs.python.org/3/library/operator.html#operator.rshift>`_
- `operator.__rshift__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__rshift__>`_
@@ -133,37 +133,37 @@ Comparison Operators

A conforming implementation of the array API standard must provide and support an array object supporting the following Python comparison operators.

- `x1 < x2`: :meth:`.array.__lt__`
- `x1 < x2`: :meth:`.Array.__lt__`

- `operator.lt(x1, x2) <https://docs.python.org/3/library/operator.html#operator.lt>`_
- `operator.__lt__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__lt__>`_

- `x1 <= x2`: :meth:`.array.__le__`
- `x1 <= x2`: :meth:`.Array.__le__`

- `operator.le(x1, x2) <https://docs.python.org/3/library/operator.html#operator.le>`_
- `operator.__le__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__le__>`_

- `x1 > x2`: :meth:`.array.__gt__`
- `x1 > x2`: :meth:`.Array.__gt__`

- `operator.gt(x1, x2) <https://docs.python.org/3/library/operator.html#operator.gt>`_
- `operator.__gt__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__gt__>`_

- `x1 >= x2`: :meth:`.array.__ge__`
- `x1 >= x2`: :meth:`.Array.__ge__`

- `operator.ge(x1, x2) <https://docs.python.org/3/library/operator.html#operator.ge>`_
- `operator.__ge__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__ge__>`_

- `x1 == x2`: :meth:`.array.__eq__`
- `x1 == x2`: :meth:`.Array.__eq__`

- `operator.eq(x1, x2) <https://docs.python.org/3/library/operator.html#operator.eq>`_
- `operator.__eq__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__eq__>`_

- `x1 != x2`: :meth:`.array.__ne__`
- `x1 != x2`: :meth:`.Array.__ne__`

- `operator.ne(x1, x2) <https://docs.python.org/3/library/operator.html#operator.ne>`_
- `operator.__ne__(x1, x2) <https://docs.python.org/3/library/operator.html#operator.__ne__>`_

:meth:`.array.__lt__`, :meth:`.array.__le__`, :meth:`.array.__gt__`, :meth:`.array.__ge__` are only defined for arrays having real-valued data types. Other comparison operators should be defined for arrays having any data type.
:meth:`.Array.__lt__`, :meth:`.Array.__le__`, :meth:`.Array.__gt__`, :meth:`.Array.__ge__` are only defined for arrays having real-valued data types. Other comparison operators should be defined for arrays having any data type.
For backward compatibility, conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).

In-place Operators
@@ -252,13 +252,13 @@ Attributes
:toctree: generated
:template: property.rst

array.dtype
array.device
array.mT
array.ndim
array.shape
array.size
array.T
Array.dtype
Array.device
Array.mT
Array.ndim
Array.shape
Array.size
Array.T

-------------------------------------------------

@@ -272,37 +272,37 @@ Methods
:toctree: generated
:template: property.rst

array.__abs__
array.__add__
array.__and__
array.__array_namespace__
array.__bool__
array.__complex__
array.__dlpack__
array.__dlpack_device__
array.__eq__
array.__float__
array.__floordiv__
array.__ge__
array.__getitem__
array.__gt__
array.__index__
array.__int__
array.__invert__
array.__le__
array.__lshift__
array.__lt__
array.__matmul__
array.__mod__
array.__mul__
array.__ne__
array.__neg__
array.__or__
array.__pos__
array.__pow__
array.__rshift__
array.__setitem__
array.__sub__
array.__truediv__
array.__xor__
array.to_device
Array.__abs__
Array.__add__
Array.__and__
Array.__array_namespace__
Array.__bool__
Array.__complex__
Array.__dlpack__
Array.__dlpack_device__
Array.__eq__
Array.__float__
Array.__floordiv__
Array.__ge__
Array.__getitem__
Array.__gt__
Array.__index__
Array.__int__
Array.__invert__
Array.__le__
Array.__lshift__
Array.__lt__
Array.__matmul__
Array.__mod__
Array.__mul__
Array.__ne__
Array.__neg__
Array.__or__
Array.__pos__
Array.__pow__
Array.__rshift__
Array.__setitem__
Array.__sub__
Array.__truediv__
Array.__xor__
Array.to_device
2 changes: 1 addition & 1 deletion spec/draft/purpose_and_scope.md
Original file line number Diff line number Diff line change
@@ -317,7 +317,7 @@ namespace (e.g. `import package_name.array_api`). This has two issues though:

To address both issues, a uniform way must be provided by a conforming
implementation to access the API namespace, namely a [method on the array
object](array.__array_namespace__):
object](Array.__array_namespace__):

```
xp = x.__array_namespace__()
5 changes: 5 additions & 0 deletions src/_array_api_conf.py
Original file line number Diff line number Diff line change
@@ -62,12 +62,16 @@
("py:obj", "typing.Union[int, float, typing.Literal[inf, - inf]]"),
("py:class", "int | float | ~typing.Literal[inf, -inf]"),
("py:class", "enum.Enum"),
("py:class", "Enum"),
("py:class", "ellipsis"),
]
nitpick_ignore_regex = [
("py:class", ".*array"),
("py:class", ".*Array"),
("py:class", ".*device"),
("py:class", ".*Device"),
("py:class", ".*dtype"),
("py:class", ".*Self"),
("py:class", ".*NestedSequence"),
("py:class", ".*SupportsBufferProtocol"),
("py:class", ".*PyCapsule"),
@@ -84,6 +88,7 @@
"array": "array",
"Device": "device",
"Dtype": "dtype",
"DType": "dtype",
}

# Make autosummary show the signatures of functions in the tables using actual
32 changes: 23 additions & 9 deletions src/array_api_stubs/_draft/_types.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@

from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
@@ -46,9 +47,22 @@
)
from enum import Enum

array = TypeVar("array")
device = TypeVar("device")
dtype = TypeVar("dtype")

if TYPE_CHECKING:
from .array_object import Array
from .data_types import DType


class Device(Protocol):
"""Protocol for device objects."""

def __eq__(self, value: Any) -> bool:
...


array = TypeVar("array", bound="Array")
device = TypeVar("device", bound=Device)
dtype = TypeVar("dtype", bound="DType")
SupportsDLPack = TypeVar("SupportsDLPack")
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
PyCapsule = TypeVar("PyCapsule")
@@ -66,7 +80,7 @@ class finfo_object:
max: float
min: float
smallest_normal: float
dtype: dtype
dtype: DType


@dataclass
@@ -76,7 +90,7 @@ class iinfo_object:
bits: int
max: int
min: int
dtype: dtype
dtype: DType


_T_co = TypeVar("_T_co", covariant=True)
@@ -96,17 +110,17 @@ class Info(Protocol):
def capabilities(self) -> Capabilities:
...

def default_device(self) -> device:
def default_device(self) -> Device:
...

def default_dtypes(self, *, device: Optional[device]) -> DefaultDataTypes:
def default_dtypes(self, *, device: Optional[Device]) -> DefaultDataTypes:
...

def devices(self) -> List[device]:
def devices(self) -> List[Device]:
...

def dtypes(
self, *, device: Optional[device], kind: Optional[Union[str, Tuple[str, ...]]]
self, *, device: Optional[Device], kind: Optional[Union[str, Tuple[str, ...]]]
) -> DataTypes:
...

Loading
Loading