1- # Copyright (c) 2024 Microsoft Corporation.
1+ # Copyright (C) 2026 Microsoft
22# Licensed under the MIT License
33
4- """A module containing embed_text method definition ."""
4+ """Streaming text embedding operation ."""
55
66import logging
7- from typing import TYPE_CHECKING
7+ from typing import TYPE_CHECKING , Any
88
99import numpy as np
10- import pandas as pd
1110from graphrag_llm .tokenizer import Tokenizer
11+ from graphrag_storage .tables .table import Table
1212from graphrag_vectors import VectorStore , VectorStoreDocument
1313
1414from graphrag .callbacks .workflow_callbacks import WorkflowCallbacks
2121
2222
async def embed_text(
    input_table: Table,
    callbacks: WorkflowCallbacks,
    model: "LLMEmbedding",
    tokenizer: Tokenizer,
    embed_column: str,
    batch_size: int,
    batch_max_tokens: int,
    num_threads: int,
    vector_store: VectorStore,
    id_column: str = "id",
    output_table: Table | None = None,
) -> int:
    """Embed text from a streaming Table into a vector store.

    Rows are buffered and flushed in batches of ``batch_size``. Each flush
    embeds the buffered texts and loads the resulting vectors into
    ``vector_store``; when ``output_table`` is provided, (id, embedding)
    rows are also written there.

    Args:
        input_table: Source of rows; each row must carry ``id_column`` and
            may carry ``embed_column`` (missing/None text is embedded as "").
        callbacks: Workflow callbacks forwarded to the embedding run.
        model: Embedding model.
        tokenizer: Tokenizer used for token-based batching.
        embed_column: Name of the column holding the text to embed.
        batch_size: Maximum number of rows embedded per flush.
        batch_max_tokens: Token budget per embedding batch.
        num_threads: Concurrency for the embedding run.
        vector_store: Destination store; its index is created up front.
        id_column: Name of the document-id column. Defaults to "id".
        output_table: Optional sink for (id, embedding) rows.

    Returns:
        Total number of rows processed (including rows whose embedding
        came back as None and were skipped by the flush helper).
    """
    vector_store.create_index()

    buffer: list[dict[str, Any]] = []
    total_rows = 0

    # Single flush path: the previous version duplicated this 11-argument
    # call in the loop body and the tail flush, which was easy to desync.
    async def _flush() -> int:
        return await _flush_embedding_buffer(
            buffer,
            embed_column,
            id_column,
            callbacks,
            model,
            tokenizer,
            batch_size,
            batch_max_tokens,
            num_threads,
            vector_store,
            output_table,
        )

    async for row in input_table:
        text = row.get(embed_column)
        if text is None:
            # Normalize missing text to the empty string rather than
            # failing the whole batch on one sparse row.
            text = ""

        buffer.append({
            id_column: row[id_column],
            embed_column: text,
        })

        if len(buffer) >= batch_size:
            total_rows += await _flush()
            buffer.clear()

    # Flush any rows remaining after the stream is exhausted.
    if buffer:
        total_rows += await _flush()

    return total_rows
84+
85+
async def _flush_embedding_buffer(
    buffer: list[dict[str, Any]],
    embed_column: str,
    id_column: str,
    callbacks: WorkflowCallbacks,
    model: "LLMEmbedding",
    tokenizer: Tokenizer,
    batch_size: int,
    batch_max_tokens: int,
    num_threads: int,
    vector_store: VectorStore,
    output_table: Table | None,
) -> int:
    """Embed a buffer of rows and load results into the vector store.

    Rows whose embedding comes back as None are skipped (a warning is
    logged with the skip count). When ``output_table`` is provided, an
    ``{"id", "embedding"}`` row is written for every successfully
    embedded document.

    Returns:
        The number of rows in the buffer, including skipped rows.
    """
    texts: list[str] = [row[embed_column] for row in buffer]
    ids: list[str] = [row[id_column] for row in buffer]

    result = await run_embed_text(
        texts,
        callbacks,
        model,
        tokenizer,
        batch_size,
        batch_max_tokens,
        num_threads,
    )

    # strict=True surfaces an id/vector length mismatch from the model
    # instead of silently dropping the tail.
    vectors = result.embeddings or []
    embedded: list[tuple[str, Any]] = []
    for doc_id, doc_vector in zip(ids, vectors, strict=True):
        if doc_vector is None:
            continue
        # Normalize numpy arrays to plain lists once, for both sinks below
        # (previously this conversion was duplicated in a second loop).
        if isinstance(doc_vector, np.ndarray):
            doc_vector = doc_vector.tolist()
        embedded.append((doc_id, doc_vector))

    vector_store.load_documents([
        VectorStoreDocument(id=doc_id, vector=doc_vector)
        for doc_id, doc_vector in embedded
    ])

    # strict zip guarantees len(vectors) == len(ids), so the difference
    # is exactly the number of None embeddings.
    skipped = len(buffer) - len(embedded)
    if skipped > 0:
        logger.warning(
            "Skipped %d rows with None embeddings out of %d",
            skipped,
            len(buffer),
        )

    if output_table is not None:
        for doc_id, doc_vector in embedded:
            await output_table.write({"id": doc_id, "embedding": doc_vector})

    return len(buffer)
0 commit comments