Skip to content

Commit 30763e1

Browse files
authored
Merge pull request yukinarit#209 from kmsquire/feature/numpy
feat: Support numpy types
2 parents 0963c35 + 61b6130 commit 30763e1

13 files changed

+419
-12
lines changed

.github/workflows/test.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ jobs:
99
runs-on: ${{ matrix.os }}
1010
strategy:
1111
matrix:
12-
python-version: ["3.7", "3.8", "3.9", "3.10", "pypy-3.8"]
12+
# Remove pypy-3.8 until it supports numpy
13+
python-version: ["3.7", "3.8", "3.9", "3.10"] # "pypy-3.8"
1314
os: [ubuntu-20.04, macos-10.15, windows-2019]
1415
steps:
1516
- name: Checkout

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ venv/
1717
.idea/
1818
out/
1919
docs/CHANGELOG.md
20+
poetry.lock

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ repos:
88
- id: end-of-file-fixer
99

1010
- repo: https://github.com/psf/black
11-
rev: 19.3b0
11+
rev: 22.3.0
1212
hooks:
1313
- id: black
1414
args: [

docs/supported-data-formats.md

+73-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Call `to_json` to serialize `Foo` object into JSON string and `from_json` to des
4444
>>> from serde.json import to_json, from_json
4545

4646
>>> to_json(Foo(i=10, s='foo', f=100.0, b=True))
47-
{"i": 10, "s": "foo", "f": 100.0, "b": true}
47+
'{"i": 10, "s": "foo", "f": 100.0, "b": true}'
4848

4949
>>> from_json(Foo, '{"i": 10, "s": "foo", "f": 100.0, "b": true}')
5050
Foo(i=10, s='foo', f=100.0, b=True)
@@ -98,3 +98,75 @@ b'\x84\xa1i\n\xa1s\xa3foo\xa1f\xcb@Y\x00\x00\x00\x00\x00\x00\xa1b\xc3'
9898
>>> from_msgpack(Foo, b'\x84\xa1i\n\xa1s\xa3foo\xa1f\xcb@Y\x00\x00\x00\x00\x00\x00\xa1b\xc3')
9999
Foo(i=10, s='foo', f=100.0, b=True)
100100
```
101+
102+
## Numpy
103+
104+
All of the above (de)serialization methods can transparently handle most numpy
105+
types when the optional "numpy" extra is installed.
106+
107+
```python
108+
import numpy as np
109+
import numpy.typing as npt
110+
111+
@serde
112+
class NPFoo:
113+
i: np.int32
114+
j: np.int64
115+
f: np.float32
116+
g: np.float64
117+
h: np.bool_
118+
u: np.ndarray
119+
v: npt.NDArray
120+
w: npt.NDArray[np.int32]
121+
x: npt.NDArray[np.int64]
122+
y: npt.NDArray[np.float32]
123+
z: npt.NDArray[np.float64]
124+
125+
npfoo = NPFoo(
126+
np.int32(1),
127+
np.int64(2),
128+
np.float32(3.0),
129+
np.float64(4.0),
130+
np.bool_(False),
131+
np.array([1, 2]),
132+
np.array([3, 4]),
133+
np.array([np.int32(i) for i in [1, 2]]),
134+
np.array([np.int64(i) for i in [3, 4]]),
135+
np.array([np.float32(i) for i in [5.0, 6.0]]),
136+
np.array([np.float64(i) for i in [7.0, 8.0]]),
137+
)
138+
```
139+
140+
```python
141+
>>> from serde.json import to_json, from_json
142+
143+
>>> to_json(npfoo)
144+
'{"i": 1, "j": 2, "f": 3.0, "g": 4.0, "h": false, "u": [1, 2], "v": [3, 4], "w": [1, 2], "x": [3, 4], "y": [5.0, 6.0], "z": [7.0, 8.0]}'
145+
146+
>>> from_json(NPFoo, to_json(npfoo))
147+
NPFoo(i=1, j=2, f=3.0, g=4.0, h=False, u=array([1, 2]), v=array([3, 4]), w=array([1, 2], dtype=int32), x=array([3, 4]), y=array([5., 6.], dtype=float32), z=array([7., 8.]))
148+
149+
>>> from serde.yaml import from_yaml, to_yaml
150+
151+
>>> to_yaml(npfoo)
152+
'f: 3.0\ng: 4.0\nh: false\ni: 1\nj: 2\nu:\n- 1\n- 2\nv:\n- 3\n- 4\nw:\n- 1\n- 2\nx:\n- 3\n- 4\ny:\n- 5.0\n- 6.0\nz:\n- 7.0\n- 8.0\n'
153+
154+
>>> from_yaml(NPFoo, to_yaml(npfoo))
155+
NPFoo(i=1, j=2, f=3.0, g=4.0, h=False, u=array([1, 2]), v=array([3, 4]), w=array([1, 2], dtype=int32), x=array([3, 4]), y=array([5., 6.], dtype=float32), z=array([7., 8.]))
156+
157+
>>> from serde.toml import from_toml, to_toml
158+
159+
>>> to_toml(npfoo)
160+
'i = 1\nj = 2\nf = 3.0\ng = 4.0\nh = false\nu = [ 1, 2,]\nv = [ 3, 4,]\nw = [ 1, 2,]\nx = [ 3, 4,]\ny = [ 5.0, 6.0,]\nz = [ 7.0, 8.0,]\n'
161+
162+
>>> from_toml(NPFoo, to_toml(npfoo))
163+
NPFoo(i=1, j=2, f=3.0, g=4.0, h=False, u=array([1, 2]), v=array([3, 4]), w=array([1, 2], dtype=int32), x=array([3, 4]), y=array([5., 6.], dtype=float32), z=array([7., 8.]))
164+
165+
>>> from serde.msgpack import from_msgpack, to_msgpack
166+
167+
>>> to_msgpack(npfoo)
168+
b'\x8b\xa1i\x01\xa1j\x02\xa1f\xcb@\x08\x00\x00\x00\x00\x00\x00\xa1g\xcb@\x10\x00\x00\x00\x00\x00\x00\xa1h\xc2\xa1u\x92\x01\x02\xa1v\x92\x03\x04\xa1w\x92\x01\x02\xa1x\x92\x03\x04\xa1y\x92\xcb@\x14\x00\x00\x00\x00\x00\x00\xcb@\x18\x00\x00\x00\x00\x00\x00\xa1z\x92\xcb@\x1c\x00\x00\x00\x00\x00\x00\xcb@ \x00\x00\x00\x00\x00\x00'
169+
170+
>>> from_msgpack(NPFoo, to_msgpack(npfoo))
171+
NPFoo(i=1, j=2, f=3.0, g=4.0, h=False, u=array([1, 2]), v=array([3, 4]), w=array([1, 2], dtype=int32), x=array([3, 4]), y=array([5., 6.], dtype=float32), z=array([7., 8.]))
172+
```

pyproject.toml

+15-1
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,23 @@ jinja2 = "*"
3030
msgpack = { version = "*", markers = "extra == 'msgpack' or extra == 'all'", optional = true }
3131
toml = { version = "*", markers = "extra == 'toml' or extra == 'all'", optional = true }
3232
pyyaml = { version = "*", markers = "extra == 'yaml' or extra == 'all'", optional = true }
33+
numpy = [
34+
{ version = "~1.21.0", markers = "python_version ~= '3.7.0' and (extra == 'numpy' or extra == 'all')", optional = true },
35+
{ version = ">1.21.0", markers = "python_version ~= '3.8.0' and (extra == 'numpy' or extra == 'all')", optional = true },
36+
{ version = ">1.21.0", markers = "python_version ~= '3.9.0' and (extra == 'numpy' or extra == 'all')", optional = true },
37+
{ version = ">1.22.0", markers = "python_version ~= '3.10' and (extra == 'numpy' or extra == 'all')", optional = true },
38+
]
3339

3440
[tool.poetry.dev-dependencies]
3541
pyyaml = "*"
3642
toml = "*"
3743
msgpack = "*"
44+
numpy = [
45+
{ version = "~1.21.0", markers = "python_version ~= '3.7.0'" },
46+
{ version = ">1.21.0", markers = "python_version ~= '3.8.0'" },
47+
{ version = ">1.21.0", markers = "python_version ~= '3.9.0'" },
48+
{ version = ">1.22.0", markers = "python_version ~= '3.10'" },
49+
]
3850
flake8 = "*"
3951
pytest = "*"
4052
pytest-cov = "*"
@@ -49,10 +61,12 @@ pytest-xdist = "^2.3.0"
4961

5062
[tool.poetry.extras]
5163
msgpack = ["msgpack"]
64+
numpy = ["numpy"]
5265
toml = ["toml"]
5366
yaml = ["pyyaml"]
54-
all = ["msgpack", "toml", "pyyaml"]
67+
all = ["msgpack", "toml", "pyyaml", "numpy"]
5568

5669
[build-system]
5770
requires = ["poetry-core>=1.0.0"]
5871
build-backend = "poetry.core.masonry.api"
72+

serde/compat.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,24 @@
88
import types
99
import typing
1010
from dataclasses import is_dataclass
11-
from typing import Any, Dict, Generic, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
11+
from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
1212

1313
import typing_inspect
1414

15+
try:
    import numpy.typing as npt

    # The numpy-specific generic alias class lives in a private module, so it is
    # looked up defensively once at import time. When it is missing, the value is
    # None and get_np_origin() simply reports "no origin" instead of raising
    # AttributeError on every call.
    _NP_GENERIC_ALIAS = getattr(getattr(npt, '_generic_alias', None), '_GenericAlias', None)

    def get_np_origin(tp):
        """
        Return the `__origin__` of a numpy generic alias, or `None`.

        `None` is returned when `tp` is not an instance of numpy's own generic
        alias class, when its origin is `ClassVar`, or when that alias class is
        not available in the installed numpy.
        """
        if _NP_GENERIC_ALIAS is None:
            return None
        if isinstance(tp, _NP_GENERIC_ALIAS) and tp.__origin__ is not ClassVar:
            return tp.__origin__
        return None

except ImportError:

    def get_np_origin(tp):
        """
        Fallback used when `numpy.typing` cannot be imported: there is never a
        numpy origin, so always return `None`.
        """
        return None
27+
28+
1529
__all__: List = []
1630

1731
T = TypeVar('T')
@@ -34,9 +48,9 @@ def get_origin(typ):
3448
Provide `get_origin` that works in all python versions.
3549
"""
3650
try:
37-
return typing.get_origin(typ) # python>=3.8 typing module has get_origin.
51+
return typing.get_origin(typ) or get_np_origin(typ) # python>=3.8 typing module has get_origin.
3852
except AttributeError:
39-
return typing_inspect.get_origin(typ)
53+
return typing_inspect.get_origin(typ) or get_np_origin(typ)
4054

4155

4256
def get_args(typ):

serde/de.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
raise_unsupported_type,
5858
union_func_name,
5959
)
60+
from .numpy import deserialize_numpy_array, deserialize_numpy_scalar, is_numpy_array, is_numpy_scalar
6061

6162
__all__: List = ['deserialize', 'is_deserializable', 'from_dict', 'from_tuple']
6263

@@ -488,6 +489,7 @@ class Renderer:
488489
func: str
489490
cls: Optional[Type] = None
490491
custom: Optional[DeserializeFunc] = None # Custom class level deserializer.
492+
import_numpy: bool = False
491493

492494
def render(self, arg: DeField) -> str:
493495
"""
@@ -509,7 +511,7 @@ def render(self, arg: DeField) -> str:
509511
res = self.tuple(arg)
510512
elif is_enum(arg.type):
511513
res = self.enum(arg)
512-
elif is_primitive(arg.type):
514+
elif is_primitive(arg.type) and not is_numpy_scalar(arg.type):
513515
res = self.primitive(arg)
514516
elif is_union(arg.type):
515517
res = self.union_func(arg)
@@ -529,6 +531,12 @@ def render(self, arg: DeField) -> str:
529531
elif is_generic(arg.type):
530532
arg.type = get_origin(arg.type)
531533
res = self.render(arg)
534+
elif is_numpy_scalar(arg.type):
535+
self.import_numpy = True
536+
res = deserialize_numpy_scalar(arg)
537+
elif is_numpy_array(arg.type):
538+
self.import_numpy = True
539+
res = deserialize_numpy_array(arg)
532540
else:
533541
return f"raise_unsupported_type({arg.data})"
534542

@@ -755,7 +763,12 @@ def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_s
755763
env = jinja2.Environment(loader=jinja2.DictLoader({'iter': template}))
756764
env.filters.update({'rvalue': renderer.render})
757765
env.filters.update({'arg': to_iter_arg})
758-
return env.get_template('iter').render(func=FROM_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=defields(cls))
766+
res = env.get_template('iter').render(func=FROM_ITER, serde_scope=getattr(cls, SERDE_SCOPE), fields=defields(cls))
767+
768+
if renderer.import_numpy:
769+
res = "import numpy\n" + res
770+
771+
return res
759772

760773

761774
def render_from_dict(cls: Type, rename_all: Optional[str] = None, custom: Optional[DeserializeFunc] = None) -> str:
@@ -779,7 +792,12 @@ def {{func}}(cls=cls, maybe_generic=None, data=None, reuse_instances = {{serde_s
779792
env = jinja2.Environment(loader=jinja2.DictLoader({'dict': template}))
780793
env.filters.update({'rvalue': renderer.render})
781794
env.filters.update({'arg': functools.partial(to_arg, rename_all=rename_all)})
782-
return env.get_template('dict').render(func=FROM_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=defields(cls))
795+
res = env.get_template('dict').render(func=FROM_DICT, serde_scope=getattr(cls, SERDE_SCOPE), fields=defields(cls))
796+
797+
if renderer.import_numpy:
798+
res = "import numpy\n" + res
799+
800+
return res
783801

784802

785803
def render_union_func(cls: Type, union_args: List[Type], tagging: Tagging = DefaultTagging) -> str:

serde/json.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .compat import T
88
from .de import Deserializer, from_dict
9+
from .numpy import encode_numpy
910
from .se import Serializer, to_dict
1011

1112
__all__ = ["from_json", "to_json"]
@@ -14,6 +15,8 @@
1415
class JsonSerializer(Serializer):
    """
    Serializer that renders an object as a JSON string via `json.dumps`.
    """

    @classmethod
    def serialize(cls, obj: Any, **opts) -> str:
        """
        Dump `obj` to JSON, forwarding `opts` to `json.dumps`.

        Unless the caller supplied its own `default` hook, `encode_numpy` is
        installed as the hook for objects `json` cannot encode natively.
        """
        # Only fills the key in when the caller did not pass one.
        opts.setdefault("default", encode_numpy)
        return json.dumps(obj, **opts)
1821

1922

serde/msgpack.py

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .compat import T
1010
from .core import SerdeError
1111
from .de import Deserializer, from_dict, from_tuple
12+
from .numpy import encode_numpy
1213
from .se import Serializer, to_dict, to_tuple
1314

1415
__all__ = ["from_msgpack", "to_msgpack"]
@@ -17,6 +18,8 @@
1718
class MsgPackSerializer(Serializer):
1819
@classmethod
1920
def serialize(cls, obj, use_bin_type: bool = True, ext_type_code: int = None, **opts) -> bytes:
21+
if "default" not in opts:
22+
opts["default"] = encode_numpy
2023
if ext_type_code is not None:
2124
obj_bytes = msgpack.packb(obj, use_bin_type=use_bin_type, **opts)
2225
obj_or_ext = msgpack.ExtType(ext_type_code, obj_bytes)

serde/numpy.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import Any, Callable, Optional
2+
3+
from serde.compat import get_origin
4+
5+
6+
def fullname(klass):
    """
    Return the qualified name of `klass`, prefixed with its module.

    Classes from `builtins` are returned without the module prefix, so `str`
    yields `'str'` rather than `'builtins.str'`.
    """
    mod = klass.__module__
    if mod != 'builtins':
        return mod + '.' + klass.__qualname__
    return klass.__qualname__
11+
12+
13+
try:
    import numpy as np
    import numpy.typing as npt

    encode_numpy: Optional[Callable[[Any], Any]]

    def encode_numpy(obj: Any):
        """
        Convert a numpy value into plain Python data.

        Arrays become (nested) lists via `tolist()` and numpy scalars become
        Python scalars via `item()`. Anything else raises `TypeError`.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")

    def is_bare_numpy_array(typ) -> bool:
        """
        Test if the type is `np.ndarray` or `npt.NDArray` without type args.

        >>> import numpy as np
        >>> import numpy.typing as npt
        >>> is_bare_numpy_array(npt.NDArray[np.int64])
        False
        >>> is_bare_numpy_array(npt.NDArray)
        True
        >>> is_bare_numpy_array(np.ndarray)
        True
        """
        return typ in (np.ndarray, npt.NDArray)

    def is_numpy_scalar(typ) -> bool:
        """
        Test if `typ` is a class derived from `np.generic`.

        Values that are not classes make `issubclass` raise `TypeError`; those
        are reported as `False`.
        """
        try:
            return issubclass(typ, np.generic)
        except TypeError:
            return False

    def serialize_numpy_scalar(arg) -> str:
        """
        Return source text that calls `.item()` on the variable named by
        `arg.varname`.
        """
        return f"{arg.varname}.item()"

    def deserialize_numpy_scalar(arg) -> str:
        """
        Return source text that constructs `arg.type` (by its full dotted name)
        from the expression `arg.data`.
        """
        return f"{fullname(arg.type)}({arg.data})"

    def is_numpy_array(typ) -> bool:
        """
        Test if `typ`, or its generic origin when it has one, is `np.ndarray`.
        """
        origin = get_origin(typ)
        if origin is not None:
            typ = origin
        return typ is np.ndarray

    def serialize_numpy_array(arg) -> str:
        """
        Return source text that calls `.tolist()` on the variable named by
        `arg.varname`.
        """
        return f"{arg.varname}.tolist()"

    def deserialize_numpy_array(arg) -> str:
        """
        Return source text that builds a `numpy.array` from `arg.data`.

        A bare array type gets no explicit dtype; otherwise the dtype is the
        full dotted name of `arg[1][0].type`.
        """
        if is_bare_numpy_array(arg.type):
            return f"numpy.array({arg.data})"

        # NOTE(review): presumably arg[1] is the second type argument of the
        # array type and [0] the scalar type nested inside it -- confirm against
        # the field class's __getitem__.
        dtype = fullname(arg[1][0].type)
        return f"numpy.array({arg.data}, dtype={dtype})"

except ImportError:
    # numpy is not installed: there is no encoder, every predicate answers
    # False, and the code-generating helpers return an empty string.
    encode_numpy = None

    def is_bare_numpy_array(typ) -> bool:
        """Without numpy, no type is a numpy array type."""
        return False

    def is_numpy_scalar(typ) -> bool:
        """Without numpy, no type is a numpy scalar type."""
        return False

    def serialize_numpy_scalar(arg) -> str:
        """Placeholder returning an empty string when numpy is unavailable."""
        return ""

    def deserialize_numpy_scalar(arg) -> str:
        """Placeholder returning an empty string when numpy is unavailable."""
        return ""

    def is_numpy_array(typ) -> bool:
        """Without numpy, no type is a numpy array type."""
        return False

    def serialize_numpy_array(arg) -> str:
        """Placeholder returning an empty string when numpy is unavailable."""
        return ""

    def deserialize_numpy_array(arg) -> str:
        """Placeholder returning an empty string when numpy is unavailable."""
        return ""

0 commit comments

Comments
 (0)