Commit 6c9305e

feat: Add proper serialization support for failure info (#75)
1 parent 3e232a4 commit 6c9305e

5 files changed: +205 -139 lines changed


dataframely/_serialization.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 
 import polars as pl
 
+SCHEMA_METADATA_KEY = "dataframely_schema"
 SERIALIZATION_FORMAT_VERSION = "1"
 
 
dataframely/failure.py

Lines changed: 110 additions & 37 deletions
@@ -1,16 +1,25 @@
 # Copyright (c) QuantCo 2025-2025
 # SPDX-License-Identifier: BSD-3-Clause
 
-import importlib
+from __future__ import annotations
+
 import json
 from functools import cached_property
 from pathlib import Path
-from typing import IO, Generic, Self, TypeVar, cast
+from typing import IO, TYPE_CHECKING, Any, Generic, TypeVar
 
 import polars as pl
+from polars._typing import PartitioningScheme
 
 from dataframely._base_schema import BaseSchema
 
+from ._serialization import SCHEMA_METADATA_KEY
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .schema import Schema
+
+RULE_METADATA_KEY = "dataframely_rule_columns"
+
 S = TypeVar("S", bound=BaseSchema)
 
 

@@ -73,52 +82,116 @@ def __len__(self) -> int:
 
     # ---------------------------------- PERSISTENCE --------------------------------- #
 
-    def write_parquet(self, file: str | Path | IO[bytes]) -> None:
-        """Write the failure info to a Parquet file.
+    def write_parquet(self, file: str | Path | IO[bytes], **kwargs: Any) -> None:
+        """Write the failure info to a parquet file.
 
         Args:
-            file: The file path or writable file-like object to write to.
+            file: The file path or writable file-like object to which to write the
+                parquet file. This should be a path to a directory if writing a
+                partitioned dataset.
+            kwargs: Additional keyword arguments passed directly to
+                :meth:`polars.write_parquet`. ``metadata`` may only be provided if it
+                is a dictionary.
+
+        Attention:
+            Be aware that this method suffers from the same limitations as
+            :meth:`Schema.serialize`.
         """
-        metadata_json = json.dumps(
-            {
-                "rule_columns": self._rule_columns,
-                "schema": f"{self.schema.__module__}.{self.schema.__name__}",
-            }
-        )
-        self._df.write_parquet(file, metadata={"dataframely": metadata_json})
+        metadata = self._build_metadata(**kwargs)
+        self._df.write_parquet(file, metadata=metadata, **kwargs)
+
+    def sink_parquet(
+        self, file: str | Path | IO[bytes] | PartitioningScheme, **kwargs: Any
+    ) -> None:
+        """Stream the failure info to a parquet file.
+
+        Args:
+            file: The file path or writable file-like object to which to write the
+                parquet file. This should be a path to a directory if writing a
+                partitioned dataset.
+            kwargs: Additional keyword arguments passed directly to
+                :meth:`polars.sink_parquet`. ``metadata`` may only be provided if it
+                is a dictionary.
+
+        Attention:
+            Be aware that this method suffers from the same limitations as
+            :meth:`Schema.serialize`.
+        """
+        metadata = self._build_metadata(**kwargs)
+        self._lf.sink_parquet(file, metadata=metadata, **kwargs)
+
+    def _build_metadata(self, **kwargs: Any) -> dict[str, Any]:
+        metadata = kwargs.pop("metadata", {})
+        metadata[RULE_METADATA_KEY] = json.dumps(self._rule_columns)
+        metadata[SCHEMA_METADATA_KEY] = self.schema.serialize()
+        return metadata
 
     @classmethod
-    def scan_parquet(cls, source: str | Path | IO[bytes]) -> Self:
-        """Lazily read the parquet file with the failure info.
+    def read_parquet(
+        cls, source: str | Path | IO[bytes], **kwargs: Any
+    ) -> FailureInfo[Schema]:
+        """Read a parquet file with the failure info.
 
         Args:
-            source: The file path or readable file-like object to read from.
+            source: Path, directory, or file-like object from which to read the data.
+            kwargs: Additional keyword arguments passed directly to
+                :meth:`polars.read_parquet`.
 
         Returns:
             The failure info object.
+
+        Raises:
+            ValueError: If no appropriate metadata can be found.
+
+        Attention:
+            Be aware that this method suffers from the same limitations as
+            :meth:`Schema.serialize`
+        """
+        return cls._from_parquet(source, scan=False, **kwargs)
+
+    @classmethod
+    def scan_parquet(
+        cls, source: str | Path | IO[bytes], **kwargs: Any
+    ) -> FailureInfo[Schema]:
+        """Lazily read a parquet file with the failure info.
+
+        Args:
+            source: Path, directory, or file-like object from which to read the data.
+
+        Returns:
+            The failure info object.
+
+        Raises:
+            ValueError: If no appropriate metadata can be found.
+
+        Attention:
+            Be aware that this method suffers from the same limitations as
+            :meth:`Schema.serialize`
         """
-        lf = pl.scan_parquet(source)
-
-        # We can read the rule columns either from the metadata of the Parquet file
-        # or, to remain backwards-compatible, from the last column of the lazy frame if
-        # the parquet file is missing metadata.
-        rule_columns: list[str]
-        schema_name: str
-        if (meta := pl.read_parquet_metadata(source).get("dataframely")) is not None:
-            metadata = json.loads(meta)
-            rule_columns = metadata["rule_columns"]
-            schema_name = metadata["schema"]
-        else:
-            last_column = lf.collect_schema().names()[-1]
-            metadata = json.loads(last_column)
-            rule_columns = metadata["rule_columns"]
-            schema_name = metadata["schema"]
-            lf = lf.drop(last_column)
-
-        *schema_module_parts, schema_name = schema_name.split(".")
-        module = importlib.import_module(".".join(schema_module_parts))
-        schema = cast(type[S], getattr(module, schema_name))
-        return cls(lf, rule_columns, schema=schema)
+        return cls._from_parquet(source, scan=True, **kwargs)
+
+    @classmethod
+    def _from_parquet(
+        cls, source: str | Path | IO[bytes], scan: bool, **kwargs: Any
+    ) -> FailureInfo[Schema]:
+        from .schema import deserialize_schema
+
+        metadata = pl.read_parquet_metadata(source)
+        schema_metadata = metadata.get(SCHEMA_METADATA_KEY)
+        rule_metadata = metadata.get(RULE_METADATA_KEY)
+        if schema_metadata is None or rule_metadata is None:
+            raise ValueError("The parquet file does not contain the required metadata.")
+
+        lf = (
+            pl.scan_parquet(source, **kwargs)
+            if scan
+            else pl.read_parquet(source, **kwargs).lazy()
+        )
+        return FailureInfo(
+            lf,
+            json.loads(rule_metadata),
+            schema=deserialize_schema(schema_metadata),
+        )
 
 
 # ------------------------------------ COMPUTATION ----------------------------------- #
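
A minimal usage sketch of the persistence API introduced above. The schema class MySchema, its column, and the input frame are illustrative assumptions, not part of this commit; only FailureInfo.write_parquet, FailureInfo.read_parquet, and FailureInfo.scan_parquet come from the diff, and FailureInfo is assumed to be re-exported at the package root.

# Hypothetical example; MySchema and the input frame are assumptions for illustration.
import polars as pl
import dataframely as dy


class MySchema(dy.Schema):
    a = dy.Int64(primary_key=True)


# Schema.filter returns a FailureInfo for the rows that violate the schema's rules.
good, failure = MySchema.filter(pl.DataFrame({"a": [1, 1, 2]}))

# The rule columns and the serialized schema are embedded as parquet key-value metadata.
failure.write_parquet("failure.parquet")

# Eager and lazy reading both rebuild the schema via deserialize_schema from that
# metadata, instead of re-importing the schema class by its module path as before.
restored = dy.FailureInfo.read_parquet("failure.parquet")
lazy = dy.FailureInfo.scan_parquet("failure.parquet")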

dataframely/schema.py

Lines changed: 6 additions & 5 deletions
@@ -19,6 +19,7 @@
 from ._compat import pa, sa
 from ._rule import Rule, rule_from_dict, with_evaluation_rules
 from ._serialization import (
+    SCHEMA_METADATA_KEY,
     SERIALIZATION_FORMAT_VERSION,
     SchemaJSONDecoder,
     SchemaJSONEncoder,
@@ -33,7 +34,7 @@
 from .random import Generator
 
 _ORIGINAL_NULL_SUFFIX = "__orig_null__"
-_METADATA_KEY = "dataframely_schema"
+
 
 # ------------------------------------------------------------------------------------ #
 # SCHEMA DEFINITION #
@@ -479,7 +480,7 @@ def _validate_schema(
     @classmethod
     def filter(
         cls, df: pl.DataFrame | pl.LazyFrame, /, *, cast: bool = False
-    ) -> tuple[DataFrame[Self], FailureInfo]:
+    ) -> tuple[DataFrame[Self], FailureInfo[Self]]:
         """Filter the data frame by the rules of this schema.
 
         This method can be thought of as a "soft alternative" to :meth:`validate`.
@@ -708,7 +709,7 @@ def write_parquet(
         """
         metadata = kwargs.pop("metadata", {})
         df.write_parquet(
-            file, metadata={**metadata, _METADATA_KEY: cls.serialize()}, **kwargs
+            file, metadata={**metadata, SCHEMA_METADATA_KEY: cls.serialize()}, **kwargs
         )
 
     @classmethod
@@ -739,7 +740,7 @@ def sink_parquet(
         """
         metadata = kwargs.pop("metadata", {})
         lf.sink_parquet(
-            file, metadata={**metadata, _METADATA_KEY: cls.serialize()}, **kwargs
+            file, metadata={**metadata, SCHEMA_METADATA_KEY: cls.serialize()}, **kwargs
         )
 
     @classmethod
@@ -860,7 +861,7 @@ def _requires_validation_for_reading_parquet(
         # does, we check whether it matches this schema. If it does, we assume that the
         # data adheres to the schema and we do not need to run validation.
         metadata = (
-            pl.read_parquet_metadata(source).get(_METADATA_KEY)
+            pl.read_parquet_metadata(source).get(SCHEMA_METADATA_KEY)
             if not isinstance(source, list)
             else None
         )
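
Because the schema metadata now lives under the shared SCHEMA_METADATA_KEY ("dataframely_schema") for both Schema and FailureInfo parquet output, a file's provenance can be checked by inspecting its key-value metadata. A small sketch, assuming a file written by one of the methods in this commit; the file name is an assumption.

# Sketch only: the metadata keys come from the diff above; the path is illustrative.
import polars as pl

meta = pl.read_parquet_metadata("failure.parquet")
schema_json = meta.get("dataframely_schema")          # SCHEMA_METADATA_KEY
rule_columns = meta.get("dataframely_rule_columns")   # RULE_METADATA_KEY (failure info files only)

if schema_json is None or rule_columns is None:
    # FailureInfo.read_parquet / scan_parquet raise ValueError in this situation.
    print("Missing dataframely metadata; cannot reconstruct a FailureInfo.")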
