-
Notifications
You must be signed in to change notification settings - Fork 77
/
Copy pathdict_read_write.py
154 lines (126 loc) · 4.74 KB
/
dict_read_write.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# -*- coding: utf-8 -*-
"""
DictReader and DictWriter.
This code is entirely copied from the Python csv module. The only exception is
that it uses the `reader` and `writer` classes from our package.
Author: Gertjan van den Burg
"""
from __future__ import annotations
import warnings
from collections import OrderedDict
from collections.abc import Collection
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import TypeVar
from typing import Union
from typing import cast
from clevercsv.read import reader
from clevercsv.write import writer
if TYPE_CHECKING:
from clevercsv._types import SupportsWrite
from clevercsv._types import _DialectLike
from clevercsv._types import _DictReadMapping
_T = TypeVar("_T")
class DictReader(
    Generic[_T], Iterator["_DictReadMapping[Union[_T, Any], Union[str, Any]]"]
):
    """Iterate over the rows of a CSV source as ordered mappings.

    This follows the interface of ``csv.DictReader`` but parses the input
    with this package's ``reader``.

    Parameters
    ----------
    f:
        Iterable of text lines handed to the underlying ``reader``.
    fieldnames:
        Keys to use for each row. When ``None`` the first row read from
        ``f`` is used as the header.
    restkey:
        Key under which surplus values are stored (as a list) when a row
        has more values than there are field names.
    restval:
        Value assigned to field names that have no value in a short row.
    dialect:
        Dialect passed to the underlying ``reader``.
    *args, **kwds:
        Forwarded unchanged to the underlying ``reader``.
    """

    def __init__(
        self,
        f: Iterable[str],
        fieldnames: Optional[Sequence[_T]] = None,
        restkey: Optional[str] = None,
        restval: Optional[str] = None,
        dialect: "_DialectLike" = "excel",
        *args: Any,
        **kwds: Any,
    ) -> None:
        self._fieldnames = fieldnames
        self.restkey = restkey
        self.restval = restval
        self.reader: reader = reader(f, dialect, *args, **kwds)
        self.dialect = dialect
        # Mirrors ``self.reader.line_num``; 0 means nothing has been read.
        self.line_num = 0

    def __iter__(self) -> "DictReader[_T]":
        return self

    @property
    def fieldnames(self) -> Sequence[_T]:
        """Return the field names, reading the header row if necessary.

        When no field names were supplied and the input has no rows at
        all, an empty list is returned (and not cached) instead of
        failing, so that iterating an empty input simply yields nothing.

        A warning is issued when the field names contain duplicates.
        """
        if self._fieldnames is None:
            try:
                header = next(self.reader)
            except StopIteration:
                # Empty input: there is no header row to read. Previously
                # this fell through to an ``assert`` and raised
                # AssertionError (or TypeError under ``python -O``).
                self.line_num = self.reader.line_num
                return []
            self._fieldnames = [cast(_T, name) for name in header]

        # Note: this was added because I don't think it's expected that Python
        # simply drops information if there are duplicate headers. There is
        # discussion on this issue in the Python bug tracker here:
        # https://bugs.python.org/issue17537 (see linked thread therein). A
        # warning is easy enough to suppress and should ensure that the user
        # is at least aware of this behavior.
        if len(self._fieldnames) != len(set(self._fieldnames)):
            warnings.warn(
                "fieldnames are not unique, some columns will be dropped."
            )

        self.line_num = self.reader.line_num
        return self._fieldnames

    @fieldnames.setter
    def fieldnames(self, value: Sequence[_T]) -> None:
        self._fieldnames = value

    def __next__(self) -> "_DictReadMapping[Union[_T, Any], Union[str, Any]]":
        """Return the next non-blank row as an ordered mapping.

        Raises
        ------
        StopIteration
            When the underlying reader is exhausted.
        """
        if self.line_num == 0:
            # Accessed only for its side effect: the header row must be
            # consumed before the first data row is read.
            self.fieldnames
        row = next(self.reader)
        self.line_num = self.reader.line_num

        # Blank rows are skipped rather than returned as mappings in which
        # every field is set to ``restval``.
        while row == []:
            row = next(self.reader)

        # Read the property once per row: every access re-runs the
        # duplicate check and refreshes ``line_num``.
        fieldnames = self.fieldnames
        d: _DictReadMapping = OrderedDict(zip(fieldnames, row))
        lf = len(fieldnames)
        lr = len(row)
        if lf < lr:
            # More values than field names: keep the surplus under restkey.
            d[self.restkey] = row[lf:]
        elif lf > lr:
            # Fewer values than field names: pad the missing fields.
            for key in fieldnames[lr:]:
                d[key] = self.restval
        return d
