diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 8c27ba14..d584a598 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -494,6 +494,19 @@ def fit_transform( documents, embeddings, assigned_documents, assigned_embeddings ) else: + # Update topic id to zeroshot topic id mapping + self._topic_id_to_zeroshot_topic_idx = { + new_topic_id: zeroshot_topic_id + for new_topic_id, zeroshot_topic_id in enumerate(set(assigned_documents.Topic)) + } + + # All documents matches zero-shot topics + documents = assigned_documents + embeddings = assigned_embeddings + + # Update topic sizes + self._update_topic_size(documents) + # All documents matches zero-shot topics documents = assigned_documents embeddings = assigned_embeddings