Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4849,19 +4849,18 @@ def _create_model_from_files(
images: The images per topic
warn_no_backend: Whether to warn the user if no backend is given
"""
from sentence_transformers import SentenceTransformer

params["n_gram_range"] = tuple(params["n_gram_range"])

if ctfidf_config is not None:
ngram_range = ctfidf_config["vectorizer_model"]["params"]["ngram_range"]
ctfidf_config["vectorizer_model"]["params"]["ngram_range"] = tuple(ngram_range)

params["n_gram_range"] = tuple(params["n_gram_range"])
ctfidf_config

# Select HF model through SentenceTransformers
try:
from sentence_transformers import SentenceTransformer

embedding_model = select_backend(SentenceTransformer(params["embedding_model"]))
except: # noqa: E722
embedding_model = BaseEmbedder()
Expand All @@ -4887,7 +4886,7 @@ def _create_model_from_files(
hdbscan_model=empty_cluster_model,
**params,
)
topic_model.topic_embeddings_ = tensors["topic_embeddings"].numpy()
topic_model.topic_embeddings_ = tensors["topic_embeddings"]
topic_model.topic_representations_ = {int(key): val for key, val in topics["topic_representations"].items()}
topic_model.topics_ = topics["topics"]
topic_model.topic_sizes_ = {int(key): val for key, val in topics["topic_sizes"].items()}
Expand Down Expand Up @@ -4924,7 +4923,7 @@ def _create_model_from_files(
# ClassTfidfTransformer
topic_model.ctfidf_model.reduce_frequent_words = ctfidf_config["ctfidf_model"]["reduce_frequent_words"]
topic_model.ctfidf_model.bm25_weighting = ctfidf_config["ctfidf_model"]["bm25_weighting"]
idf = ctfidf_tensors["diag"].numpy()
idf = ctfidf_tensors["diag"]
topic_model.ctfidf_model._idf_diag = sp.diags(
idf, offsets=0, shape=(len(idf), len(idf)), format="csr", dtype=np.float64
)
Expand Down
48 changes: 28 additions & 20 deletions bertopic/_save_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def load_local_files(path):
torch_path = path / HF_WEIGHTS_NAME
if torch_path.is_file():
tensors = torch.load(torch_path, map_location="cpu")
tensors = {k: v.numpy() for k, v in tensors.items()}

# c-TF-IDF
try:
Expand All @@ -196,6 +197,7 @@ def load_local_files(path):
torch_path = path / CTFIDF_WEIGHTS_NAME
if torch_path.is_file():
ctfidf_tensors = torch.load(torch_path, map_location="cpu")
ctfidf_tensors = {k: v.numpy() for k, v in ctfidf_tensors.items()}
ctfidf_config = load_cfg_from_json(path / CTFIDF_CFG_NAME)
except: # noqa: E722
ctfidf_config, ctfidf_tensors = None, None
Expand Down Expand Up @@ -315,35 +317,43 @@ def generate_readme(model, repo_id: str):

def save_hf(model, save_directory, serialization: str):
    """Save topic embeddings, either safely (using safetensors) or using legacy pytorch.

    Arguments:
        model: Fitted topic model whose `topic_embeddings_` are persisted.
        save_directory: Directory (a `Path`) the weights file is written into.
        serialization: `"safetensors"` or `"pytorch"`; any other value writes nothing.
    """
    # Keep the embeddings as a float32 numpy array; they are only wrapped in a
    # torch tensor for the legacy pytorch (.bin) format.
    embeddings = np.array(model.topic_embeddings_, dtype=np.float32)

    if serialization == "safetensors":
        save_safetensors(save_directory / HF_SAFE_WEIGHTS_NAME, {"topic_embeddings": embeddings})
    if serialization == "pytorch":
        assert _has_torch, "`pip install pytorch` to save as bin"
        tensors = {"topic_embeddings": torch.from_numpy(embeddings)}
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)


def save_ctfidf(model, save_directory: str, serialization: str):
    """Save the c-TF-IDF sparse matrix and IDF diagonal.

    Arguments:
        model: Fitted topic model providing `c_tf_idf_` (a scipy CSR matrix)
               and `ctfidf_model._idf_diag`.
        save_directory: Directory (a `Path`) the weights file is written into.
        serialization: `"safetensors"` or `"pytorch"`; any other value writes nothing.
    """
    # Decompose the CSR matrix into its raw numpy components so it can be
    # reassembled on load; the diagonal IDF weights are stored alongside.
    arrays = {
        "indptr": model.c_tf_idf_.indptr,
        "indices": model.c_tf_idf_.indices,
        "data": model.c_tf_idf_.data,
        "shape": np.array(model.c_tf_idf_.shape),
        "diag": np.array(model.ctfidf_model._idf_diag.data),
    }

    if serialization == "safetensors":
        save_safetensors(save_directory / CTFIDF_SAFE_WEIGHTS_NAME, arrays)
    if serialization == "pytorch":
        assert _has_torch, "`pip install pytorch` to save as .bin"
        # Legacy format stores torch tensors; convert lazily, only here.
        tensors = {key: torch.from_numpy(value) for key, value in arrays.items()}
        torch.save(tensors, save_directory / CTFIDF_WEIGHTS_NAME)


Expand Down Expand Up @@ -511,20 +521,18 @@ def get_package_versions():
def load_safetensors(path):
    """Load a .safetensors file as a dict of numpy arrays.

    Arguments:
        path: Location of the .safetensors file.

    Raises:
        ValueError: If the `safetensors` package is not installed.
    """
    try:
        # Importing the submodule also imports the parent package, so a
        # separate `import safetensors` is unnecessary.
        import safetensors.numpy

        return safetensors.numpy.load_file(path)
    except ImportError:
        raise ValueError("`pip install safetensors` to load .safetensors")


def save_safetensors(path, tensors):
    """Save a dict of numpy arrays to a .safetensors file.

    Arguments:
        path: Destination file path.
        tensors: Mapping of names to numpy arrays to serialize.

    Raises:
        ValueError: If the `safetensors` package is not installed.
    """
    try:
        # Importing the submodule also imports the parent package, so a
        # separate `import safetensors` is unnecessary.
        import safetensors.numpy

        safetensors.numpy.save_file(tensors, path)
    except ImportError:
        raise ValueError("`pip install safetensors` to save as .safetensors")