Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1b3056d
[Feature][DataLoader] Wire TableTransformer into OpenHouseDataLoader
ShreyeshArangath Mar 10, 2026
39bc6fc
[DataLoader] Clean up DataFusion TableTransformer integration
ShreyeshArangath Mar 10, 2026
6819fa3
[DataLoader] Simplify transform to SQL string + add pickle support
ShreyeshArangath Mar 10, 2026
f31e778
[DataLoader] Harden SQL identifier escaping and runtime validation
ShreyeshArangath Mar 10, 2026
0b62dc2
[DataLoader] Reuse DataFusion session per split and rebind batch table
ShreyeshArangath Mar 10, 2026
01a92e5
Merge upstream/main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 11, 2026
6f130be
[DataLoader] Move table_id into TableScanContext
ShreyeshArangath Mar 11, 2026
60c3642
[DataLoader] Remove unused Any import and simplify dict type hint
ShreyeshArangath Mar 11, 2026
fdfc92b
Merge main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 11, 2026
07e8d3c
[DataLoader] Make table_id required and introduce DataLoaderRuntimeError
ShreyeshArangath Mar 11, 2026
89be56e
[DataLoader] Remove unused DataLoaderRuntimeError and dead defensive …
ShreyeshArangath Mar 11, 2026
e076a1a
Merge upstream/main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 12, 2026
c49e5f4
[DataLoader] Address PR #496 review comments
ShreyeshArangath Mar 16, 2026
1f403d9
Merge upstream/main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 16, 2026
cdbb160
[DataLoader] Integrate SQL transpilation into data loader pipeline
ShreyeshArangath Mar 16, 2026
f7fbeb3
[DataLoader] Remove dialect property from TableTransformer
ShreyeshArangath Mar 16, 2026
d0aec79
[DataLoader] Add dialect support to TableTransformer with transpilation
ShreyeshArangath Mar 16, 2026
511ce4b
[DataLoader] Simplify TableTransformer docstrings
ShreyeshArangath Mar 16, 2026
f6bfbc8
[DataLoader] Remove implementation details from transform() docstring
ShreyeshArangath Mar 16, 2026
c9d4a67
[DataLoader] Address PR review: simplify docstrings, remove implement…
ShreyeshArangath Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,22 @@
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadata

from openhouse.dataloader.table_identifier import TableIdentifier


def _unpickle_scan_context(
    table_metadata: TableMetadata,
    io_properties: dict[str, str],
    projected_schema: Schema,
    row_filter: BooleanExpression,
    table_id: TableIdentifier,
) -> TableScanContext:
    """Rebuild a ``TableScanContext`` from its pickled components.

    Used as the factory in ``TableScanContext.__reduce__``: the FileIO is not
    carried across the pickle boundary directly, so a fresh one is loaded here
    from the pickled properties and the table's storage location.
    """
    file_io = load_file_io(properties=io_properties, location=table_metadata.location)
    return TableScanContext(
        table_metadata=table_metadata,
        io=file_io,
        projected_schema=projected_schema,
        row_filter=row_filter,
        table_id=table_id,
    )


Expand All @@ -33,16 +37,18 @@ class TableScanContext:
table_metadata: Full Iceberg table metadata (schema, properties, partition specs, etc.)
io: FileIO configured for the table's storage location
projected_schema: Subset of columns to read (equals table schema when no projection)
table_id: Identifier for the table being scanned
row_filter: Row-level filter expression pushed down to the scan
"""

table_metadata: TableMetadata
io: FileIO
projected_schema: Schema
table_id: TableIdentifier
row_filter: BooleanExpression = AlwaysTrue()

def __reduce__(self) -> tuple:
return (
_unpickle_scan_context,
(self.table_metadata, dict(self.io.properties), self.projected_schema, self.row_filter),
(self.table_metadata, dict(self.io.properties), self.projected_schema, self.row_filter, self.table_id),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader._timer import log_duration
from openhouse.dataloader.data_loader_split import DataLoaderSplit
from openhouse.dataloader.datafusion_sql import to_datafusion_sql
from openhouse.dataloader.filters import Filter, _to_pyiceberg, always_true
from openhouse.dataloader.table_identifier import TableIdentifier
from openhouse.dataloader.table_transformer import TableTransformer
Expand Down Expand Up @@ -146,6 +147,13 @@ def _verify_snapshot(self, snapshot: Snapshot | None) -> None:
else:
logger.info("No snapshot found for table %s", self._table_id)

def _build_transform_sql(self, transformer: TableTransformer, context: Mapping[str, str]) -> str | None:
    """Produce DataFusion-compatible SQL for this table's transformation.

    Asks the transformer for its SQL (written in the transformer's declared
    dialect) and transpiles it to DataFusion syntax. Returns ``None`` when the
    transformer reports that no transformation is needed.
    """
    raw_sql = transformer.transform(self._table_id, context)
    return None if raw_sql is None else to_datafusion_sql(raw_sql, transformer.dialect)

def __iter__(self) -> Iterator[DataLoaderSplit]:
"""Iterate over data splits for distributed data loading of the table.

