diff --git a/.semversioner/next-release/patch-20260212002508389038.json b/.semversioner/next-release/patch-20260212002508389038.json
new file mode 100644
index 0000000000..64c07d115f
--- /dev/null
+++ b/.semversioner/next-release/patch-20260212002508389038.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add streaming to the first two workflows"
+}
diff --git a/packages/graphrag-input/graphrag_input/csv.py b/packages/graphrag-input/graphrag_input/csv.py
index 6c0f51dd3a..e041bff275 100644
--- a/packages/graphrag-input/graphrag_input/csv.py
+++ b/packages/graphrag-input/graphrag_input/csv.py
@@ -5,12 +5,18 @@
 
 import csv
 import logging
+import sys
 
 from graphrag_input.structured_file_reader import StructuredFileReader
 from graphrag_input.text_document import TextDocument
 
 logger = logging.getLogger(__name__)
 
+try:
+    csv.field_size_limit(sys.maxsize)
+except OverflowError:
+    csv.field_size_limit(100 * 1024 * 1024)
+
 
 class CSVFileReader(StructuredFileReader):
     """Reader implementation for csv files."""
diff --git a/packages/graphrag-storage/graphrag_storage/file_storage.py b/packages/graphrag-storage/graphrag_storage/file_storage.py
index 547659abcd..7eb89dcc2e 100644
--- a/packages/graphrag-storage/graphrag_storage/file_storage.py
+++ b/packages/graphrag-storage/graphrag_storage/file_storage.py
@@ -144,6 +144,10 @@ async def get_creation_date(self, key: str) -> str:
 
         return get_timestamp_formatted_with_local_tz(creation_time_utc)
 
+    def get_path(self, key: str) -> Path:
+        """Get the full file path for a key (for streaming access)."""
+        return _join_path(self._base_dir, key)
+
 
 def _join_path(file_path: Path, file_name: str) -> Path:
     """Join a path and a file. Independent of the OS."""
diff --git a/packages/graphrag-storage/graphrag_storage/tables/__init__.py b/packages/graphrag-storage/graphrag_storage/tables/__init__.py
index 0210d935f3..9f95b076ca 100644
--- a/packages/graphrag-storage/graphrag_storage/tables/__init__.py
+++ b/packages/graphrag-storage/graphrag_storage/tables/__init__.py
@@ -3,6 +3,7 @@
 
 """Table provider module for GraphRAG storage."""
 
+from .table import Table
 from .table_provider import TableProvider
 
-__all__ = ["TableProvider"]
+__all__ = ["Table", "TableProvider"]
diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py
new file mode 100644
index 0000000000..0c55b17ea3
--- /dev/null
+++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A CSV-based implementation of the Table abstraction for streaming row access."""
+
+from __future__ import annotations
+
+import csv
+import inspect
+import sys
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import aiofiles
+
+from graphrag_storage.file_storage import FileStorage
+from graphrag_storage.tables.table import RowTransformer, Table
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+    from io import TextIOWrapper
+
+    from graphrag_storage import Storage
+
+try:
+    csv.field_size_limit(sys.maxsize)
+except OverflowError:
+    csv.field_size_limit(100 * 1024 * 1024)
+
+
+def _identity(row: dict[str, Any]) -> Any:
+    """Return row unchanged (default transformer)."""
+    return row
+
+
+def _apply_transformer(transformer: RowTransformer, row: dict[str, Any]) -> Any:
+    """Apply transformer to row, handling both callables and classes.
+
+    If transformer is a class (e.g., Pydantic model), calls it with **row.
+    Otherwise calls it with row as positional argument.
+    """
+    if inspect.isclass(transformer):
+        return transformer(**row)
+    return transformer(row)
+
+
+class CSVTable(Table):
+    """Row-by-row streaming interface for CSV tables."""
+
+    def __init__(
+        self,
+        storage: Storage,
+        table_name: str,
+        transformer: RowTransformer | None = None,
+        truncate: bool = True,
+        encoding: str = "utf-8",
+    ):
+        """Initialize with storage backend and table name.
+
+        Args:
+            storage: Storage instance (File, Blob, or Cosmos)
+            table_name: Name of the table (e.g., "documents")
+            transformer: Optional callable to transform each row before
+                yielding. Receives a dict, returns a transformed value.
+                Defaults to identity (no transformation).
+            truncate: If True (default), truncate file on first write.
+                If False, append to existing file.
+            encoding: Character encoding for reading/writing CSV files.
+                Defaults to "utf-8".
+        """
+        self._storage = storage
+        self._table_name = table_name
+        self._file_key = f"{table_name}.csv"
+        self._transformer = transformer or _identity
+        self._truncate = truncate
+        self._encoding = encoding
+        self._write_file: TextIOWrapper | None = None
+        self._writer: csv.DictWriter | None = None
+        self._header_written = False
+
+    def __aiter__(self) -> AsyncIterator[Any]:
+        """Iterate through rows one at a time.
+
+        The transformer is applied to each row before yielding.
+        If transformer is a Pydantic model, yields model instances.
+
+        Yields
+        ------
+        Any:
+            Each row as dict or transformed type (e.g., Pydantic model).
+        """
+        return self._aiter_impl()
+
+    async def _aiter_impl(self) -> AsyncIterator[Any]:
+        """Implement async iteration over rows."""
+        if isinstance(self._storage, FileStorage):
+            file_path = self._storage.get_path(self._file_key)
+            with Path.open(file_path, "r", encoding=self._encoding) as f:
+                reader = csv.DictReader(f)
+                for row in reader:
+                    yield _apply_transformer(self._transformer, row)
+
+    async def length(self) -> int:
+        """Return the number of rows in the table."""
+        if isinstance(self._storage, FileStorage):
+            file_path = self._storage.get_path(self._file_key)
+            count = 0
+            async with aiofiles.open(file_path, "rb") as f:
+                while True:
+                    chunk = await f.read(65536)
+                    if not chunk:
+                        break
+                    count += chunk.count(b"\n")
+            # Newline counting assumes no embedded newlines in quoted fields;
+            # subtract the header row and clamp at zero for empty files.
+            return max(count - 1, 0)
+        return 0
+
+    async def has(self, row_id: str) -> bool:
+        """Check if row with given ID exists."""
+        async for row in self:
+            # Handle both dict and object (e.g., Pydantic model)
+            if isinstance(row, dict):
+                if row.get("id") == row_id:
+                    return True
+            elif getattr(row, "id", None) == row_id:
+                return True
+        return False
+
+    async def write(self, row: dict[str, Any]) -> None:
+        """Write a single row to the CSV file.
+
+        On first write, opens the file. If truncate=True, overwrites any existing
+        file and writes header. If truncate=False, appends to existing file
+        (skips header if file exists).
+
+        Args
+        ----
+        row: Dictionary representing a single row to write.
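+
+        Example (illustrative sketch; assumes a FileStorage-backed provider):
+            table = provider.open("text_units")
+            await table.write({"id": "1", "text": "hello"})
+            await table.close()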
+        """
+        if isinstance(self._storage, FileStorage) and self._write_file is None:
+            file_path = self._storage.get_path(self._file_key)
+            file_path.parent.mkdir(parents=True, exist_ok=True)
+            file_exists = file_path.exists() and file_path.stat().st_size > 0
+            mode = "w" if self._truncate else "a"
+            write_header = self._truncate or not file_exists
+            self._write_file = Path.open(
+                file_path, mode, encoding=self._encoding, newline=""
+            )
+            self._writer = csv.DictWriter(self._write_file, fieldnames=list(row.keys()))
+            if write_header:
+                self._writer.writeheader()
+            self._header_written = write_header
+
+        if self._writer is not None:
+            self._writer.writerow(row)
+
+    async def close(self) -> None:
+        """Flush buffered writes and release resources.
+
+        Closes the file handle if writing was performed.
+        """
+        if self._write_file is not None:
+            self._write_file.close()
+            self._write_file = None
+            self._writer = None
+            self._header_written = False
diff --git a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py
index 5de021b8a5..2561bde0d8 100644
--- a/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py
+++ b/packages/graphrag-storage/graphrag_storage/tables/csv_table_provider.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024 Microsoft Corporation.
+# Copyright (c) 2025 Microsoft Corporation.
 # Licensed under the MIT License
 
 """CSV-based table provider implementation."""
@@ -9,7 +9,10 @@
 
 import pandas as pd
 
+from graphrag_storage.file_storage import FileStorage
 from graphrag_storage.storage import Storage
+from graphrag_storage.tables.csv_table import CSVTable
+from graphrag_storage.tables.table import RowTransformer
 from graphrag_storage.tables.table_provider import TableProvider
 
 logger = logging.getLogger(__name__)
@@ -32,6 +35,9 @@ def __init__(self, storage: Storage, **kwargs) -> None:
         **kwargs: Any
             Additional keyword arguments (currently unused).
         """
+        if not isinstance(storage, FileStorage):
+            msg = "CSVTableProvider only works with FileStorage backends for now."
+            raise TypeError(msg)
         self._storage = storage
 
     async def read_dataframe(self, table_name: str) -> pd.DataFrame:
@@ -108,3 +114,32 @@ def list(self) -> list[str]:
             file.replace(".csv", "")
             for file in self._storage.find(re.compile(r"\.csv$"))
         ]
+
+    def open(
+        self,
+        table_name: str,
+        transformer: RowTransformer | None = None,
+        truncate: bool = True,
+        encoding: str = "utf-8",
+    ) -> CSVTable:
+        """Open table for streaming.
+
+        Args:
+            table_name: Name of the table to open
+            transformer: Optional callable to transform each row
+            truncate: If True, truncate file on first write
+            encoding: Character encoding for reading/writing CSV files.
+                Defaults to "utf-8".
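+
+        Example (illustrative sketch):
+            async with provider.open("documents") as table:
+                async for row in table:
+                    ...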
+        """
+        return CSVTable(
+            self._storage,
+            table_name,
+            transformer=transformer,
+            truncate=truncate,
+            encoding=encoding,
+        )
diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py
new file mode 100644
index 0000000000..2c62713987
--- /dev/null
+++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A Parquet-based implementation of the Table abstraction with simulated streaming."""
+
+from __future__ import annotations
+
+import inspect
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, cast
+
+import pandas as pd
+
+from graphrag_storage.tables.table import RowTransformer, Table
+
+if TYPE_CHECKING:
+    from collections.abc import AsyncIterator
+
+    from graphrag_storage.storage import Storage
+
+
+def _identity(row: dict[str, Any]) -> Any:
+    """Return row unchanged (default transformer)."""
+    return row
+
+
+def _apply_transformer(transformer: RowTransformer, row: dict[str, Any]) -> Any:
+    """Apply transformer to row, handling both callables and classes.
+
+    If transformer is a class (e.g., Pydantic model), calls it with **row.
+    Otherwise calls it with row as positional argument.
+    """
+    if inspect.isclass(transformer):
+        return transformer(**row)
+    return transformer(row)
+
+
+class ParquetTable(Table):
+    """Simulated streaming interface for Parquet tables.
+
+    Parquet format doesn't support true row-by-row streaming, so this
+    implementation simulates streaming via:
+    - Read: Loads DataFrame, yields rows via iterrows()
+    - Write: Accumulates rows in memory, writes all at once on close()
+
+    This provides API compatibility with CSVTable while maintaining
+    Parquet's performance characteristics for bulk operations.
+    """
+
+    def __init__(
+        self,
+        storage: Storage,
+        table_name: str,
+        transformer: RowTransformer | None = None,
+        truncate: bool = True,
+    ):
+        """Initialize with storage backend and table name.
+
+        Args:
+            storage: Storage instance (File, Blob, or Cosmos)
+            table_name: Name of the table (e.g., "documents")
+            transformer: Optional callable to transform each row before
+                yielding. Receives a dict, returns a transformed value.
+                Defaults to identity (no transformation).
+            truncate: If True (default), overwrite file on close.
+                If False, append to existing file.
+        """
+        self._storage = storage
+        self._table_name = table_name
+        self._file_key = f"{table_name}.parquet"
+        self._transformer = transformer or _identity
+        self._truncate = truncate
+        self._df: pd.DataFrame | None = None
+        self._write_rows: list[dict[str, Any]] = []
+
+    def __aiter__(self) -> AsyncIterator[Any]:
+        """Iterate through rows one at a time.
+
+        Loads the entire DataFrame on first iteration, then yields rows
+        one at a time with the transformer applied.
+
+        Yields
+        ------
+        Any:
+            Each row as dict or transformed type (e.g., Pydantic model).
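+
+        Example (illustrative; ``Entity`` is a hypothetical Pydantic model):
+            async for entity in provider.open("entities", Entity):
+                print(entity.title)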
+        """
+        return self._aiter_impl()
+
+    async def _aiter_impl(self) -> AsyncIterator[Any]:
+        """Implement async iteration over rows."""
+        if self._df is None:
+            if await self._storage.has(self._file_key):
+                data = await self._storage.get(self._file_key, as_bytes=True)
+                self._df = pd.read_parquet(BytesIO(data))
+            else:
+                self._df = pd.DataFrame()
+
+        for _, row in self._df.iterrows():
+            row_dict = cast("dict[str, Any]", row.to_dict())
+            yield _apply_transformer(self._transformer, row_dict)
+
+    async def length(self) -> int:
+        """Return the number of rows in the table."""
+        if self._df is None:
+            if await self._storage.has(self._file_key):
+                data = await self._storage.get(self._file_key, as_bytes=True)
+                self._df = pd.read_parquet(BytesIO(data))
+            else:
+                return 0
+        return len(self._df)
+
+    async def has(self, row_id: str) -> bool:
+        """Check if row with given ID exists."""
+        async for row in self:
+            if isinstance(row, dict):
+                if row.get("id") == row_id:
+                    return True
+            elif getattr(row, "id", None) == row_id:
+                return True
+        return False
+
+    async def write(self, row: dict[str, Any]) -> None:
+        """Accumulate a single row for later batch write.
+
+        Rows are stored in memory and written to Parquet format
+        when close() is called.
+
+        Args
+        ----
+        row: Dictionary representing a single row to write.
+        """
+        self._write_rows.append(row)
+
+    async def close(self) -> None:
+        """Flush accumulated rows to Parquet file and release resources.
+
+        Converts all accumulated rows to a DataFrame and writes
+        to storage as a Parquet file. If truncate=False and file exists,
+        appends to existing data.
+        """
+        if self._write_rows:
+            new_df = pd.DataFrame(self._write_rows)
+            if not self._truncate and await self._storage.has(self._file_key):
+                existing_data = await self._storage.get(self._file_key, as_bytes=True)
+                existing_df = pd.read_parquet(BytesIO(existing_data))
+                new_df = pd.concat([existing_df, new_df], ignore_index=True)
+            await self._storage.set(self._file_key, new_df.to_parquet())
+            self._write_rows = []
+
+        self._df = None
diff --git a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py
index 74f63660dc..b6c6f251bc 100644
--- a/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py
+++ b/packages/graphrag-storage/graphrag_storage/tables/parquet_table_provider.py
@@ -10,6 +10,8 @@
 import pandas as pd
 
 from graphrag_storage.storage import Storage
+from graphrag_storage.tables.parquet_table import ParquetTable
+from graphrag_storage.tables.table import RowTransformer, Table
 from graphrag_storage.tables.table_provider import TableProvider
 
 logger = logging.getLogger(__name__)
@@ -106,3 +108,35 @@ def list(self) -> list[str]:
             file.replace(".parquet", "")
             for file in self._storage.find(re.compile(r"\.parquet$"))
         ]
+
+    def open(
+        self,
+        table_name: str,
+        transformer: RowTransformer | None = None,
+        truncate: bool = True,
+    ) -> Table:
+        """Open a table for streaming row operations.
+
+        Returns a ParquetTable that simulates streaming by loading the
+        DataFrame and iterating rows, or accumulating writes for batch output.
+
+        Args
+        ----
+        table_name: str
+            The name of the table to open.
+        transformer: RowTransformer | None
+            Optional callable to transform each row on read.
+        truncate: bool
+            If True (default), overwrite existing file on close.
+            If False, append new rows to existing file.
+
+        Returns
+        -------
+        Table:
+            A ParquetTable instance for row-by-row access.
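+
+        Example (illustrative; rows are buffered in memory until close()):
+            async with provider.open("entities") as table:
+                await table.write({"id": "1", "title": "A"})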
+        """
+        return ParquetTable(self._storage, table_name, transformer, truncate=truncate)
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table.py b/packages/graphrag-storage/graphrag_storage/tables/table.py
new file mode 100644
index 0000000000..d845fb8b3f
--- /dev/null
+++ b/packages/graphrag-storage/graphrag_storage/tables/table.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Table abstraction for streaming row-by-row access."""
+
+from abc import ABC, abstractmethod
+from collections.abc import AsyncIterator, Callable
+from types import TracebackType
+from typing import Any
+
+from typing_extensions import Self
+
+RowTransformer = Callable[[dict[str, Any]], Any]
+
+
+class Table(ABC):
+    """Abstract base class for streaming table access.
+
+    Provides row-by-row iteration and write capabilities for memory-efficient
+    processing of large datasets. Supports async context manager protocol for
+    automatic resource cleanup.
+
+    Examples
+    --------
+    Reading rows as dicts:
+    >>> async with provider.open("documents") as table:
+    ...     async for row in table:
+    ...         process(row)
+
+    With Pydantic model as transformer:
+    >>> async with provider.open("entities", Entity) as table:
+    ...     async for entity in table:  # yields Entity instances
+    ...         print(entity.name)
+    """
+
+    @abstractmethod
+    def __aiter__(self) -> AsyncIterator[Any]:
+        """Yield rows asynchronously, transformed if transformer provided.
+
+        Yields
+        ------
+        Any:
+            Each row, either as dict or transformed type (e.g., Pydantic model).
+        """
+        ...
+
+    @abstractmethod
+    async def length(self) -> int:
+        """Return number of rows asynchronously.
+
+        Returns
+        -------
+        int:
+            Number of rows in the table.
+        """
+
+    @abstractmethod
+    async def has(self, row_id: str) -> bool:
+        """Check if a row with the given ID exists.
+
+        Args
+        ----
+        row_id: The ID value to search for.
+
+        Returns
+        -------
+        bool:
+            True if a row with matching ID exists.
+        """
+
+    @abstractmethod
+    async def write(self, row: dict[str, Any]) -> None:
+        """Write a single row to the table.
+
+        Args
+        ----
+        row: Dictionary representing a single row to write.
+        """
+
+    @abstractmethod
+    async def close(self) -> None:
+        """Flush buffered writes and release resources.
+
+        This method is called automatically when exiting the async context
+        manager, but can also be called explicitly.
+        """
+
+    async def __aenter__(self) -> Self:
+        """Enter async context manager.
+
+        Returns
+        -------
+        Table:
+            Self for context manager usage.
+        """
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Exit async context manager, ensuring close() is called.
+
+        Args
+        ----
+        exc_type: Exception type if an exception occurred
+        exc_val: Exception value if an exception occurred
+        exc_tb: Exception traceback if an exception occurred
+        """
+        await self.close()
diff --git a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py
index 07a86c3119..39965839f8 100644
--- a/packages/graphrag-storage/graphrag_storage/tables/table_provider.py
+++ b/packages/graphrag-storage/graphrag_storage/tables/table_provider.py
@@ -8,6 +8,8 @@
 
 import pandas as pd
 
+from graphrag_storage.tables.table import RowTransformer, Table
+
 
 class TableProvider(ABC):
     """Provide a table-based storage interface with support for DataFrames and row dictionaries."""
@@ -73,3 +75,28 @@ def list(self) -> list[str]:
         list[str]:
             List of table names (without file extensions).
         """
+
+    @abstractmethod
+    def open(
+        self,
+        table_name: str,
+        transformer: RowTransformer | None = None,
+        truncate: bool = True,
+    ) -> Table:
+        """Open a table for row-by-row streaming operations.
+
+        Args
+        ----
+        table_name: str
+            The name of the table to open.
+        transformer: RowTransformer | None
+            Optional transformer function to apply to each row.
+        truncate: bool
+            If True (default), truncate existing file on first write.
+            If False, append rows to existing file (DB-like behavior).
+
+        Returns
+        -------
+        Table:
+            A Table instance for streaming row operations.
+        """
diff --git a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py
index 196ab3f1b6..a3d8964255 100644
--- a/packages/graphrag/graphrag/index/workflows/create_base_text_units.py
+++ b/packages/graphrag/graphrag/index/workflows/create_base_text_units.py
@@ -4,18 +4,17 @@
 """A module containing run_workflow method definition."""
 
 import logging
-from typing import Any, cast
+from typing import Any
 
-import pandas as pd
 from graphrag_chunking.chunker import Chunker
 from graphrag_chunking.chunker_factory import create_chunker
 from graphrag_chunking.transformers import add_metadata
 from graphrag_input import TextDocument
 from graphrag_llm.tokenizer import Tokenizer
+from graphrag_storage.tables.table import Table
 
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.data_model.data_reader import DataReader
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.workflow import WorkflowFunctionOutput
 from graphrag.index.utils.hashing import gen_sha512_hash
@@ -31,89 +30,139 @@ async def run_workflow(
 ) -> WorkflowFunctionOutput:
     """All the steps to transform base text_units."""
     logger.info("Workflow started: create_base_text_units")
-    reader = DataReader(context.output_table_provider)
-    documents = await reader.documents()
 
     tokenizer = get_tokenizer(encoding_model=config.chunking.encoding_model)
     chunker = create_chunker(config.chunking, tokenizer.encode, tokenizer.decode)
-    output = create_base_text_units(
-        documents,
-        context.callbacks,
-        tokenizer=tokenizer,
-        chunker=chunker,
-        prepend_metadata=config.chunking.prepend_metadata,
-    )
 
-    await context.output_table_provider.write_dataframe("text_units", output)
+    async with (
+        context.output_table_provider.open("documents") as documents_table,
+        context.output_table_provider.open("text_units") as text_units_table,
+    ):
+        total_rows = await documents_table.length()
+        sample_rows = await create_base_text_units(
+            documents_table,
+            text_units_table,
+            total_rows,
+            context.callbacks,
+            tokenizer=tokenizer,
+            chunker=chunker,
+            prepend_metadata=config.chunking.prepend_metadata,
+        )
 
     logger.info("Workflow completed: create_base_text_units")
-    return WorkflowFunctionOutput(result=output)
+    return WorkflowFunctionOutput(result=sample_rows)
 
 
-def create_base_text_units(
-    documents: pd.DataFrame,
+async def create_base_text_units(
+    documents_table: Table,
+    text_units_table: Table,
+    total_rows: int,
     callbacks: WorkflowCallbacks,
     tokenizer: Tokenizer,
     chunker: Chunker,
     prepend_metadata: list[str] | None = None,
-) -> pd.DataFrame:
-    """All the steps to transform base text_units."""
-    documents.sort_values(by=["id"], ascending=[True], inplace=True)
-
-    total_rows = len(documents)
+) -> list[dict[str, Any]]:
+    """Transform documents into chunked text units via streaming read/write.
+
+    Reads documents row-by-row from an async iterable and writes text units
+    directly to the output table, avoiding loading all data into memory.
+
+    Args
+    ----
+    documents_table: Table
+        Table instance for reading documents. Supports async iteration.
+    text_units_table: Table
+        Table instance for writing text units row by row.
+    total_rows: int
+        Total number of documents for progress reporting.
+    callbacks: WorkflowCallbacks
+        Callbacks for progress reporting.
+    tokenizer: Tokenizer
+        Tokenizer for measuring chunk token counts.
+    chunker: Chunker
+        Chunker instance for splitting document text.
+    prepend_metadata: list[str] | None
+        Optional list of metadata fields to prepend to
+        each chunk.
+
+    Returns
+    -------
+    list[dict[str, Any]]:
+        The first few text-unit rows, kept as a small sample for reporting.
+    """
     tick = progress_ticker(callbacks.progress, total_rows)
 
-    # Track progress of row-wise apply operation
-    logger.info("Starting chunking process for %d documents", total_rows)
-
-    def chunker_with_logging(row: pd.Series, row_index: int) -> Any:
-        if prepend_metadata:
-            # create a standard text document for metadata plucking
-            # ignore any additional fields in case the input dataframe has extra columns
-            document = TextDocument(
-                id=row["id"],
-                title=row["title"],
-                text=row["text"],
-                creation_date=row["creation_date"],
-                raw_data=row["raw_data"],
-            )
-            metadata = document.collect(prepend_metadata)
-            transformer = add_metadata(
-                metadata=metadata, line_delimiter=".\n"
-            )  # delim with . for back-compat older indexes
-        else:
-            transformer = None
-
-        row["chunks"] = [
-            chunk.text for chunk in chunker.chunk(row["text"], transform=transformer)
-        ]
+    logger.info(
+        "Starting chunking process for %d documents",
+        total_rows,
+    )
+    doc_index = 0
+    sample_rows: list[dict[str, Any]] = []
+    sample_size = 5
+
+    async for doc in documents_table:
+        chunks = chunk_document(doc, chunker, prepend_metadata)
+        for chunk_text in chunks:
+            if chunk_text is None:
+                continue
+            row = {
+                "id": "",
+                "document_id": doc["id"],
+                "text": chunk_text,
+                "n_tokens": len(tokenizer.encode(chunk_text)),
+            }
+            row["id"] = gen_sha512_hash(row, ["text"])
+            await text_units_table.write(row)
+
+            if len(sample_rows) < sample_size:
+                sample_rows.append(row)
+
+        doc_index += 1
         tick()
-        logger.info("chunker progress: %d/%d", row_index + 1, total_rows)
-        return row
-
-    text_units = documents.apply(
-        lambda row: chunker_with_logging(row, row.name), axis=1
-    )
+        logger.info(
+            "chunker progress: %d/%d",
+            doc_index,
+            total_rows,
+        )
 
-    text_units = cast("pd.DataFrame", text_units[["id", "chunks"]])
-    text_units = text_units.explode("chunks")
-    text_units.rename(
-        columns={
-            "id": "document_id",
-            "chunks": "text",
-        },
-        inplace=True,
-    )
+    return sample_rows
 
-    text_units["id"] = text_units.apply(
-        lambda row: gen_sha512_hash(row, ["text"]), axis=1
-    )
-    # get a final token measurement
-    text_units["n_tokens"] = text_units["text"].apply(
-        lambda x: len(tokenizer.encode(x))
-    )
-    return cast(
-        "pd.DataFrame", text_units[text_units["text"].notna()].reset_index(drop=True)
-    )
+
+def chunk_document(
+    doc: dict[str, Any],
+    chunker: Chunker,
+    prepend_metadata: list[str] | None = None,
+) -> list[str]:
+    """Chunk a single document row into text fragments.
+
+    Args
+    ----
+    doc: dict[str, Any]
+        A single document row as a dictionary.
+    chunker: Chunker
+        Chunker instance for splitting text.
+    prepend_metadata: list[str] | None
+        Optional metadata fields to prepend.
+
+    Returns
+    -------
+    list[str]:
+        List of chunk text strings.
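+
+    Example (illustrative; assumes ``chunker`` was built via create_chunker):
+        chunks = chunk_document({"id": "1", "text": "One. Two."}, chunker)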
+    """
+    transformer = None
+    if prepend_metadata:
+        document = TextDocument(
+            id=doc["id"],
+            title=doc.get("title", ""),
+            text=doc["text"],
+            creation_date=doc.get("creation_date", ""),
+            raw_data=doc.get("raw_data"),
+        )
+        metadata = document.collect(prepend_metadata)
+        transformer = add_metadata(metadata=metadata, line_delimiter=".\n")
+
+    return [chunk.text for chunk in chunker.chunk(doc["text"], transform=transformer)]
diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py
index ebce58b914..a78de2bb58 100644
--- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py
+++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py
@@ -105,21 +105,22 @@ async def generate_text_embeddings(
     """All the steps to generate all embeddings."""
     embedding_param_map = {
         text_unit_text_embedding: {
-            "data": text_units.loc[:, ["id", "text"]]
+            "data": text_units.loc[:, ["id", "text"]].fillna("")
             if text_units is not None
             else None,
             "embed_column": "text",
         },
         entity_description_embedding: {
-            "data": entities.loc[:, ["id", "title", "description"]].assign(
-                title_description=lambda df: df["title"] + ":" + df["description"]
-            )
+            "data": entities
+            .loc[:, ["id", "title", "description"]]
+            .fillna("")
+            .assign(title_description=lambda df: df["title"] + ":" + df["description"])
             if entities is not None
             else None,
             "embed_column": "title_description",
         },
         community_full_content_embedding: {
-            "data": community_reports.loc[:, ["id", "full_content"]]
+            "data": community_reports.loc[:, ["id", "full_content"]].fillna("")
             if community_reports is not None
             else None,
             "embed_column": "full_content",
diff --git a/packages/graphrag/graphrag/index/workflows/load_input_documents.py b/packages/graphrag/graphrag/index/workflows/load_input_documents.py
index 26166bb279..a2d7ba937e 100644
--- a/packages/graphrag/graphrag/index/workflows/load_input_documents.py
+++ b/packages/graphrag/graphrag/index/workflows/load_input_documents.py
@@ -8,6 +8,7 @@
 import pandas as pd
 
 from graphrag_input import InputReader, create_input_reader
+from graphrag_storage.tables.table import Table
 
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.typing.context import PipelineRunContext
@@ -23,26 +24,37 @@ async def run_workflow(
     """Load and parse input documents into a standard format."""
     input_reader = create_input_reader(config.input, context.input_storage)
 
-    output = await load_input_documents(input_reader)
+    async with (
+        context.output_table_provider.open("documents") as documents_table,
+    ):
+        sample, total_count = await load_input_documents(input_reader, documents_table)
 
-    if len(output) == 0:
-        msg = "Error reading documents, please see logs."
-        logger.error(msg)
-        raise ValueError(msg)
+        if total_count == 0:
+            msg = "Error reading documents, please see logs."
+            logger.error(msg)
+            raise ValueError(msg)
 
-    logger.info("Final # of rows loaded: %s", len(output))
-    context.stats.num_documents = len(output)
+        logger.info("Final # of rows loaded: %s", total_count)
+        context.stats.num_documents = total_count
 
-    await context.output_table_provider.write_dataframe("documents", output)
+    return WorkflowFunctionOutput(result=sample)
 
-    return WorkflowFunctionOutput(result=output)
-
-async def load_input_documents(input_reader: InputReader) -> pd.DataFrame:
+
+async def load_input_documents(
+    input_reader: InputReader, documents_table: Table, sample_size: int = 5
+) -> tuple[pd.DataFrame, int]:
     """Load and parse input documents into a standard format."""
-    documents = [asdict(doc) async for doc in input_reader]
-    documents = pd.DataFrame(documents)
-    documents["human_readable_id"] = documents.index
-    if "raw_data" not in documents.columns:
-        documents["raw_data"] = pd.Series(dtype="object")
-    return documents
+    sample: list[dict] = []
+    idx = 0
+
+    async for doc in input_reader:
+        row = asdict(doc)
+        row["human_readable_id"] = idx
+        if "raw_data" not in row:
+            row["raw_data"] = None
+        await documents_table.write(row)
+        if len(sample) < sample_size:
+            sample.append(row)
+        idx += 1
+
+    return pd.DataFrame(sample), idx
diff --git a/packages/graphrag/graphrag/prompt_tune/loader/input.py b/packages/graphrag/graphrag/prompt_tune/loader/input.py
index 0cfdb2299a..fb3be66744 100644
--- a/packages/graphrag/graphrag/prompt_tune/loader/input.py
+++ b/packages/graphrag/graphrag/prompt_tune/loader/input.py
@@ -3,6 +3,7 @@
 
 """Input loading module."""
 
+import dataclasses
 import logging
 from typing import Any
 
@@ -18,7 +19,7 @@
 from graphrag.index.operations.embed_text.run_embed_text import (
     run_embed_text,
 )
-from graphrag.index.workflows.create_base_text_units import create_base_text_units
+from graphrag.index.workflows.create_base_text_units import chunk_document
 from graphrag.prompt_tune.defaults import (
     LIMIT,
     N_SUBSET_MAX,
@@ -58,12 +59,14 @@ async def load_docs_in_chunks(
     input_storage = create_storage(config.input_storage)
     input_reader = create_input_reader(config.input, input_storage)
     dataset = await input_reader.read_files()
-    chunks_df = create_base_text_units(
-        documents=pd.DataFrame(dataset),
-        callbacks=NoopWorkflowCallbacks(),
-        tokenizer=tokenizer,
-        chunker=chunker,
-    )
+
+    all_chunks: list[str] = []
+    for doc in dataset:
+        doc_dict = dataclasses.asdict(doc)
+        chunks = chunk_document(doc_dict, chunker)
+        all_chunks.extend(chunks)
+
+    chunks_df = pd.DataFrame({"text": all_chunks})
 
     # Depending on the select method, build the dataset
     if limit <= 0 or limit > len(chunks_df):
diff --git a/tests/unit/prompt_tune/__init__.py b/tests/unit/prompt_tune/__init__.py
new file mode 100644
index 0000000000..4d4df03613
--- /dev/null
+++ b/tests/unit/prompt_tune/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Unit tests for prompt_tune module."""
diff --git a/tests/unit/prompt_tune/test_load_docs_in_chunks.py b/tests/unit/prompt_tune/test_load_docs_in_chunks.py
new file mode 100644
index 0000000000..6268aef91e
--- /dev/null
+++ b/tests/unit/prompt_tune/test_load_docs_in_chunks.py
@@ -0,0 +1,292 @@
+# Copyright (c) 2025 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Unit tests for load_docs_in_chunks function."""
+
+import logging
+from dataclasses import dataclass
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from graphrag.prompt_tune.loader.input import load_docs_in_chunks
+from graphrag.prompt_tune.types import DocSelectionType
+
+
+@dataclass
+class MockTextDocument:
+    """Mock TextDocument for testing."""
+
+    id: str
+    text: str
+    title: str
+    creation_date: str
+    raw_data: dict[str, Any] | None = None
+
+
+class MockTokenizer:
+    """Mock tokenizer for testing."""
+
+    def encode(self, text: str) -> list[int]:
+        """Encode text to tokens (simple char-based)."""
+        return [ord(c) for c in text]
+
+    def decode(self, tokens: list[int]) -> str:
+        """Decode tokens to text."""
+        return "".join(chr(t) for t in tokens)
+
+
+@dataclass
+class MockChunk:
+    """Mock chunk result."""
+
+    text: str
+
+
+class MockChunker:
+    """Mock chunker for testing."""
+
+    def chunk(self, text: str, transform: Any = None) -> list[MockChunk]:
+        """Split text into sentence-like chunks."""
+        sentences = [s.strip() for s in text.split(".") if s.strip()]
+        return [MockChunk(text=s + ".") for s in sentences]
+
+
+class MockEmbeddingModel:
+    """Mock embedding model for testing."""
+
+    def __init__(self):
+        """Initialize with mock tokenizer."""
+        self.tokenizer = MockTokenizer()
+
+
+@pytest.fixture
+def mock_config():
+    """Create a mock GraphRagConfig."""
+    config = MagicMock()
+    config.embed_text.embedding_model_id = "test-model"
+    config.embed_text.batch_size = 10
+    config.embed_text.batch_max_tokens = 1000
+    config.concurrent_requests = 1
+    config.get_embedding_model_config.return_value = MagicMock()
+    return config
+
+
+@pytest.fixture
+def mock_logger():
+    """Create a mock logger."""
+    return logging.getLogger("test")
+
+
+@pytest.fixture
+def sample_documents():
+    """Create sample documents for testing."""
+    return [
+        MockTextDocument(
+            id="doc1",
+            text="First sentence. Second sentence. Third sentence.",
+            title="Doc 1",
+            creation_date="2025-01-01",
+        ),
+        MockTextDocument(
+            id="doc2",
+            text="Another document. With content.",
+            title="Doc 2",
+            creation_date="2025-01-02",
+        ),
+    ]
+
+
+class TestLoadDocsInChunks:
+    """Tests for load_docs_in_chunks function."""
+
+    @pytest.mark.asyncio
+    async def test_top_selection_returns_limited_chunks(
+        self, mock_config, mock_logger, sample_documents
+    ):
+        """Test TOP selection method returns the first N chunks."""
+        mock_reader = AsyncMock()
+        mock_reader.read_files.return_value = sample_documents
+
+        with (
+            patch(
+                "graphrag.prompt_tune.loader.input.create_embedding",
+                return_value=MockEmbeddingModel(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_storage",
+                return_value=MagicMock(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_input_reader",
+                return_value=mock_reader,
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_chunker",
+                return_value=MockChunker(),
+            ),
+        ):
+            result = await load_docs_in_chunks(
+                config=mock_config,
+                select_method=DocSelectionType.TOP,
+                limit=2,
+                logger=mock_logger,
+            )
+
+        assert len(result) == 2
+        assert result[0] == "First sentence."
+        assert result[1] == "Second sentence."
+
+    @pytest.mark.asyncio
+    async def test_random_selection_returns_correct_count(
+        self, mock_config, mock_logger, sample_documents
+    ):
+        """Test RANDOM selection method returns the correct number of chunks."""
+        mock_reader = AsyncMock()
+        mock_reader.read_files.return_value = sample_documents
+
+        with (
+            patch(
+                "graphrag.prompt_tune.loader.input.create_embedding",
+                return_value=MockEmbeddingModel(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_storage",
+                return_value=MagicMock(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_input_reader",
+                return_value=mock_reader,
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_chunker",
+                return_value=MockChunker(),
+            ),
+        ):
+            result = await load_docs_in_chunks(
+                config=mock_config,
+                select_method=DocSelectionType.RANDOM,
+                limit=3,
+                logger=mock_logger,
+            )
+
+        assert len(result) == 3
+
+    @pytest.mark.asyncio
+    async def test_escapes_braces_in_output(self, mock_config, mock_logger):
+        """Test that curly braces are escaped for str.format() compatibility."""
+        docs_with_braces = [
+            MockTextDocument(
+                id="doc1",
+                text="Some {latex} content.",
+                title="Doc 1",
+                creation_date="2025-01-01",
+            ),
+        ]
+
+        mock_reader = AsyncMock()
+        mock_reader.read_files.return_value = docs_with_braces
+
+        with (
+            patch(
+                "graphrag.prompt_tune.loader.input.create_embedding",
+                return_value=MockEmbeddingModel(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_storage",
+                return_value=MagicMock(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_input_reader",
+                return_value=mock_reader,
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_chunker",
+                return_value=MockChunker(),
+            ),
+        ):
+            result = await load_docs_in_chunks(
+                config=mock_config,
+                select_method=DocSelectionType.TOP,
+                limit=1,
+                logger=mock_logger,
+            )
+
+        assert len(result) == 1
+        assert "{{latex}}" in result[0]
+
+    @pytest.mark.asyncio
+    async def test_limit_out_of_range_uses_default(
+        self, mock_config, mock_logger, sample_documents
+    ):
+        """Test that invalid limit falls back to default LIMIT."""
+        mock_reader = AsyncMock()
+        mock_reader.read_files.return_value = sample_documents
+
+        with (
+            patch(
+                "graphrag.prompt_tune.loader.input.create_embedding",
+                return_value=MockEmbeddingModel(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_storage",
+                return_value=MagicMock(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_input_reader",
+                return_value=mock_reader,
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_chunker",
+                return_value=MockChunker(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.LIMIT",
+                3,
+            ),
+        ):
+            result = await load_docs_in_chunks(
+                config=mock_config,
+                select_method=DocSelectionType.TOP,
+                limit=-1,
+                logger=mock_logger,
+            )
+
+        assert len(result) == 3
+
+    @pytest.mark.asyncio
+    async def test_chunks_all_documents(
+        self, mock_config, mock_logger, sample_documents
+    ):
+        """Test that all documents are chunked correctly."""
+        mock_reader = AsyncMock()
+        mock_reader.read_files.return_value = sample_documents
+
+        with (
+            patch(
+                "graphrag.prompt_tune.loader.input.create_embedding",
+                return_value=MockEmbeddingModel(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_storage",
+                return_value=MagicMock(),
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_input_reader",
+                return_value=mock_reader,
+            ),
+            patch(
+                "graphrag.prompt_tune.loader.input.create_chunker",
+                return_value=MockChunker(),
+            ),
+        ):
+            result = await load_docs_in_chunks(
+                config=mock_config,
+                select_method=DocSelectionType.TOP,
+                limit=5,
+                logger=mock_logger,
+            )
+
+        assert len(result) == 5
+        assert "First sentence." in result
+        assert "Another document." in result
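
A minimal end-to-end sketch of the streaming Table API added above (illustrative only: the `./demo_output` path and row fields are invented, and `FileStorage`'s constructor argument is an assumption, not taken from this diff):

```python
import asyncio

from graphrag_storage.file_storage import FileStorage
from graphrag_storage.tables.csv_table_provider import CSVTableProvider


async def main() -> None:
    # CSVTableProvider raises TypeError for non-FileStorage backends (see above).
    storage = FileStorage("./demo_output")  # constructor argument assumed
    provider = CSVTableProvider(storage)

    # Stream rows in one at a time; close() is handled by the context manager.
    async with provider.open("documents") as table:
        await table.write({"id": "1", "text": "hello world"})

    # Stream rows back out without materializing a DataFrame.
    async with provider.open("documents") as table:
        async for row in table:
            print(row["id"], row["text"])


asyncio.run(main())
```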