Skip to content

Commit 3c1e694

Browse files
committed
TYP: Compact Python scalar types
1 parent 50a155a commit 3c1e694

8 files changed

+62
-75
lines changed

array_api_strict/_array_object.py

+40-40
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __array__(
191191
# NumPy behavior
192192

193193
def _check_allowed_dtypes(
194-
self, other: Array | bool | int | float | complex, dtype_category: str, op: str
194+
self, other: Array | complex, dtype_category: str, op: str
195195
) -> Array:
196196
"""
197197
Helper function for operators to only allow specific input dtypes
@@ -233,7 +233,7 @@ def _check_allowed_dtypes(
233233

234234
return other
235235

236-
def _check_device(self, other: Array | bool | int | float | complex) -> None:
236+
def _check_device(self, other: Array | complex) -> None:
237237
"""Check that other is on a device compatible with the current array"""
238238
if isinstance(other, (bool, int, float, complex)):
239239
return
@@ -244,7 +244,7 @@ def _check_device(self, other: Array | bool | int | float | complex) -> None:
244244
raise TypeError(f"Expected Array | python scalar; got {type(other)}")
245245

246246
# Helper function to match the type promotion rules in the spec
247-
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
247+
def _promote_scalar(self, scalar: complex) -> Array:
248248
"""
249249
Returns a promoted version of a Python scalar appropriate for use with
250250
operations on self.
@@ -538,7 +538,7 @@ def __abs__(self) -> Array:
538538
res = self._array.__abs__()
539539
return self.__class__._new(res, device=self.device)
540540

541-
def __add__(self, other: Array | int | float | complex, /) -> Array:
541+
def __add__(self, other: Array | complex, /) -> Array:
542542
"""
543543
Performs the operation __add__.
544544
"""
@@ -550,7 +550,7 @@ def __add__(self, other: Array | int | float | complex, /) -> Array:
550550
res = self._array.__add__(other._array)
551551
return self.__class__._new(res, device=self.device)
552552

553-
def __and__(self, other: Array | bool | int, /) -> Array:
553+
def __and__(self, other: Array | int, /) -> Array:
554554
"""
555555
Performs the operation __and__.
556556
"""
@@ -647,7 +647,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
647647
# Note: device support is required for this
648648
return self._array.__dlpack_device__()
649649

650-
def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
650+
def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override]
651651
"""
652652
Performs the operation __eq__.
653653
"""
@@ -673,7 +673,7 @@ def __float__(self) -> float:
673673
res = self._array.__float__()
674674
return res
675675

676-
def __floordiv__(self, other: Array | int | float, /) -> Array:
676+
def __floordiv__(self, other: Array | float, /) -> Array:
677677
"""
678678
Performs the operation __floordiv__.
679679
"""
@@ -685,7 +685,7 @@ def __floordiv__(self, other: Array | int | float, /) -> Array:
685685
res = self._array.__floordiv__(other._array)
686686
return self.__class__._new(res, device=self.device)
687687