Expand All @@ -154,11 +162,20 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
"""
table = self._iceberg_table

# Build transform SQL: call transformer once to get the SQL string
transformer = self._context.table_transformer
execution_context = self._context.execution_context or {}
transform_sql = self._build_transform_sql(transformer, execution_context) if transformer is not None else None

if self._columns and transform_sql is not None:
raise ValueError("Column projections with table transformers are not supported yet")

row_filter = _to_pyiceberg(self._filters)

scan_kwargs: dict = {"row_filter": row_filter}
if self.snapshot_id is not None:
scan_kwargs["snapshot_id"] = self.snapshot_id

if self._columns:
scan_kwargs["selected_fields"] = tuple(self._columns)

Expand All @@ -171,6 +188,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
io=table.io,
projected_schema=scan.projection(),
row_filter=row_filter,
table_id=self._table_id,
)

# plan_files() materializes all tasks at once (PyIceberg doesn't support streaming)
Expand All @@ -183,5 +201,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
yield DataLoaderSplit(
file_scan_task=scan_task,
scan_context=scan_context,
transform_sql=transform_sql,
udf_registry=self._context.udf_registry,
batch_size=self._batch_size,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,65 @@
from types import MappingProxyType

from datafusion.context import SessionContext
from datafusion.plan import LogicalPlan
from datafusion.substrait import Producer
from pyarrow import RecordBatch
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import ArrivalOrder, FileScanTask

from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader.table_identifier import TableIdentifier
from openhouse.dataloader.udf_registry import NoOpRegistry, UDFRegistry


def _quote_identifier(name: str) -> str:
"""Escape a SQL identifier by doubling embedded double quotes and wrapping in double quotes."""
return '"' + name.replace('"', '""') + '"'


def to_sql_identifier(table_id: TableIdentifier) -> str:
    """Render ``table_id`` as a fully quoted DataFusion identifier, e.g. ``"db"."tbl"``."""
    parts = (table_id.database, table_id.table)
    return ".".join(_quote_identifier(part) for part in parts)


def _create_transform_session(
    table_id: TableIdentifier,
    udf_registry: UDFRegistry,
) -> SessionContext:
    """Build a DataFusion session prepared for split-level transform SQL.

    The returned session has every UDF from ``udf_registry`` registered, and a
    schema named after the table's database already exists, so transform SQL
    referencing the qualified table name resolves once a batch table is bound.
    """
    ctx = SessionContext()
    udf_registry.register_udfs(ctx)

    schema_name = _quote_identifier(table_id.database)
    ctx.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_name}").collect()
    return ctx


def _bind_batch_table(session: SessionContext, table_id: TableIdentifier, batch: RecordBatch) -> None:
    """(Re)register ``batch`` under the quoted table name used by transform SQL.

    Any previously bound batch is deregistered first, so a single session can
    be reused across the batches of a split.
    """
    table_name = to_sql_identifier(table_id)
    session.deregister_table(table_name)
    session.register_record_batches(table_name, [[batch]])


class DataLoaderSplit:
"""A single data split"""

def __init__(
self,
file_scan_task: FileScanTask,
scan_context: TableScanContext,
plan: LogicalPlan | None = None,
session_context: SessionContext | None = None,
transform_sql: str | None = None,
udf_registry: UDFRegistry | None = None,
batch_size: int | None = None,
):
self._file_scan_task = file_scan_task
self._udf_registry = udf_registry or NoOpRegistry()
self._scan_context = scan_context
self._transform_sql = transform_sql
self._udf_registry = udf_registry or NoOpRegistry()
self._batch_size = batch_size

if (plan is None) != (session_context is None):
raise ValueError("plan and session_context must both be provided or both be None")

if plan is not None:
# TODO: Deserialize back to a LogicalPlan once we integrate with DataFusion for execution.
# The UDF registry is retained so UDFs can be re-registered on remote workers.
assert session_context is not None # guaranteed by the guard above
self._udf_registry.register_udfs(session_context)
self._plan_substrait_bytes: bytes | None = Producer.to_substrait_plan(plan, session_context).encode()
else:
self._plan_substrait_bytes = None

@property
def id(self) -> str:
"""Unique ID for the split. This is stable across executions for a given
Expand Down Expand Up @@ -71,7 +91,21 @@ def __iter__(self) -> Iterator[RecordBatch]:
projected_schema=ctx.projected_schema,
row_filter=ctx.row_filter,
)
yield from arrow_scan.to_record_batches(

batches = arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
)

if self._transform_sql is None:
yield from batches
else:
session = _create_transform_session(self._scan_context.table_id, self._udf_registry)
for batch in batches:
yield from self._apply_transform(session, batch)

def _apply_transform(self, session: SessionContext, batch: RecordBatch) -> Iterator[RecordBatch]:
    """Run the configured transform SQL over one ``RecordBatch`` and yield the results.

    The batch is first bound into ``session`` under the table's SQL name so
    the transform query can reference it.
    """
    _bind_batch_table(session, self._scan_context.table_id, batch)
    result = session.sql(self._transform_sql)  # type: ignore[arg-type]  # caller guarantees not None
    yield from result.collect()
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping

from datafusion.context import SessionContext
from datafusion.dataframe import DataFrame

from openhouse.dataloader.table_identifier import TableIdentifier


class TableTransformer(ABC):
"""Interface for applying additional transformation logic to the data
being loaded (e.g. column masking, row filtering)
"""Applies transformation logic to the base table that is being loaded.

Args:
dialect: The SQL dialect used by ``transform()`` (e.g. ``"spark"``).
"""

def __init__(self, dialect: str) -> None:
self.dialect: str = dialect

@abstractmethod
def transform(
self, session_context: SessionContext, table: TableIdentifier, context: Mapping[str, str]
) -> DataFrame | None:
"""Applies transformation logic to the base table that is being loaded.
def transform(self, table: TableIdentifier, context: Mapping[str, str]) -> str | None:
"""Builds a SQL string representing the transformation to apply.

Args:
table: Identifier for the table
context: Dictionary of context information (e.g. tenant, environment, etc.)

Returns:
The DataFrame representing the transformation. This is expected to read from the exact
base table identifier passed in as input. If no transformation is required, None is returned.
A SQL string, or None if no transformation is needed.
"""
pass
Loading