Skip to content

Commit f7f6996

Browse files
committed
feat: add support for lazy annotations PEP563 (yukinarit#112)
* also move SerdeError to compat.py
1 parent f6a36ba commit f7f6996

File tree

4 files changed

+223
-29
lines changed

4 files changed

+223
-29
lines changed

serde/compat.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
"""
44
import dataclasses
55
import enum
6+
import sys
67
import typing
7-
from dataclasses import fields, is_dataclass
8+
from dataclasses import is_dataclass
89
from itertools import zip_longest
910
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
1011

@@ -15,6 +16,13 @@
1516
T = TypeVar('T')
1617

1718

19+
# moved SerdeError from core.py to compat.py to prevent circular dependency issues
20+
class SerdeError(TypeError):
21+
"""
22+
Serde error class.
23+
"""
24+
25+
1826
def get_origin(typ):
1927
"""
2028
Provide `get_origin` that works in all python versions.
@@ -128,13 +136,41 @@ def union_args(typ: Union) -> Tuple:
128136
return tuple(types)
129137

130138

139+
def dataclass_fields(cls: Type) -> Iterator:
140+
raw_fields = dataclasses.fields(cls)
141+
142+
try:
143+
# this resolves types when string forward reference
144+
# or PEP 563: "from __future__ import annotations" are used
145+
resolved_hints = typing.get_type_hints(cls)
146+
except Exception as e:
147+
raise SerdeError(
148+
f"Failed to resolve type hints for {typename(cls)}:\n"
149+
f"{e.__class__.__name__}: {e}\n\n"
150+
f"If you are using forward references make sure you are calling deserialize & serialize after all classes are globally visible."
151+
)
152+
153+
for f in raw_fields:
154+
real_type = resolved_hints.get(f.name)
155+
# python <= 3.6 has no typing.ForwardRef so we need to skip the check
156+
if sys.version_info[:2] != (3, 6) and isinstance(real_type, typing.ForwardRef):
157+
raise SerdeError(
158+
f"Failed to resolve {real_type} for {typename(cls)}.\n\n"
159+
f"Make sure you are calling deserialize & serialize after all classes are globally visible."
160+
)
161+
if real_type is not None:
162+
f.type = real_type
163+
164+
return iter(raw_fields)
165+
166+
131167
def iter_types(cls: Type) -> Iterator[Type]:
132168
"""
133169
Iterate field types recursively.
134170
"""
135171
if is_dataclass(cls):
136172
yield cls
137-
for f in fields(cls):
173+
for f in dataclass_fields(cls):
138174
yield from iter_types(f.type)
139175
elif isinstance(cls, str):
140176
yield cls
@@ -170,7 +206,7 @@ def iter_unions(cls: Type) -> Iterator[Type]:
170206
for arg in type_args(cls):
171207
yield from iter_unions(arg)
172208
if is_dataclass(cls):
173-
for f in fields(cls):
209+
for f in dataclass_fields(cls):
174210
yield from iter_unions(f.type)
175211
elif is_opt(cls):
176212
arg = type_args(cls)
@@ -369,9 +405,9 @@ def has_default(field) -> bool:
369405
... class C:
370406
... a: int
371407
... d: int = 10
372-
>>> has_default(fields(C)[0])
408+
>>> has_default(dataclasses.fields(C)[0])
373409
False
374-
>>> has_default(fields(C)[1])
410+
>>> has_default(dataclasses.fields(C)[1])
375411
True
376412
"""
377413
return not isinstance(field.default, dataclasses._MISSING_TYPE)
@@ -386,9 +422,9 @@ def has_default_factory(field) -> bool:
386422
... class C:
387423
... a: int
388424
... d: Dict = dataclasses.field(default_factory=dict)
389-
>>> has_default_factory(fields(C)[0])
425+
>>> has_default_factory(dataclasses.fields(C)[0])
390426
False
391-
>>> has_default_factory(fields(C)[1])
427+
>>> has_default_factory(dataclasses.fields(C)[1])
392428
True
393429
"""
394430
return not isinstance(field.default_factory, dataclasses._MISSING_TYPE)

serde/core.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import stringcase
1111

1212
from .compat import (
13+
SerdeError,
14+
dataclass_fields,
1315
is_bare_dict,
1416
is_bare_list,
1517
is_bare_set,
@@ -124,12 +126,6 @@ def _justify(self, s: str, length=50) -> str:
124126
return ' ' * (white_spaces if white_spaces > 0 else 0) + s
125127

126128

127-
class SerdeError(TypeError):
128-
"""
129-
Serde error class.
130-
"""
131-
132-
133129
def raise_unsupported_type(obj):
134130
# needed because we can not render a raise statement everywhere, e.g. as argument
135131
raise SerdeError(f"Unsupported type: {typename(type(obj))}")
@@ -297,7 +293,7 @@ def conv_name(self) -> str:
297293

298294

299295
def fields(FieldCls: Type, cls: Type) -> Iterator[Field]:
300-
return iter(FieldCls.from_dataclass(f) for f in dataclasses.fields(cls))
296+
return iter(FieldCls.from_dataclass(f) for f in dataclass_fields(cls))
301297

302298

303299
def conv(f: Field, case: Optional[str] = None) -> str:

tests/test_basics.py

+51-15
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
import more_itertools
1616
import pytest
1717

18-
import serde.compat
18+
import serde
1919
from serde import SerdeError, deserialize, from_dict, from_tuple, serialize, to_dict, to_tuple
20+
from serde.compat import dataclass_fields
2021
from serde.core import SERDE_SCOPE
2122
from serde.json import from_json, to_json
2223
from serde.msgpack import from_msgpack, to_msgpack
@@ -211,22 +212,57 @@ class Foo:
211212
i: int
212213

213214

214-
def test_forward_declaration():
215-
@serialize
216-
@deserialize
217-
@dataclass
218-
class Foo:
219-
bar: 'Bar'
215+
# test_string_forward_reference_works currently only works with global visible classes
216+
# and can not be mixed with PEP 563 "from __future__ import annotations"
217+
@dataclass
218+
class ForwardReferenceFoo:
219+
bar: 'ForwardReferenceBar'
220220

221-
@serialize
222-
@deserialize
223-
@dataclass
224-
class Bar:
225-
i: int
226221

227-
h = Foo(bar=Bar(i=10))
228-
assert h.bar.i == 10
229-
assert 'Bar' == dataclasses.fields(Foo)[0].type
222+
@serialize
223+
@deserialize
224+
@dataclass
225+
class ForwardReferenceBar:
226+
i: int
227+
228+
229+
# assert type is str
230+
assert 'ForwardReferenceBar' == dataclasses.fields(ForwardReferenceFoo)[0].type
231+
232+
# setup pyserde for Foo after Bar becomes visible to global scope
233+
deserialize(ForwardReferenceFoo)
234+
serialize(ForwardReferenceFoo)
235+
236+
# now the type really is of type Bar
237+
assert ForwardReferenceBar == dataclasses.fields(ForwardReferenceFoo)[0].type
238+
assert ForwardReferenceBar == next(dataclass_fields(ForwardReferenceFoo)).type
239+
240+
# verify usage works
241+
def test_string_forward_reference_works():
242+
h = ForwardReferenceFoo(bar=ForwardReferenceBar(i=10))
243+
h_dict = {"bar": {"i": 10}}
244+
245+
assert to_dict(h) == h_dict
246+
assert from_dict(ForwardReferenceFoo, h_dict) == h
247+
248+
249+
# trying to use string forward reference normally will throw
250+
def test_unresolved_forward_reference_throws():
251+
with pytest.raises(SerdeError) as e:
252+
253+
@serialize
254+
@deserialize
255+
@dataclass
256+
class UnresolvedForwardFoo:
257+
bar: 'UnresolvedForwardBar'
258+
259+
@serialize
260+
@deserialize
261+
@dataclass
262+
class UnresolvedForwardBar:
263+
i: int
264+
265+
assert "Failed to resolve type hints for UnresolvedForwardFoo" in str(e)
230266

231267

232268
@pytest.mark.parametrize('opt', opt_case, ids=opt_case_ids())

tests/test_lazy_type_evaluation.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from __future__ import annotations # this is the line this test file is all about
2+
3+
import dataclasses
4+
from dataclasses import dataclass
5+
from enum import Enum
6+
from typing import List, Tuple
7+
8+
import pytest
9+
10+
import serde
11+
from serde import SerdeError, deserialize, from_dict, serialize, to_dict
12+
from serde.compat import dataclass_fields
13+
14+
serde.init(True)
15+
16+
17+
class Status(Enum):
18+
OK = "ok"
19+
ERR = "err"
20+
21+
22+
@deserialize
23+
@serialize
24+
@dataclass
25+
class A:
26+
a: int
27+
b: Status
28+
c: List[str]
29+
30+
31+
@deserialize
32+
@serialize
33+
@dataclass
34+
class B:
35+
a: A
36+
b: Tuple[str, A]
37+
c: Status
38+
39+
40+
# only works with global classes
41+
def test_serde_with_lazy_type_annotations():
42+
a = A(1, Status.ERR, ["foo"])
43+
a_dict = {"a": 1, "b": "err", "c": ["foo"]}
44+
45+
assert a == from_dict(A, a_dict)
46+
assert a_dict == to_dict(a)
47+
48+
b = B(a, ("foo", a), Status.OK)
49+
b_dict = {"a": a_dict, "b": ("foo", a_dict), "c": "ok"}
50+
51+
assert b == from_dict(B, b_dict)
52+
assert b_dict == to_dict(b)
53+
54+
55+
# test_forward_reference_works currently only works with global visible classes
56+
@dataclass
57+
class ForwardReferenceFoo:
58+
# this is not a string forward reference because we use PEP 563 (see 1st line of this file)
59+
bar: ForwardReferenceBar
60+
61+
62+
@serialize
63+
@deserialize
64+
@dataclass
65+
class ForwardReferenceBar:
66+
i: int
67+
68+
69+
# assert type is str
70+
assert 'ForwardReferenceBar' == dataclasses.fields(ForwardReferenceFoo)[0].type
71+
72+
# setup pyserde for Foo after Bar becomes visible to global scope
73+
deserialize(ForwardReferenceFoo)
74+
serialize(ForwardReferenceFoo)
75+
76+
# now the type really is of type Bar
77+
assert ForwardReferenceBar == dataclasses.fields(ForwardReferenceFoo)[0].type
78+
assert ForwardReferenceBar == next(dataclass_fields(ForwardReferenceFoo)).type
79+
80+
# verify usage works
81+
def test_forward_reference_works():
82+
h = ForwardReferenceFoo(bar=ForwardReferenceBar(i=10))
83+
h_dict = {"bar": {"i": 10}}
84+
85+
assert to_dict(h) == h_dict
86+
assert from_dict(ForwardReferenceFoo, h_dict) == h
87+
88+
89+
# trying to use forward reference normally will throw
90+
def test_unresolved_forward_reference_throws():
91+
with pytest.raises(SerdeError) as e:
92+
93+
@serialize
94+
@deserialize
95+
@dataclass
96+
class UnresolvedForwardFoo:
97+
bar: UnresolvedForwardBar
98+
99+
@serialize
100+
@deserialize
101+
@dataclass
102+
class UnresolvedForwardBar:
103+
i: int
104+
105+
assert "Failed to resolve type hints for UnresolvedForwardFoo" in str(e)
106+
107+
108+
# trying to use string forward reference will throw
109+
def test_string_forward_reference_throws():
110+
with pytest.raises(SerdeError) as e:
111+
112+
@serialize
113+
@deserialize
114+
@dataclass
115+
class UnresolvedStringForwardFoo:
116+
# string forward references are not compatible with PEP 563 and will throw
117+
bar: 'UnresolvedStringForwardBar'
118+
119+
@serialize
120+
@deserialize
121+
@dataclass
122+
class UnresolvedStringForwardBar:
123+
i: int
124+
125+
# message is different between <= 3.8 & >= 3.9
126+
assert "Failed to resolve " in str(e.value)

0 commit comments

Comments
 (0)