Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
1b3056d
[Feature][DataLoader] Wire TableTransformer into OpenHouseDataLoader
ShreyeshArangath Mar 10, 2026
39bc6fc
[DataLoader] Clean up DataFusion TableTransformer integration
ShreyeshArangath Mar 10, 2026
6819fa3
[DataLoader] Simplify transform to SQL string + add pickle support
ShreyeshArangath Mar 10, 2026
f31e778
[DataLoader] Harden SQL identifier escaping and runtime validation
ShreyeshArangath Mar 10, 2026
0b62dc2
[DataLoader] Reuse DataFusion session per split and rebind batch table
ShreyeshArangath Mar 10, 2026
01a92e5
Merge upstream/main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 11, 2026
6f130be
[DataLoader] Move table_id into TableScanContext
ShreyeshArangath Mar 11, 2026
60c3642
[DataLoader] Remove unused Any import and simplify dict type hint
ShreyeshArangath Mar 11, 2026
fdfc92b
Merge main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 11, 2026
07e8d3c
[DataLoader] Make table_id required and introduce DataLoaderRuntimeError
ShreyeshArangath Mar 11, 2026
89be56e
[DataLoader] Remove unused DataLoaderRuntimeError and dead defensive …
ShreyeshArangath Mar 11, 2026
e076a1a
Merge upstream/main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 12, 2026
c49e5f4
[DataLoader] Address PR #496 review comments
ShreyeshArangath Mar 16, 2026
1f403d9
Merge upstream/main into feat/add-basic-datafusion-integration
ShreyeshArangath Mar 16, 2026
cdbb160
[DataLoader] Integrate SQL transpilation into data loader pipeline
ShreyeshArangath Mar 16, 2026
f7fbeb3
[DataLoader] Remove dialect property from TableTransformer
ShreyeshArangath Mar 16, 2026
d0aec79
[DataLoader] Add dialect support to TableTransformer with transpilation
ShreyeshArangath Mar 16, 2026
511ce4b
[DataLoader] Simplify TableTransformer docstrings
ShreyeshArangath Mar 16, 2026
f6bfbc8
[DataLoader] Remove implementation details from transform() docstring
ShreyeshArangath Mar 16, 2026
c9d4a67
[DataLoader] Address PR review: simplify docstrings, remove implement…
ShreyeshArangath Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,22 @@
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadata

from openhouse.dataloader.table_identifier import TableIdentifier


def _unpickle_scan_context(
    table_metadata: TableMetadata,
    io_properties: dict[str, str],
    projected_schema: Schema,
    row_filter: BooleanExpression,
    table_id: TableIdentifier,
) -> TableScanContext:
    """Rebuild a ``TableScanContext`` from the positional state produced by ``__reduce__``.

    Args:
        table_metadata: Iceberg table metadata; its ``location`` is passed to ``load_file_io``.
        io_properties: Properties used to load the FileIO for the rebuilt context.
        projected_schema: Schema carried over unchanged.
        row_filter: Row filter expression carried over unchanged.
        table_id: Identifier of the table being scanned.

    Returns:
        A ``TableScanContext`` whose ``io`` is loaded from ``io_properties``.
    """
    # The FileIO is loaded here from its properties rather than received as an argument.
    file_io = load_file_io(properties=io_properties, location=table_metadata.location)
    return TableScanContext(
        table_metadata=table_metadata,
        io=file_io,
        projected_schema=projected_schema,
        row_filter=row_filter,
        table_id=table_id,
    )


Expand All @@ -33,16 +37,18 @@ class TableScanContext:
table_metadata: Full Iceberg table metadata (schema, properties, partition specs, etc.)
io: FileIO configured for the table's storage location
projected_schema: Subset of columns to read (equals table schema when no projection)
table_id: Identifier for the table being scanned
row_filter: Row-level filter expression pushed down to the scan
"""

table_metadata: TableMetadata
io: FileIO
projected_schema: Schema
table_id: TableIdentifier
row_filter: BooleanExpression = AlwaysTrue()

def __reduce__(self) -> tuple:
    """Describe how to pickle this context as a call to ``_unpickle_scan_context``.

    The ``io`` object is not included directly; a plain ``dict`` copy of
    ``io.properties`` is passed instead so the receiver can load its own FileIO.

    Returns:
        A ``(callable, args)`` pair as expected by the pickle protocol.
    """
    return (
        _unpickle_scan_context,
        # Exactly one argument tuple, positionally aligned with the five parameters of
        # _unpickle_scan_context (table_id last).
        (self.table_metadata, dict(self.io.properties), self.projected_schema, self.row_filter, self.table_id),
    )
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ def _verify_snapshot(self, snapshot: Snapshot | None) -> None:
else:
logger.info("No snapshot found for table %s", self._table_id)

def _build_transform_sql(self, transformer: TableTransformer, context: Mapping[str, str]) -> str | None:
    """Ask the transformer for the SQL that should be applied to this loader's table.

    Args:
        transformer: Transformer whose ``transform`` is called with this loader's table id.
        context: Context mapping forwarded to the transformer unchanged.

    Returns:
        Whatever the transformer returned: a SQL string, or ``None`` when it
        requests no transformation.
    """
    transform_sql = transformer.transform(self._table_id, context)
    return transform_sql

def __iter__(self) -> Iterator[DataLoaderSplit]:
"""Iterate over data splits for distributed data loading of the table.

