diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 6a75171b..d512d1fd 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1488,6 +1488,7 @@ def update_topics( vectorizer_model: CountVectorizer = None, ctfidf_model: ClassTfidfTransformer = None, representation_model: BaseRepresentation = None, + embeddings: np.ndarray = None, ): """Updates the topic representation by recalculating c-TF-IDF with the new parameters as defined in this function. @@ -1514,6 +1515,7 @@ def update_topics( representation_model: Pass in a model that fine-tunes the topic representations calculated through c-TF-IDF. Models from `bertopic.representation` are supported. + embeddings: Pre-trained document embeddings. Examples: In order to update the topic representation, you will need to first fit the topic @@ -1586,7 +1588,12 @@ def update_topics( if same_position and -1 not in topics and -1 in self.topics_: self.topic_embeddings_ = self.topic_embeddings_[1:] else: - self._create_topic_vectors() + if embeddings is not None: + # Use provided embeddings + self._create_topic_vectors(documents=documents, embeddings=embeddings) + else: + # Use self.embedding_model to calculate embeddings + self._create_topic_vectors() def get_topics(self, full: bool = False) -> Mapping[str, Tuple[str, float]]: """Return topics with top n words and their c-TF-IDF score.