Skip to content

Commit 17aa1a5

Browse files
authored
refactor!: Make @dy.rule decorator apply to classmethods (#198)
1 parent 9fc131f commit 17aa1a5

File tree

20 files changed

+111
-149
lines changed

20 files changed

+111
-149
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ jobs:
3131
run: rustup show
3232
- name: Cache Rust dependencies
3333
uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
34-
- name: Install repository
35-
run: pixi run -e default postinstall
3634
- name: pre-commit
3735
run: pixi run pre-commit-run --color=always --show-diff-on-failure
3836

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,12 @@ class HouseSchema(dy.Schema):
5252
price = dy.Float64(nullable=False)
5353

5454
@dy.rule()
55-
def reasonable_bathroom_to_bedroom_ratio() -> pl.Expr:
55+
def reasonable_bathroom_to_bedroom_ratio(cls) -> pl.Expr:
5656
ratio = pl.col("num_bathrooms") / pl.col("num_bedrooms")
5757
return (ratio >= 1 / 3) & (ratio <= 3)
5858

5959
@dy.rule(group_by=["zip_code"])
60-
def minimum_zip_code_count() -> pl.Expr:
60+
def minimum_zip_code_count(cls) -> pl.Expr:
6161
return pl.len() >= 2
6262
```
6363

dataframely/_base_schema.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import polars as pl
1414

15-
from ._rule import DtypeCastRule, GroupRule, Rule
15+
from ._rule import DtypeCastRule, GroupRule, Rule, RuleFactory
1616
from .columns import Column
1717
from .exc import ImplementationError
1818

@@ -81,7 +81,7 @@ class Metadata:
8181
"""Utility class to gather columns and rules associated with a schema."""
8282

8383
columns: dict[str, Column] = field(default_factory=dict)
84-
rules: dict[str, Rule] = field(default_factory=dict)
84+
rules: dict[str, RuleFactory] = field(default_factory=dict)
8585

8686
def update(self, other: Self) -> None:
8787
self.columns.update(other.columns)
@@ -102,7 +102,11 @@ def __new__(
102102
result.update(mcs._get_metadata_recursively(base))
103103
result.update(mcs._get_metadata(namespace))
104104
namespace[_COLUMN_ATTR] = result.columns
105-
namespace[_RULE_ATTR] = result.rules
105+
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)
106+
107+
# Assign rules retroactively as we only encounter rule factories in the result
108+
rules = {name: factory.make(cls) for name, factory in result.rules.items()}
109+
setattr(cls, _RULE_ATTR, rules)
106110

107111
# At this point, we already know all columns and custom rules. We want to run
108112
# some checks...
@@ -111,7 +115,7 @@ def __new__(
111115
# we assume that users cast dtypes, i.e. additional rules for dtype casting
112116
# are also checked.
113117
all_column_names = set(result.columns)
114-
all_rule_names = set(_build_rules(result.rules, result.columns, with_cast=True))
118+
all_rule_names = set(_build_rules(rules, result.columns, with_cast=True))
115119
common_names = all_column_names & all_rule_names
116120
if len(common_names) > 0:
117121
common_list = ", ".join(sorted(f"'{col}'" for col in common_names))
@@ -121,7 +125,7 @@ def __new__(
121125
)
122126

123127
# 2) Check that the columns referenced in the group rules exist.
124-
for rule_name, rule in result.rules.items():
128+
for rule_name, rule in rules.items():
125129
if isinstance(rule, GroupRule):
126130
missing_columns = set(rule.group_columns) - set(result.columns)
127131
if len(missing_columns) > 0:
@@ -138,6 +142,7 @@ def __new__(
138142
for attr, value in namespace.items():
139143
if attr.startswith("__"):
140144
continue
145+
141146
# Check for tuple of column (commonly caused by trailing comma)
142147
if (
143148
isinstance(value, tuple)
@@ -157,7 +162,7 @@ def __new__(
157162
f"Did you forget to add parentheses?"
158163
)
159164

160-
return super().__new__(mcs, name, bases, namespace, *args, **kwargs)
165+
return cls
161166

162167
def __getattribute__(cls, name: str) -> Any:
163168
val = super().__getattribute__(name)
@@ -182,7 +187,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
182187
}.items():
183188
if isinstance(value, Column):
184189
result.columns[value.alias or attr] = value
185-
if isinstance(value, Rule):
190+
if isinstance(value, RuleFactory):
186191
# We must ensure that custom rules do not clash with internal rules.
187192
if attr == "primary_key":
188193
raise ImplementationError(

dataframely/_rule.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
else:
1616
from typing_extensions import Self
1717

18-
ValidationFunction = Callable[[], pl.Expr]
18+
ValidationFunction = Callable[[Any], pl.Expr]
1919

2020

2121
class Rule:
2222
"""Internal class representing validation rules."""
2323

24-
def __init__(self, expr: pl.Expr | ValidationFunction) -> None:
24+
def __init__(self, expr: pl.Expr | Callable[[], pl.Expr]) -> None:
2525
self._expr = expr
2626

2727
@property
@@ -71,7 +71,7 @@ class GroupRule(Rule):
7171
"""Rule that is evaluated on a group of columns."""
7272

7373
def __init__(
74-
self, expr: pl.Expr | ValidationFunction, group_columns: list[str]
74+
self, expr: pl.Expr | Callable[[], pl.Expr], group_columns: list[str]
7575
) -> None:
7676
super().__init__(expr)
7777
self.group_columns = group_columns
@@ -92,7 +92,41 @@ def __repr__(self) -> str:
9292
return f"{super().__repr__()} grouped by {self.group_columns}"
9393

9494

95-
def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction], Rule]:
95+
# -------------------------------------- FACTORY ------------------------------------- #
96+
97+
98+
class RuleFactory:
99+
"""Factory class for rules created within schemas."""
100+
101+
def __init__(
102+
self, validation_fn: Callable[[Any], pl.Expr], group_columns: list[str] | None
103+
) -> None:
104+
self.validation_fn = validation_fn
105+
self.group_columns = group_columns
106+
107+
@classmethod
108+
def from_rule(cls, rule: Rule) -> Self:
109+
"""Create a rule factory from an existing rule."""
110+
if isinstance(rule, GroupRule):
111+
return cls(
112+
validation_fn=lambda _: rule.expr,
113+
group_columns=rule.group_columns,
114+
)
115+
return cls(validation_fn=lambda _: rule.expr, group_columns=None)
116+
117+
def make(self, schema: Any) -> Rule:
118+
"""Create a new rule from this factory."""
119+
if self.group_columns is not None:
120+
return GroupRule(
121+
expr=lambda: self.validation_fn(schema),
122+
group_columns=self.group_columns,
123+
)
124+
return Rule(expr=lambda: self.validation_fn(schema))
125+
126+
127+
def rule(
128+
*, group_by: list[str] | None = None
129+
) -> Callable[[ValidationFunction], RuleFactory]:
96130
"""Mark a function as a rule to evaluate during validation.
97131
98132
The name of the function will be used as the name of the rule. The function should
@@ -128,10 +162,8 @@ def rule(*, group_by: list[str] | None = None) -> Callable[[ValidationFunction],
128162
and (de-)serialization.
129163
"""
130164

131-
def decorator(validation_fn: ValidationFunction) -> Rule:
132-
if group_by is not None:
133-
return GroupRule(expr=validation_fn, group_columns=group_by)
134-
return Rule(expr=validation_fn)
165+
def decorator(validation_fn: ValidationFunction) -> RuleFactory:
166+
return RuleFactory(validation_fn=validation_fn, group_columns=group_by)
135167

136168
return decorator
137169

dataframely/mypy.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

dataframely/schema.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ._native import format_rule_failures
2525
from ._plugin import all_rules, all_rules_horizontal, all_rules_required
2626
from ._polars import collect_if
27-
from ._rule import Rule, rule_from_dict, with_evaluation_rules
27+
from ._rule import Rule, RuleFactory, rule_from_dict, with_evaluation_rules
2828
from ._serialization import (
2929
SERIALIZATION_FORMAT_VERSION,
3030
SchemaJSONDecoder,
@@ -1377,7 +1377,10 @@ def _schema_from_dict(data: dict[str, Any]) -> type[Schema]:
13771377
(Schema,),
13781378
{
13791379
**{name: column_from_dict(col) for name, col in data["columns"].items()},
1380-
**{name: rule_from_dict(rule) for name, rule in data["rules"].items()},
1380+
**{
1381+
name: RuleFactory.from_rule(rule_from_dict(rule))
1382+
for name, rule in data["rules"].items()
1383+
},
13811384
},
13821385
)
13831386

dataframely/testing/factory.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any
55

66
from dataframely._filter import Filter
7-
from dataframely._rule import Rule
7+
from dataframely._rule import Rule, RuleFactory
88
from dataframely._typing import LazyFrame
99
from dataframely.collection import Collection
1010
from dataframely.columns import Column
@@ -14,7 +14,7 @@
1414
def create_schema(
1515
name: str,
1616
columns: dict[str, Column],
17-
rules: dict[str, Rule] | None = None,
17+
rules: dict[str, Rule | RuleFactory] | None = None,
1818
) -> type[Schema]:
1919
"""Dynamically create a new schema with the provided name.
2020
@@ -23,12 +23,18 @@ def create_schema(
2323
columns: The columns to set on the schema. When properly defining the schema,
2424
this would be the annotations that define the column types.
2525
rules: The custom non-column-specific validation rules. When properly defining
26-
the schema, this would be the functions annotated with `@dy.rule`.
26+
the schema, this would be the functions annotated with ``@dy.rule``.
2727
2828
Returns:
2929
The dynamically created schema.
3030
"""
31-
return type(name, (Schema,), {**columns, **(rules or {})})
31+
rule_factories = {
32+
rule_name: (
33+
rule if isinstance(rule, RuleFactory) else RuleFactory.from_rule(rule)
34+
)
35+
for rule_name, rule in (rules or {}).items()
36+
}
37+
return type(name, (Schema,), {**columns, **rule_factories})
3238

3339

3440
def create_collection(

docs/guides/examples/real-world.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
},
8080
{
8181
"cell_type": "code",
82-
"execution_count": 4,
82+
"execution_count": null,
8383
"metadata": {},
8484
"outputs": [],
8585
"source": [
@@ -91,11 +91,11 @@
9191
" amount = dy.Decimal(nullable=False, min_exclusive=Decimal(0))\n",
9292
"\n",
9393
" @dy.rule()\n",
94-
" def discharge_after_admission() -> pl.Expr:\n",
94+
" def discharge_after_admission(cls) -> pl.Expr:\n",
9595
" return pl.col(\"discharge_date\") >= pl.col(\"admission_date\")\n",
9696
"\n",
9797
" @dy.rule()\n",
98-
" def received_at_after_discharge() -> pl.Expr:\n",
98+
" def received_at_after_discharge(cls) -> pl.Expr:\n",
9999
" return pl.col(\"received_at\").dt.date() >= pl.col(\"discharge_date\")"
100100
]
101101
},
@@ -318,7 +318,7 @@
318318
},
319319
{
320320
"cell_type": "code",
321-
"execution_count": 11,
321+
"execution_count": null,
322322
"metadata": {},
323323
"outputs": [],
324324
"source": [
@@ -328,7 +328,7 @@
328328
" is_main = dy.Bool(nullable=False)\n",
329329
"\n",
330330
" @dy.rule(group_by=[\"invoice_id\"])\n",
331-
" def exactly_one_main_diagnosis() -> pl.Expr:\n",
331+
" def exactly_one_main_diagnosis(cls) -> pl.Expr:\n",
332332
" return pl.col(\"is_main\").sum() == 1"
333333
]
334334
},
@@ -351,7 +351,7 @@
351351
},
352352
{
353353
"cell_type": "code",
354-
"execution_count": 12,
354+
"execution_count": null,
355355
"metadata": {},
356356
"outputs": [],
357357
"source": [
@@ -368,11 +368,11 @@
368368
" amount = dy.Decimal(nullable=False, min_exclusive=Decimal(0))\n",
369369
"\n",
370370
" @dy.rule()\n",
371-
" def discharge_after_admission() -> pl.Expr:\n",
371+
" def discharge_after_admission(cls) -> pl.Expr:\n",
372372
" return pl.col(\"discharge_date\") >= pl.col(\"admission_date\")\n",
373373
"\n",
374374
" @dy.rule()\n",
375-
" def received_at_after_discharge() -> pl.Expr:\n",
375+
" def received_at_after_discharge(cls) -> pl.Expr:\n",
376376
" return pl.col(\"received_at\").dt.date() >= pl.col(\"discharge_date\")\n",
377377
"\n",
378378
"\n",
@@ -381,7 +381,7 @@
381381
" is_main = dy.Bool(nullable=False)\n",
382382
"\n",
383383
" @dy.rule(group_by=[\"invoice_id\"])\n",
384-
" def exactly_one_main_diagnosis() -> pl.Expr:\n",
384+
" def exactly_one_main_diagnosis(cls) -> pl.Expr:\n",
385385
" return pl.col(\"is_main\").sum() == 1"
386386
]
387387
},

docs/guides/faq.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@ class UserSchema(dy.Schema):
2020
email = dy.String(nullable=True) # Must be unique, or null.
2121

2222
@dy.rule(group_by=["username"])
23-
def unique_username() -> pl.Expr:
23+
def unique_username(cls) -> pl.Expr:
2424
"""Username, a non-nullable field, must be total unique."""
2525
return pl.len() == 1
2626

2727
@dy.rule()
28-
def unique_email_or_null() -> pl.Expr:
28+
def unique_email_or_null(cls) -> pl.Expr:
2929
"""Email must be unique, if provided."""
3030
return pl.col("email").is_null() | pl.col("email").is_unique()
3131
```

0 commit comments

Comments
 (0)