Skip to content

Commit 5d1c4d7

Browse files
feat: Support Dataframes as Fields in Pydantic Models (#148)
Co-authored-by: Andreas Albert <[email protected]>
1 parent fbecd10 commit 5d1c4d7

File tree

10 files changed

+572
-29
lines changed

10 files changed

+572
-29
lines changed

dataframely/_base_schema.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
import sys
77
import textwrap
8-
from abc import ABCMeta
8+
from abc import ABCMeta, abstractmethod
99
from copy import copy
1010
from dataclasses import dataclass, field
11-
from typing import Any
11+
from typing import TYPE_CHECKING, Any
1212

1313
import polars as pl
1414

@@ -21,6 +21,10 @@
2121
else:
2222
from typing_extensions import Self
2323

24+
25+
if TYPE_CHECKING:
26+
from ._typing import DataFrame
27+
2428
_COLUMN_ATTR = "__dataframely_columns__"
2529
_RULE_ATTR = "__dataframely_rules__"
2630

@@ -198,11 +202,46 @@ def columns(cls) -> dict[str, Column]:
198202
columns[name]._name = name
199203
return columns
200204

205+
    @classmethod
    @abstractmethod
    def polars_schema(cls) -> pl.Schema:
        """Obtain the polars schema for this schema.

        Abstract on the base class; concrete ``Schema`` subclasses provide the
        implementation from their declared columns.

        Returns:
            A :mod:`polars` schema that mirrors the schema defined by this class.
        """
213+
201214
@classmethod
202215
def primary_keys(cls) -> list[str]:
203216
"""The primary key columns in this schema (possibly empty)."""
204217
return _primary_keys(cls.columns())
205218

219+
    @classmethod
    @abstractmethod
    def validate(
        cls, df: pl.DataFrame | pl.LazyFrame, /, *, cast: bool = False
    ) -> DataFrame[Self]:
        """Validate that a data frame satisfies the schema.

        Abstract on the base class; concrete ``Schema`` subclasses provide the
        implementation.

        Args:
            df: The data frame to validate.
            cast: Whether columns with a wrong data type in the input data frame are
                cast to the schema's defined data type if possible.

        Returns:
            The (collected) input data frame, wrapped in a generic version of the
            input's data frame type to reflect schema adherence. The data frame is
            guaranteed to maintain its order.

        Raises:
            ValidationError: If the input data frame does not satisfy the schema
                definition.

        Note:
            This method _always_ collects the input data frame in order to raise
            potential validation errors.
        """
244+
206245
@classmethod
207246
def _validation_rules(cls, *, with_cast: bool) -> dict[str, Rule]:
208247
return _build_rules(

dataframely/_compat.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,19 @@ class Dialect: # type: ignore # noqa: N801
5454
except ImportError: # pragma: no cover
5555
pa = _DummyModule("pyarrow")
5656

57+
58+
# -------------------------------------- PYDANTIC ------------------------------------ #

# Optional dependency: fall back to a dummy module that raises on use, mirroring
# the handling of the other optional dependencies above.
try:
    import pydantic
except ImportError:  # pragma: no cover
    pydantic = _DummyModule("pydantic")  # type: ignore

try:
    from pydantic_core import core_schema as pydantic_core_schema
except ImportError:  # pragma: no cover
    # Fix: the `pragma: no cover` belongs on the except branch (which only runs
    # when pydantic-core is missing), not on the always-executed import line.
    pydantic_core_schema = _DummyModule("pydantic_core_schema")  # type: ignore
69+
5770
# ------------------------------------------------------------------------------------ #
5871

5972
__all__ = [
@@ -64,4 +77,6 @@ class Dialect: # type: ignore # noqa: N801
6477
"pa",
6578
"MSDialect_pyodbc",
6679
"PGDialect_psycopg2",
80+
"pydantic",
81+
"pydantic_core_schema",
6782
]

dataframely/_pydantic.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) QuantCo 2025-2025
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
from __future__ import annotations
4+
5+
from functools import partial
6+
from typing import TYPE_CHECKING, Literal, TypeVar, get_args, get_origin, overload
7+
8+
import polars as pl
9+
10+
from ._base_schema import BaseSchema
11+
from ._compat import pydantic, pydantic_core_schema
12+
from .exc import ValidationError
13+
14+
if TYPE_CHECKING:
15+
from ._typing import DataFrame, LazyFrame
16+
17+
18+
_S = TypeVar("_S", bound=BaseSchema)
19+
20+
21+
def _dict_to_df(schema_type: type[BaseSchema], data: dict) -> pl.DataFrame:
    """Build a polars data frame from a plain dict of column values.

    The schema's polars dtypes are applied so that the resulting frame carries
    the column types declared by ``schema_type``.
    """
    polars_schema = schema_type.polars_schema()
    return pl.from_dict(data, schema=polars_schema)
26+
27+
28+
def _validate_df_schema(schema_type: type[_S], df: pl.DataFrame) -> DataFrame[_S]:
    """Validate ``df`` against ``schema_type`` for use in a pydantic validator.

    Pydantic treats ``ValueError`` raised from validator functions as a field
    validation failure, so dataframely's ``ValidationError`` is re-raised as a
    ``ValueError`` with the original error attached as the cause.
    """
    try:
        validated = schema_type.validate(df, cast=False)
    except ValidationError as exc:
        raise ValueError("DataFrame violates schema") from exc
    return validated
33+
34+
35+
def _serialize_df(df: pl.DataFrame) -> dict:
    """Serialize a data frame to a dict mapping column name -> list of values."""
    return df.to_dict(as_series=False)
37+
38+
39+
@overload
def get_pydantic_core_schema(
    source_type: type[DataFrame],
    _handler: pydantic.GetCoreSchemaHandler,
    lazy: Literal[False],
) -> pydantic_core_schema.CoreSchema: ...


@overload
def get_pydantic_core_schema(
    source_type: type[LazyFrame],
    _handler: pydantic.GetCoreSchemaHandler,
    lazy: Literal[True],
) -> pydantic_core_schema.CoreSchema: ...


def get_pydantic_core_schema(
    source_type: type[DataFrame | LazyFrame],
    _handler: pydantic.GetCoreSchemaHandler,
    lazy: bool,
) -> pydantic_core_schema.CoreSchema:
    """Build the pydantic core schema for a ``DataFrame``/``LazyFrame`` field.

    Args:
        source_type: The annotated field type, expected to be parametrized with
            a dataframely schema (e.g. ``DataFrame[MySchema]``).
        _handler: The pydantic core schema handler (unused).
        lazy: Whether the field type is a ``LazyFrame``; if so, the validated
            frame is converted back to a lazy frame at the end.

    Returns:
        A pydantic-core schema that accepts a data frame, a lazy frame, or a
        dict (converted to a data frame with the schema's dtypes), validates it
        against the dataframely schema, and serializes it as a dict of columns.
    """
    # https://docs.pydantic.dev/2.11/concepts/types/#handling-custom-generic-classes
    if get_origin(source_type) is None:
        # Annotated as a bare `x: dy.DataFrame` without `[Schema]`.
        raise TypeError("DataFrame must be parametrized with a schema")

    schema_cls: type[BaseSchema] = get_args(source_type)[0]

    # A dict input is first parsed into a DataFrame with the schema's dtypes.
    dict_to_frame = pydantic_core_schema.chain_schema(
        [
            pydantic_core_schema.dict_schema(),
            pydantic_core_schema.no_info_plain_validator_function(
                partial(_dict_to_df, schema_cls)
            ),
        ]
    )
    # Accept a DataFrame, a LazyFrame, or a dict that is converted to a
    # DataFrame (-> output: DataFrame or LazyFrame).
    accept_input = pydantic_core_schema.union_schema(
        [
            pydantic_core_schema.is_instance_schema(pl.DataFrame),
            pydantic_core_schema.is_instance_schema(pl.LazyFrame),
            dict_to_frame,
        ]
    )

    steps = [
        accept_input,
        pydantic_core_schema.no_info_plain_validator_function(
            partial(_validate_df_schema, schema_cls)
        ),
    ]
    if lazy:
        # If the Pydantic field type is LazyFrame, add a step to convert the
        # validated (eager) frame back to a LazyFrame.
        steps.append(
            pydantic_core_schema.no_info_plain_validator_function(
                lambda frame: frame.lazy(),
            )
        )

    return pydantic_core_schema.chain_schema(
        steps,
        serialization=pydantic_core_schema.plain_serializer_function_ser_schema(
            _serialize_df
        ),
    )
107+
108+
109+
def get_pydantic_json_schema(
    handler: pydantic.GetJsonSchemaHandler,
) -> pydantic.json_schema.JsonSchemaValue:
    """Return the JSON schema for dataframely frame fields in pydantic models.

    Frames are serialized as a dict of column name to list of values, so the
    JSON schema is a plain dict schema for now.
    """
    # Use the module-level compat alias instead of re-importing pydantic_core
    # locally, keeping optional-dependency handling in one place (`_compat`).
    # This could be made more sophisticated by actually reflecting the schema.
    return handler(pydantic_core_schema.dict_schema())

dataframely/_typing.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import polars as pl
1010

1111
from ._base_schema import BaseSchema
12+
from ._compat import pydantic, pydantic_core_schema
13+
from ._pydantic import get_pydantic_core_schema, get_pydantic_json_schema
1214

1315
S = TypeVar("S", bound=BaseSchema, covariant=True)
1416

@@ -70,6 +72,20 @@ def set_sorted(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
7072
def shrink_to_fit(self, *args: Any, **kwargs: Any) -> DataFrame[S]:
7173
raise NotImplementedError # pragma: no cover
7274

75+
    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler
    ) -> pydantic_core_schema.CoreSchema:
        # Pydantic hook: `lazy=False` keeps the validated field as an eager
        # DataFrame.
        return get_pydantic_core_schema(source_type, handler, lazy=False)

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: pydantic_core_schema.CoreSchema,
        handler: pydantic.GetJsonSchemaHandler,
    ) -> pydantic.json_schema.JsonSchemaValue:
        # Pydantic hook: JSON schema for the serialized (dict) representation.
        return get_pydantic_json_schema(handler)
