Skip to content

TYP: Type annotations, part 4 #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
specification for more details.

"""
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]

return inner

Expand Down
14 changes: 8 additions & 6 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
from collections.abc import Sequence
from types import NoneType
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from ._helpers import _check_device, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._helpers import is_cupy_namespace
from ._typing import Array, Device, DType, Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -381,8 +383,8 @@ def clip(
# TODO: np.clip has other ufunc kwargs
out: Array | None = None,
) -> Array:
def _isscalar(a: object) -> TypeIs[int | float | None]:
return isinstance(a, (int, float, type(None)))
def _isscalar(a: object) -> TypeIs[float | None]:
return isinstance(a, int | float | NoneType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With

Suggested change
return isinstance(a, int | float | NoneType)
return a is None or isinstance(a, int | float)

we avoid the types import while simultaneously accentuating the violent dissonance between the Python runtime and its type-system, given that the sole purpose of a type-system is to accurately describe the runtime behavior...

Copy link
Contributor Author

@crusaderky crusaderky Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we avoid the types import? This is a runtime check.


min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
Expand Down Expand Up @@ -450,7 +452,7 @@ def reshape(
shape: tuple[int, ...],
xp: Namespace,
*,
copy: Optional[bool] = None,
copy: bool | None = None,
**kwargs: object,
) -> Array:
if copy is True:
Expand Down Expand Up @@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]

Expand Down
70 changes: 33 additions & 37 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,50 +12,46 @@
import math
import sys
import warnings
from collections.abc import Collection, Hashable
from collections.abc import Hashable
from functools import lru_cache
from types import NoneType
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
SupportsIndex,
TypeAlias,
TypeGuard,
TypeVar,
cast,
overload,
)

from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace

if TYPE_CHECKING:

import cupy as cp
import dask.array as da
import jax
import ndonnx as ndx
import numpy as np
import numpy.typing as npt
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
import torch

# TODO: import from typing (requires Python >=3.13)
from typing_extensions import TypeIs, TypeVar

_SizeT = TypeVar("_SizeT", bound = int | None)
from typing_extensions import TypeIs

_ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
_CupyArray: TypeAlias = Any # cupy has no py.typed

_ArrayApiObj: TypeAlias = (
npt.NDArray[Any]
| cp.ndarray
| da.Array
| jax.Array
| ndx.Array
| sparse.SparseArray
| torch.Tensor
| SupportsArrayNamespace[Any]
| _CupyArray
| SupportsArrayNamespace
)

_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
Expand Down Expand Up @@ -95,7 +91,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
return dtype == jax.float0


def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
"""
Return True if `x` is a NumPy array.

Expand Down Expand Up @@ -266,7 +262,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
return _issubclass_fast(cls, "sparse", "SparseArray")


def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the current definition of _ArrayApiObj, TypeIs would cause downstream failures for all unknown array api compliant libraries.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't SupportsArrayNamespace cover all downstream array types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SupportsArrayNamespace just states that the object has a __array_namespace__ method, nothing more.
It's missing all the other methods and properties of an Array.
I would much rather NOT write the full Array protocol here (I did it in array-api-extra and I regret it), as this is squarely in scope for array-api-types.

