@@ -630,6 +630,29 @@ def _make_array(x, dim_str, dtype):
630630 return out
631631
632632
633+ def _normalize_array_type (array_type ):
634+ if IS_NUMPY_INSTALLED and array_type is npt .ArrayLike :
635+ # Work around https://github.com/numpy/numpy/commit/1041f940f91660c91770679c60f6e63539581c72
636+ # which removes `bool`/`int`/`float` from the union.
637+ return Union [(* get_args (array_type ), bool , int , float , complex )]
638+ elif isinstance (array_type , TypeVar ):
639+ bound = array_type .__bound__
640+ if bound is None :
641+ constraints = array_type .__constraints__
642+ if constraints == ():
643+ return Any
644+ else :
645+ return _normalize_array_type (Union [constraints ])
646+ else :
647+ return _normalize_array_type (bound )
648+ elif isinstance (array_type , TypeAliasType ):
649+ return _normalize_array_type (array_type .__value__ )
650+ elif get_origin (array_type ) in _union_types :
651+ return Union [tuple (_normalize_array_type (x ) for x in get_args (array_type ))]
652+ else :
653+ return array_type
654+
655+
633656class _MetaAbstractDtype (type ):
634657 def __instancecheck__ (cls , obj : Any ) -> NoReturn :
635658 raise AnnotationError (
@@ -646,23 +669,8 @@ def __getitem__(cls, item: tuple[Any, str]):
646669 "Ellipsis can be used to accept any shape: `Float[jax.Array, '...']`."
647670 )
648671 array_type , dim_str = item
672+ array_type = _normalize_array_type (array_type )
649673 dim_str = dim_str .strip ()
650- if isinstance (array_type , TypeVar ):
651- bound = array_type .__bound__
652- if bound is None :
653- constraints = array_type .__constraints__
654- if constraints == ():
655- array_type = Any
656- else :
657- array_type = Union [constraints ]
658- else :
659- array_type = bound
660- if isinstance (array_type , TypeAliasType ):
661- array_type = array_type .__value__
662- if IS_NUMPY_INSTALLED and item [0 ] is npt .ArrayLike :
663- # Work around https://github.com/numpy/numpy/commit/1041f940f91660c91770679c60f6e63539581c72
664- # which removes `bool`/`int`/`float` from the union.
665- array_type = Union [(* get_args (array_type ), bool , int , float , complex )]
666674 del item
667675 if get_origin (array_type ) in _union_types :
668676 out = [_make_array (x , dim_str , cls ) for x in get_args (array_type )]
0 commit comments