8
8
from dataclasses import Field as DataclassField
9
9
from dataclasses import asdict as _asdict
10
10
from dataclasses import astuple as _astuple
11
- from dataclasses import dataclass , fields as dataclass_fields , is_dataclass
12
- from typing import Any , ClassVar , Dict , List , Optional , Tuple , Type
11
+ from dataclasses import dataclass
12
+ from dataclasses import fields as dataclass_fields
13
+ from dataclasses import is_dataclass
14
+ from typing import Any , Callable , ClassVar , Dict , Iterator , List , Optional , Tuple , Type
13
15
14
16
import jinja2
15
17
import stringcase
@@ -75,8 +77,13 @@ def serialize(self, ser, **opts) -> None:
75
77
return ser .serialize (self , ** opts )
76
78
77
79
setattr (cls , SE_NAME , serialize )
78
- cls = se_func (cls , TO_ITER , render_astuple (cls ))
79
- cls = se_func (cls , TO_DICT , render_asdict (cls , rename_all ))
80
+
81
+ g : Dict [str , Any ] = globals ().copy ()
82
+ for f in fields (cls ):
83
+ if f .skip_if :
84
+ g [f .skip_if .mangled ] = f .skip_if
85
+ cls = se_func (cls , TO_ITER , render_astuple (cls ), g )
86
+ cls = se_func (cls , TO_DICT , render_asdict (cls , rename_all ), g )
80
87
return cls
81
88
82
89
if _cls is None :
@@ -185,11 +192,28 @@ class Field:
185
192
case : Optional [str ] = None
186
193
rename : Optional [str ] = None
187
194
skip : Optional [bool ] = None
195
+ skip_if : Optional [Callable [[Any ], bool ]] = None
188
196
skip_if_false : Optional [bool ] = None
189
197
190
198
@staticmethod
191
199
def from_dataclass (f : DataclassField ) -> '' :
192
- return Field (f .type , f .name , rename = f .metadata .get ('serde_rename' ), skip = f .metadata .get ('serde_skip' ), skip_if_false = f .metadata .get ('serde_skip_if_false' ))
200
+ if f .metadata .get ('serde_skip_if_false' ):
201
+ skip_if_false = lambda v : not bool (v )
202
+ skip_if_false .mangled = Field .mangle (f , 'skip_if' )
203
+ else :
204
+ skip_if_false = None
205
+
206
+ skip_if = f .metadata .get ('serde_skip_if' )
207
+ if skip_if :
208
+ skip_if .mangled = Field .mangle (f , 'skip_if' )
209
+
210
+ return Field (
211
+ f .type ,
212
+ f .name ,
213
+ rename = f .metadata .get ('serde_rename' ),
214
+ skip = f .metadata .get ('serde_skip' ),
215
+ skip_if = skip_if or skip_if_false ,
216
+ )
193
217
194
218
@property
195
219
def varname (self ) -> str :
@@ -203,9 +227,13 @@ def __getitem__(self, n) -> 'Field':
203
227
typ = type_args (self .type )[n ]
204
228
return Field (typ , None )
205
229
230
+ @staticmethod
231
+ def mangle (field : DataclassField , name : str ) -> str :
232
+ return f'{ field .name } _{ name } '
233
+
206
234
207
- def fields (cls : Type ) -> List [Field ]:
208
- return [ Field .from_dataclass (f ) for f in dataclass_fields (cls )]
235
+ def fields (cls : Type ) -> Iterator [Field ]:
236
+ return iter ( Field .from_dataclass (f ) for f in dataclass_fields (cls ))
209
237
210
238
211
239
def to_arg (f : Field ) -> Field :
@@ -248,14 +276,16 @@ def {{func}}(obj):
248
276
{% if cls|is_dataclass %}
249
277
res = {}
250
278
{% for f in cls|fields -%}
251
- {% if not f.skip|default(False) %}
252
- {% if f.skip_if_false|default(False) %}
253
- if {{f|arg|rvalue()}}:
279
+
280
+ {% if not f.skip %}
281
+ {% if f.skip_if %}
282
+ if not {{f.skip_if.mangled}}({{f|arg|rvalue()}}):
254
283
res["{{f|case}}"] = {{f|arg|rvalue()}}
255
284
{% else %}
256
285
res["{{f|case}}"] = {{f|arg|rvalue()}}
257
286
{% endif %}
258
287
{% endif %}
288
+
259
289
{% endfor -%}
260
290
return res
261
291
{% endif %}
@@ -359,16 +389,18 @@ def primitive(self, arg: Field) -> str:
359
389
return f'{ arg .varname } '
360
390
361
391
362
- def se_func (cls : Type [T ], func : str , code : str ) -> Type [T ]:
392
+ def se_func (cls : Type [T ], func : str , code : str , g : Dict = None , local : Dict = None ) -> Type [T ]:
363
393
"""
364
394
Generate function to serialize into an object.
365
395
"""
366
- g : Dict [str , Any ] = globals ().copy ()
367
-
368
396
# Generate serialize function.
369
- code = gen (code , g , cls = cls )
397
+ if not g :
398
+ g = globals ().copy ()
399
+ if not local :
400
+ local = locals ().copy ()
401
+ code = gen (code , g , local , cls = cls )
370
402
371
- setattr (cls , func , g [func ])
403
+ setattr (cls , func , local [func ])
372
404
if SETTINGS ['debug' ]:
373
405
hidden = getattr (cls , HIDDEN_NAME )
374
406
hidden .code [func ] = code
0 commit comments