Skip to content

Commit 1784af4

Browse files
committed
feat: add draw_chart
1 parent d1dc3e6 commit 1784af4

File tree

4 files changed

+330
-0
lines changed

4 files changed

+330
-0
lines changed

graph_datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
__version__ = "0.9.0"
44

55
from .load_data import load_data
6+
from .utils import *

graph_datasets/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .output import make_parent_dirs
1414
from .output import refresh_file
1515
from .output import save_to_csv_files
16+
from .plt import draw_chart
1617
from .statistics import edge_homo
1718
from .statistics import node_homo
1819
from .statistics import statistics

graph_datasets/utils/plt.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
"""Draw plots.
2+
"""
3+
from typing import Any
4+
from typing import List
5+
from typing import Tuple
6+
7+
import matplotlib.axes as Axes
8+
import matplotlib.pyplot as plt
9+
10+
11+
def charts(t, ax: Axes):
12+
return {
13+
"line": ax.plot,
14+
"scatter": ax.scatter,
15+
"bar": ax.bar,
16+
}[t]
17+
18+
19+
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
20+
def draw_chart(
21+
ys: List[List],
22+
lbls: List[str],
23+
y_colors: List[str] = None,
24+
xs: List[Any] = None,
25+
fill_between: List = None,
26+
fill_colors: List[str] = None,
27+
fill_alpha: List[float] or float = 0.5,
28+
figsize: Tuple = (24, 3),
29+
xmode: str = "s",
30+
# pylint: disable=unused-argument
31+
ymode: str = "s",
32+
title: str = None,
33+
xticks: bool = False,
34+
x_rotation: float = 45,
35+
boxes: List[List[int]] = None,
36+
box_colors: List[str] = None,
37+
box_alphas: List[float] = None,
38+
ylim: (float, float) = None,
39+
xlim: (float, float) = None,
40+
linewidth: List[int] = None,
41+
types: List[str] or str = "line",
42+
markersize: float = 1.25,
43+
save_path: str = None,
44+
bar_width: float = 0.25,
45+
legend_loc: str = "center left",
46+
legend_bbox_to_anchor: Tuple[float] = (1, 0.5),
47+
) -> None:
48+
"""Draw charts.
49+
50+
Args:
51+
ys (List[List]): value lists of the y axis.
52+
lbls (List[str]): legend labels of the value lists.
53+
y_colors (List[str], optional): colors of the the ys. Defaults to None.
54+
xs (List[Any], optional): value list of the x axis. \
55+
If None, idx will be used. Defaults to None.
56+
fill_between (List, optional): fill the range of \
57+
`[ys[idx][i] - fill_between[idx][i], ys[idx][i] + fill_between[idx][i]]`.\
58+
Defaults to None.
59+
fill_colors (List[str], optional): colors of the fill range. Defaults to None.
60+
fill_alpha (List[float] or float, optional): color alphas of the fill range. \
61+
Defaults to 0.5.
62+
figsize (Tuple, optional): size of the output figure. Defaults to (24, 3).
63+
xmode (str, optional): all ys with the same x ('s') or with different xs ('d'). \
64+
Defaults to "s".
65+
ymode (str, optional): all ys with the same tick ('s') or with different ticks ('d'). \
66+
Defaults to "s".
67+
title (str, optional): figure title. Defaults to None.
68+
xticks (bool, optional): set the x ticks. Defaults to True.
69+
x_rotation (float, optional): rotation of the x ticks. Defaults to 45.
70+
boxes (List[List[int]], optional): draw boxes. Defaults to None.
71+
box_colors (List[str], optional): colors of the boxes to draw. Defaults to None.
72+
box_alpha (List[float], optional): color alphas of the boxes to draw. Defaults to 0.5.
73+
ylim (float, optional): maximum value of the y axis. Defaults to None.
74+
xlim (float, optional): maximum value of the x axis. Defaults to None.
75+
linewidth (List[int], optional): line width of the line charts to draw. Defaults to None.
76+
types (List[str] or str, optional): types of charts of the ys. Defaults to "line".
77+
markersize (float, optional): dot size of the scatter charts. Defaults to 1.25.
78+
save_path (str, optional): If not None, save image to the path. Defaults to None.
79+
bar_width (float, optional): width of the bars. Defaults to 0.25.
80+
legend_loc (str, optional): location of the legend. Defaults to "center left".
81+
legend_bbox_to_anchor (Tuple[float], optional): bbox_to_anchor of the legend. \
82+
Defaults to (1, 0.5).
83+
84+
Raises:
85+
ValueError: "The arg 'types:'{types} has values not supported."
86+
"""
87+
88+
plt.clf()
89+
_, ax = plt.subplots(figsize=figsize)
90+
91+
# drawing func list
92+
funcs = []
93+
optional_args = []
94+
if isinstance(types, List):
95+
for t in types:
96+
funcs.append(charts(t, ax))
97+
if t == "bar":
98+
optional_args.append({
99+
"width": bar_width,
100+
})
101+
else:
102+
optional_args.append({})
103+
elif isinstance(types, str):
104+
for _ in range(len(ys)):
105+
funcs.append(charts(types, ax))
106+
if t == "bar":
107+
optional_args.append({
108+
"width": bar_width,
109+
})
110+
else:
111+
optional_args.append({})
112+
else:
113+
raise ValueError(f"The arg 'types:'{types} has values not supported.")
114+
115+
plt.rcParams["lines.markersize"] = markersize
116+
# ys with the same x
117+
if xmode == "s":
118+
for idx, y in enumerate(ys):
119+
funcs[idx](
120+
xs,
121+
y,
122+
label=lbls[idx],
123+
color=y_colors[idx] if y_colors is not None and y_colors[idx] is not None else None,
124+
linewidth=1 if linewidth is None else linewidth[idx],
125+
**optional_args[idx],
126+
)
127+
128+
# fill the range of
129+
# [ys[idx][i] - fill_between[idx][i], ys[idx][i] + fill_between[idx][i]]
130+
if fill_between is not None and fill_between[idx] is not None:
131+
ax.fill_between(
132+
xs,
133+
[y + fill_between[idx][i] for i, y in enumerate(y)],
134+
[y - fill_between[idx][i] for i, y in enumerate(y)],
135+
facecolor=fill_colors[idx]
136+
if fill_colors is not None and fill_colors[idx] is not None else None,
137+
alpha=fill_alpha[idx] if isinstance(fill_alpha, List) and
138+
fill_colors[idx] is not None else fill_alpha,
139+
)
140+
141+
# set the ticks of the x axis
142+
if xticks:
143+
ax.set_xticks(
144+
range(0, len(xs)),
145+
labels=xs,
146+
rotation=x_rotation,
147+
ha="right",
148+
)
149+
150+
# ys with different xs
151+
elif xmode == "d":
152+
for idx, y in enumerate(ys):
153+
funcs[idx](
154+
xs[idx],
155+
y,
156+
label=lbls[idx],
157+
color=y_colors[idx] if y_colors is not None and y_colors[idx] is not None else None,
158+
linewidth=1 if linewidth is None else linewidth[idx],
159+
**optional_args[idx],
160+
)
161+
162+
# fill the range of
163+
# [ys[idx][i] - fill_between[idx][i], ys[idx][i] + fill_between[idx][i]]
164+
if fill_between is not None and fill_between[idx] is not None:
165+
ax.fill_between(
166+
xs,
167+
[y + fill_between[idx][i] for i, y in enumerate(y)],
168+
[y - fill_between[idx][i] for i, y in enumerate(y)],
169+
facecolor=fill_colors[idx]
170+
if fill_colors is not None and fill_colors[idx] is not None else None,
171+
alpha=fill_alpha[idx] if isinstance(fill_alpha, List) and
172+
fill_colors[idx] is not None else fill_alpha,
173+
)
174+
175+
# set the ticks of the x axis
176+
if xticks:
177+
ax.set_xticks(
178+
range(0, len(xs[0])),
179+
labels=xs[0],
180+
rotation=x_rotation,
181+
ha="right",
182+
)
183+
184+
# set xs with numbers when no xs provided
185+
else:
186+
for idx, y in enumerate(ys):
187+
xs = list(range(len(y)))
188+
funcs[idx](
189+
xs,
190+
y,
191+
label=lbls[idx],
192+
color=y_colors[idx] if y_colors is not None and y_colors[idx] is not None else None,
193+
linewidth=1 if linewidth is None else linewidth[idx],
194+
**optional_args[idx],
195+
)
196+
197+
# fill the range of
198+
# [ys[idx][i] - fill_between[idx][i], ys[idx][i] + fill_between[idx][i]]
199+
if fill_between is not None and fill_between[idx] is not None:
200+
ax.fill_between(
201+
xs,
202+
[y + fill_between[idx][i] for i, y in enumerate(y)],
203+
[y - fill_between[idx][i] for i, y in enumerate(y)],
204+
facecolor=fill_colors[idx]
205+
if fill_colors is not None and fill_colors[idx] is not None else None,
206+
alpha=fill_alpha[idx] if isinstance(fill_alpha, List) and
207+
fill_colors[idx] is not None else fill_alpha,
208+
)
209+
210+
# set the ticks of the x axis
211+
if xticks:
212+
ax.set_xticks(
213+
range(0, len(xs)),
214+
list(range(len(ys[0]))),
215+
rotation=45,
216+
ha="right",
217+
)
218+
219+
# draw a box on the figure
220+
if boxes is not None:
221+
for idx, box in enumerate(boxes):
222+
ax.add_patch(
223+
plt.Rectangle(
224+
box[0],
225+
box[1],
226+
box[2],
227+
transform=ax.transAxes,
228+
color=box_colors[idx] if box_colors is not None else "darkgrey",
229+
alpha=box_alphas[idx] if box_alphas is not None else 0.5,
230+
)
231+
)
232+
233+
ax.legend(loc=legend_loc, bbox_to_anchor=legend_bbox_to_anchor)
234+
if title:
235+
plt.title(title)
236+
if xlim is not None:
237+
plt.xlim(xlim)
238+
if ylim is not None:
239+
plt.ylim(ylim)
240+
241+
if save_path is not None:
242+
plt.savefig(save_path)
243+
plt.show()

0 commit comments

Comments
 (0)