688-
def __ge__(self, other: Array | int | float, /) -> Array:
688+
def __ge__(self, other: Array | float, /) -> Array:
689689
"""
690690
Performs the operation __ge__.
691691
"""
@@ -737,7 +737,7 @@ def __getitem__(
737737
res = self._array.__getitem__(np_key)
738738
return self._new(res, device=self.device)
739739

740-
def __gt__(self, other: Array | int | float, /) -> Array:
740+
def __gt__(self, other: Array | float, /) -> Array:
741741
"""
742742
Performs the operation __gt__.
743743
"""
@@ -792,7 +792,7 @@ def __iter__(self) -> Iterator[Array]:
792792
# implemented, which implies iteration on 1-D arrays.
793793
return (Array._new(i, device=self.device) for i in self._array)
794794

795-
def __le__(self, other: Array | int | float, /) -> Array:
795+
def __le__(self, other: Array | float, /) -> Array:
796796
"""
797797
Performs the operation __le__.
798798
"""
@@ -816,7 +816,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
816816
res = self._array.__lshift__(other._array)
817817
return self.__class__._new(res, device=self.device)
818818

819-
def __lt__(self, other: Array | int | float, /) -> Array:
819+
def __lt__(self, other: Array | float, /) -> Array:
820820
"""
821821
Performs the operation __lt__.
822822
"""
@@ -841,7 +841,7 @@ def __matmul__(self, other: Array, /) -> Array:
841841
res = self._array.__matmul__(other._array)
842842
return self.__class__._new(res, device=self.device)
843843

844-
def __mod__(self, other: Array | int | float, /) -> Array:
844+
def __mod__(self, other: Array | float, /) -> Array:
845845
"""
846846
Performs the operation __mod__.
847847
"""
@@ -853,7 +853,7 @@ def __mod__(self, other: Array | int | float, /) -> Array:
853853
res = self._array.__mod__(other._array)
854854
return self.__class__._new(res, device=self.device)
855855

856-
def __mul__(self, other: Array | int | float | complex, /) -> Array:
856+
def __mul__(self, other: Array | complex, /) -> Array:
857857
"""
858858
Performs the operation __mul__.
859859
"""
@@ -865,7 +865,7 @@ def __mul__(self, other: Array | int | float | complex, /) -> Array:
865865
res = self._array.__mul__(other._array)
866866
return self.__class__._new(res, device=self.device)
867867

868-
def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
868+
def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override]
869869
"""
870870
Performs the operation __ne__.
871871
"""
@@ -886,7 +886,7 @@ def __neg__(self) -> Array:
886886
res = self._array.__neg__()
887887
return self.__class__._new(res, device=self.device)
888888

889-
def __or__(self, other: Array | bool | int, /) -> Array:
889+
def __or__(self, other: Array | int, /) -> Array:
890890
"""
891891
Performs the operation __or__.
892892
"""
@@ -907,7 +907,7 @@ def __pos__(self) -> Array:
907907
res = self._array.__pos__()
908908
return self.__class__._new(res, device=self.device)
909909

