Skip to content

[WIP] Add paddle support #206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/array-api-tests-paddle.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# CI entry point: run the shared array-api-tests suite against the latest Paddle.
name: Array API Tests (Paddle Latest)

on: [push, pull_request]

jobs:
  array-api-tests-paddle:
    # Delegates to the repository's reusable array-api-tests workflow.
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: paddle
      # Skip unsigned integer dtypes — presumably unsupported by Paddle; confirm
      # against Paddle's dtype support before removing.
      extra-env-vars: |
        ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

This is a small wrapper around common array libraries that is compatible with
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
NumPy, CuPy, PyTorch, Dask, JAX, ndonnx, `sparse`, and Paddle are supported. If you want
support for other array libraries, or if you encounter any issues, please [open
an issue](https://github.com/data-apis/array-api-compat/issues).

See the documentation for more details https://data-apis.org/array-api-compat/
See the documentation for more details <https://data-apis.org/array-api-compat/>
78 changes: 78 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,32 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_paddle_array(x):
    """
    Return True if `x` is a Paddle tensor.

    This function does not import Paddle if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_dask_array
    is_jax_array
    is_pydata_sparse_array
    """
    # If Paddle was never imported in this process, `x` cannot be a Paddle
    # tensor; bail out without triggering an (expensive) import.
    paddle = sys.modules.get('paddle')
    if paddle is None:
        return False

    return paddle.is_tensor(x)

def is_ndonnx_array(x):
"""
Return True if `x` is a ndonnx Array.
Expand Down Expand Up @@ -252,6 +278,7 @@ def is_array_api_obj(x):
or is_dask_array(x) \
or is_jax_array(x) \
or is_pydata_sparse_array(x) \
or is_paddle_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():
Expand Down Expand Up @@ -319,6 +346,27 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_paddle_namespace(xp) -> bool:
    """
    Returns True if `xp` is a Paddle namespace.

    This includes both Paddle itself and the version wrapped by array-api-compat.

    See Also
    --------

    array_namespace
    is_numpy_namespace
    is_cupy_namespace
    is_ndonnx_namespace
    is_dask_namespace
    is_jax_namespace
    is_pydata_sparse_namespace
    is_array_api_strict_namespace
    """
    # Accept either the bare module or its array-api-compat wrapper.
    wrapped = _compat_module_name() + '.paddle'
    return xp.__name__ in ('paddle', wrapped)


def is_ndonnx_namespace(xp):
"""
Returns True if `xp` is an NDONNX namespace.
Expand Down Expand Up @@ -543,6 +591,14 @@ def your_function(x, y):
else:
import jax.experimental.array_api as jnp
namespaces.add(jnp)
elif is_paddle_array(x):
if _use_compat:
_check_api_version(api_version)
from .. import paddle as paddle_namespace
namespaces.add(paddle_namespace)
else:
import paddle
namespaces.add(paddle)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
Expand Down Expand Up @@ -660,6 +716,16 @@ def device(x: Array, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner)
elif is_paddle_array(x):
raw_place_str = str(x.place)
if "gpu_pinned" in raw_place_str:
return "cpu"
elif "cpu" in raw_place_str:
return "cpu"
elif "gpu" in raw_place_str:
return "gpu"
raise ValueError(f"Unsupported Paddle device: {x.place}")

return x.device

# Prevent shadowing, used below
Expand Down Expand Up @@ -709,6 +775,14 @@ def _torch_to_device(x, device, /, stream=None):
raise NotImplementedError
return x.to(device)

def _paddle_to_device(x, device, /, stream=None):
if stream is not None:
raise NotImplementedError(
"paddle.Tensor.to() do not support stream argument yet"
)
return x.to(device)


def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
Expand Down Expand Up @@ -781,6 +855,8 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
# In JAX v0.4.31 and older, this import adds to_device method to x.
import jax.experimental.array_api # noqa: F401
return x.to_device(device, stream=stream)
elif is_paddle_array(x):
return _paddle_to_device(x, device, stream=stream)
elif is_pydata_sparse_array(x) and device == _device(x):
# Perform trivial check to return the same array if
# device is same instead of err-ing.
Expand Down Expand Up @@ -819,6 +895,8 @@ def size(x):
"is_torch_namespace",
"is_ndonnx_array",
"is_ndonnx_namespace",
"is_paddle_array",
"is_paddle_namespace",
"is_pydata_sparse_array",
"is_pydata_sparse_namespace",
"size",
Expand Down
22 changes: 22 additions & 0 deletions array_api_compat/paddle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from paddle import *  # noqa: F403

# Several names are not included in the above import *
import paddle

# Re-export the remaining public paddle attributes. Private names
# (leading/trailing underscore) and device/autograd helpers containing
# "gpu", "cpu", or "backward" are deliberately excluded, as in the
# original filter.
#
# Fixed: the original used exec(f"{n} = paddle.{n}") — string evaluation
# where a plain globals() assignment does the same thing — and leaked the
# loop variable `n` as a public module attribute.
for _name in dir(paddle):
    if (
        _name.startswith("_")
        or _name.endswith("_")
        or "gpu" in _name
        or "cpu" in _name
        or "backward" in _name
    ):
        continue
    globals()[_name] = getattr(paddle, _name)
del _name


# These imports may overwrite names from the import * above.
from ._aliases import *  # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")

__import__(__package__ + ".fft")

from ..common._helpers import *  # noqa: F403

__array_api_version__ = "2023.12"
Loading