Skip to content

Commit 85dc3ba

Browse files
update promotion table and can_cast table
1 parent 7118894 commit 85dc3ba

File tree

5 files changed

+64
-77
lines changed

5 files changed

+64
-77
lines changed

array_api_compat/common/_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,6 @@ def is_paddle_array(x):
144144

145145
import paddle
146146

147-
# TODO: Should we reject ndarray subclasses?
148147
return paddle.is_tensor(x)
149148

150149
def is_ndonnx_array(x):
@@ -725,7 +724,7 @@ def device(x: Array, /) -> Device:
725724
return "cpu"
726725
elif "gpu" in raw_place_str:
727726
return "gpu"
728-
raise NotImplementedError(f"Unsupported device {raw_place_str}")
727+
raise ValueError(f"Unsupported Paddle device: {x.place}")
729728

730729
return x.device
731730

array_api_compat/paddle/_aliases.py

Lines changed: 51 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -42,37 +42,18 @@
4242
paddle.complex128,
4343
}
4444

45+
# NOTE: Implicit promotion rules of Paddle is a bit strict than other frameworks,
46+
# see details: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html
4547
_promotion_table = {
4648
# bool
4749
(paddle.bool, paddle.bool): paddle.bool,
4850
# ints
4951
(paddle.int8, paddle.int8): paddle.int8,
50-
(paddle.int8, paddle.int16): paddle.int16,
51-
(paddle.int8, paddle.int32): paddle.int32,
52-
(paddle.int8, paddle.int64): paddle.int64,
53-
(paddle.int16, paddle.int8): paddle.int16,
5452
(paddle.int16, paddle.int16): paddle.int16,
55-
(paddle.int16, paddle.int32): paddle.int32,
56-
(paddle.int16, paddle.int64): paddle.int64,
57-
(paddle.int32, paddle.int8): paddle.int32,
58-
(paddle.int32, paddle.int16): paddle.int32,
5953
(paddle.int32, paddle.int32): paddle.int32,
60-
(paddle.int32, paddle.int64): paddle.int64,
61-
(paddle.int64, paddle.int8): paddle.int64,
62-
(paddle.int64, paddle.int16): paddle.int64,
63-
(paddle.int64, paddle.int32): paddle.int64,
6454
(paddle.int64, paddle.int64): paddle.int64,
6555
# uints
6656
(paddle.uint8, paddle.uint8): paddle.uint8,
67-
# ints and uints (mixed sign)
68-
(paddle.int8, paddle.uint8): paddle.int16,
69-
(paddle.int16, paddle.uint8): paddle.int16,
70-
(paddle.int32, paddle.uint8): paddle.int32,
71-
(paddle.int64, paddle.uint8): paddle.int64,
72-
(paddle.uint8, paddle.int8): paddle.int16,
73-
(paddle.uint8, paddle.int16): paddle.int16,
74-
(paddle.uint8, paddle.int32): paddle.int32,
75-
(paddle.uint8, paddle.int64): paddle.int64,
7657
# floats
7758
(paddle.float32, paddle.float32): paddle.float32,
7859
(paddle.float32, paddle.float64): paddle.float64,
@@ -158,12 +139,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
158139
paddle.float64: True,
159140
paddle.complex64: True,
160141
paddle.complex128: True,
161-
paddle.uint8: False,
162-
paddle.int8: False,
163-
paddle.int16: False,
164-
paddle.int32: False,
165-
paddle.int64: False,
166-
paddle.bool: False,
142+
paddle.uint8: True,
143+
paddle.int8: True,
144+
paddle.int16: True,
145+
paddle.int32: True,
146+
paddle.int64: True,
147+
paddle.bool: True,
167148
},
168149
paddle.float16: {
169150
paddle.bfloat16: True,
@@ -172,12 +153,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
172153
paddle.float64: True,
173154
paddle.complex64: True,
174155
paddle.complex128: True,
175-
paddle.uint8: False,
176-
paddle.int8: False,
177-
paddle.int16: False,
178-
paddle.int32: False,
179-
paddle.int64: False,
180-
paddle.bool: False,
156+
paddle.uint8: True,
157+
paddle.int8: True,
158+
paddle.int16: True,
159+
paddle.int32: True,
160+
paddle.int64: True,
161+
paddle.bool: True,
181162
},
182163
paddle.float32: {
183164
paddle.bfloat16: True,
@@ -186,12 +167,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
186167
paddle.float64: True,
187168
paddle.complex64: True,
188169
paddle.complex128: True,
189-
paddle.uint8: False,
190-
paddle.int8: False,
191-
paddle.int16: False,
192-
paddle.int32: False,
193-
paddle.int64: False,
194-
paddle.bool: False,
170+
paddle.uint8: True,
171+
paddle.int8: True,
172+
paddle.int16: True,
173+
paddle.int32: True,
174+
paddle.int64: True,
175+
paddle.bool: True,
195176
},
196177
paddle.float64: {
197178
paddle.bfloat16: True,
@@ -200,40 +181,40 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
200181
paddle.float64: True,
201182
paddle.complex64: True,
202183
paddle.complex128: True,
203-
paddle.uint8: False,
204-
paddle.int8: False,
205-
paddle.int16: False,
206-
paddle.int32: False,
207-
paddle.int64: False,
208-
paddle.bool: False,
184+
paddle.uint8: True,
185+
paddle.int8: True,
186+
paddle.int16: True,
187+
paddle.int32: True,
188+
paddle.int64: True,
189+
paddle.bool: True,
209190
},
210191
paddle.complex64: {
211-
paddle.bfloat16: False,
212-
paddle.float16: False,
213-
paddle.float32: False,
214-
paddle.float64: False,
192+
paddle.bfloat16: True,
193+
paddle.float16: True,
194+
paddle.float32: True,
195+
paddle.float64: True,
215196
paddle.complex64: True,
216197
paddle.complex128: True,
217-
paddle.uint8: False,
218-
paddle.int8: False,
219-
paddle.int16: False,
220-
paddle.int32: False,
221-
paddle.int64: False,
222-
paddle.bool: False,
198+
paddle.uint8: True,
199+
paddle.int8: True,
200+
paddle.int16: True,
201+
paddle.int32: True,
202+
paddle.int64: True,
203+
paddle.bool: True,
223204
},
224205
paddle.complex128: {
225-
paddle.bfloat16: False,
226-
paddle.float16: False,
227-
paddle.float32: False,
228-
paddle.float64: False,
206+
paddle.bfloat16: True,
207+
paddle.float16: True,
208+
paddle.float32: True,
209+
paddle.float64: True,
229210
paddle.complex64: True,
230211
paddle.complex128: True,
231-
paddle.uint8: False,
232-
paddle.int8: False,
233-
paddle.int16: False,
234-
paddle.int32: False,
235-
paddle.int64: False,
236-
paddle.bool: False,
212+
paddle.uint8: True,
213+
paddle.int8: True,
214+
paddle.int16: True,
215+
paddle.int32: True,
216+
paddle.int64: True,
217+
paddle.bool: True,
237218
},
238219
paddle.uint8: {
239220
paddle.bfloat16: True,
@@ -247,7 +228,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
247228
paddle.int16: True,
248229
paddle.int32: True,
249230
paddle.int64: True,
250-
paddle.bool: False,
231+
paddle.bool: True,
251232
},
252233
paddle.int8: {
253234
paddle.bfloat16: True,
@@ -261,7 +242,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
261242
paddle.int16: True,
262243
paddle.int32: True,
263244
paddle.int64: True,
264-
paddle.bool: False,
245+
paddle.bool: True,
265246
},
266247
paddle.int16: {
267248
paddle.bfloat16: True,
@@ -275,7 +256,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
275256
paddle.int16: True,
276257
paddle.int32: True,
277258
paddle.int64: True,
278-
paddle.bool: False,
259+
paddle.bool: True,
279260
},
280261
paddle.int32: {
281262
paddle.bfloat16: True,
@@ -289,7 +270,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
289270
paddle.int16: True,
290271
paddle.int32: True,
291272
paddle.int64: True,
292-
paddle.bool: False,
273+
paddle.bool: True,
293274
},
294275
paddle.int64: {
295276
paddle.bfloat16: True,
@@ -303,7 +284,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
303284
paddle.int16: True,
304285
paddle.int32: True,
305286
paddle.int64: True,
306-
paddle.bool: False,
287+
paddle.bool: True,
307288
},
308289
paddle.bool: {
309290
paddle.bfloat16: True,

tests/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
wrapped_libraries = ["numpy", "paddle"]
6+
wrapped_libraries = ["numpy", "paddle", "torch"]
77
all_libraries = wrapped_libraries + []
88

99
# `sparse` added array API support as of Python 3.10.

tests/test_all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ def test_all(library):
4040
all_names = module.__all__
4141

4242
if set(dir_names) != set(all_names):
43-
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
44-
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
43+
assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
44+
assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"

tests/test_common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
is_array_functions = {
1818
'numpy': 'is_numpy_array',
1919
# 'cupy': 'is_cupy_array',
20-
# 'torch': 'is_torch_array',
20+
'torch': 'is_torch_array',
2121
# 'dask.array': 'is_dask_array',
2222
# 'jax.numpy': 'is_jax_array',
2323
# 'sparse': 'is_pydata_sparse_array',
@@ -27,7 +27,7 @@
2727
is_namespace_functions = {
2828
'numpy': 'is_numpy_namespace',
2929
# 'cupy': 'is_cupy_namespace',
30-
# 'torch': 'is_torch_namespace',
30+
'torch': 'is_torch_namespace',
3131
# 'dask.array': 'is_dask_namespace',
3232
# 'jax.numpy': 'is_jax_namespace',
3333
# 'sparse': 'is_pydata_sparse_namespace',
@@ -103,6 +103,13 @@ def test_asarray_cross_library(source_library, target_library, request):
103103
if source_library == "cupy" and target_library != "cupy":
104104
# cupy explicitly disallows implicit conversions to CPU
105105
pytest.skip(reason="cupy does not support implicit conversion to CPU")
106+
if source_library == "paddle" or target_library == "paddle":
107+
pytest.skip(
108+
reason=(
109+
"paddle does not support implicit conversion from/to other framework "
110+
"via 'asarray', dlpack is recommend now."
111+
)
112+
)
106113
elif source_library == "sparse" and target_library != "sparse":
107114
pytest.skip(reason="`sparse` does not allow implicit densification")
108115
src_lib = import_(source_library, wrapper=True)

0 commit comments

Comments
 (0)