@@ -10,21 +10,16 @@ def test_where_with_scalars():
10
10
x = xp .asarray ([1 , 2 , 3 , 1 ])
11
11
12
12
# Versions up to and including 2023.12 don't support scalar arguments
13
- with pytest .raises (AttributeError , match = "object has no attribute 'dtype'" ):
14
- xp .where (x == 1 , 42 , 44 )
13
+ with ArrayAPIStrictFlags (api_version = '2023.12' ):
14
+ with pytest .raises (AttributeError , match = "object has no attribute 'dtype'" ):
15
+ xp .where (x == 1 , 42 , 44 )
15
16
16
17
# Versions after 2023.12 support scalar arguments
17
- with (pytest .warns (
18
- UserWarning ,
19
- match = "The 2024.12 version of the array API specification is in draft status"
20
- ),
21
- ArrayAPIStrictFlags (api_version = draft_version ),
22
- ):
23
- x_where = xp .where (x == 1 , xp .asarray (42 ), 44 )
24
-
25
- expected = xp .asarray ([42 , 44 , 44 , 42 ])
26
- assert xp .all (x_where == expected )
27
-
28
- # The spec does not allow both x1 and x2 to be scalars
29
- with pytest .raises (ValueError , match = "One of" ):
30
- xp .where (x == 1 , 42 , 44 )
18
+ x_where = xp .where (x == 1 , xp .asarray (42 ), 44 )
19
+
20
+ expected = xp .asarray ([42 , 44 , 44 , 42 ])
21
+ assert xp .all (x_where == expected )
22
+
23
+ # The spec does not allow both x1 and x2 to be scalars
24
+ with pytest .raises (ValueError , match = "One of" ):
25
+ xp .where (x == 1 , 42 , 44 )
0 commit comments