Skip to content

Commit d9a6b73

Browse files
committed
Add support for scalar arguments to xp.where
1 parent d086c61 commit d9a6b73

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

array_api_strict/_searching_functions.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from ._array_object import Array
44
from ._dtypes import _result_type, _real_numeric_dtypes
5-
from ._flags import requires_data_dependent_shapes, requires_api_version
5+
from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
@@ -72,12 +72,19 @@ def searchsorted(
7272
# x1 must be 1-D, but NumPy already requires this.
7373
return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
7474

75-
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
75+
def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array:
7676
"""
7777
Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
7878
7979
See its docstring for more information.
8080
"""
81+
if get_array_api_strict_flags()['api_version'] > '2023.12':
82+
if isinstance(x1, (bool, float, int)):
83+
x1 = Array._new(np.asarray(x1), device=condition.device)
84+
85+
if isinstance(x2, (bool, float, int)):
86+
x2 = Array._new(np.asarray(x2), device=condition.device)
87+
8188
# Call result type here just to raise on disallowed type combinations
8289
_result_type(x1.dtype, x2.dtype)
8390

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import pytest
2+
3+
import array_api_strict as xp
4+
5+
from array_api_strict import ArrayAPIStrictFlags
6+
from array_api_strict._flags import next_supported_version
7+
8+
9+
def test_where_with_scalars():
10+
x = xp.asarray([1, 2, 3, 1])
11+
12+
# Versions up to and including 2023.12 don't support scalar arguments
13+
with pytest.raises(AttributeError, match="object has no attribute 'dtype'"):
14+
xp.where(x == 1, 42, 44)
15+
16+
# Versions after 2023.12 support scalar arguments
17+
with ArrayAPIStrictFlags(api_version=next_supported_version):
18+
x_where = xp.where(x == 1, 42, 44)
19+
20+
expected = xp.asarray([42, 44, 44, 42])
21+
assert xp.all(x_where == expected)

0 commit comments

Comments
 (0)