ENH: Wrapped namespaces support #144

crusaderky opened this issue Feb 19, 2025 · 4 comments
Labels
enhancement New feature or request

Comments

@crusaderky
Contributor

I'm starting to see multiple Array API compliant libraries that are just thin wrappers around another arbitrary Array API compliant library:

  • @mdhaber's marray
  • @lucascolley's quantity-array
  • I'm deliberately keeping Dask out of scope for this issue, as there is a lot of nuance involved in it.
  • xarray's wrapping design also predates the Array API and xarray is not (today) Array API compatible, but I can see it joining the fray soon.

These libraries currently fail to work with array-api-compat, array-api-extra, or anything that relies on them (like scipy) except in the most basic use cases, e.g. wrapping around a writeable numpy or cupy array.

These are the reasons I could find so far:

  • array_api_compat.is_writeable_array blindly returns True for unknown namespaces. This is also a problem when the wrapper holds a read-only numpy array.
  • array_api_compat.is_lazy_array incorrectly returns False for wrapped Dask. This will lead to major performance issues.
  • xpx.at, xpx.apply_where, and xpx.lazy_apply have special casing for Dask which won't be triggered when Dask is wrapped. You can't just change array_api_compat.is_dask_array to return True for wrapped dask objects, because there are a wealth of dask-specific calls (da.map_blocks, passing the arrays to a @delayed function) that won't work out of the box.
  • The same functions, plus nunique, have special casing for JAX. Again, there are jax-specific function calls e.g. jax.pure_callback that won't accept wrapped arrays.
  • Delegated functions (xpx.pad) should work thanks to their generic code path, but they are untested; so for example pad on PyTorch works due to special casing, but there are no tests that verify that the generic implementation of pad applied to a wrapper around PyTorch will work.
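The first failure mode above can be reproduced without any array library. The following is a minimal sketch using toy stand-ins: ToyBackendArray, ToyWrapper, naive_is_writeable and unwrapping_is_writeable are all hypothetical names invented for illustration, not the array_api_compat implementation. It only shows why an optimistic default for unknown types goes wrong once a read-only array is wrapped.

```python
class ToyBackendArray:
    """Stand-in for a backend array that may be read-only."""

    def __init__(self, values, writeable=True):
        self.values = list(values)
        self.writeable = writeable


class ToyWrapper:
    """Stand-in for a thin wrapper (masked array, quantity, ...)."""

    def __init__(self, data):
        self.data = data


def naive_is_writeable(x):
    """Type-based check: knows the backend, assumes True for anything else."""
    if isinstance(x, ToyBackendArray):
        return x.writeable
    return True  # unknown type: optimistic default


def unwrapping_is_writeable(x):
    """Hypothetical alternative: peel wrapper layers before checking."""
    while isinstance(x, ToyWrapper):
        x = x.data
    return naive_is_writeable(x)


frozen = ToyBackendArray([1, 2, 3], writeable=False)
wrapped = ToyWrapper(frozen)

print(naive_is_writeable(frozen))        # False: the backend is recognised
print(naive_is_writeable(wrapped))       # True: the wrapper hides the flag
print(unwrapping_is_writeable(wrapped))  # False: unwrapping recovers it
```

The same pattern applies to the laziness check: any predicate keyed on the outer type needs a way to look through the wrapper.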
@crusaderky
Contributor Author

jax.jit will fail with these libraries too, because jitted functions require their inputs to be jax arrays. This could be fixed with an ad-hoc wrapper around jax.jit that unwraps the arrays just before entering the jitted function and re-wraps the outputs immediately afterwards.
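The unwrap-then-rewrap idea can be sketched as a decorator factory. This is an illustration under stated assumptions, not a proposed API: strict_transform is a toy stand-in for a transform such as jax.jit (it simply rejects anything that is not a plain list), and ToyWrapper and wrapper_aware are hypothetical names. Re-wrapping copies the metadata of the first wrapped argument, which a real wrapper library would have to decide for itself.

```python
import functools


class ToyWrapper:
    """Stand-in for a wrapper array: inner data plus metadata."""

    def __init__(self, data, **meta):
        self.data = data
        self.meta = meta


def strict_transform(fn):
    """Toy stand-in for a transform that only accepts backend arrays (lists)."""

    @functools.wraps(fn)
    def checked(*args):
        for a in args:
            if not isinstance(a, list):
                raise TypeError(f"unsupported input type: {type(a).__name__}")
        return fn(*args)

    return checked


def wrapper_aware(transform):
    """Build a decorator that unwraps inputs before `transform` and re-wraps the output."""

    def decorator(fn):
        inner = transform(fn)

        @functools.wraps(fn)
        def call(*args):
            template = next((a for a in args if isinstance(a, ToyWrapper)), None)
            raw = [a.data if isinstance(a, ToyWrapper) else a for a in args]
            out = inner(*raw)
            if template is None:
                return out
            # Naive choice: reuse the first wrapper's metadata for the result.
            return ToyWrapper(out, **template.meta)

        return call

    return decorator


@wrapper_aware(strict_transform)
def double(x):
    return [2 * v for v in x]


x = ToyWrapper([1, 2, 3], units="m")
y = double(x)
print(y.data, y.meta)
```

Keeping the unwrap and re-wrap steps at the boundary means the transformed function body never sees the wrapper type.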

@mdhaber
Contributor

mdhaber commented Feb 20, 2025

I thought the MArray __array_namespace__ info might be relevant here:

from marray import numpy as mxp
x = mxp.asarray([1, 2, 3], mask=[False, True, False])
x.__array_namespace__()
# <module 'marray.numpy'>

So if the base library underlying the MArray is important for special-casing, it's not hard to get at that information. Perhaps there can be an informal (or formal?) standard for this sort of thing, and standard recipe for parsing that information?
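One possible shape for such a recipe is to parse the namespace's module name. The sketch below assumes a `<wrapper>.<backend>` naming convention, which this thread only shows for marray; backend_name is a hypothetical helper, and the module object is built with types.ModuleType so the example runs without marray installed.

```python
import types


def backend_name(xp):
    """Guess the wrapped backend from a namespace named '<wrapper>.<backend>'.

    Returns None when the module name has no dot.
    """
    _wrapper, sep, backend = xp.__name__.partition(".")
    return backend if sep else None


# Toy module carrying the same name as the namespace shown above.
fake_mxp = types.ModuleType("marray.numpy")
print(backend_name(fake_mxp))                    # numpy
print(backend_name(types.ModuleType("numpy")))   # None
```

A name-based recipe like this cannot tell a wrapper from any other dotted module name (array_api_compat.numpy would match too), so a shared convention or an explicit attribute on the namespace would still be needed.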

As for library-specific calls, I've thought about having separately-installable packages (or maybe just separate repos that get vendored) that could be maintained by the developers of the underlying libraries. This would allow Dask developers to add additional features to the marray.dask namespace or to the array objects themselves.

@crusaderky
Contributor Author

> so if Dask developers wanted to add additional features to the marray.dask namespace or to the array objects themselves, they could do so.

I don't think that's realistic, particularly for lesser-known wrappers.

Maybe a more sustainable approach could be for all wrapper namespaces to offer a coherent API that lets you unwrap and rewrap the inner objects:

def apply(callback, *args, xp=None, **kwargs):
    inner_args = tuple(x.data for x in args)  # Backend-specific
    outer_ns = array_namespace(*args)
    inner_ns = array_namespace(*inner_args)
    if xp in (None, inner_ns):
        out = callback(*inner_args, **kwargs)
    else:
        out = inner_ns.apply(callback, *inner_args, **kwargs, xp=xp)

    # Backend-specific: re-add units=, mask=, etc. from args here.
    return outer_ns.asarray(out)

so this (mdhaber/marray#91 (comment)):

mxp = marray.wrap_namespace(np)
dxp = da.wrap_namespace(mxp)
qxp = quantity_array.wrap_namespace(dxp)

a = qxp.asarray(qxp.arange(10), units="pint")
a = qxp.asarray(a.data.map_blocks(lambda x: mxp.asarray(x, mask=x > 5)), units=a.units)

would become:

a = qxp.asarray(qxp.arange(10), units="pint")
a = qxp.apply(lambda x: mxp.asarray(x, mask=x > 5), a, xp=mxp)
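To check how the recursion in the proposed apply behaves across two wrapper layers, here is a toy model in pure Python. Everything in it is a stand-in invented for illustration: ListNS plays the innermost backend (arrays are plain lists), WrapperNS and Wrapped play a generic wrapper namespace and its array type, and array_namespace is a local toy lookup rather than the array_api_compat function. Metadata is re-attached by copying it from the first argument, which is the simplest possible choice and not necessarily what a real wrapper should do.

```python
class ListNS:
    """Innermost toy backend: its arrays are plain Python lists."""

    def asarray(self, obj):
        return list(obj)


LIST_NS = ListNS()


class Wrapped:
    """Toy wrapper array: inner data plus metadata (mask=, units=, ...)."""

    def __init__(self, data, xp, **meta):
        self.data, self.xp, self.meta = data, xp, meta


def array_namespace(*arrays):
    """Toy lookup: wrappers report their namespace, lists belong to LIST_NS."""
    x = arrays[0]
    return x.xp if isinstance(x, Wrapped) else LIST_NS


class WrapperNS:
    """Toy wrapper namespace exposing the proposed apply()."""

    def asarray(self, obj, **meta):
        return Wrapped(obj, self, **meta)

    def apply(self, callback, *args, xp=None, **kwargs):
        inner_args = tuple(a.data for a in args)
        inner_ns = array_namespace(*inner_args)
        if xp is None or xp is inner_ns:
            out = callback(*inner_args, **kwargs)
        else:
            # The target namespace is deeper: the next layer unwraps further.
            out = inner_ns.apply(callback, *inner_args, xp=xp, **kwargs)
        # Naive re-wrap: copy the metadata of the first argument.
        return self.asarray(out, **args[0].meta)


mxp, qxp = WrapperNS(), WrapperNS()  # "masked" layer inside a "quantity" layer
a = qxp.asarray(mxp.asarray([1, 2, 3], mask=[False, True, False]), units="m")

# xp=LIST_NS: the callback sees the innermost list; both layers are restored.
b = qxp.apply(lambda v: [2 * e for e in v], a, xp=LIST_NS)
print(b.meta, b.data.meta, b.data.data)
```

With xp left as None the callback receives the object one layer down (here the masked array), so the xp argument effectively selects how many layers are peeled before the callback runs.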

@crusaderky
Contributor Author

Functions like xpx.at that special-case individual backends would need to add two lines:

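def special_case_func(x, xp):
    # Added lines: let the wrapper namespace unwrap x, re-enter this function, and re-wrap.
    if is_wrapper_namespace(xp):
        return xp.apply(special_case_func, x)
    ...  # existing backend-specific code paths continue here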
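A runnable toy version of that dispatch is sketched below. It rests on assumptions that are not settled in this thread: is_wrapper_namespace is taken to mean "the namespace offers a callable apply", the function's xp argument is made optional and inferred when apply re-enters it on the inner array, and ToyWrapped, ToyWrapperNS, BACKEND_NS and set_first_to_zero are hypothetical stand-ins.

```python
class ToyWrapped:
    """Toy wrapper array: inner data, owning namespace, metadata."""

    def __init__(self, data, xp, meta):
        self.data, self.xp, self.meta = data, xp, meta


class ToyWrapperNS:
    """Toy wrapper namespace with a one-layer apply()."""

    def asarray(self, data, **meta):
        return ToyWrapped(data, self, meta)

    def apply(self, callback, *args, **kwargs):
        out = callback(*(a.data for a in args), **kwargs)
        return ToyWrapped(out, self, args[0].meta)


BACKEND_NS = object()  # placeholder for the innermost (non-wrapper) namespace


def array_namespace(x):
    """Toy lookup: wrappers report their namespace, anything else is the backend."""
    return x.xp if isinstance(x, ToyWrapped) else BACKEND_NS


def is_wrapper_namespace(xp):
    """Hypothetical predicate: a wrapper namespace is one that offers apply()."""
    return callable(getattr(xp, "apply", None))


def set_first_to_zero(x, xp=None):
    """Function with a backend-specific code path (here: list slicing)."""
    xp = array_namespace(x) if xp is None else xp
    if is_wrapper_namespace(xp):
        # Re-enter on the inner array; apply() restores this layer afterwards.
        return xp.apply(set_first_to_zero, x)
    return [0, *x[1:]]


wxp = ToyWrapperNS()
nested = wxp.asarray(wxp.asarray([5, 6, 7], mask=[False, False, True]), units="m")
out = set_first_to_zero(nested)
print(out.meta, out.data.meta, out.data.data)
```

Because each layer only removes itself and then re-enters the same function, the backend-specific branch runs exactly once, on the innermost array.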

@lucascolley lucascolley added the enhancement New feature or request label Feb 26, 2025
@lucascolley lucascolley changed the title Wrapped namespaces support ENH: Wrapped namespaces support Feb 26, 2025