-
Notifications
You must be signed in to change notification settings - Fork 33
TYP: Type annotations, part 4 #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
362c48a
ad375dc
49f9ba7
4371506
c724a52
14f70af
0a571bc
0172300
8711041
014e20f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,50 +12,46 @@ | |||||
import math | ||||||
import sys | ||||||
import warnings | ||||||
from collections.abc import Collection, Hashable | ||||||
from collections.abc import Hashable | ||||||
from functools import lru_cache | ||||||
from types import NoneType | ||||||
from typing import ( | ||||||
TYPE_CHECKING, | ||||||
Any, | ||||||
Final, | ||||||
Literal, | ||||||
SupportsIndex, | ||||||
TypeAlias, | ||||||
TypeGuard, | ||||||
TypeVar, | ||||||
cast, | ||||||
overload, | ||||||
) | ||||||
|
||||||
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
||||||
import cupy as cp | ||||||
import dask.array as da | ||||||
import jax | ||||||
import ndonnx as ndx | ||||||
import numpy as np | ||||||
import numpy.typing as npt | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
import torch | ||||||
|
||||||
# TODO: import from typing (requires Python >=3.13) | ||||||
from typing_extensions import TypeIs, TypeVar | ||||||
|
||||||
_SizeT = TypeVar("_SizeT", bound = int | None) | ||||||
from typing_extensions import TypeIs | ||||||
|
||||||
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void] | ||||||
_CupyArray: TypeAlias = Any # cupy has no py.typed | ||||||
|
||||||
_ArrayApiObj: TypeAlias = ( | ||||||
npt.NDArray[Any] | ||||||
| cp.ndarray | ||||||
| da.Array | ||||||
| jax.Array | ||||||
| ndx.Array | ||||||
| sparse.SparseArray | ||||||
| torch.Tensor | ||||||
| SupportsArrayNamespace[Any] | ||||||
| _CupyArray | ||||||
| SupportsArrayNamespace | ||||||
) | ||||||
|
||||||
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) | ||||||
|
@@ -95,7 +91,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | |||||
return dtype == jax.float0 | ||||||
|
||||||
|
||||||
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: | ||||||
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: | ||||||
""" | ||||||
Return True if `x` is a NumPy array. | ||||||
|
||||||
|
@@ -266,7 +262,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: | |||||
return _issubclass_fast(cls, "sparse", "SparseArray") | ||||||
|
||||||
|
||||||
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] | ||||||
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the current definition of _ArrayApiObj, TypeIs would cause downstream failures for all unknown array api compliant libraries. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
""" | ||||||
Return True if `x` is an array API compatible array object. | ||||||
|
||||||
|
@@ -581,7 +577,7 @@ def your_function(x, y): | |||||
|
||||||
namespaces.add(cupy_namespace) | ||||||
else: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
namespaces.add(cp) | ||||||
elif is_torch_array(x): | ||||||
|
@@ -618,14 +614,14 @@ def your_function(x, y): | |||||
if hasattr(jax.numpy, "__array_api_version__"): | ||||||
jnp = jax.numpy | ||||||
else: | ||||||
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] | ||||||
import jax.experimental.array_api as jnp # type: ignore[no-redef] | ||||||
namespaces.add(jnp) | ||||||
elif is_pydata_sparse_array(x): | ||||||
if use_compat is True: | ||||||
_check_api_version(api_version) | ||||||
raise ValueError("`sparse` does not have an array-api-compat wrapper") | ||||||
else: | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's no need for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather not enlarge the scope of this PR. |
||||||
# `sparse` is already an array namespace. We do not have a wrapper | ||||||
# submodule for it. | ||||||
namespaces.add(sparse) | ||||||
|
@@ -634,9 +630,9 @@ def your_function(x, y): | |||||
raise ValueError( | ||||||
"The given array does not have an array-api-compat wrapper" | ||||||
) | ||||||
x = cast("SupportsArrayNamespace[Any]", x) | ||||||
x = cast(SupportsArrayNamespace, x) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is why I quoted it: https://docs.astral.sh/ruff/rules/runtime-cast-value/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I strongly disagree with TC006 |
||||||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||||||
elif isinstance(x, (bool, int, float, complex, type(None))): | ||||||
elif isinstance(x, int | float | complex | NoneType): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
(I'll spare you the pseudo-philosophical rant this time) |
||||||
continue | ||||||
else: | ||||||
# TODO: Support Python scalars? | ||||||
|
@@ -732,7 +728,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
elif is_dask_array(x): | ||||||
# Peek at the metadata of the Dask array to determine type | ||||||
if is_numpy_array(x._meta): # pyright: ignore | ||||||
if is_numpy_array(x._meta): | ||||||
# Must be on CPU since backed by numpy | ||||||
return "cpu" | ||||||
return _DASK_DEVICE | ||||||
|
@@ -761,7 +757,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
# Return the device of the constituent array | ||||||
return device(inner) # pyright: ignore | ||||||
return x.device # pyright: ignore | ||||||
return x.device # type: ignore # pyright: ignore | ||||||
|
||||||
|
||||||
# Prevent shadowing, used below | ||||||
|
@@ -770,12 +766,12 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
|
||||||
# Based on cupy.array_api.Array.to_device | ||||||
def _cupy_to_device( | ||||||
x: _CupyArray, | ||||||
x: cp.ndarray, | ||||||
device: Device, | ||||||
/, | ||||||
stream: int | Any | None = None, | ||||||
) -> _CupyArray: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
) -> cp.ndarray: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. has this been fixed in cupy since last time or something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not that I'm aware of. But we have to skip all missing imports anyway or people that can't pip-install cupy won't be able to run the linter. |
||||||
import cupy as cp | ||||||
from cupy.cuda import Device as _Device # pyright: ignore | ||||||
from cupy.cuda import stream as stream_module # pyright: ignore | ||||||
from cupy_backends.cuda.api import runtime # pyright: ignore | ||||||
|
@@ -791,10 +787,10 @@ def _cupy_to_device( | |||||
raise ValueError(f"Unsupported device {device!r}") | ||||||
else: | ||||||
# see cupy/cupy#5985 for the reason how we handle device/stream here | ||||||
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||||||
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||||||
prev_stream = None | ||||||
if stream is not None: | ||||||
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore | ||||||
prev_stream = stream_module.get_current_stream() # pyright: ignore | ||||||
# stream can be an int as specified in __dlpack__, or a CuPy stream | ||||||
if isinstance(stream, int): | ||||||
stream = cp.cuda.ExternalStream(stream) # pyright: ignore | ||||||
|
@@ -808,7 +804,7 @@ def _cupy_to_device( | |||||
arr = x.copy() | ||||||
finally: | ||||||
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] | ||||||
if stream is not None: | ||||||
if prev_stream is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. casual bugfix 🤔 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's actually logically identical to before. But it was convoluted and rightfully the type checker was complaining. |
||||||
prev_stream.use() | ||||||
return arr | ||||||
|
||||||
|
@@ -817,7 +813,7 @@ def _torch_to_device( | |||||
x: torch.Tensor, | ||||||
device: torch.device | str | int, | ||||||
/, | ||||||
stream: None = None, | ||||||
stream: int | Any | None = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps we should create a stream type-alias (unless there already is one) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there isn't one, and IMHO it's not referenced enough to warrant one |
||||||
) -> torch.Tensor: | ||||||
if stream is not None: | ||||||
raise NotImplementedError | ||||||
|
@@ -883,7 +879,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
# cupy does not yet have to_device | ||||||
return _cupy_to_device(x, device, stream=stream) | ||||||
elif is_torch_array(x): | ||||||
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] | ||||||
return _torch_to_device(x, device, stream=stream) | ||||||
elif is_dask_array(x): | ||||||
if stream is not None: | ||||||
raise ValueError("The stream argument to to_device() is not supported") | ||||||
|
@@ -908,12 +904,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
|
||||||
|
||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... | ||||||
def size(x: HasShape[int]) -> int: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[None]]) -> None: ... | ||||||
def size(x: HasShape[int | None]) -> int | None: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | ||||||
def size(x: HasShape[float]) -> int | None: ... # Dask special case | ||||||
def size(x: HasShape[float | None]) -> int | None: | ||||||
Comment on lines
+907
to
+912
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why did you remove the overload that returns There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because no Array API backend in existence, present or future, will ever have all its arrays with no shape information whatsoever. |
||||||
""" | ||||||
Return the total number of elements of x. | ||||||
|
||||||
|
@@ -928,9 +924,9 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | |||||
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape | ||||||
if None in x.shape: | ||||||
return None | ||||||
out = math.prod(cast("Collection[SupportsIndex]", x.shape)) | ||||||
out = math.prod(cast(tuple[float, ...], x.shape)) | ||||||
# dask.array.Array.shape can contain NaN | ||||||
return None if math.isnan(out) else out | ||||||
return None if math.isnan(out) else cast(int, out) | ||||||
|
||||||
|
||||||
@lru_cache(100) | ||||||
|
@@ -946,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None: | |||||
return None | ||||||
|
||||||
|
||||||
def is_writeable_array(x: object) -> bool: | ||||||
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
""" | ||||||
Return False if ``x.__setitem__`` is expected to raise; True otherwise. | ||||||
Return False if `x` is not an array API compatible object. | ||||||
|
@@ -984,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None: | |||||
return None | ||||||
|
||||||
|
||||||
def is_lazy_array(x: object) -> bool: | ||||||
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Anything with a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's an issue with the implementation, not the type annotation. |
||||||
"""Return True if x is potentially a future or it may be otherwise impossible or | ||||||
expensive to eagerly read its contents, regardless of their size, e.g. by | ||||||
calling ``bool(x)`` or ``float(x)``. | ||||||
|
@@ -1021,7 +1017,7 @@ def is_lazy_array(x: object) -> bool: | |||||
# on __bool__ (dask is one such example, which however is special-cased above). | ||||||
|
||||||
# Select a single point of the array | ||||||
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) | ||||||
s = size(cast(HasShape, x)) | ||||||
if s is None: | ||||||
return True | ||||||
xp = array_namespace(x) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With
we avoid the
types
import while simultaneously accentuating the violent dissonance between the Python runtime and its type-system, given that the sole purpose of a type-system is to accurately describe the runtime behavior...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we avoid the
types
import? This is a runtime check.