ENH: Wrapped namespaces support #144
Comments
|
I thought the

from marray import numpy as mxp
x = mxp.asarray([1, 2, 3], mask=[False, True, False])
x.__array_namespace__()
# <module 'marray.numpy'>

So if the base library underlying the

As for library-specific calls, I've thought about having separately-installable packages (or maybe just separate repos that get vendored) that could be maintained by the developers of the underlying libraries. This would allow Dask developers to add additional features to the
|
I don't think that's realistic, particularly for lesser-known wrappers. Maybe a more sustainable approach could be for all wrapper namespaces to offer a coherent API that lets you unwrap and rewrap the inner objects:

def apply(callback, *args, xp=None, **kwargs):
    inner_args = tuple(x.data for x in args)  # Backend-specific
    outer_ns = array_namespace(*args)
    inner_ns = array_namespace(*inner_args)
    if xp in (None, inner_ns):
        out = callback(*inner_args, **kwargs)
    else:
        out = inner_ns.apply(callback, *inner_args, **kwargs, xp=xp)
    # Backend-specific: re-add units=, mask=, etc. from args here.
    return outer_ns.asarray(out)

so this (mdhaber/marray#91 (comment)):

mxp = marray.wrap_namespace(np)
dxp = da.wrap_namespace(mxp)
qxp = quantity_array.wrap_namespace(dxp)
a = qxp.asarray(qxp.arange(10), units="pint")
a = qxp.asarray(a.data.map_blocks(lambda x: mxp.asarray(x, mask=x > 5)), units=a.units)

would become:

a = qxp.asarray(qxp.arange(10), units="pint")
a = qxp.apply(lambda x: mxp.asarray(x, mask=x > 5), a, xp=mxp)
|
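The unwrap/rewrap idea above can be tried out with toy classes. This is only a sketch under stated assumptions: `LeafArray`, `LeafNamespace`, `WrappedArray` and `WrapperNamespace` are hypothetical stand-ins (not marray, Dask or any real library), metadata such as mask= or units= is kept in a plain dict, and rewrapping simply copies the metadata of the first argument.

```python
class LeafArray:
    """Hypothetical innermost array (stands in for e.g. a NumPy array)."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self):
        return leaf_xp


class LeafNamespace:
    """Hypothetical innermost namespace; it has no apply()."""

    def asarray(self, obj):
        return obj if isinstance(obj, LeafArray) else LeafArray(obj)


leaf_xp = LeafNamespace()


class WrappedArray:
    """Hypothetical wrapper array: an inner array plus metadata."""

    def __init__(self, data, xp, meta):
        self.data, self.xp, self.meta = data, xp, meta

    def __array_namespace__(self):
        return self.xp


class WrapperNamespace:
    """Hypothetical wrapper namespace layered on top of `inner`."""

    def __init__(self, inner):
        self.inner = inner

    def asarray(self, obj, **meta):
        if isinstance(obj, WrappedArray) and obj.xp is self:
            # Already one of ours: keep the data, update the metadata.
            return WrappedArray(obj.data, self, {**obj.meta, **meta})
        return WrappedArray(self.inner.asarray(obj), self, meta)

    def apply(self, callback, *args, xp=None, **kwargs):
        inner_args = tuple(a.data for a in args)  # unwrap one layer
        if xp is None or xp is self.inner:
            out = callback(*inner_args, **kwargs)
        else:
            # Keep unwrapping until the requested namespace is reached.
            out = self.inner.apply(callback, *inner_args, xp=xp, **kwargs)
        # Rewrap, copying the metadata of the first argument (an assumption).
        return self.asarray(out, **args[0].meta)


mxp = WrapperNamespace(leaf_xp)  # plays the role of a masked namespace
qxp = WrapperNamespace(mxp)      # plays the role of a units namespace
a = qxp.asarray(mxp.asarray([1, 2, 3], mask=[False, True, False]), units="m")

# Run a leaf-level callback two layers down; both layers are restored.
b = qxp.apply(lambda x: LeafArray(v * 2 for v in x.values), a, xp=leaf_xp)
```

The `xp=` argument names the layer at which the callback expects to operate; each wrapper peels off one layer and delegates inward until that layer is reached. A real backend would need its own rule for combining metadata from several arguments.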
Functions like

def special_case_func(x, xp):
    if is_wrapper_namespace(xp):
        return xp.apply(special_case_func, x)
|
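A hedged sketch of how such a dispatch could look end to end. Everything here is hypothetical: `is_wrapper_namespace` is implemented as a simple duck-typing check for an `apply` attribute, the toy `apply` passes the inner namespace to the callback as its second argument (an assumption, since the thread does not spell this out), and `reversed_copy` stands in for any function with a backend-specific code path.

```python
class ToyArray:
    """Toy backend array standing in for e.g. a Dask or JAX array."""

    def __init__(self, values):
        self.values = list(values)


class ToyBackend:
    """Toy unwrapped namespace; note that it has no apply()."""

    @staticmethod
    def asarray(obj):
        return obj if isinstance(obj, ToyArray) else ToyArray(obj)


class Tagged:
    """Toy wrapper array carrying a tag alongside the inner array."""

    def __init__(self, data, tag):
        self.data, self.tag = data, tag


class TaggedNamespace:
    """Toy wrapper namespace around an inner namespace."""

    def __init__(self, inner):
        self.inner = inner

    def asarray(self, obj, tag=None):
        return Tagged(self.inner.asarray(obj), tag)

    def apply(self, callback, x):
        # Unwrap, run the callback against the inner namespace, rewrap.
        return Tagged(callback(x.data, self.inner), x.tag)


def is_wrapper_namespace(xp):
    """Hypothetical check: a wrapper namespace is one that offers apply()."""
    return callable(getattr(xp, "apply", None))


def reversed_copy(x, xp):
    if is_wrapper_namespace(xp):
        # Recurse one layer down; the wrapper restores its own metadata.
        return xp.apply(reversed_copy, x)
    # Backend-specific path that only understands ToyArray.
    return ToyArray(reversed(x.values))


xp2 = TaggedNamespace(TaggedNamespace(ToyBackend))
x = xp2.asarray([1, 2, 3], tag="outer")
y = reversed_copy(x, xp2)
```

Because each layer only unwraps itself, the function with the special case never needs to know how many wrappers sit between it and the backend array.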
I'm starting to see multiple Array API compliant libraries that are just thin wrappers around another arbitrary Array API compliant library:

These libraries currently fail to work with array-api-compat, array-api-extra, or anything that relies on them (like scipy) except in the most basic use cases, e.g. wrapping around a writeable numpy or cupy array.

These are the reasons I could find so far:

array_api_compat.is_writeable_array blindly returns True for unknown namespaces. This is problematic also when wrapping around read-only numpy arrays.

array_api_compat.is_lazy_array incorrectly returns False for wrapped Dask. This will lead to major performance issues.

xpx.at, xpx.apply_where, and xpx.lazy_apply have special casing for Dask which won't be triggered when Dask is wrapped. You can't just change array_api_compat.is_dask_array to return True for wrapped dask objects, because there are a wealth of dask-specific calls (da.map_blocks, passing the arrays to a @delayed function) that won't work out of the box.

nunique, have special casing for JAX. Again, there are jax-specific function calls e.g. jax.pure_callback that won't accept wrapped arrays.

xpx.pad) should work thanks to their generic code path, but they are untested; so for example pad on PyTorch works due to special casing, but there are no tests that verify that the generic implementation of pad applied to a wrapper around PyTorch will work.
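The first two reasons can be reproduced with a toy model. The sketch below does not use array-api-compat; `LazyArray`, `ReadOnlyArray`, `Wrapper` and the helper functions are hypothetical names, and the type-based checks only mimic the behaviour described above (known types are recognised, everything else gets a default answer). The `innermost` helper assumes wrappers expose the inner array as `.data`.

```python
class LazyArray:
    """Toy stand-in for a lazy backend array such as a Dask array."""


class ReadOnlyArray:
    """Toy stand-in for a read-only backend array."""


class Wrapper:
    """Toy wrapper array exposing the inner array as .data."""

    def __init__(self, data):
        self.data = data


def is_lazy_by_type(x):
    # Recognises only the known lazy type; anything else counts as eager.
    return isinstance(x, LazyArray)


def is_writeable_by_type(x):
    # Recognises only the known read-only type; anything else is "writeable".
    return not isinstance(x, ReadOnlyArray)


def innermost(x):
    """Follow .data until something that is not a Wrapper is reached."""
    while isinstance(x, Wrapper):
        x = x.data
    return x


wrapped_lazy = Wrapper(Wrapper(LazyArray()))
wrapped_readonly = Wrapper(ReadOnlyArray())

# Type-based checks give the wrong answer on wrapped objects...
naive = (is_lazy_by_type(wrapped_lazy), is_writeable_by_type(wrapped_readonly))
# ...while checking the innermost array gives the expected one.
fixed = (
    is_lazy_by_type(innermost(wrapped_lazy)),
    is_writeable_by_type(innermost(wrapped_readonly)),
)
```

Unwrapping before the check fixes the predicates in this toy model, but as noted above it does not help with backend-specific calls that must receive the unwrapped array itself.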