Skip to content

Commit 83acd98

Browse files
authored
fix: Allow passing custom metadata to failure info parquets (#83)
1 parent 9c6d527 commit 83acd98

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

dataframely/failure.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def write_parquet(self, file: str | Path | IO[bytes], **kwargs: Any) -> None:
9797
Be aware that this method suffers from the same limitations as
9898
:meth:`Schema.serialize`.
9999
"""
100-
metadata = self._build_metadata(**kwargs)
100+
metadata, kwargs = self._build_metadata(**kwargs)
101101
self._df.write_parquet(file, metadata=metadata, **kwargs)
102102

103103
def sink_parquet(
@@ -117,14 +117,16 @@ def sink_parquet(
117117
Be aware that this method suffers from the same limitations as
118118
:meth:`Schema.serialize`.
119119
"""
120-
metadata = self._build_metadata(**kwargs)
120+
metadata, kwargs = self._build_metadata(**kwargs)
121121
self._lf.sink_parquet(file, metadata=metadata, **kwargs)
122122

123-
def _build_metadata(self, **kwargs: Any) -> dict[str, Any]:
123+
def _build_metadata(
124+
self, **kwargs: dict[str, Any]
125+
) -> tuple[dict[str, Any], dict[str, Any]]:
124126
metadata = kwargs.pop("metadata", {})
125127
metadata[RULE_METADATA_KEY] = json.dumps(self._rule_columns)
126128
metadata[SCHEMA_METADATA_KEY] = self.schema.serialize()
127-
return metadata
129+
return metadata, kwargs
128130

129131
@classmethod
130132
def read_parquet(

tests/test_failure_info.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ def test_scan_sink_parquet(tmp_path: Path) -> None:
5252
assert MySchema.matches(read.schema)
5353

5454

55+
def test_write_parquet_custom_metadata(tmp_path: Path) -> None:
56+
df = pl.DataFrame(
57+
{
58+
"a": [4, 5, 6, 6, 7, 8],
59+
"b": [1, 2, 3, 4, 5, 6],
60+
}
61+
)
62+
_, failure = MySchema.filter(df)
63+
failure.write_parquet(tmp_path / "failure.parquet", metadata={"custom": "test"})
64+
assert pl.read_parquet_metadata(tmp_path / "failure.parquet")["custom"] == "test"
65+
66+
5567
@pytest.mark.parametrize(
5668
"read_fn",
5769
[dy.FailureInfo.read_parquet, dy.FailureInfo.scan_parquet],

0 commit comments

Comments
 (0)