|
1 | 1 | import logging
|
2 | 2 | from dataclasses import fields
|
| 3 | +from typing import List |
3 | 4 |
|
4 | 5 | import numpy as np
|
5 | 6 | import numpy.typing as npt
|
|
8 | 9 | import serde
|
9 | 10 | from serde.compat import get_origin
|
10 | 11 |
|
11 |
| -from .common import all_formats, opt_case, opt_case_ids |
| 12 | +from .common import all_formats, format_json, format_msgpack, opt_case, opt_case_ids |
12 | 13 |
|
13 | 14 | log = logging.getLogger("test")
|
14 | 15 |
|
@@ -80,3 +81,109 @@ def __eq__(self, other):
|
80 | 81 |
|
81 | 82 | assert type(de_value) == typ
|
82 | 83 | assert type(de_value).dtype == field.type.dtype
|
| 84 | + |
| 85 | + |
| 86 | +@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids()) |
| 87 | +@pytest.mark.parametrize("se,de", format_json + format_msgpack) |
| 88 | +def test_encode_numpy(se, de, opt): |
| 89 | + log.info(f"Running test with se={se.__name__} de={de.__name__} opts={opt}") |
| 90 | + |
| 91 | + de_int = de(int, se(np.int32(1))) |
| 92 | + assert de_int == 1 |
| 93 | + assert isinstance(de_int, int) |
| 94 | + |
| 95 | + de_arr = de(list, se(np.array([1, 2, 3]))) |
| 96 | + assert isinstance(de_arr, list) |
| 97 | + assert de_arr == [1, 2, 3] |
| 98 | + |
| 99 | + @serde.serde(**opt) |
| 100 | + class MisTyped: |
| 101 | + a: int |
| 102 | + b: float |
| 103 | + h: bool |
| 104 | + c: List[int] |
| 105 | + d: List[float] |
| 106 | + e: List[bool] |
| 107 | + |
| 108 | + expected = MisTyped(1, 3.0, False, [1, 2], [5.0, 6.0], [True, False]) |
| 109 | + |
| 110 | + test1 = MisTyped( |
| 111 | + np.int32(1), |
| 112 | + np.float32(3.0), |
| 113 | + np.bool_(False), |
| 114 | + np.array([np.int32(i) for i in [1, 2]]), |
| 115 | + np.array([np.float32(i) for i in [5.0, 6.0]]), |
| 116 | + np.array([np.bool_(i) for i in [True, False]]), |
| 117 | + ) |
| 118 | + |
| 119 | + assert de(MisTyped, se(test1)) == expected |
| 120 | + |
| 121 | + test2 = MisTyped( |
| 122 | + np.int64(1), |
| 123 | + np.float64(3.0), |
| 124 | + np.bool_(False), |
| 125 | + np.array([np.int64(i) for i in [1, 2]]), |
| 126 | + np.array([np.float64(i) for i in [5.0, 6.0]]), |
| 127 | + np.array([np.bool_(i) for i in [True, False]]), |
| 128 | + ) |
| 129 | + |
| 130 | + assert de(MisTyped, se(test2)) == expected |
| 131 | + |
| 132 | + test3 = MisTyped( |
| 133 | + np.int64(1), |
| 134 | + np.float64(3.0), |
| 135 | + np.bool_(False), |
| 136 | + np.array([np.int64(i) for i in [1, 2]]), |
| 137 | + np.array([np.float64(i) for i in [5.0, 6.0]]), |
| 138 | + np.array([np.bool_(i) for i in [True, False]]), |
| 139 | + ) |
| 140 | + |
| 141 | + assert de(MisTyped, se(test2)) == expected |
| 142 | + |
| 143 | + |
| 144 | +@pytest.mark.parametrize("se,de", format_json + format_msgpack) |
| 145 | +def test_encode_numpy_with_no_default_encoder(se, de): |
| 146 | + log.info(f"Running test with se={se.__name__} de={de.__name__} with no default encoder") |
| 147 | + |
| 148 | + with pytest.raises(TypeError): |
| 149 | + se(np.int32(1), default=None) |
| 150 | + |
| 151 | + with pytest.raises(TypeError): |
| 152 | + se(np.array([1, 2, 3]), default=None) |
| 153 | + |
| 154 | + @serde.serde |
| 155 | + class MisTypedNoDefaultEncoder: |
| 156 | + a: int |
| 157 | + b: float |
| 158 | + h: bool |
| 159 | + c: List[int] |
| 160 | + d: List[float] |
| 161 | + e: List[bool] |
| 162 | + |
| 163 | + test1 = MisTypedNoDefaultEncoder( |
| 164 | + np.int32(1), |
| 165 | + np.float32(3.0), |
| 166 | + np.bool_(False), |
| 167 | + np.array([np.int32(i) for i in [1, 2]]), |
| 168 | + np.array([np.float32(i) for i in [5.0, 6.0]]), |
| 169 | + np.array([np.bool_(i) for i in [True, False]]), |
| 170 | + ) |
| 171 | + |
| 172 | + with pytest.raises(TypeError): |
| 173 | + se(test1, default=None) |
| 174 | + |
| 175 | + test2 = MisTypedNoDefaultEncoder( |
| 176 | + np.int64(1), |
| 177 | + np.float64(3.0), |
| 178 | + np.bool_(False), |
| 179 | + np.array([np.int64(i) for i in [1, 2]]), |
| 180 | + np.array([np.float64(i) for i in [5.0, 6.0]]), |
| 181 | + np.array([np.bool_(i) for i in [True, False]]), |
| 182 | + ) |
| 183 | + |
| 184 | + with pytest.raises(TypeError): |
| 185 | + se(test2, default=None) |
| 186 | + |
| 187 | + |
| 188 | +def test_numpy_misc(): |
| 189 | + assert serde.numpy.fullname(str) == "str" |
0 commit comments