Skip to content

Commit 3e232a4

Browse files
authored
feat: Allow to efficiently read and write collections (#77)
1 parent a4d3ebf commit 3e232a4

File tree

4 files changed

+496
-76
lines changed

4 files changed

+496
-76
lines changed

dataframely/collection.py

Lines changed: 207 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@
2121
SchemaJSONEncoder,
2222
serialization_versions,
2323
)
24-
from ._typing import LazyFrame
25-
from .exc import MemberValidationError, RuleValidationError, ValidationError
24+
from ._typing import LazyFrame, Validation
25+
from .exc import (
26+
MemberValidationError,
27+
RuleValidationError,
28+
ValidationError,
29+
ValidationRequiredError,
30+
)
2631
from .failure import FailureInfo
2732
from .random import Generator
2833
from .schema import _schema_from_dict
@@ -620,84 +625,254 @@ def serialize(cls) -> str:
620625

621626
# ---------------------------------- PERSISTENCE --------------------------------- #
622627

623-
def write_parquet(self, directory: Path) -> None:
624-
"""Write the members of this collection to Parquet files in a directory.
628+
def write_parquet(self, directory: str | Path, **kwargs: Any) -> None:
629+
"""Write the members of this collection to parquet files in a directory.
625630
626-
This method writes one Parquet file per member into the provided directory.
631+
This method writes one parquet file per member into the provided directory.
627632
Each parquet file is named ``<member>.parquet``. No file is written for optional
628633
members which are not provided in the current collection.
629634
635+
In addition, one JSON file named ``schema.json`` is written, serializing the
636+
collection's definition for fast reads.
637+
630638
Args:
631639
directory: The directory where the Parquet files should be written to. If
632640
the directory does not exist, it is created automatically, including all
633641
of its parents.
642+
kwargs: Additional keyword arguments passed directly to
643+
:meth:`polars.write_parquet` of all members. ``metadata`` may only be
644+
provided if it is a dictionary.
645+
646+
Attention:
647+
This method suffers from the same limitations as :meth:`Schema.serialize`.
634648
"""
635-
directory.mkdir(parents=True, exist_ok=True)
649+
self._to_parquet(directory, sink=False, **kwargs)
650+
651+
def sink_parquet(self, directory: str | Path, **kwargs: Any) -> None:
652+
"""Stream the members of this collection into parquet files in a directory.
653+
654+
This method writes one parquet file per member into the provided directory.
655+
Each parquet file is named ``<member>.parquet``. No file is written for optional
656+
members which are not provided in the current collection.
657+
658+
In addition, one JSON file named ``schema.json`` is written, serializing the
659+
collection's definition for fast reads.
660+
661+
Args:
662+
directory: The directory where the Parquet files should be written to. If
663+
the directory does not exist, it is created automatically, including all
664+
of its parents.
665+
kwargs: Additional keyword arguments passed directly to
666+
:meth:`polars.sink_parquet` of all members. ``metadata`` may only be
667+
provided if it is a dictionary.
668+
669+
Attention:
670+
This method suffers from the same limitations as :meth:`Schema.serialize`.
671+
"""
672+
self._to_parquet(directory, sink=True, **kwargs)
673+
674+
def _to_parquet(self, directory: str | Path, *, sink: bool, **kwargs: Any) -> None:
675+
path = Path(directory) if isinstance(directory, str) else directory
676+
path.mkdir(parents=True, exist_ok=True)
677+
with open(path / "schema.json", "w") as f:
678+
f.write(self.serialize())
679+
680+
member_schemas = self.member_schemas()
636681
for key, lf in self.to_dict().items():
637-
lf.collect().write_parquet(directory / f"{key}.parquet")
682+
destination = (
683+
path / key if "partition_by" in kwargs else path / f"{key}.parquet"
684+
)
685+
if sink:
686+
member_schemas[key].sink_parquet(
687+
lf, # type: ignore
688+
destination,
689+
**kwargs,
690+
)
691+
else:
692+
member_schemas[key].write_parquet(
693+
lf.collect(), # type: ignore
694+
destination,
695+
**kwargs,
696+
)
638697

639698
@classmethod
640-
def read_parquet(cls, directory: Path) -> Self:
641-
"""Eagerly read and validate all collection members from Parquet file in a
642-
directory.
699+
def read_parquet(
700+
cls,
701+
directory: str | Path,
702+
*,
703+
validation: Validation = "warn",
704+
**kwargs: Any,
705+
) -> Self:
706+
"""Read all collection members from parquet files in a directory.
643707
644708
This method searches for files named ``<member>.parquet`` in the provided
645709
directory for all required and optional members of the collection.
646710
647711
Args:
648712
directory: The directory where the Parquet files should be read from.
713+
Parquet files may have been written with Hive partitioning.
714+
validation: The strategy for running validation when reading the data:
715+
716+
- ``"allow"``: The method tries to read the ``schema.json`` file in the
717+
directory. If the stored collection schema matches this collection
718+
schema, the collection is read without validation. If the stored
719+
schema mismatches this schema or no ``schema.json`` can be found in
720+
the directory, this method automatically runs :meth:`validate` with
721+
``cast=True``.
722+
- ``"warn"`: The method behaves similarly to ``"allow"``. However,
723+
it prints a warning if validation is necessary.
724+
- ``"forbid"``: The method never runs validation automatically and only
725+
returns if the ``schema.json`` stores a collection schema that matches
726+
this collection.
727+
- ``"skip"``: The method never runs validation and simply reads the
728+
data, entrusting the user that the schema is valid. _Use this option
729+
carefully_.
730+
731+
kwargs: Additional keyword arguments passed directly to
732+
:meth:`polars.read_parquet`.
649733
650734
Returns:
651735
The initialized collection.
652736
653737
Raises:
738+
ValidationRequiredError: If no collection schema can be read from the
739+
directory and ``validation`` is set to ``"forbid"``.
654740
ValueError: If the provided directory does not contain parquet files for
655741
all required members.
656742
ValidationError: If the collection cannot be validated.
657743
658-
Note:
659-
If you are certain that your Parquet files contain valid data, you can also
660-
use :meth:`scan_parquet` to prevent the runtime overhead of validation.
744+
Attention:
745+
Be aware that this method suffers from the same limitations as
746+
:meth:`serialize`.
661747
"""
662-
data = {
663-
key: pl.scan_parquet(directory / f"{key}.parquet")
664-
for key in cls.members()
665-
if (directory / f"{key}.parquet").exists()
666-
}
667-
return cls.validate(data)
748+
path = Path(directory)
749+
data = cls._from_parquet(path, scan=True, **kwargs)
750+
if not cls._requires_validation_for_reading_parquets(path, validation):
751+
cls._validate_input_keys(data)
752+
return cls._init(data)
753+
return cls.validate(data, cast=True)
668754