class DictWriter(Generic[_T]):
    """Write mappings as CSV rows, ordered by a fixed list of field names.

    This follows the interface of ``csv.DictWriter`` but writes through
    this package's ``writer``.

    Parameters
    ----------
    f:
        Object with a ``write`` method, handed to the underlying ``writer``.
    fieldnames:
        Keys that determine which values are written and in which order.
    restval:
        Value written for field names missing from a row mapping.
    extrasaction:
        ``"raise"`` to raise ``ValueError`` when a row mapping contains
        keys that are not in ``fieldnames``, ``"ignore"`` to drop them.
        The value is matched case-insensitively.
    dialect:
        Dialect passed to the underlying ``writer``.
    *args, **kwds:
        Forwarded unchanged to the underlying ``writer``.

    Raises
    ------
    ValueError
        If ``extrasaction`` is neither ``"raise"`` nor ``"ignore"``.
    """

    def __init__(
        self,
        f: "SupportsWrite[str]",
        fieldnames: Collection[_T],
        restval: Optional[Any] = "",
        extrasaction: Literal["raise", "ignore"] = "raise",
        dialect: "_DialectLike" = "excel",
        *args: Any,
        **kwds: Any,
    ):
        # The field names are iterated once per row (and twice by
        # writeheader), so a one-shot iterator has to be materialised.
        if fieldnames is not None and iter(fieldnames) is fieldnames:
            fieldnames = list(fieldnames)
        self.fieldnames = fieldnames
        self.restval = restval

        # Store the normalised value: validation is case-insensitive, but
        # _dict_to_list compares against the lowercase literal "raise", so
        # keeping e.g. "RAISE" as given would silently act like "ignore".
        action = extrasaction.lower()
        if action not in ("raise", "ignore"):
            raise ValueError(
                "extrasaction (%s) must be 'raise' or 'ignore'" % extrasaction
            )
        self.extrasaction = action
        self.writer = writer(f, dialect, *args, **kwds)

    def writeheader(self) -> Any:
        """Write a row containing the field names themselves."""
        header = dict(zip(self.fieldnames, self.fieldnames))
        return self.writerow(header)

    def _dict_to_list(self, rowdict: Mapping[_T, Any]) -> Iterator[Any]:
        """Yield the values of ``rowdict`` in ``fieldnames`` order.

        Missing fields yield ``restval``. Raises ``ValueError`` for keys
        not in ``fieldnames`` when ``extrasaction`` is ``"raise"``.
        """
        if self.extrasaction == "raise":
            wrong_fields = rowdict.keys() - self.fieldnames
            if wrong_fields:
                raise ValueError(
                    "dict contains fields not in fieldnames: "
                    + ", ".join([repr(x) for x in wrong_fields])
                )
        return (rowdict.get(key, self.restval) for key in self.fieldnames)

    def writerow(self, rowdict: Mapping[_T, Any]) -> Any:
        """Write one mapping as a row; returns what the writer returns."""
        return self.writer.writerow(self._dict_to_list(rowdict))

    def writerows(self, rowdicts: Iterable[Mapping[_T, Any]]) -> None:
        """Write every mapping in ``rowdicts`` as a row."""
        return self.writer.writerows(map(self._dict_to_list, rowdicts))