Skip to content

Commit dd0406a

Browse files
committed
Fix GPU clustering
A variable which did not support GPU operations was errantly on GPU when cuda=True. Move to CPU.
1 parent e4c58bf commit dd0406a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vamb/cluster.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def _smaller_indices(tensor, kept_mask, threshold, cuda):
349349

350350
# If it's on GPU, we remove the already clustered points at this step.
351351
if cuda:
352-
return _torch.nonzero((tensor <= threshold) & kept_mask).flatten()
352+
return _torch.nonzero((tensor <= threshold) & kept_mask).flatten().cpu()
353353
else:
354354
arr = tensor.numpy()
355355
indices = (arr <= threshold).nonzero()[0]

0 commit comments

Comments
 (0)