@@ -170,78 +170,58 @@ def default_dtypes(self, *, device=None):
170
170
"indexing" : default_integral ,
171
171
}
172
172
173
-
174
173
def _dtypes (self , kind ):
175
- bool = torch .bool
176
- int8 = torch .int8
177
- int16 = torch .int16
178
- int32 = torch .int32
179
- int64 = torch .int64
180
- uint8 = torch .uint8
181
- # uint16, uint32, and uint64 are present in newer versions of pytorch,
182
- # but they aren't generally supported by the array API functions, so
183
- # we omit them from this function.
184
- float32 = torch .float32
185
- float64 = torch .float64
186
- complex64 = torch .complex64
187
- complex128 = torch .complex128
188
-
189
174
if kind is None :
190
- return {
191
- "bool" : bool ,
192
- "int8" : int8 ,
193
- "int16" : int16 ,
194
- "int32" : int32 ,
195
- "int64" : int64 ,
196
- "uint8" : uint8 ,
197
- "float32" : float32 ,
198
- "float64" : float64 ,
199
- "complex64" : complex64 ,
200
- "complex128" : complex128 ,
201
- }
175
+ return self ._dtypes (
176
+ (
177
+ "bool" ,
178
+ "signed integer" ,
179
+ "unsigned integer" ,
180
+ "real floating" ,
181
+ "complex floating" ,
182
+ )
183
+ )
202
184
if kind == "bool" :
203
- return {"bool" : bool }
185
+ return {"bool" : torch . bool }
204
186
if kind == "signed integer" :
205
187
return {
206
- "int8" : int8 ,
207
- "int16" : int16 ,
208
- "int32" : int32 ,
209
- "int64" : int64 ,
188
+ "int8" : torch . int8 ,
189
+ "int16" : torch . int16 ,
190
+ "int32" : torch . int32 ,
191
+ "int64" : torch . int64 ,
210
192
}
211
193
if kind == "unsigned integer" :
212
- return {
213
- "uint8" : uint8 ,
214
- }
194
+ try :
195
+ # torch >=2.3
196
+ return {
197
+ "uint8" : torch .uint8 ,
198
+ "uint16" : torch .uint16 ,
199
+ "uint32" : torch .uint32 ,
200
+ "uint64" : torch .uint32 ,
201
+ }
202
+ except AttributeError :
203
+ return {"uint8" : torch .uint8 }
215
204
if kind == "integral" :
216
- return {
217
- "int8" : int8 ,
218
- "int16" : int16 ,
219
- "int32" : int32 ,
220
- "int64" : int64 ,
221
- "uint8" : uint8 ,
222
- }
205
+ return self ._dtypes (("signed integer" , "unsigned integer" ))
223
206
if kind == "real floating" :
224
207
return {
225
- "float32" : float32 ,
226
- "float64" : float64 ,
208
+ "float32" : torch . float32 ,
209
+ "float64" : torch . float64 ,
227
210
}
228
211
if kind == "complex floating" :
229
212
return {
230
- "complex64" : complex64 ,
231
- "complex128" : complex128 ,
213
+ "complex64" : torch . complex64 ,
214
+ "complex128" : torch . complex128 ,
232
215
}
233
216
if kind == "numeric" :
234
- return {
235
- "int8" : int8 ,
236
- "int16" : int16 ,
237
- "int32" : int32 ,
238
- "int64" : int64 ,
239
- "uint8" : uint8 ,
240
- "float32" : float32 ,
241
- "float64" : float64 ,
242
- "complex64" : complex64 ,
243
- "complex128" : complex128 ,
244
- }
217
+ return self ._dtypes (
218
+ (
219
+ "signed integer" ,
220
+ "unsigned integer" ,
221
+ "real floating" ,
222
+ "complex floating" ,
223
+ )
224
+ )
245
225
if isinstance (kind , tuple ):
246
226
res = {}
247
227
for k in kind :
@@ -261,7 +241,6 @@ def dtypes(self, *, device=None, kind=None):
261
241
----------
262
242
device : Device, optional
263
243
The device to get the data types for.
264
- Unused for PyTorch, as all devices use the same dtypes.
265
244
kind : str or tuple of str, optional
266
245
The kind of data types to return. If ``None``, all data types are
267
246
returned. If a string, only data types of that kind are returned.
0 commit comments