From 63b569f5cf8e2c4eebb69182b727416365f0f270 Mon Sep 17 00:00:00 2001 From: Olivier Bougeant <36014848+Bougeant@users.noreply.github.com> Date: Tue, 19 Mar 2024 23:43:54 +0200 Subject: [PATCH 1/3] Fix CUML HDBSCAN predictions by using correct method. --- bertopic/cluster/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 355a53f6..9815b96d 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -51,8 +51,8 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import approximate_predict - probabilities = approximate_predict(model, embeddings) + from cuml.cluster.hdbscan.prediction import membership_vector + probabilities = membership_vector(model, embeddings) return probabilities return None From bc0f3b3000326a35979e1fbf0dd88f14184445af Mon Sep 17 00:00:00 2001 From: Olivier Bougeant <36014848+Bougeant@users.noreply.github.com> Date: Wed, 20 Mar 2024 08:26:16 +0200 Subject: [PATCH 2/3] Trigger tests --- bertopic/cluster/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 9815b96d..c2733d18 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -51,6 +51,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: + from cuml.cluster.hdbscan.prediction import membership_vector probabilities = membership_vector(model, embeddings) return probabilities From 608867daf282800a3e7434535720a11874f2c92c Mon Sep 17 00:00:00 2001 From: Olivier Bougeant <36014848+Bougeant@users.noreply.github.com> Date: Wed, 20 Mar 2024 08:26:45 +0200 Subject: [PATCH 3/3] Restore --- bertopic/cluster/_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index c2733d18..9815b96d 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -51,7 +51,6 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: - from cuml.cluster.hdbscan.prediction import membership_vector probabilities = membership_vector(model, embeddings) return probabilities