Skip to content

Commit ff6ea18

Browse files
committed
Torch info for large uints
1 parent cfbd75b commit ff6ea18

File tree

2 files changed

+37
-60
lines changed

2 files changed

+37
-60
lines changed

.github/workflows/array-api-tests-torch.yml

-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,4 @@ jobs:
88
with:
99
package-name: torch
1010
extra-requires: '--index-url https://download.pytorch.org/whl/cpu'
11-
extra-env-vars: |
12-
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
1311
python-versions: '[''3.10'', ''3.13'']'

array_api_compat/torch/_info.py

+37-58
Original file line numberDiff line numberDiff line change
@@ -170,78 +170,58 @@ def default_dtypes(self, *, device=None):
170170
"indexing": default_integral,
171171
}
172172

173-
174173
def _dtypes(self, kind):
175-
bool = torch.bool
176-
int8 = torch.int8
177-
int16 = torch.int16
178-
int32 = torch.int32
179-
int64 = torch.int64
180-
uint8 = torch.uint8
181-
# uint16, uint32, and uint64 are present in newer versions of pytorch,
182-
# but they aren't generally supported by the array API functions, so
183-
# we omit them from this function.
184-
float32 = torch.float32
185-
float64 = torch.float64
186-
complex64 = torch.complex64
187-
complex128 = torch.complex128
188-
189174
if kind is None:
190-
return {
191-
"bool": bool,
192-
"int8": int8,
193-
"int16": int16,
194-
"int32": int32,
195-
"int64": int64,
196-
"uint8": uint8,
197-
"float32": float32,
198-
"float64": float64,
199-
"complex64": complex64,
200-
"complex128": complex128,
201-
}
175+
return self._dtypes(
176+
(
177+
"bool",
178+
"signed integer",
179+
"unsigned integer",
180+
"real floating",
181+
"complex floating",
182+
)
183+
)
202184
if kind == "bool":
203-
return {"bool": bool}
185+
return {"bool": torch.bool}
204186
if kind == "signed integer":
205187
return {
206-
"int8": int8,
207-
"int16": int16,
208-
"int32": int32,
209-
"int64": int64,
188+
"int8": torch.int8,
189+
"int16": torch.int16,
190+
"int32": torch.int32,
191+
"int64": torch.int64,
210192
}
211193
if kind == "unsigned integer":
212-
return {
213-
"uint8": uint8,
214-
}
194+
try:
195+
# torch >=2.3
196+
return {
197+
"uint8": torch.uint8,
198+
"uint16": torch.uint16,
199+
"uint32": torch.uint32,
200+
"uint64": torch.uint32,
201+
}
202+
except AttributeError:
203+
return {"uint8": torch.uint8}
215204
if kind == "integral":
216-
return {
217-
"int8": int8,
218-
"int16": int16,
219-
"int32": int32,
220-
"int64": int64,
221-
"uint8": uint8,
222-
}
205+
return self._dtypes(("signed integer", "unsigned integer"))
223206
if kind == "real floating":
224207
return {
225-
"float32": float32,
226-
"float64": float64,
208+
"float32": torch.float32,
209+
"float64": torch.float64,
227210
}
228211
if kind == "complex floating":
229212
return {
230-
"complex64": complex64,
231-
"complex128": complex128,
213+
"complex64": torch.complex64,
214+
"complex128": torch.complex128,
232215
}
233216
if kind == "numeric":
234-
return {
235-
"int8": int8,
236-
"int16": int16,
237-
"int32": int32,
238-
"int64": int64,
239-
"uint8": uint8,
240-
"float32": float32,
241-
"float64": float64,
242-
"complex64": complex64,
243-
"complex128": complex128,
244-
}
217+
return self._dtypes(
218+
(
219+
"signed integer",
220+
"unsigned integer",
221+
"real floating",
222+
"complex floating",
223+
)
224+
)
245225
if isinstance(kind, tuple):
246226
res = {}
247227
for k in kind:
@@ -261,7 +241,6 @@ def dtypes(self, *, device=None, kind=None):
261241
----------
262242
device : Device, optional
263243
The device to get the data types for.
264-
Unused for PyTorch, as all devices use the same dtypes.
265244
kind : str or tuple of str, optional
266245
The kind of data types to return. If ``None``, all data types are
267246
returned. If a string, only data types of that kind are returned.

0 commit comments

Comments
 (0)