Skip to content

Commit 97045b5

Browse files
authored
generate_text_embeddings streaming (#2241)
* add streaming * fixes
1 parent bf1034a commit 97045b5

File tree

10 files changed

+723
-243
lines changed

10 files changed

+723
-243
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "generate_text_embeddings streaming"
4+
}

docs/examples_notebooks/index_migration_to_v1.ipynb

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -205,44 +205,23 @@
205205
"metadata": {},
206206
"outputs": [],
207207
"source": [
208-
"from graphrag.cache.factory import CacheFactory\n",
209208
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
210209
"from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n",
211-
"from graphrag.language_model.manager import ModelManager\n",
212-
"from graphrag.tokenizer.get_tokenizer import get_tokenizer\n",
210+
"from graphrag_cache import create_cache\n",
213211
"\n",
214-
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
215-
"# We'll construct the context and run this function flow directly to avoid everything else\n",
212+
"# We only need to re-run the embeddings workflow, to ensure that embeddings\n",
213+
"# for all required search fields are in place.\n",
214+
"# We pass in the table_provider created earlier so that generate_text_embeddings\n",
215+
"# reads the migrated tables we just wrote.\n",
216216
"\n",
217-
"model_config = config.get_language_model_config(config.embed_text.model_id)\n",
218217
"callbacks = NoopWorkflowCallbacks()\n",
219-
"cache_config = config.cache.model_dump() # type: ignore\n",
220-
"cache = CacheFactory().create_cache(\n",
221-
" cache_type=cache_config[\"type\"], # type: ignore\n",
222-
" **cache_config,\n",
223-
")\n",
224-
"model = ModelManager().get_or_create_embedding_model(\n",
225-
" name=\"text_embedding\",\n",
226-
" model_type=model_config.type,\n",
227-
" config=model_config,\n",
228-
" callbacks=callbacks,\n",
229-
" cache=cache,\n",
230-
")\n",
231-
"\n",
232-
"tokenizer = get_tokenizer(model_config)\n",
218+
"cache = create_cache(config.cache)\n",
233219
"\n",
234220
"await generate_text_embeddings(\n",
235-
" text_units=final_text_units,\n",
236-
" entities=final_entities,\n",
237-
" community_reports=final_community_reports,\n",
221+
" config=config,\n",
222+
" table_provider=table_provider,\n",
223+
" cache=cache,\n",
238224
" callbacks=callbacks,\n",
239-
" model=model,\n",
240-
" tokenizer=tokenizer,\n",
241-
" batch_size=config.embed_text.batch_size,\n",
242-
" batch_max_tokens=config.embed_text.batch_max_tokens,\n",
243-
" num_threads=model_config.concurrent_requests,\n",
244-
" vector_store_config=config.vector_store,\n",
245-
" embedded_fields=config.embed_text.names,\n",
246225
")"
247226
]
248227
}

packages/graphrag/graphrag/data_model/row_transformers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ def transform_entity_row(row: dict[str, Any]) -> dict[str, Any]:
8989
return row
9090

9191

92+
def transform_entity_row_for_embedding(
93+
row: dict[str, Any],
94+
) -> dict[str, Any]:
95+
"""Add a title_description column for embedding generation."""
96+
title = row.get("title") or ""
97+
description = row.get("description") or ""
98+
row["title_description"] = f"{title}:{description}"
99+
return row
100+
101+
92102
# -- relationships (mirrors relationships_typed) --------------------------
93103

94104

Lines changed: 102 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
# Copyright (c) 2024 Microsoft Corporation.
1+
# Copyright (C) 2026 Microsoft
22
# Licensed under the MIT License
33

4-
"""A module containing embed_text method definition."""
4+
"""Streaming text embedding operation."""
55

66
import logging
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Any
88

99
import numpy as np
10-
import pandas as pd
1110
from graphrag_llm.tokenizer import Tokenizer
11+
from graphrag_storage.tables.table import Table
1212
from graphrag_vectors import VectorStore, VectorStoreDocument
1313

1414
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@@ -21,7 +21,7 @@
2121

2222

2323
async def embed_text(
24-
input: pd.DataFrame,
24+
input_table: Table,
2525
callbacks: WorkflowCallbacks,
2626
model: "LLMEmbedding",
2727
tokenizer: Tokenizer,
@@ -31,59 +31,116 @@ async def embed_text(
3131
num_threads: int,
3232
vector_store: VectorStore,
3333
id_column: str = "id",
34-
):
35-
"""Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector."""
36-
if embed_column not in input.columns:
37-
msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}"
38-
raise ValueError(msg)
39-
if id_column not in input.columns:
40-
msg = f"Column {id_column} not found in input dataframe with columns {input.columns}"
41-
raise ValueError(msg)
42-
34+
output_table: Table | None = None,
35+
) -> int:
36+
"""Embed text from a streaming Table into a vector store."""
4337
vector_store.create_index()
4438