Expand All @@ -154,12 +162,20 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
"""
table = self._iceberg_table

# Build transform SQL: call transformer once to get the SQL string
transformer = self._context.table_transformer
execution_context = self._context.execution_context or {}
transform_sql = self._build_transform_sql(transformer, execution_context) if transformer is not None else None

row_filter = _to_pyiceberg(self._filters)

scan_kwargs: dict = {"row_filter": row_filter}
if self.snapshot_id is not None:
scan_kwargs["snapshot_id"] = self.snapshot_id
if self._columns:

# When a transform is active, skip column projection — the transform may need all columns
# TODO: extract projected columns from the plan instead of skipping projection
if self._columns and transform_sql is None:
Comment thread
ShreyeshArangath marked this conversation as resolved.
Outdated
scan_kwargs["selected_fields"] = tuple(self._columns)

scan = table.scan(**scan_kwargs)
Expand All @@ -171,6 +187,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
io=table.io,
projected_schema=scan.projection(),
row_filter=row_filter,
table_id=self._table_id,
)

# plan_files() materializes all tasks at once (PyIceberg doesn't support streaming)
Expand All @@ -183,5 +200,7 @@ def __iter__(self) -> Iterator[DataLoaderSplit]:
yield DataLoaderSplit(
file_scan_task=scan_task,
scan_context=scan_context,
transform_sql=transform_sql,
udf_registry=self._context.udf_registry,
batch_size=self._batch_size,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,54 @@
from types import MappingProxyType

from datafusion.context import SessionContext
from datafusion.plan import LogicalPlan
from datafusion.substrait import Producer
from pyarrow import RecordBatch
from pyiceberg.io.pyarrow import ArrowScan
from pyiceberg.table import ArrivalOrder, FileScanTask

from openhouse.dataloader._table_scan_context import TableScanContext
from openhouse.dataloader.table_identifier import TableIdentifier, _quote_identifier
from openhouse.dataloader.udf_registry import NoOpRegistry, UDFRegistry


def _create_transform_session(
    table_id: TableIdentifier,
    udf_registry: UDFRegistry,
) -> SessionContext:
    """Build the DataFusion session used to run transform SQL for a split.

    A fresh ``SessionContext`` is created, the registry registers its UDFs on it,
    and a schema named after ``table_id.database`` is created if it does not exist.

    Args:
        table_id: Identifier whose ``database`` becomes the schema name.
        udf_registry: Registry asked to register its UDFs on the new session.

    Returns:
        The configured ``SessionContext``.
    """
    ctx = SessionContext()
    udf_registry.register_udfs(ctx)

    # Quote the database name so it is treated as a single SQL identifier.
    schema_name = _quote_identifier(table_id.database)
    ctx.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_name}").collect()
    return ctx


def _bind_batch_table(session: SessionContext, table_id: TableIdentifier, batch: RecordBatch) -> None:
    """Register ``batch`` in ``session`` under the table name that transform SQL refers to.

    Any table currently registered under ``table_id.sql_name`` is deregistered first,
    then the batch is registered as the only batch of a single partition.
    """
    for bind_step, extra_args in (
        (session.deregister_table, ()),
        (session.register_record_batches, ([[batch]],)),
    ):
        bind_step(table_id.sql_name, *extra_args)


class DataLoaderSplit:
"""A single data split"""

def __init__(
    self,
    file_scan_task: FileScanTask,
    scan_context: TableScanContext,
    plan: LogicalPlan | None = None,
    session_context: SessionContext | None = None,
    transform_sql: str | None = None,
    udf_registry: UDFRegistry | None = None,
    batch_size: int | None = None,
):
    """Store the inputs needed to read (and optionally transform) one split.

    Args:
        file_scan_task: The scan task this split reads.
        scan_context: Shared scan context for the table.
        plan: Optional logical plan; must be given together with ``session_context``.
        session_context: Optional session; must be given together with ``plan``.
        transform_sql: Optional SQL string kept for per-batch transformation.
        udf_registry: Registry of UDFs; a ``NoOpRegistry`` is used when omitted.
        batch_size: Optional batch size kept for the read.

    Raises:
        ValueError: If exactly one of ``plan`` and ``session_context`` is provided.
    """
    self._file_scan_task = file_scan_task
    self._scan_context = scan_context
    self._transform_sql = transform_sql
    # Assigned once; the fallback registry is only constructed when none was supplied.
    self._udf_registry = udf_registry or NoOpRegistry()
    self._batch_size = batch_size

    if (plan is None) != (session_context is None):
        raise ValueError("plan and session_context must both be provided or both be None")

    if plan is not None:
        # TODO: Deserialize back to a LogicalPlan once we integrate with DataFusion for execution.
        # The UDF registry is retained so UDFs can be re-registered on remote workers.
        assert session_context is not None  # guaranteed by the guard above
        self._udf_registry.register_udfs(session_context)
        self._plan_substrait_bytes: bytes | None = Producer.to_substrait_plan(plan, session_context).encode()
    else:
        self._plan_substrait_bytes = None

@property
def id(self) -> str:
"""Unique ID for the split. This is stable across executions for a given
Expand Down Expand Up @@ -71,7 +80,22 @@ def __iter__(self) -> Iterator[RecordBatch]:
projected_schema=ctx.projected_schema,
row_filter=ctx.row_filter,
)
yield from arrow_scan.to_record_batches(

batches = arrow_scan.to_record_batches(
[self._file_scan_task],
order=ArrivalOrder(concurrent_streams=1, batch_size=self._batch_size),
)

if self._transform_sql is None:
yield from batches
else:
session = _create_transform_session(self._scan_context.table_id, self._udf_registry)
for batch in batches:
yield from self._apply_transform(session, batch)

def _apply_transform(self, session: SessionContext, batch: RecordBatch) -> Iterator[RecordBatch]:
    """Execute the transform SQL against a single RecordBatch.

    The batch is first bound to the table name via ``_bind_batch_table``, then the
    stored SQL is run in ``session`` and the collected result batches are yielded.

    Args:
        session: Session in which the batch is registered and the SQL is run.
        batch: The batch to transform.

    Yields:
        The record batches returned by collecting the query result.
    """
    assert self._transform_sql is not None  # guaranteed by caller
    _bind_batch_table(session, self._scan_context.table_id, batch)
    df = session.sql(self._transform_sql)
    yield from df.collect()
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from dataclasses import dataclass


def _quote_identifier(name: str) -> str:
    """Return ``name`` as a double-quoted SQL identifier.

    Every embedded double quote is doubled before the result is wrapped in
    double quotes, e.g. ``my"db`` becomes ``"my""db"``.
    """
    escaped = name.replace('"', '""')
    return f'"{escaped}"'


@dataclass
class TableIdentifier:
"""Identifier for a table in OpenHouse
Expand All @@ -15,6 +20,11 @@ class TableIdentifier:
table: str
branch: str | None = None

@property
def sql_name(self) -> str:
    """Return the quoted DataFusion SQL identifier, e.g. ``"db"."tbl"``.

    The database and table parts are each quoted with ``_quote_identifier``
    and joined with a dot.
    """
    return f"{_quote_identifier(self.database)}.{_quote_identifier(self.table)}"

def __str__(self) -> str:
"""Return the fully qualified table name."""
base = f"{self.database}.{self.table}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping

from datafusion.context import SessionContext
from datafusion.dataframe import DataFrame

from openhouse.dataloader.table_identifier import TableIdentifier


Expand All @@ -13,17 +10,21 @@ class TableTransformer(ABC):
"""

@abstractmethod
def transform(self, table: TableIdentifier, context: Mapping[str, str]) -> str | None:
    """Builds a SQL string representing the transformation to apply.

    Called once to extract the SQL. The SQL is then executed per batch in
    each split against a DataFusion session where the batch is registered
    under ``table.sql_name``.

    The decision to return a SQL string or ``None`` **must not** depend on
    row data — it should be based solely on the table identifier and context.

    Args:
        table: Identifier for the table
        context: Dictionary of context information (e.g. tenant, environment, etc.)

    Returns:
        A SQL string to execute against each batch, or None if no transformation is needed.
    """
    pass
95 changes: 95 additions & 0 deletions integrations/python/dataloader/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from openhouse.dataloader import DataLoaderContext, OpenHouseDataLoader, __version__
from openhouse.dataloader.data_loader_split import DataLoaderSplit
from openhouse.dataloader.filters import col
from openhouse.dataloader.table_transformer import TableTransformer


def test_package_imports():
Expand Down Expand Up @@ -324,6 +325,100 @@ def test_snapshot_id_with_columns_and_filters(tmp_path):
assert "row_filter" in scan_kwargs


# --- Transformer tests ---


class _NoneTransformer(TableTransformer):
    """Test double whose ``transform`` always reports that no transformation is needed."""

    def transform(self, table, context) -> None:
        # Returning None is the "no transformation" signal of the transformer contract.
        return None


class _MaskingTransformer(TableTransformer):
    """Test double that replaces the ``name`` column with the literal ``'MASKED'``."""

    def transform(self, table, context) -> str:
        # Select from the quoted table name so the SQL matches what the split registers.
        source = table.sql_name
        return f"SELECT id, 'MASKED' as name, value FROM {source}"


def test_iter_with_transformer_returning_none(tmp_path):
    """A transformer that returns None keeps the plain scan path, with selected_fields still passed."""
    catalog = _make_real_catalog(tmp_path)
    mock_table = catalog.load_table.return_value
    requested_columns = [COL_ID, COL_NAME]

    loader = OpenHouseDataLoader(
        catalog=catalog,
        database="db",
        table="tbl",
        columns=requested_columns,
        context=DataLoaderContext(table_transformer=_NoneTransformer()),
    )
    result = _materialize(loader)

    # Data comes back projected to the requested columns.
    assert result.num_rows == 3
    assert set(result.column_names) == {COL_ID, COL_NAME}

    # The scan was issued once and carried the projection.
    mock_table.scan.assert_called_once()
    passed_kwargs = mock_table.scan.call_args.kwargs
    assert passed_kwargs["selected_fields"] == (COL_ID, COL_NAME)


def test_iter_with_transformer_returning_sql(tmp_path):
    """A transformer that returns SQL has that SQL applied to the data read from the splits."""
    catalog = _make_real_catalog(tmp_path)
    masking_context = DataLoaderContext(table_transformer=_MaskingTransformer())

    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", context=masking_context)
    result = _materialize(loader)

    assert result.num_rows == 3
    masked_names = result.column("name").to_pylist()
    assert masked_names == ["MASKED", "MASKED", "MASKED"]


def test_iter_with_transformer_skips_column_projection(tmp_path):
    """columns + transformer → Iceberg scan is called WITHOUT selected_fields."""
    catalog = _make_real_catalog(tmp_path)
    mock_table = catalog.load_table.return_value

    loader = OpenHouseDataLoader(
        catalog=catalog,
        database="db",
        table="tbl",
        columns=[COL_ID],
        context=DataLoaderContext(table_transformer=_MaskingTransformer()),
    )
    result = _materialize(loader)

    assert result.num_rows == 3
    # With an active transform the loader must not push a projection into the scan.
    mock_table.scan.assert_called_once()
    scan_kwargs = mock_table.scan.call_args.kwargs
    assert "selected_fields" not in scan_kwargs


def test_iter_with_transformer_and_special_char_database(tmp_path):
    """The transform path still works when the database name contains a double quote."""

    class _QuotedMaskingTransformer(TableTransformer):
        def transform(self, table, context):
            quoted_table = table.sql_name
            return f"SELECT id, 'MASKED' as name, value FROM {quoted_table}"

    catalog = _make_real_catalog(tmp_path)
    transformer_context = DataLoaderContext(table_transformer=_QuotedMaskingTransformer())

    loader = OpenHouseDataLoader(
        catalog=catalog,
        database='my"db',
        table="tbl",
        context=transformer_context,
    )
    result = _materialize(loader)

    assert result.num_rows == 3
    masked_names = result.column("name").to_pylist()
    assert masked_names == ["MASKED", "MASKED", "MASKED"]


# --- branch tests ---


Expand Down
Loading