Skip to content

Commit 4fa045c

Browse files
committed
test: Add more numpy tests
1 parent 731876f commit 4fa045c

File tree

1 file changed

+108
-1
lines changed

1 file changed

+108
-1
lines changed

tests/test_numpy.py

+108-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from dataclasses import fields
3+
from typing import List
34

45
import numpy as np
56
import numpy.typing as npt
@@ -8,7 +9,7 @@
89
import serde
910
from serde.compat import get_origin
1011

11-
from .common import all_formats, opt_case, opt_case_ids
12+
from .common import all_formats, format_json, format_msgpack, opt_case, opt_case_ids
1213

1314
log = logging.getLogger("test")
1415

@@ -80,3 +81,109 @@ def __eq__(self, other):
8081

8182
assert type(de_value) == typ
8283
assert type(de_value).dtype == field.type.dtype
84+
85+
86+
@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids())
87+
@pytest.mark.parametrize("se,de", format_json + format_msgpack)
88+
def test_encode_numpy(se, de, opt):
89+
log.info(f"Running test with se={se.__name__} de={de.__name__} opts={opt}")
90+
91+
de_int = de(int, se(np.int32(1)))
92+
assert de_int == 1
93+
assert isinstance(de_int, int)
94+
95+
de_arr = de(list, se(np.array([1, 2, 3])))
96+
assert isinstance(de_arr, list)
97+
assert de_arr == [1, 2, 3]
98+
99+
@serde.serde(**opt)
100+
class MisTyped:
101+
a: int
102+
b: float
103+
h: bool
104+
c: List[int]
105+
d: List[float]
106+
e: List[bool]
107+
108+
expected = MisTyped(1, 3.0, False, [1, 2], [5.0, 6.0], [True, False])
109+
110+
test1 = MisTyped(
111+
np.int32(1),
112+
np.float32(3.0),
113+
np.bool_(False),
114+
np.array([np.int32(i) for i in [1, 2]]),
115+
np.array([np.float32(i) for i in [5.0, 6.0]]),
116+
np.array([np.bool_(i) for i in [True, False]]),
117+
)
118+
119+
assert de(MisTyped, se(test1)) == expected
120+
121+
test2 = MisTyped(
122+
np.int64(1),
123+
np.float64(3.0),
124+
np.bool_(False),
125+
np.array([np.int64(i) for i in [1, 2]]),
126+
np.array([np.float64(i) for i in [5.0, 6.0]]),
127+
np.array([np.bool_(i) for i in [True, False]]),
128+
)
129+
130+
assert de(MisTyped, se(test2)) == expected
131+
132+
test3 = MisTyped(
133+
np.int64(1),
134+
np.float64(3.0),
135+
np.bool_(False),
136+
np.array([np.int64(i) for i in [1, 2]]),
137+
np.array([np.float64(i) for i in [5.0, 6.0]]),
138+
np.array([np.bool_(i) for i in [True, False]]),
139+
)
140+
141+
assert de(MisTyped, se(test2)) == expected
142+
143+
144+
@pytest.mark.parametrize("se,de", format_json + format_msgpack)
145+
def test_encode_numpy_with_no_default_encoder(se, de):
146+
log.info(f"Running test with se={se.__name__} de={de.__name__} with no default encoder")
147+
148+
with pytest.raises(TypeError):
149+
se(np.int32(1), default=None)
150+
151+
with pytest.raises(TypeError):
152+
se(np.array([1, 2, 3]), default=None)
153+
154+
@serde.serde
155+
class MisTypedNoDefaultEncoder:
156+
a: int
157+
b: float
158+
h: bool
159+
c: List[int]
160+
d: List[float]
161+
e: List[bool]
162+
163+
test1 = MisTypedNoDefaultEncoder(
164+
np.int32(1),
165+
np.float32(3.0),
166+
np.bool_(False),
167+
np.array([np.int32(i) for i in [1, 2]]),
168+
np.array([np.float32(i) for i in [5.0, 6.0]]),
169+
np.array([np.bool_(i) for i in [True, False]]),
170+
)
171+
172+
with pytest.raises(TypeError):
173+
se(test1, default=None)
174+
175+
test2 = MisTypedNoDefaultEncoder(
176+
np.int64(1),
177+
np.float64(3.0),
178+
np.bool_(False),
179+
np.array([np.int64(i) for i in [1, 2]]),
180+
np.array([np.float64(i) for i in [5.0, 6.0]]),
181+
np.array([np.bool_(i) for i in [True, False]]),
182+
)
183+
184+
with pytest.raises(TypeError):
185+
se(test2, default=None)
186+
187+
188+
def test_numpy_misc():
189+
assert serde.numpy.fullname(str) == "str"

0 commit comments

Comments
 (0)