45-
index = 0
39+
buffer: list[dict[str, Any]] = []
40+
total_rows = 0
4641

47-
all_results = []
42+
async for row in input_table:
43+
text = row.get(embed_column)
44+
if text is None:
45+
text = ""
4846

49-
num_total_batches = (input.shape[0] + batch_size - 1) // batch_size
50-
while batch_size * index < input.shape[0]:
51-
logger.info(
52-
"uploading text embeddings batch %d/%d of size %d to vector store",
53-
index + 1,
54-
num_total_batches,
55-
batch_size,
56-
)
57-
batch = input.iloc[batch_size * index : batch_size * (index + 1)]
58-
texts: list[str] = batch[embed_column].tolist()
59-
ids: list[str] = batch[id_column].tolist()
60-
result = await run_embed_text(
61-
texts,
47+
buffer.append({
48+
id_column: row[id_column],
49+
embed_column: text,
50+
})
51+
52+
if len(buffer) >= batch_size:
53+
total_rows += await _flush_embedding_buffer(
54+
buffer,
55+
embed_column,
56+
id_column,
57+
callbacks,
58+
model,
59+
tokenizer,
60+
batch_size,
61+
batch_max_tokens,
62+
num_threads,
63+
vector_store,
64+
output_table,
65+
)
66+
buffer.clear()
67+
68+
if buffer:
69+
total_rows += await _flush_embedding_buffer(
70+
buffer,
71+
embed_column,
72+
id_column,
6273
callbacks,
6374
model,
6475
tokenizer,
6576
batch_size,
6677
batch_max_tokens,
6778
num_threads,
79+
vector_store,
80+
output_table,
6881
)
69-
if result.embeddings:
70-
embeddings = [
71-
embedding for embedding in result.embeddings if embedding is not None
72-
]
73-
all_results.extend(embeddings)
74-
75-
vectors = result.embeddings or []
76-
documents: list[VectorStoreDocument] = []
77-
for doc_id, doc_vector in zip(ids, vectors, strict=True):
78-
if type(doc_vector) is np.ndarray:
79-
doc_vector = doc_vector.tolist()
80-
document = VectorStoreDocument(
82+
83+
return total_rows
84+
85+
86+
async def _flush_embedding_buffer(
87+
buffer: list[dict[str, Any]],
88+
embed_column: str,
89+
id_column: str,
90+
callbacks: WorkflowCallbacks,
91+
model: "LLMEmbedding",
92+
tokenizer: Tokenizer,
93+
batch_size: int,
94+
batch_max_tokens: int,
95+
num_threads: int,
96+
vector_store: VectorStore,
97+
output_table: Table | None,
98+
) -> int:
99+
"""Embed a buffer of rows and load results into the vector store."""
100+
texts: list[str] = [row[embed_column] for row in buffer]
101+
ids: list[str] = [row[id_column] for row in buffer]
102+
103+
result = await run_embed_text(
104+
texts,
105+
callbacks,
106+
model,
107+
tokenizer,
108+
batch_size,
109+
batch_max_tokens,
110+
num_threads,
111+
)
112+
113+
vectors = result.embeddings or []
114+
skipped = 0
115+
documents: list[VectorStoreDocument] = []
116+
for doc_id, doc_vector in zip(ids, vectors, strict=True):
117+
if doc_vector is None:
118+
skipped += 1
119+
continue
120+
if type(doc_vector) is np.ndarray:
121+
doc_vector = doc_vector.tolist()
122+
documents.append(
123+
VectorStoreDocument(
81124
id=doc_id,
82125
vector=doc_vector,
83126
)
84-
documents.append(document)
127+
)
128+
129+
vector_store.load_documents(documents)
130+
131+
if skipped > 0:
132+
logger.warning(
133+
"Skipped %d rows with None embeddings out of %d",
134+
skipped,
135+
len(buffer),
136+
)
85137

86-
vector_store.load_documents(documents)
87-
index += 1
138+
if output_table is not None:
139+
for doc_id, doc_vector in zip(ids, vectors, strict=True):
140+
if doc_vector is None:
141+
continue
142+
if type(doc_vector) is np.ndarray:
143+
doc_vector = doc_vector.tolist()
144+
await output_table.write({"id": doc_id, "embedding": doc_vector})
88145

89-
return all_results
146+
return len(buffer)

0 commit comments

Comments
 (0)