Skip to content

Commit 8f9f71c

Browse files
committed
fix(numpy): Fix direct numpy array deserialization
* The existing numpy deserialization handles the case where a numpy array is a member of a dataclass, but not where the numpy array is deserialized directy.
1 parent 30763e1 commit 8f9f71c

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

serde/compat.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,26 @@
1515
try:
1616
import numpy.typing as npt
1717

18+
# Note: these functions are only needed on Python 3.8 or earlier
1819
def get_np_origin(tp):
1920
if isinstance(tp, npt._generic_alias._GenericAlias) and tp.__origin__ is not ClassVar:
2021
return tp.__origin__
2122
return None
2223

24+
def get_np_args(tp):
25+
if isinstance(tp, npt._generic_alias._GenericAlias) and tp.__origin__ is not ClassVar:
26+
return tp.__args__
27+
28+
return ()
29+
2330
except ImportError:
2431

2532
def get_np_origin(tp):
2633
return None
2734

35+
def get_np_args(tp):
36+
return ()
37+
2838

2939
__all__: List = []
3040

@@ -58,9 +68,9 @@ def get_args(typ):
5868
Provide `get_args` that works in all python versions.
5969
"""
6070
try:
61-
return typing.get_args(typ) # python>=3.8 typing module has get_args.
71+
return typing.get_args(typ) or get_np_args(typ) # python>=3.8 typing module has get_args.
6272
except AttributeError:
63-
return typing_inspect.get_args(typ)
73+
return typing_inspect.get_args(typ) or get_np_args(typ)
6474

6575

6676
def typename(typ) -> str:

serde/de.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@
5757
raise_unsupported_type,
5858
union_func_name,
5959
)
60-
from .numpy import deserialize_numpy_array, deserialize_numpy_scalar, is_numpy_array, is_numpy_scalar
60+
from .numpy import (
61+
deserialize_numpy_array,
62+
deserialize_numpy_array_direct,
63+
deserialize_numpy_scalar,
64+
is_numpy_array,
65+
is_numpy_scalar,
66+
)
6167

6268
__all__: List = ['deserialize', 'is_deserializable', 'from_dict', 'from_tuple']
6369

@@ -343,6 +349,8 @@ def from_obj(c: Type, o: Any, named: bool, reuse_instances: bool):
343349
return {k: v for k, v in o.items()}
344350
else:
345351
return {thisfunc(type_args(c)[0], k): thisfunc(type_args(c)[1], v) for k, v in o.items()}
352+
elif is_numpy_array(c):
353+
return deserialize_numpy_array_direct(c, o)
346354
elif c in DateTimeTypes:
347355
return c.fromisoformat(o)
348356
elif c is Any:

serde/numpy.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Callable, Optional
22

3-
from serde.compat import get_origin
3+
from serde.compat import get_args, get_origin
44

55

66
def fullname(klass):
@@ -66,6 +66,13 @@ def deserialize_numpy_array(arg) -> str:
6666
dtype = fullname(arg[1][0].type)
6767
return f"numpy.array({arg.data}, dtype={dtype})"
6868

69+
def deserialize_numpy_array_direct(typ, arg):
70+
if is_bare_numpy_array(typ):
71+
return np.array(arg)
72+
73+
dtype = get_args(get_args(typ)[1])[0]
74+
return np.array(arg, dtype=dtype)
75+
6976
except ImportError:
7077
encode_numpy = None
7178

@@ -86,3 +93,6 @@ def serialize_numpy_array(arg) -> str:
8693

8794
def deserialize_numpy_array(arg) -> str:
8895
return ""
96+
97+
def deserialize_numpy_array_direct(typ, arg):
98+
return arg

tests/test_numpy.py

+27
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,33 @@ class MisTyped:
140140

141141
assert de(MisTyped, se(test2)) == expected
142142

143+
for value in [np.int32(1), np.int64(1), np.bool_(False), np.bool_(True), int(1), False, True]:
144+
typ = type(value)
145+
de_value = de(typ, se(value))
146+
assert de_value == value
147+
assert type(de_value) == typ
148+
149+
# test bare numpy array deserialization
150+
arr = np.array([0, 1, 2], dtype=typ)
151+
de_arr = de(np.ndarray, se(arr))
152+
assert (de_arr == arr).all()
153+
154+
# test arrays with dtype=type(value)
155+
arr_typ = npt.NDArray[typ]
156+
157+
de_arr = de(arr_typ, se(arr))
158+
assert (de_arr == arr).all()
159+
assert de_arr.dtype == arr.dtype == typ
160+
161+
class BadClass:
162+
def __init__(self, x):
163+
self.x = x
164+
165+
b = BadClass(1)
166+
167+
with pytest.raises(TypeError):
168+
se(b)
169+
143170

144171
@pytest.mark.parametrize("se,de", format_json + format_msgpack)
145172
def test_encode_numpy_with_no_default_encoder(se, de):

0 commit comments

Comments
 (0)