forked from MaartenGr/BERTopic
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_utils.py
More file actions
70 lines (54 loc) · 2.46 KB
/
_utils.py
File metadata and controls
70 lines (54 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import hdbscan
import numpy as np
def _resolve_hdbscan_backend(model):
    """ Find the module that provides the HDBSCAN helper functions for `model`.

    Arguments:
        model: The cluster model.

    Returns:
        The `hdbscan` module when `model` is an `hdbscan.HDBSCAN` instance,
        the `cuml.cluster.hdbscan` module when the lowercased type string of
        `model` contains both "cuml" and "hdbscan", and None otherwise.
    """
    if isinstance(model, hdbscan.HDBSCAN):
        return hdbscan

    # cuML models are recognised through their type name so that cuML is
    # only imported when such a model is actually passed in.
    str_type_model = str(type(model)).lower()
    if "cuml" in str_type_model and "hdbscan" in str_type_model:
        from cuml.cluster import hdbscan as cuml_hdbscan
        return cuml_hdbscan

    return None


def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
    """ Function used to select the HDBSCAN-like model for generating
    predictions and probabilities.

    Arguments:
        model: The cluster model.
        func: The function to use. Options:
                - "approximate_predict"
                - "all_points_membership_vectors"
                - "membership_vector"
        embeddings: Input embeddings for "approximate_predict"
                    and "membership_vector"

    Returns:
        For "approximate_predict": a `(predictions, probabilities)` tuple.
        When `model` is not HDBSCAN-like, `model.predict(embeddings)` is
        used instead and `probabilities` is None.
        For "all_points_membership_vectors" and "membership_vector": the
        result of the matching backend function, or None when `model` is
        not HDBSCAN-like.

    Raises:
        ValueError: If `func` is not one of the supported options.
    """
    supported_funcs = (
        "approximate_predict",
        "all_points_membership_vectors",
        "membership_vector",
    )
    # Fail loudly on a misspelled function name instead of falling through
    # and returning None, which callers would only notice much later.
    if func not in supported_funcs:
        raise ValueError(
            f"Unknown func {func!r}; expected one of: {', '.join(supported_funcs)}"
        )

    backend = _resolve_hdbscan_backend(model)

    # Approximate predict
    if func == "approximate_predict":
        if backend is None:
            # Not an HDBSCAN-like model: use its own predict method, which
            # yields no probabilities.
            predictions = model.predict(embeddings)
            return predictions, None
        predictions, probabilities = backend.approximate_predict(model, embeddings)
        return predictions, probabilities

    # The membership functions have no fallback for other cluster models.
    if backend is None:
        return None

    # All points membership
    if func == "all_points_membership_vectors":
        return backend.all_points_membership_vectors(model)

    # membership_vector
    return backend.membership_vector(model, embeddings)
def is_supported_hdbscan(model):
    """ Tell whether `model` is an HDBSCAN-like model that is supported.

    Arguments:
        model: The cluster model to inspect.

    Returns:
        True when `model` is an `hdbscan.HDBSCAN` instance, or when the
        lowercased string form of its type contains both "cuml" and
        "hdbscan"; False otherwise.
    """
    if not isinstance(model, hdbscan.HDBSCAN):
        # Fall back to a name-based check on the type string, which needs
        # no import of the cuML package.
        type_description = str(type(model)).lower()
        return all(token in type_description for token in ("cuml", "hdbscan"))
    return True