910-
def __pow__(self, other: Array | int | float | complex, /) -> Array:
910+
def __pow__(self, other: Array | complex, /) -> Array:
911911
"""
912912
Performs the operation __pow__.
913913
"""
@@ -944,7 +944,7 @@ def __setitem__(
944944
| Array
945945
| tuple[int | slice | EllipsisType, ...]
946946
),
947-
value: Array | bool | int | float | complex,
947+
value: Array | complex,
948948
/,
949949
) -> None:
950950
"""
@@ -957,7 +957,7 @@ def __setitem__(
957957
np_key = key._array if isinstance(key, Array) else key
958958
self._array.__setitem__(np_key, asarray(value)._array)
959959

960-
def __sub__(self, other: Array | int | float | complex, /) -> Array:
960+
def __sub__(self, other: Array | complex, /) -> Array:
961961
"""
962962
Performs the operation __sub__.
963963
"""
@@ -971,7 +971,7 @@ def __sub__(self, other: Array | int | float | complex, /) -> Array:
971971

972972
# PEP 484 requires int to be a subtype of float, but __truediv__ should
973973
# not accept int.
974-
def __truediv__(self, other: Array | int | float | complex, /) -> Array:
974+
def __truediv__(self, other: Array | complex, /) -> Array:
975975
"""
976976
Performs the operation __truediv__.
977977
"""
@@ -983,7 +983,7 @@ def __truediv__(self, other: Array | int | float | complex, /) -> Array:
983983
res = self._array.__truediv__(other._array)
984984
return self.__class__._new(res, device=self.device)
985985

986-
def __xor__(self, other: Array | bool | int, /) -> Array:
986+
def __xor__(self, other: Array | int, /) -> Array:
987987
"""
988988
Performs the operation __xor__.
989989
"""
@@ -995,7 +995,7 @@ def __xor__(self, other: Array | bool | int, /) -> Array:
995995
res = self._array.__xor__(other._array)
996996
return self.__class__._new(res, device=self.device)
997997

998-
def __iadd__(self, other: Array | int | float | complex, /) -> Array:
998+
def __iadd__(self, other: Array | complex, /) -> Array:
999999
"""
10001000
Performs the operation __iadd__.
10011001
"""
@@ -1006,7 +1006,7 @@ def __iadd__(self, other: Array | int | float | complex, /) -> Array:
10061006
self._array.__iadd__(other._array)
10071007
return self
10081008

1009-
def __radd__(self, other: Array | int | float | complex, /) -> Array:
1009+
def __radd__(self, other: Array | complex, /) -> Array:
10101010
"""
10111011
Performs the operation __radd__.
10121012
"""
@@ -1018,7 +1018,7 @@ def __radd__(self, other: Array | int | float | complex, /) -> Array:
10181018
res = self._array.__radd__(other._array)
10191019
return self.__class__._new(res, device=self.device)
10201020

1021-
def __iand__(self, other: Array | bool | int, /) -> Array:
1021+
def __iand__(self, other: Array | int, /) -> Array:
10221022
"""
10231023
Performs the operation __iand__.
10241024
"""
@@ -1029,7 +1029,7 @@ def __iand__(self, other: Array | bool | int, /) -> Array:
10291029
self._array.__iand__(other._array)
10301030
return self
10311031

1032-
def __rand__(self, other: Array | bool | int, /) -> Array:
1032+
def __rand__(self, other: Array | int, /) -> Array:
10331033
"""
10341034
Performs the operation __rand__.
10351035
"""
@@ -1041,7 +1041,7 @@ def __rand__(self, other: Array | bool | int, /) -> Array:
10411041
res = self._array.__rand__(other._array)
10421042
return self.__class__._new(res, device=self.device)
10431043

1044-
def __ifloordiv__(self, other: Array | int | float, /) -> Array:
1044+
def __ifloordiv__(self, other: Array | float, /) -> Array:
10451045
"""
10461046
Performs the operation __ifloordiv__.
10471047
"""
@@ -1052,7 +1052,7 @@ def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10521052
self._array.__ifloordiv__(other._array)
10531053
return self
10541054

1055-
def __rfloordiv__(self, other: Array | int | float, /) -> Array:
1055+
def __rfloordiv__(self, other: Array | float, /) -> Array:
10561056
"""
10571057
Performs the operation __rfloordiv__.
10581058
"""
@@ -1113,7 +1113,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
11131113
res = self._array.__rmatmul__(other._array)
11141114
return self.__class__._new(res, device=self.device)
11151115

1116-
def __imod__(self, other: Array | int | float, /) -> Array:
1116+
def __imod__(self, other: Array | float, /) -> Array:
11171117
"""
11181118
Performs the operation __imod__.
11191119
"""
@@ -1123,7 +1123,7 @@ def __imod__(self, other: Array | int | float, /) -> Array:
11231123
self._array.__imod__(other._array)
11241124
return self
11251125

1126-
def __rmod__(self, other: Array | int | float, /) -> Array:
1126+
def __rmod__(self, other: Array | float, /) -> Array:
11271127
"""
11281128
Performs the operation __rmod__.
11291129
"""
@@ -1135,7 +1135,7 @@ def __rmod__(self, other: Array | int | float, /) -> Array:
11351135
res = self._array.__rmod__(other._array)
11361136
return self.__class__._new(res, device=self.device)
11371137

1138-
def __imul__(self, other: Array | int | float | complex, /) -> Array:
1138+
def __imul__(self, other: Array | complex, /) -> Array:
11391139
"""
11401140
Performs the operation __imul__.
11411141
"""
@@ -1145,7 +1145,7 @@ def __imul__(self, other: Array | int | float | complex, /) -> Array:
11451145
self._array.__imul__(other._array)
11461146
return self
11471147

1148-
def __rmul__(self, other: Array | int | float | complex, /) -> Array:
1148+
def __rmul__(self, other: Array | complex, /) -> Array:
11491149
"""
11501150
Performs the operation __rmul__.
11511151
"""
@@ -1157,7 +1157,7 @@ def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11571157
res = self._array.__rmul__(other._array)
11581158
return self.__class__._new(res, device=self.device)
11591159

1160-
def __ior__(self, other: Array | bool | int, /) -> Array:
1160+
def __ior__(self, other: Array | int, /) -> Array:
11611161
"""
11621162
Performs the operation __ior__.
11631163
"""
@@ -1167,7 +1167,7 @@ def __ior__(self, other: Array | bool | int, /) -> Array:
11671167
self._array.__ior__(other._array)
11681168
return self
11691169

1170-
def __ror__(self, other: Array | bool | int, /) -> Array:
1170+
def __ror__(self, other: Array | int, /) -> Array:
11711171
"""
11721172
Performs the operation __ror__.
11731173
"""
@@ -1179,7 +1179,7 @@ def __ror__(self, other: Array | bool | int, /) -> Array:
11791179
res = self._array.__ror__(other._array)
11801180
return self.__class__._new(res, device=self.device)
11811181

1182-
def __ipow__(self, other: Array | int | float | complex, /) -> Array:
1182+
def __ipow__(self, other: Array | complex, /) -> Array:
11831183
"""
11841184
Performs the operation __ipow__.
11851185
"""
@@ -1189,7 +1189,7 @@ def __ipow__(self, other: Array | int | float | complex, /) -> Array:
11891189
self._array.__ipow__(other._array)
11901190
return self
11911191

1192-
def __rpow__(self, other: Array | int | float | complex, /) -> Array:
1192+
def __rpow__(self, other: Array | complex, /) -> Array:
11931193
"""
11941194
Performs the operation __rpow__.
11951195
"""
@@ -1224,7 +1224,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12241224
res = self._array.__rrshift__(other._array)
12251225
return self.__class__._new(res, device=self.device)
12261226

1227-
def __isub__(self, other: Array | int | float | complex, /) -> Array:
1227+
def __isub__(self, other: Array | complex, /) -> Array:
12281228
"""
12291229
Performs the operation __isub__.
12301230
"""
@@ -1234,7 +1234,7 @@ def __isub__(self, other: Array | int | float | complex, /) -> Array:
12341234
self._array.__isub__(other._array)
12351235
return self
12361236

1237-
def __rsub__(self, other: Array | int | float | complex, /) -> Array:
1237+
def __rsub__(self, other: Array | complex, /) -> Array:
12381238
"""
12391239
Performs the operation __rsub__.
12401240
"""
@@ -1246,7 +1246,7 @@ def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12461246
res = self._array.__rsub__(other._array)
12471247
return self.__class__._new(res, device=self.device)
12481248

1249-
def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
1249+
def __itruediv__(self, other: Array | complex, /) -> Array:
12501250
"""
12511251
Performs the operation __itruediv__.
12521252
"""
@@ -1256,7 +1256,7 @@ def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
12561256
self._array.__itruediv__(other._array)
12571257
return self
12581258

1259-
def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
1259+
def __rtruediv__(self, other: Array | complex, /) -> Array:
12601260
"""
12611261
Performs the operation __rtruediv__.
12621262
"""
@@ -1268,7 +1268,7 @@ def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12681268
res = self._array.__rtruediv__(other._array)
12691269
return self.__class__._new(res, device=self.device)
12701270

1271-
def __ixor__(self, other: Array | bool | int, /) -> Array:
1271+
def __ixor__(self, other: Array | int, /) -> Array:
12721272
"""
12731273
Performs the operation __ixor__.
12741274
"""
@@ -1278,7 +1278,7 @@ def __ixor__(self, other: Array | bool | int, /) -> Array:
12781278
self._array.__ixor__(other._array)
12791279
return self
12801280

1281-
def __rxor__(self, other: Array | bool | int, /) -> Array:
1281+
def __rxor__(self, other: Array | int, /) -> Array:
12821282
"""
12831283
Performs the operation __rxor__.
12841284
"""

0 commit comments

Comments
 (0)