Skip to content

Commit e97f18e

Browse files
committed
BUG: _utils.in1d: fix device
1 parent 33ccdbb commit e97f18e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/array_api_extra/_lib/_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def in1d(
3333
# This code is run to make the code significantly faster
3434
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
3535
if invert:
36-
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=x1.device)
36+
mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
3737
for a in x2:
3838
mask &= x1 != a
3939
else:
40-
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=x1.device)
40+
mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
4141
for a in x2:
4242
mask |= x1 == a
4343
return mask

0 commit comments

Comments
 (0)