We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
_utils.in1d
1 parent 33ccdbb commit e97f18eCopy full SHA for e97f18e
src/array_api_extra/_lib/_utils.py
@@ -33,11 +33,11 @@ def in1d(
33
# This code is run to make the code significantly faster
34
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
35
if invert:
36
- mask = xp.ones(x1.shape[0], dtype=xp.bool, device=x1.device)
+ mask = xp.ones(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
37
for a in x2:
38
mask &= x1 != a
39
else:
40
- mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=x1.device)
+ mask = xp.zeros(x1.shape[0], dtype=xp.bool, device=_compat.device(x1))
41
42
mask |= x1 == a
43
return mask
0 commit comments