diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index 90379b86..6ae842e2 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -370,6 +370,7 @@ def fit_transform(
         embeddings: np.ndarray = None,
         images: List[str] = None,
         y: Union[List[int], np.ndarray] = None,
+        nr_repr_docs: int = None,
     ) -> Tuple[List[int], Union[np.ndarray, None]]:
         """Fit the models on a collection of documents, generate topics,
         and return the probabilities and topic per document.
@@ -381,6 +382,7 @@ def fit_transform(
             images: A list of paths to the images to fit on or the images themselves
             y: The target class for (semi)-supervised modeling. Use -1 if no class for a
                specific instance is specified.
+            nr_repr_docs: [optional] Number of representative docs to keep for each topic.

         Returns:
             predictions: Topic predictions for each documents
@@ -487,7 +489,7 @@ def fit_transform(
                 custom_documents = self._reduce_topics(custom_documents)

             # Save the top 3 most representative documents per topic
-            self._save_representative_docs(custom_documents)
+            self._save_representative_docs(custom_documents, nr_repr_docs=nr_repr_docs)
         else:
             # Extract topics by calculating c-TF-IDF
             self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)
@@ -497,7 +499,7 @@ def fit_transform(
                 documents = self._reduce_topics(documents)

             # Save the top 3 most representative documents per topic
-            self._save_representative_docs(documents)
+            self._save_representative_docs(documents, nr_repr_docs=nr_repr_docs)

         # In the case of zero-shot topics, probability will come from cosine similarity,
         # and the HDBSCAN model will be removed
@@ -2077,6 +2079,7 @@ def merge_topics(
         docs: List[str],
         topics_to_merge: List[Union[Iterable[int], int]],
         images: List[str] = None,
+        nr_repr_docs: int = None,
     ) -> None:
         """Arguments:
             docs: The documents you used when calling either `fit` or `fit_transform`
@@ -2087,6 +2090,7 @@ def merge_topics(
                                 separately merge topics 3 and 4.
             images: A list of paths to the images used when calling either
                     `fit` or `fit_transform`.
+            nr_repr_docs: [optional] Number of representative docs to keep for each topic.

         Examples:
         If you want to merge topics 1, 2, and 3:
@@ -2147,7 +2151,7 @@ def merge_topics(
         documents = self._sort_mappings_by_frequency(documents)
         self._extract_topics(documents, mappings=mappings)
         self._update_topic_size(documents)
-        self._save_representative_docs(documents)
+        self._save_representative_docs(documents, nr_repr_docs=nr_repr_docs)
         self.probabilities_ = self._map_probabilities(self.probabilities_)

     def reduce_topics(
@@ -2156,6 +2160,7 @@ def reduce_topics(
         self,
         docs: List[str],
         nr_topics: Union[int, str] = 20,
         images: List[str] = None,
         use_ctfidf: bool = False,
+        nr_repr_docs: int = None,
     ) -> None:
         """Reduce the number of topics to a fixed number of topics or automatically.
@@ -2176,6 +2181,7 @@ def reduce_topics(
                     `fit` or `fit_transform`
             use_ctfidf: Whether to calculate distances between topics based on c-TF-IDF embeddings.
                         If False, the embeddings from the embedding model are used.
+            nr_repr_docs: [optional] Number of representative docs to keep for each topic.

         Updates:
             topics_ : Assigns topics to their merged representations.
@@ -2212,7 +2218,7 @@ def reduce_topics(
         # Reduce number of topics
         documents = self._reduce_topics(documents, use_ctfidf)
         self._merged_topics = None
-        self._save_representative_docs(documents)
+        self._save_representative_docs(documents, nr_repr_docs=nr_repr_docs)
         self.probabilities_ = self._map_probabilities(self.probabilities_)

         return self
@@ -3993,21 +3999,22 @@ def _extract_topics(
         if verbose:
             logger.info("Representation - Completed \u2713")

-    def _save_representative_docs(self, documents: pd.DataFrame):
+    def _save_representative_docs(self, documents: pd.DataFrame, nr_repr_docs: int):
         """Save the 3 most representative docs per topic.

         Arguments:
             documents: Dataframe with documents and their corresponding IDs
+            nr_repr_docs: The number of representative documents to extract per topic

         Updates:
-            self.representative_docs_: Populate each topic with 3 representative docs
+            self.representative_docs_: Populate each topic with {nr_repr_docs} representative docs
         """
         repr_docs, _, _, _ = self._extract_representative_docs(
             self.c_tf_idf_,
             documents,
             self.topic_representations_,
             nr_samples=500,
-            nr_repr_docs=3,
+            nr_repr_docs=nr_repr_docs,
         )
         self.representative_docs_ = repr_docs
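A minimal usage sketch of the new parameter, assuming this patch is applied; `docs` is assumed to be your list of input documents and the values below are illustrative, not defaults:

    from bertopic import BERTopic

    # docs: List[str] with your input documents (assumed to exist)
    topic_model = BERTopic()

    # Keep 10 representative documents per topic instead of the previously hardcoded 3
    topics, probs = topic_model.fit_transform(docs, nr_repr_docs=10)

    # The same knob is exposed when merging or reducing topics
    topic_model.merge_topics(docs, topics_to_merge=[1, 2], nr_repr_docs=10)
    topic_model.reduce_topics(docs, nr_topics=20, nr_repr_docs=10)

    # Inspect the stored representative documents for a given topic
    topic_model.get_representative_docs(topic=0)

One caveat: `_save_representative_docs` forwards `nr_repr_docs` to `_extract_representative_docs` unchanged, so if that helper does not handle `None` itself, callers relying on the public default of `None` may still need a fallback to the previous value of 3.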