Skip to content

Commit 5e521cf

Browse files
committed
feat(numpy): Support serialization of numpy.datetime64
1 parent 8f9f71c commit 5e521cf

File tree

5 files changed

+52
-15
lines changed

5 files changed

+52
-15
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ examples:
2525
pushd examples && $(POETRY) run $(PYTHON) runner.py && popd
2626

2727
coverage:
28-
$(POETRY) run pytest tests --doctest-modules serde -v -nauto --cov=serde --cov-report term --cov-report xml
28+
$(POETRY) run pytest tests --doctest-modules serde -v -nauto --cov=serde --cov-report term --cov-report xml --cov-report html
2929

3030
pep8:
3131
$(POETRY) run flake8

serde/de.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,13 @@ def render(self, arg: DeField) -> str:
519519
res = self.tuple(arg)
520520
elif is_enum(arg.type):
521521
res = self.enum(arg)
522-
elif is_primitive(arg.type) and not is_numpy_scalar(arg.type):
522+
elif is_numpy_scalar(arg.type):
523+
self.import_numpy = True
524+
res = deserialize_numpy_scalar(arg)
525+
elif is_numpy_array(arg.type):
526+
self.import_numpy = True
527+
res = deserialize_numpy_array(arg)
528+
elif is_primitive(arg.type):
523529
res = self.primitive(arg)
524530
elif is_union(arg.type):
525531
res = self.union_func(arg)
@@ -539,12 +545,6 @@ def render(self, arg: DeField) -> str:
539545
elif is_generic(arg.type):
540546
arg.type = get_origin(arg.type)
541547
res = self.render(arg)
542-
elif is_numpy_scalar(arg.type):
543-
self.import_numpy = True
544-
res = deserialize_numpy_scalar(arg)
545-
elif is_numpy_array(arg.type):
546-
self.import_numpy = True
547-
res = deserialize_numpy_array(arg)
548548
else:
549549
return f"raise_unsupported_type({arg.data})"
550550

serde/numpy.py

+17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def fullname(klass):
1919
def encode_numpy(obj: Any):
2020
if isinstance(obj, np.ndarray):
2121
return obj.tolist()
22+
if isinstance(obj, np.datetime64):
23+
return obj.item().isoformat()
2224
if isinstance(obj, np.generic):
2325
return obj.item()
2426
raise TypeError(f"Object of type {fullname(type(obj))} is not serializable")
@@ -44,6 +46,12 @@ def is_numpy_scalar(typ) -> bool:
4446
except TypeError:
4547
return False
4648

49+
def is_numpy_datetime(typ) -> bool:
50+
try:
51+
return issubclass(typ, np.datetime64)
52+
except TypeError:
53+
return False
54+
4755
def serialize_numpy_scalar(arg) -> str:
4856
return f"{arg.varname}.item()"
4957

@@ -59,6 +67,9 @@ def is_numpy_array(typ) -> bool:
5967
def serialize_numpy_array(arg) -> str:
6068
return f"{arg.varname}.tolist()"
6169

70+
def serialize_numpy_datetime(arg) -> str:
71+
return f"{arg.varname}.item().isoformat()"
72+
6273
def deserialize_numpy_array(arg) -> str:
6374
if is_bare_numpy_array(arg.type):
6475
return f"numpy.array({arg.data})"
@@ -79,6 +90,9 @@ def deserialize_numpy_array_direct(typ, arg):
7990
def is_numpy_scalar(typ) -> bool:
8091
return False
8192

93+
def is_numpy_datetime(typ) -> bool:
94+
return False
95+
8296
def serialize_numpy_scalar(arg) -> str:
8397
return ""
8498

@@ -91,6 +105,9 @@ def is_numpy_array(typ) -> bool:
91105
def serialize_numpy_array(arg) -> str:
92106
return ""
93107

108+
def serialize_numpy_datetime(arg) -> str:
109+
return ""
110+
94111
def deserialize_numpy_array(arg) -> str:
95112
return ""
96113

serde/se.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@
5555
raise_unsupported_type,
5656
union_func_name,
5757
)
58-
from .numpy import is_numpy_array, is_numpy_scalar, serialize_numpy_array, serialize_numpy_scalar
58+
from .numpy import (
59+
is_numpy_array,
60+
is_numpy_datetime,
61+
is_numpy_scalar,
62+
serialize_numpy_array,
63+
serialize_numpy_datetime,
64+
serialize_numpy_scalar,
65+
)
5966

6067
__all__ = ["serialize", "is_serializable", "to_dict", "to_tuple"]
6168

@@ -589,7 +596,13 @@ def render(self, arg: SeField) -> str:
589596
res = self.tuple(arg)
590597
elif is_enum(arg.type):
591598
res = self.enum(arg)
592-
elif is_primitive(arg.type) and not is_numpy_scalar(arg.type):
599+
elif is_numpy_datetime(arg.type):
600+
res = serialize_numpy_datetime(arg)
601+
elif is_numpy_scalar(arg.type):
602+
res = serialize_numpy_scalar(arg)
603+
elif is_numpy_array(arg.type):
604+
res = serialize_numpy_array(arg)
605+
elif is_primitive(arg.type):
593606
res = self.primitive(arg)
594607
elif is_union(arg.type):
595608
res = self.union_func(arg)
@@ -604,10 +617,6 @@ def render(self, arg: SeField) -> str:
604617
elif is_generic(arg.type):
605618
arg.type = get_origin(arg.type)
606619
res = self.render(arg)
607-
elif is_numpy_scalar(arg.type):
608-
res = serialize_numpy_scalar(arg)
609-
elif is_numpy_array(arg.type):
610-
res = serialize_numpy_array(arg)
611620
else:
612621
res = f"raise_unsupported_type({arg.varname})"
613622

tests/test_numpy.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ def __eq__(self, other):
8282
assert type(de_value) == typ
8383
assert type(de_value).dtype == field.type.dtype
8484

85+
@serde.serde(**opt)
86+
class NumpyDate:
87+
d: np.datetime64
88+
89+
date_test = NumpyDate(np.datetime64(10, "Y"))
90+
91+
assert de(NumpyDate, se(date_test)) == date_test
92+
8593

8694
@pytest.mark.parametrize("opt", opt_case, ids=opt_case_ids())
8795
@pytest.mark.parametrize("se,de", format_json + format_msgpack)
@@ -138,7 +146,10 @@ class MisTyped:
138146
np.array([np.bool_(i) for i in [True, False]]),
139147
)
140148

141-
assert de(MisTyped, se(test2)) == expected
149+
assert de(MisTyped, se(test3)) == expected
150+
151+
np_datetime = np.datetime64("2022-04-27")
152+
assert de(np.datetime64, se(np_datetime)) == np_datetime
142153

143154
for value in [np.int32(1), np.int64(1), np.bool_(False), np.bool_(True), int(1), False, True]:
144155
typ = type(value)

0 commit comments

Comments
 (0)