1212
1313import polars as pl
1414
15- from ._rule import DtypeCastRule , GroupRule , Rule
15+ from ._rule import DtypeCastRule , GroupRule , Rule , RuleFactory
1616from .columns import Column
1717from .exc import ImplementationError
1818
@@ -81,7 +81,7 @@ class Metadata:
8181 """Utility class to gather columns and rules associated with a schema."""
8282
8383 columns : dict [str , Column ] = field (default_factory = dict )
84- rules : dict [str , Rule ] = field (default_factory = dict )
84+ rules : dict [str , RuleFactory ] = field (default_factory = dict )
8585
8686 def update (self , other : Self ) -> None :
8787 self .columns .update (other .columns )
@@ -102,7 +102,11 @@ def __new__(
102102 result .update (mcs ._get_metadata_recursively (base ))
103103 result .update (mcs ._get_metadata (namespace ))
104104 namespace [_COLUMN_ATTR ] = result .columns
105- namespace [_RULE_ATTR ] = result .rules
105+ cls = super ().__new__ (mcs , name , bases , namespace , * args , ** kwargs )
106+
107+ # Assign rules retroactively as we only encounter rule factories in the result
108+ rules = {name : factory .make (cls ) for name , factory in result .rules .items ()}
109+ setattr (cls , _RULE_ATTR , rules )
106110
107111 # At this point, we already know all columns and custom rules. We want to run
108112 # some checks...
@@ -111,7 +115,7 @@ def __new__(
111115 # we assume that users cast dtypes, i.e. additional rules for dtype casting
112116 # are also checked.
113117 all_column_names = set (result .columns )
114- all_rule_names = set (_build_rules (result . rules , result .columns , with_cast = True ))
118+ all_rule_names = set (_build_rules (rules , result .columns , with_cast = True ))
115119 common_names = all_column_names & all_rule_names
116120 if len (common_names ) > 0 :
117121 common_list = ", " .join (sorted (f"'{ col } '" for col in common_names ))
@@ -121,7 +125,7 @@ def __new__(
121125 )
122126
123127 # 2) Check that the columns referenced in the group rules exist.
124- for rule_name , rule in result . rules .items ():
128+ for rule_name , rule in rules .items ():
125129 if isinstance (rule , GroupRule ):
126130 missing_columns = set (rule .group_columns ) - set (result .columns )
127131 if len (missing_columns ) > 0 :
@@ -138,6 +142,7 @@ def __new__(
138142 for attr , value in namespace .items ():
139143 if attr .startswith ("__" ):
140144 continue
145+
141146 # Check for tuple of column (commonly caused by trailing comma)
142147 if (
143148 isinstance (value , tuple )
@@ -157,7 +162,7 @@ def __new__(
157162 f"Did you forget to add parentheses?"
158163 )
159164
160- return super (). __new__ ( mcs , name , bases , namespace , * args , ** kwargs )
165+ return cls
161166
162167 def __getattribute__ (cls , name : str ) -> Any :
163168 val = super ().__getattribute__ (name )
@@ -182,7 +187,7 @@ def _get_metadata(source: dict[str, Any]) -> Metadata:
182187 }.items ():
183188 if isinstance (value , Column ):
184189 result .columns [value .alias or attr ] = value
185- if isinstance (value , Rule ):
190+ if isinstance (value , RuleFactory ):
186191 # We must ensure that custom rules do not clash with internal rules.
187192 if attr == "primary_key" :
188193 raise ImplementationError (
0 commit comments