|
1 | 1 | # Copyright (c) QuantCo 2025-2025 |
2 | 2 | # SPDX-License-Identifier: BSD-3-Clause |
3 | 3 |
|
4 | | -import importlib |
| 4 | +from __future__ import annotations |
| 5 | + |
5 | 6 | import json |
6 | 7 | from functools import cached_property |
7 | 8 | from pathlib import Path |
8 | | -from typing import IO, Generic, Self, TypeVar, cast |
| 9 | +from typing import IO, TYPE_CHECKING, Any, Generic, TypeVar |
9 | 10 |
|
10 | 11 | import polars as pl |
| 12 | +from polars._typing import PartitioningScheme |
11 | 13 |
|
12 | 14 | from dataframely._base_schema import BaseSchema |
13 | 15 |
|
| 16 | +from ._serialization import SCHEMA_METADATA_KEY |
| 17 | + |
| 18 | +if TYPE_CHECKING: # pragma: no cover |
| 19 | + from .schema import Schema |
| 20 | + |
| 21 | +RULE_METADATA_KEY = "dataframely_rule_columns" |
| 22 | + |
14 | 23 | S = TypeVar("S", bound=BaseSchema) |
15 | 24 |
|
16 | 25 |
|
@@ -73,52 +82,116 @@ def __len__(self) -> int: |
73 | 82 |
|
74 | 83 | # ---------------------------------- PERSISTENCE --------------------------------- # |
75 | 84 |
|
76 | | - def write_parquet(self, file: str | Path | IO[bytes]) -> None: |
77 | | - """Write the failure info to a Parquet file. |
| 85 | + def write_parquet(self, file: str | Path | IO[bytes], **kwargs: Any) -> None: |
| 86 | + """Write the failure info to a parquet file. |
78 | 87 |
|
79 | 88 | Args: |
80 | | - file: The file path or writable file-like object to write to. |
| 89 | + file: The file path or writable file-like object to which to write the |
| 90 | + parquet file. This should be a path to a directory if writing a |
| 91 | + partitioned dataset. |
| 92 | + kwargs: Additional keyword arguments passed directly to |
|     | 93  | +                :meth:`polars.DataFrame.write_parquet`. ``metadata``, if provided,
|     | 94  | +                must be a dictionary.
| 95 | +
|
| 96 | + Attention: |
| 97 | + Be aware that this method suffers from the same limitations as |
| 98 | + :meth:`Schema.serialize`. |
81 | 99 | """ |
82 | | - metadata_json = json.dumps( |
83 | | - { |
84 | | - "rule_columns": self._rule_columns, |
85 | | - "schema": f"{self.schema.__module__}.{self.schema.__name__}", |
86 | | - } |
87 | | - ) |
88 | | - self._df.write_parquet(file, metadata={"dataframely": metadata_json}) |
| 100 | + metadata = self._build_metadata(**kwargs) |
| 101 | + self._df.write_parquet(file, metadata=metadata, **kwargs) |
| 102 | + |
| 103 | + def sink_parquet( |
| 104 | + self, file: str | Path | IO[bytes] | PartitioningScheme, **kwargs: Any |
| 105 | + ) -> None: |
| 106 | + """Stream the failure info to a parquet file. |
| 107 | +
|
| 108 | + Args: |
| 109 | + file: The file path or writable file-like object to which to write the |
| 110 | + parquet file. This should be a path to a directory if writing a |
| 111 | + partitioned dataset. |
| 112 | + kwargs: Additional keyword arguments passed directly to |
|     | 113 | +                :meth:`polars.LazyFrame.sink_parquet`. ``metadata``, if provided,
|     | 114 | +                must be a dictionary.
| 115 | +
|
| 116 | + Attention: |
| 117 | + Be aware that this method suffers from the same limitations as |
| 118 | + :meth:`Schema.serialize`. |
| 119 | + """ |
| 120 | + metadata = self._build_metadata(**kwargs) |
| 121 | + self._lf.sink_parquet(file, metadata=metadata, **kwargs) |
| 122 | + |
| 123 | + def _build_metadata(self, **kwargs: Any) -> dict[str, Any]: |
| 124 | + metadata = kwargs.pop("metadata", {}) |
| 125 | + metadata[RULE_METADATA_KEY] = json.dumps(self._rule_columns) |
| 126 | + metadata[SCHEMA_METADATA_KEY] = self.schema.serialize() |
| 127 | + return metadata |
89 | 128 |
|
90 | 129 | @classmethod |
91 | | - def scan_parquet(cls, source: str | Path | IO[bytes]) -> Self: |
92 | | - """Lazily read the parquet file with the failure info. |
| 130 | + def read_parquet( |
| 131 | + cls, source: str | Path | IO[bytes], **kwargs: Any |
| 132 | + ) -> FailureInfo[Schema]: |
| 133 | + """Read a parquet file with the failure info. |
93 | 134 |
|
94 | 135 | Args: |
95 | | - source: The file path or readable file-like object to read from. |
| 136 | + source: Path, directory, or file-like object from which to read the data. |
| 137 | + kwargs: Additional keyword arguments passed directly to |
|     | 138 | +                :func:`polars.read_parquet`.
96 | 139 |
|
97 | 140 | Returns: |
98 | 141 | The failure info object. |
| 142 | +
|
| 143 | + Raises: |
| 144 | + ValueError: If no appropriate metadata can be found. |
| 145 | +
|
| 146 | + Attention: |
| 147 | + Be aware that this method suffers from the same limitations as |
|     | 148 | +            :meth:`Schema.serialize`.
| 149 | + """ |
| 150 | + return cls._from_parquet(source, scan=False, **kwargs) |
| 151 | + |
| 152 | + @classmethod |
| 153 | + def scan_parquet( |
| 154 | + cls, source: str | Path | IO[bytes], **kwargs: Any |
| 155 | + ) -> FailureInfo[Schema]: |
| 156 | + """Lazily read a parquet file with the failure info. |
| 157 | +
|
| 158 | + Args: |
| 159 | + source: Path, directory, or file-like object from which to read the data. |
| 160 | +
|
| 161 | + Returns: |
| 162 | + The failure info object. |
| 163 | +
|
| 164 | + Raises: |
| 165 | + ValueError: If no appropriate metadata can be found. |
| 166 | +
|
| 167 | + Attention: |
| 168 | + Be aware that this method suffers from the same limitations as |
|     | 169 | +            :meth:`Schema.serialize`.
99 | 170 | """ |
100 | | - lf = pl.scan_parquet(source) |
101 | | - |
102 | | - # We can read the rule columns either from the metadata of the Parquet file |
103 | | - # or, to remain backwards-compatible, from the last column of the lazy frame if |
104 | | - # the parquet file is missing metadata. |
105 | | - rule_columns: list[str] |
106 | | - schema_name: str |
107 | | - if (meta := pl.read_parquet_metadata(source).get("dataframely")) is not None: |
108 | | - metadata = json.loads(meta) |
109 | | - rule_columns = metadata["rule_columns"] |
110 | | - schema_name = metadata["schema"] |
111 | | - else: |
112 | | - last_column = lf.collect_schema().names()[-1] |
113 | | - metadata = json.loads(last_column) |
114 | | - rule_columns = metadata["rule_columns"] |
115 | | - schema_name = metadata["schema"] |
116 | | - lf = lf.drop(last_column) |
117 | | - |
118 | | - *schema_module_parts, schema_name = schema_name.split(".") |
119 | | - module = importlib.import_module(".".join(schema_module_parts)) |
120 | | - schema = cast(type[S], getattr(module, schema_name)) |
121 | | - return cls(lf, rule_columns, schema=schema) |
| 171 | + return cls._from_parquet(source, scan=True, **kwargs) |
| 172 | + |
| 173 | + @classmethod |
| 174 | + def _from_parquet( |
| 175 | + cls, source: str | Path | IO[bytes], scan: bool, **kwargs: Any |
| 176 | + ) -> FailureInfo[Schema]: |
| 177 | + from .schema import deserialize_schema |
| 178 | + |
| 179 | + metadata = pl.read_parquet_metadata(source) |
| 180 | + schema_metadata = metadata.get(SCHEMA_METADATA_KEY) |
| 181 | + rule_metadata = metadata.get(RULE_METADATA_KEY) |
| 182 | + if schema_metadata is None or rule_metadata is None: |
| 183 | + raise ValueError("The parquet file does not contain the required metadata.") |
| 184 | + |
| 185 | + lf = ( |
| 186 | + pl.scan_parquet(source, **kwargs) |
| 187 | + if scan |
| 188 | + else pl.read_parquet(source, **kwargs).lazy() |
| 189 | + ) |
| 190 | + return FailureInfo( |
| 191 | + lf, |
| 192 | + json.loads(rule_metadata), |
| 193 | + schema=deserialize_schema(schema_metadata), |
| 194 | + ) |
122 | 195 |
|
123 | 196 |
|
124 | 197 | # ------------------------------------ COMPUTATION ----------------------------------- # |
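
The persistence methods introduced above can be exercised end to end. Below is a minimal round-trip sketch, assuming dataframely's `Schema.filter` entry point for producing a `FailureInfo`; the schema `MySchema`, its column, and the file name are made up for illustration:

```python
import dataframely as dy
import polars as pl


class MySchema(dy.Schema):
    """Hypothetical schema used only for this sketch."""

    a = dy.Int64(primary_key=True)


df = pl.DataFrame({"a": [1, 1, 2]})
# `filter` splits the input into valid rows and a FailureInfo for the rest.
_, failure = MySchema.filter(df)

# Persist the failure info; the new write path stores the rule columns and the
# serialized schema in the parquet key-value metadata.
failure.write_parquet("failures.parquet")

# Restore it either eagerly or lazily; both read the metadata written above.
restored = dy.FailureInfo.read_parquet("failures.parquet")
lazy_restored = dy.FailureInfo.scan_parquet("failures.parquet")
```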
|
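Extra keyword arguments are forwarded to polars, and `_build_metadata` merges a user-provided `metadata` dictionary with the dataframely keys rather than replacing it. A sketch continuing from the snippet above; the `"owner"` key and the compression choice are arbitrary:

```python
# User metadata must be a dict; RULE_METADATA_KEY and SCHEMA_METADATA_KEY are
# added next to the user's own entries.
failure.write_parquet(
    "failures.parquet",
    metadata={"owner": "data-team"},  # hypothetical user-supplied metadata
    compression="zstd",               # forwarded to polars unchanged
)

# The streaming variant merges the metadata the same way.
failure.sink_parquet("failures_stream.parquet", metadata={"owner": "data-team"})
```

Files that were not written through these helpers lack the required metadata keys, so `read_parquet` and `scan_parquet` raise a `ValueError` for them.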