88+
7389

7490
class LazyFrame(pl.LazyFrame, Generic[S]):
7591
"""Generic wrapper around a :class:`polars.LazyFrame` to attach schema information.
@@ -113,3 +129,17 @@ def pipe(
113129
@inherit_signature(pl.LazyFrame.set_sorted)
114130
def set_sorted(self, *args: Any, **kwargs: Any) -> LazyFrame[S]:
115131
raise NotImplementedError # pragma: no cover
132+
133+
    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler
    ) -> pydantic_core_schema.CoreSchema:
        # Pydantic hook: `lazy=True` converts the validated frame back into a
        # LazyFrame to match the annotated field type.
        return get_pydantic_core_schema(source_type, handler, lazy=True)

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: pydantic_core_schema.CoreSchema,
        handler: pydantic.GetJsonSchemaHandler,
    ) -> pydantic.json_schema.JsonSchemaValue:
        # Pydantic hook: JSON schema for the serialized (dict) representation.
        return get_pydantic_json_schema(handler)

dataframely/schema.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -414,26 +414,6 @@ def _sampling_overrides(cls) -> dict[str, pl.Expr]:
414414
def validate(
415415
cls, df: pl.DataFrame | pl.LazyFrame, /, *, cast: bool = False
416416
) -> DataFrame[Self]:
417-
"""Validate that a data frame satisfies the schema.
418-
419-
Args:
420-
df: The data frame to validate.
421-
cast: Whether columns with a wrong data type in the input data frame are
422-
cast to the schema's defined data type if possible.
423-
424-
Returns:
425-
The (collected) input data frame, wrapped in a generic version of the
426-
input's data frame type to reflect schema adherence. The data frame is
427-
guaranteed to maintain its order.
428-
429-
Raises:
430-
ValidationError: If the input data frame does not satisfy the schema
431-
definition.
432-
433-
Note:
434-
This method _always_ collects the input data frame in order to raise
435-
potential validation errors.
436-
"""
437417
# We can dispatch to the `filter` method and raise an error if any row cannot
438418
# be validated
439419
df_valid, failures = cls.filter(df, cast=cast)
@@ -1118,11 +1098,6 @@ def _validate_if_needed(
11181098

11191099
@classmethod
11201100
def polars_schema(cls) -> pl.Schema:
1121-
"""Obtain the polars schema for this schema.
1122-
1123-
Returns:
1124-
A :mod:`polars` schema that mirrors the schema defined by this class.
1125-
"""
11261101
return pl.Schema({name: col.dtype for name, col in cls.columns().items()})
11271102

11281103
@classmethod

docs/sites/quickstart.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ Lastly, ``dataframely`` schemas can be used to integrate with external tools:
253253
- ``HouseSchema.create_empty()`` creates an empty ``dy.DataFrame[HouseSchema]`` that can be used for testing
254254
- ``HouseSchema.sql_schema()`` provides a list of `sqlalchemy <https://www.sqlalchemy.org>`_ columns that can be used to create SQL tables using types and constraints in line with the schema
255255
- ``HouseSchema.pyarrow_schema()`` provides a `pyarrow <https://arrow.apache.org/docs/python/index.html>`_ schema with appropriate column dtypes and nullability information
256+
- You can use ``dy.DataFrame[HouseSchema]`` (or the ``LazyFrame`` equivalent) as fields in `pydantic <https://pydantic.dev>`_ models, including support for validation and serialization. Integration with pydantic is unstable.
256257

257258

258259
Outlook

0 commit comments

Comments
 (0)