"""
Return True if `x` is an array API compatible array object.

Expand Down Expand Up @@ -581,7 +577,7 @@ def your_function(x, y):

namespaces.add(cupy_namespace)
else:
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
import cupy as cp

namespaces.add(cp)
elif is_torch_array(x):
Expand Down Expand Up @@ -618,14 +614,14 @@ def your_function(x, y):
if hasattr(jax.numpy, "__array_api_version__"):
jnp = jax.numpy
else:
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
import jax.experimental.array_api as jnp # type: ignore[no-redef]
namespaces.add(jnp)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
else:
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no need for the else: clause

Copy link
Contributor Author

@crusaderky crusaderky Apr 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not enlarge the scope of this PR.
On a side note, this is one point where I disagree with the ruff policy. The else clause, while not being functionally useful, makes the code flow more readable IMHO.

# `sparse` is already an array namespace. We do not have a wrapper
# submodule for it.
namespaces.add(sparse)
Expand All @@ -634,9 +630,9 @@ def your_function(x, y):
raise ValueError(
"The given array does not have an array-api-compat wrapper"
)
x = cast("SupportsArrayNamespace[Any]", x)
x = cast(SupportsArrayNamespace, x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly disagree with TC006

namespaces.add(x.__array_namespace__(api_version=api_version))
elif isinstance(x, (bool, int, float, complex, type(None))):
elif isinstance(x, int | float | complex | NoneType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif isinstance(x, int | float | complex | NoneType):
elif x is None or isinstance(x, int | float | complex):

(I'll spare you the pseudo-philosophical rant this time)

continue
else:
# TODO: Support Python scalars?
Expand Down Expand Up @@ -732,7 +728,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the Dask array to determine type
if is_numpy_array(x._meta): # pyright: ignore
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
Expand Down Expand Up @@ -761,7 +757,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner) # pyright: ignore
return x.device # pyright: ignore
return x.device # type: ignore # pyright: ignore


# Prevent shadowing, used below
Expand All @@ -770,12 +766,12 @@ def device(x: _ArrayApiObj, /) -> Device:

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(
x: _CupyArray,
x: cp.ndarray,
device: Device,
/,
stream: int | Any | None = None,
) -> _CupyArray:
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
) -> cp.ndarray:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has this been fixed in cupy since last time or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that I'm aware of. But we have to skip all missing imports anyway or people that can't pip-install cupy won't be able to run the linter.

import cupy as cp
from cupy.cuda import Device as _Device # pyright: ignore
from cupy.cuda import stream as stream_module # pyright: ignore
from cupy_backends.cuda.api import runtime # pyright: ignore
Expand All @@ -791,10 +787,10 @@ def _cupy_to_device(
raise ValueError(f"Unsupported device {device!r}")
else:
# see cupy/cupy#5985 for the reason how we handle device/stream here
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
prev_stream = None
if stream is not None:
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
prev_stream = stream_module.get_current_stream() # pyright: ignore
# stream can be an int as specified in __dlpack__, or a CuPy stream
if isinstance(stream, int):
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
Expand All @@ -808,7 +804,7 @@ def _cupy_to_device(
arr = x.copy()
finally:
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
if stream is not None:
if prev_stream is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casual bugfix 🤔 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually logically identical to before. But it was convoluted and rightfully the type checker was complaining.

prev_stream.use()
return arr

Expand All @@ -817,7 +813,7 @@ def _torch_to_device(
x: torch.Tensor,
device: torch.device | str | int,
/,
stream: None = None,
stream: int | Any | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we should create a stream type-alias (unless there already is one)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there isn't one, and IMHO it's not referenced enough to warrant one

) -> torch.Tensor:
if stream is not None:
raise NotImplementedError
Expand Down Expand Up @@ -883,7 +879,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
# cupy does not yet have to_device
return _cupy_to_device(x, device, stream=stream)
elif is_torch_array(x):
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
return _torch_to_device(x, device, stream=stream)
elif is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
Expand All @@ -908,12 +904,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -


@overload
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
def size(x: HasShape[int]) -> int: ...
@overload
def size(x: HasShape[Collection[None]]) -> None: ...
def size(x: HasShape[int | None]) -> int | None: ...
@overload
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
def size(x: HasShape[float]) -> int | None: ... # Dask special case
def size(x: HasShape[float | None]) -> int | None:
Comment on lines +907 to +912
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you remove the overload that returns -> None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because no Array API backend in existence, present or future, will ever have all its arrays with no shape information whatsoever.

"""
Return the total number of elements of x.

Expand All @@ -928,9 +924,9 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
if None in x.shape:
return None
out = math.prod(cast("Collection[SupportsIndex]", x.shape))
out = math.prod(cast(tuple[float, ...], x.shape))
# dask.array.Array.shape can contain NaN
return None if math.isnan(out) else out
return None if math.isnan(out) else cast(int, out)


@lru_cache(100)
Expand All @@ -946,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None:
return None


def is_writeable_array(x: object) -> bool:
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.
Expand Down Expand Up @@ -984,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None:
return None


def is_lazy_array(x: object) -> bool:
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything with a shape: tuple that contains a None would return True here, so this isn't correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an issue with the implementation, not the type annotation.
As per the docstring, this function returns True if and only if x is an array API object and it is lazy.

"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Expand Down Expand Up @@ -1021,7 +1017,7 @@ def is_lazy_array(x: object) -> bool:
# on __bool__ (dask is one such example, which however is special-cased above).

# Select a single point of the array
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
s = size(cast(HasShape, x))
if s is None:
return True
xp = array_namespace(x)
Expand Down
2 changes: 1 addition & 1 deletion array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]

from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
Expand Down
Loading
Loading