Skip to content

Commit b3ad46f

Browse files
Numpy 2.4.0 compat: fix union of arraylike.
1 parent 103e004 commit b3ad46f

3 files changed

Lines changed: 31 additions & 17 deletions

File tree

jaxtyping/_array_types.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,29 @@ def _make_array(x, dim_str, dtype):
630630
return out
631631

632632

633+
def _normalize_array_type(array_type):
634+
if IS_NUMPY_INSTALLED and array_type is npt.ArrayLike:
635+
# Work around https://github.com/numpy/numpy/commit/1041f940f91660c91770679c60f6e63539581c72
636+
# which removes `bool`/`int`/`float` from the union.
637+
return Union[(*get_args(array_type), bool, int, float, complex)]
638+
elif isinstance(array_type, TypeVar):
639+
bound = array_type.__bound__
640+
if bound is None:
641+
constraints = array_type.__constraints__
642+
if constraints == ():
643+
return Any
644+
else:
645+
return _normalize_array_type(Union[constraints])
646+
else:
647+
return _normalize_array_type(bound)
648+
elif isinstance(array_type, TypeAliasType):
649+
return _normalize_array_type(array_type.__value__)
650+
elif get_origin(array_type) in _union_types:
651+
return Union[tuple(_normalize_array_type(x) for x in get_args(array_type))]
652+
else:
653+
return array_type
654+
655+
633656
class _MetaAbstractDtype(type):
634657
def __instancecheck__(cls, obj: Any) -> NoReturn:
635658
raise AnnotationError(
@@ -646,23 +669,8 @@ def __getitem__(cls, item: tuple[Any, str]):
646669
"Ellipsis can be used to accept any shape: `Float[jax.Array, '...']`."
647670
)
648671
array_type, dim_str = item
672+
array_type = _normalize_array_type(array_type)
649673
dim_str = dim_str.strip()
650-
if isinstance(array_type, TypeVar):
651-
bound = array_type.__bound__
652-
if bound is None:
653-
constraints = array_type.__constraints__
654-
if constraints == ():
655-
array_type = Any
656-
else:
657-
array_type = Union[constraints]
658-
else:
659-
array_type = bound
660-
if isinstance(array_type, TypeAliasType):
661-
array_type = array_type.__value__
662-
if IS_NUMPY_INSTALLED and item[0] is npt.ArrayLike:
663-
# Work around https://github.com/numpy/numpy/commit/1041f940f91660c91770679c60f6e63539581c72
664-
# which removes `bool`/`int`/`float` from the union.
665-
array_type = Union[(*get_args(array_type), bool, int, float, complex)]
666674
del item
667675
if get_origin(array_type) in _union_types:
668676
out = [_make_array(x, dim_str, cls) for x in get_args(array_type)]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ name = "jaxtyping"
2828
readme = "README.md"
2929
requires-python = ">=3.10"
3030
urls = {repository = "https://github.com/patrick-kidger/jaxtyping"}
31-
version = "0.3.6"
31+
version = "0.3.7"
3232

3333
[project.optional-dependencies]
3434
dev = [

test/test_array.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,3 +929,9 @@ def test_typealiastype():
929929

930930
x = Float[TypeAliasType("Foo", bool | np.ndarray), "3"]
931931
assert _to_set([x]) == _to_set([Float[np.ndarray, "3"]])
932+
933+
934+
# https://github.com/patrick-kidger/jaxtyping/issues/374
935+
def test_union_of_arraylike():
936+
Arrayish = Union[jax.Array, np.typing.ArrayLike]
937+
Float[Arrayish, ""]

0 commit comments

Comments
 (0)