|
2 | 2 |
|
3 | 3 | from ._array_object import Array
|
4 | 4 | from ._dtypes import _result_type, _real_numeric_dtypes
|
5 |
| -from ._flags import requires_data_dependent_shapes, requires_api_version |
| 5 | +from ._flags import requires_data_dependent_shapes, requires_api_version, get_array_api_strict_flags |
6 | 6 |
|
7 | 7 | from typing import TYPE_CHECKING
|
8 | 8 | if TYPE_CHECKING:
|
@@ -72,12 +72,19 @@ def searchsorted(
|
72 | 72 | # x1 must be 1-D, but NumPy already requires this.
|
73 | 73 | return Array._new(np.searchsorted(x1._array, x2._array, side=side, sorter=sorter), device=x1.device)
|
74 | 74 |
|
75 |
| -def where(condition: Array, x1: Array, x2: Array, /) -> Array: |
| 75 | +def where(condition: Array, x1: bool | int | float | Array, x2: bool | int | float | Array, /) -> Array: |
76 | 76 | """
|
77 | 77 | Array API compatible wrapper for :py:func:`np.where <numpy.where>`.
|
78 | 78 |
|
79 | 79 | See its docstring for more information.
|
80 | 80 | """
|
| 81 | + if get_array_api_strict_flags()['api_version'] > '2023.12': |
| 82 | + if isinstance(x1, (bool, float, int)): |
| 83 | + x1 = Array._new(np.asarray(x1), device=condition.device) |
| 84 | + |
| 85 | + if isinstance(x2, (bool, float, int)): |
| 86 | + x2 = Array._new(np.asarray(x2), device=condition.device) |
| 87 | + |
81 | 88 | # Call result type here just to raise on disallowed type combinations
|
82 | 89 | _result_type(x1.dtype, x2.dtype)
|
83 | 90 |
|
|
0 commit comments