669755
@classmethod
670-
def scan_parquet(cls, directory: Path) -> Self:
671-
"""Lazily read all collection members from Parquet files in a directory.
756+
def scan_parquet(
757+
cls,
758+
directory: str | Path,
759+
*,
760+
validation: Validation = "warn",
761+
**kwargs: Any,
762+
) -> Self:
763+
"""Lazily read all collection members from parquet files in a directory.
672764
673765
This method searches for files named ``<member>.parquet`` in the provided
674766
directory for all required and optional members of the collection.
675767
676768
Args:
677769
directory: The directory where the Parquet files should be read from.
770+
Parquet files may have been written with Hive partitioning.
771+
validation: The strategy for running validation when reading the data:
772+
773+
- ``"allow"``: The method tries to read the ``schema.json`` file in the
774+
directory. If the stored collection schema matches this collection
775+
schema, the collection is read without validation. If the stored
776+
schema mismatches this schema or no ``schema.json`` can be found in
777+
the directory, this method automatically runs :meth:`validate` with
778+
``cast=True``.
779+
- ``"warn"`: The method behaves similarly to ``"allow"``. However,
780+
it prints a warning if validation is necessary.
781+
- ``"forbid"``: The method never runs validation automatically and only
782+
returns if the ``schema.json`` stores a collection schema that matches
783+
this collection.
784+
- ``"skip"``: The method never runs validation and simply reads the
785+
data, entrusting the user that the schema is valid. _Use this option
786+
carefully_.
787+
788+
kwargs: Additional keyword arguments passed directly to
789+
:meth:`polars.scan_parquet` for all members.
678790
679791
Returns:
680792
The initialized collection.
681793
682794
Raises:
795+
ValidationRequiredError: If no collection schema can be read from the
796+
directory and ``validation`` is set to ``"forbid"``.
683797
ValueError: If the provided directory does not contain parquet files for
684798
all required members.
685799
686800
Note:
687-
If you want to eagerly read all Parquet files, consider calling
688-
:meth:`collect_all` on the returned collection.
801+
Due to current limitations in dataframely, this method actually reads the
802+
parquet file into memory if ``validation`` is ``"warn"`` or ``"allow"``
803+
and validation is required.
689804
690805
Attention:
691-
This method does **not** validate the contents of the Parquet file. Consider
692-
using :meth:`read_parquet` if you want to validate the collection.
806+
Be aware that this method suffers from the same limitations as
807+
:meth:`serialize`.
693808
"""
694-
data = {
695-
key: pl.scan_parquet(directory / f"{key}.parquet")
696-
for key in cls.members()
697-
if (directory / f"{key}.parquet").exists()
698-
}
699-
cls._validate_input_keys(data)
700-
return cls._init(data)
809+
path = Path(directory)
810+
data = cls._from_parquet(path, scan=True, **kwargs)
811+
if not cls._requires_validation_for_reading_parquets(path, validation):
812+
cls._validate_input_keys(data)
813+
return cls._init(data)
814+
return cls.validate(data, cast=True)
815+
816+
@classmethod
817+
def _from_parquet(
818+
cls, path: Path, scan: bool, **kwargs: Any
819+
) -> dict[str, pl.LazyFrame]:
820+
data = {}
821+
for key in cls.members():
822+
if (source_path := cls._member_source_path(path, key)) is not None:
823+
data[key] = (
824+
pl.scan_parquet(source_path, **kwargs)
825+
if scan
826+
else pl.read_parquet(source_path, **kwargs).lazy()
827+
)
828+
return data
829+
830+
@classmethod
831+
def _member_source_path(cls, base_path: Path, name: str) -> Path | None:
832+
if (path := base_path / name).exists() and base_path.is_dir():
833+
# We assume that the member is stored as a hive-partitioned dataset
834+
return path
835+
if (path := base_path / f"{name}.parquet").exists():
836+
# We assume that the member is stored as a single parquet file
837+
return path
838+
return None
839+
840+
@classmethod
841+
def _requires_validation_for_reading_parquets(
842+
cls,
843+
directory: Path,
844+
validation: Validation,
845+
) -> bool:
846+
if validation == "skip":
847+
return False
848+
849+
# First, we check whether the path provides the serialization of the collection.
850+
# If it does, we check whether it matches this collection. If it does, we assume
851+
# that the data adheres to the collection and we do not need to run validation.
852+
if (json_serialization := directory / "schema.json").exists():
853+
metadata = json_serialization.read_text()
854+
serialized_collection = deserialize_collection(metadata)
855+
if cls.matches(serialized_collection):
856+
return False
857+
else:
858+
serialized_collection = None
859+
860+
# Otherwise, we definitely need to run validation. However, we emit different
861+
# information to the user depending on the value of `validate`.
862+
msg = (
863+
"current collection schema does not match stored collection schema"
864+
if serialized_collection is not None
865+
else "no collection schema to check validity can be read from the source"
866+
)
867+
if validation == "forbid":
868+
raise ValidationRequiredError(
869+
f"Cannot read collection from '{directory!r}' without validation: {msg}."
870+
)
871+
if validation == "warn":
872+
warnings.warn(
873+
f"Reading parquet file from '{directory!r}' requires validation: {msg}."
874+
)
875+
return True
701876

702877
# ----------------------------------- UTILITIES ---------------------------------- #
703878

dataframely/schema.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def read_parquet(
783783
784784
Raises:
785785
ValidationRequiredError: If no schema information can be read from the
786-
source and ``validate`` is set to ``False``.
786+
source and ``validation`` is set to ``"forbid"``.
787787
788788
Attention:
789789
Be aware that this method suffers from the same limitations as
@@ -824,7 +824,7 @@ def scan_parquet(
824824
- ``"skip"``: The method never runs validation and simply reads the
825825
parquet file, entrusting the user that the schema is valid. _Use this
826826
option carefully and consider replacing it with
827-
:meth:`polars.read_parquet` to convey the purpose better_.
827+
:meth:`polars.scan_parquet` to convey the purpose better_.
828828
829829
kwargs: Additional keyword arguments passed directly to
830830
:meth:`polars.scan_parquet`.
@@ -834,11 +834,11 @@ def scan_parquet(
834834
835835
Raises:
836836
ValidationRequiredError: If no schema information can be read from the
837-
source and ``validate`` is set to ``False``.
837+
source and ``validation`` is set to ``"forbid"``.
838838
839839
Note:
840840
Due to current limitations in dataframely, this method actually reads the
841-
parquet file into memory if ``validate`` is ``"auto"`` or ``True`` and
841+
parquet file into memory if ``validation`` is ``"warn"`` or ``"allow"`` and
842842
validation is required.
843843
844844
Attention:

tests/collection/test_base.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
# Copyright (c) QuantCo 2025-2025
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
from collections.abc import Callable
5-
from pathlib import Path
6-
74
import polars as pl
85
import pytest
9-
from polars.testing import assert_frame_equal
106

117
import dataframely as dy
128

@@ -113,39 +109,3 @@ def test_collect_all_optional() -> None:
113109
assert isinstance(out, MyCollection)
114110
assert len(out.first.collect()) == 3
115111
assert out.second is None
116-
117-
118-
@pytest.mark.parametrize(
119-
"read_fn", [MyCollection.scan_parquet, MyCollection.read_parquet]
120-
)
121-
def test_read_write_parquet(
122-
tmp_path: Path, read_fn: Callable[[Path], MyCollection]
123-
) -> None:
124-
collection = MyCollection.cast(
125-
{
126-
"first": pl.LazyFrame({"a": [1, 2, 3]}),
127-
"second": pl.LazyFrame({"a": [1, 2], "b": [10, 15]}),
128-
}
129-
)
130-
collection.write_parquet(tmp_path)
131-
132-
read = read_fn(tmp_path)
133-
assert_frame_equal(collection.first, read.first)
134-
assert collection.second is not None
135-
assert read.second is not None
136-
assert_frame_equal(collection.second, read.second)
137-
138-
139-
@pytest.mark.parametrize(
140-
"read_fn", [MyCollection.scan_parquet, MyCollection.read_parquet]
141-
)
142-
def test_read_write_parquet_optional(
143-
tmp_path: Path, read_fn: Callable[[Path], MyCollection]
144-
) -> None:
145-
collection = MyCollection.cast({"first": pl.LazyFrame({"a": [1, 2, 3]})})
146-
collection.write_parquet(tmp_path)
147-
148-
read = read_fn(tmp_path)
149-
assert_frame_equal(collection.first, read.first)
150-
assert collection.second is None
151-
assert read.second is None

0 commit comments

Comments
 (0)