Skip to content

ENH: unary functions overhaul; better input validation #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

crusaderky
Copy link
Contributor

@crusaderky crusaderky commented Apr 23, 2025

xref #145

  • Rewrite all unary functions with a generator
  • Disallow numpy generics in binary functions, clip, and where
  • Improve error message when the first argument of where is not an Array
  • Test for device mismatches in the inputs of binary functions, clip, and where
  • Test input-output device propagation in where

@@ -168,9 +231,6 @@ def _array_vals():
for d in _floating_dtypes:
yield asarray(1.0, dtype=d)

# Use the latest version of the standard so all functions are included
set_array_api_strict_flags(api_version="2024.12")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant with auto-applied fixture

Copy link
Member

@ev-br ev-br left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Let keep this PR open for a while though, in case somebody has opinions on generating unary functions from a decorator. I personally think this is a good change, but there were concerns in #100

res = xp.where(cond, 1, x2)
assert res.device == device
res = xp.where(cond, x1, 2)
assert res.device == device
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the following tests are great. I imaging we'll want to parrot them in array-api-tests at some point.

@crusaderky
Copy link
Contributor Author

Note: this does not close #102 , as Python sneakily callls numpy.ndarray.__radd__. When array_api_strict.Array.__add__ fails. The opposite (LHS is numpy, RHS is array-api-strict) is also impossible to fix without disallowing __array__ and __buffer__.

@ev-br
Copy link
Member

ev-br commented Apr 24, 2025

Note: this does not close #102 , as Python sneakily callls numpy.ndarray.radd. When array_api_strict.Array.add fails. The opposite (LHS is numpy, RHS is array-api-strict) is also impossible to fix without disallowing array and buffer.

Exactly.
I suggest we ignore this for the time being.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants