Skip to content

Commit ad3bfd8

Browse files
committed
0. 添加iou和dice曲线的绘制功能。
1. 简化绘图代码中关于轴对应的指标的选择逻辑 2. 基于ruff的格式,屏蔽对终端参数选项的格式化。 3. 调整converter.py中关于最优指标使用的指令(`\best` -> `\first`),以和两外两个位次的名称对应。
1 parent 8d8c86e commit ad3bfd8

File tree

4 files changed

+91
-180
lines changed

4 files changed

+91
-180
lines changed

eval.py

+13-67
Original file line numberDiff line numberDiff line change
@@ -64,82 +64,28 @@ def get_args():
6464
),
6565
formatter_class=argparse.RawTextHelpFormatter,
6666
)
67+
# fmt: off
6768
parser.add_argument("--dataset-json", required=True, type=str, help="Json file for datasets.")
68-
parser.add_argument(
69-
"--method-json", required=True, nargs="+", type=str, help="Json file for methods."
70-
)
69+
parser.add_argument("--method-json", required=True, nargs="+", type=str, help="Json file for methods.")
7170
parser.add_argument("--metric-npy", type=str, help="Npy file for saving metric results.")
7271
parser.add_argument("--curves-npy", type=str, help="Npy file for saving curve results.")
7372
parser.add_argument("--record-txt", type=str, help="Txt file for saving metric results.")
7473
parser.add_argument("--to-overwrite", action="store_true", help="To overwrite the txt file.")
7574
parser.add_argument("--record-xlsx", type=str, help="Xlsx file for saving metric results.")
76-
parser.add_argument(
77-
"--include-methods",
78-
type=str,
79-
nargs="+",
80-
help="Names of only specific methods you want to evaluate.",
81-
)
82-
parser.add_argument(
83-
"--exclude-methods",
84-
type=str,
85-
nargs="+",
86-
help="Names of some specific methods you do not want to evaluate.",
87-
)
88-
parser.add_argument(
89-
"--include-datasets",
90-
type=str,
91-
nargs="+",
92-
help="Names of only specific datasets you want to evaluate.",
93-
)
94-
parser.add_argument(
95-
"--exclude-datasets",
96-
type=str,
97-
nargs="+",
98-
help="Names of some specific datasets you do not want to evaluate.",
99-
)
100-
parser.add_argument(
101-
"--num-workers",
102-
type=int,
103-
default=4,
104-
help="Number of workers for multi-threading or multi-processing. Default: 4",
105-
)
106-
parser.add_argument(
107-
"--num-bits",
108-
type=int,
109-
default=3,
110-
help="Number of decimal places for showing results. Default: 3",
111-
)
112-
parser.add_argument(
113-
"--metric-names",
114-
type=str,
115-
nargs="+",
116-
default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"],
117-
choices=SUPPORTED_METRICS,
118-
help="Names of metrics",
119-
)
120-
parser.add_argument(
121-
"--data-type",
122-
type=str,
123-
default="image",
124-
choices=["image", "video"],
125-
help="Type of data.",
126-
)
75+
parser.add_argument("--include-methods", type=str, nargs="+", help="Names of only specific methods you want to evaluate.")
76+
parser.add_argument("--exclude-methods", type=str, nargs="+", help="Names of some specific methods you do not want to evaluate.")
77+
parser.add_argument("--include-datasets", type=str, nargs="+", help="Names of only specific datasets you want to evaluate.")
78+
parser.add_argument("--exclude-datasets", type=str, nargs="+", help="Names of some specific datasets you do not want to evaluate.")
79+
parser.add_argument("--num-workers", type=int, default=4, help="Number of workers for multi-threading or multi-processing. Default: 4")
80+
parser.add_argument("--num-bits", type=int, default=3, help="Number of decimal places for showing results. Default: 3")
81+
parser.add_argument("--metric-names", type=str, nargs="+", default=["sm", "wfm", "mae", "fmeasure", "em", "precision", "recall", "msiou"], choices=SUPPORTED_METRICS, help="Names of metrics")
82+
parser.add_argument("--data-type", type=str, default="image", choices=["image", "video"], help="Type of data.")
12783

12884
known_args = parser.parse_known_args()[0]
12985
if known_args.data_type == "video":
130-
parser.add_argument(
131-
"--valid-frame-start",
132-
type=int,
133-
default=0,
134-
help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.",
135-
)
136-
parser.add_argument(
137-
"--valid-frame-end",
138-
type=int,
139-
default=0,
140-
help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.",
141-
)
142-
86+
parser.add_argument("--valid-frame-start", type=int, default=0, help="Valid start index of the frame in each gt video. Defaults to 1, it will skip the first frame. If it is set to None, the code will not skip frames.")
87+
parser.add_argument("--valid-frame-end", type=int, default=0, help="Valid end index of the frame in each gt video. Defaults to -1, it will skip the last frame. If it is set to 0, the code will not skip frames.")
88+
# fmt: on
14389
args = parser.parse_args()
14490

14591
if args.data_type == "video":

metrics/draw_curves.py

+21-34
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@
77

88
from utils.recorders import CurveDrawer
99

10+
# Align the mode with those in GRAYSCALE_METRICS
11+
_YX_AXIS_NAMES = {
12+
"pr": ("precision", "recall"),
13+
"fm": ("fmeasure", None),
14+
"fmeasure": ("fmeasure", None),
15+
"em": ("em", None),
16+
"iou": ("iou", None),
17+
"dice": ("dice", None),
18+
}
19+
1020

1121
def draw_curves(
1222
mode: str,
@@ -40,12 +50,8 @@ def draw_curves(
4050
line_width (int, optional): Width of lines. Defaults to 3.
4151
save_name (str, optional): Name or path (without the extension format). Defaults to None.
4252
"""
43-
assert mode in ["pr", "fm", "em"]
4453
save_name = save_name or mode
45-
mode_axes_setting = axes_setting[mode]
46-
47-
x_label, y_label = mode_axes_setting["x_label"], mode_axes_setting["y_label"]
48-
x_ticks, y_ticks = mode_axes_setting["x_ticks"], mode_axes_setting["y_ticks"]
54+
y_axis_name, x_axis_name = _YX_AXIS_NAMES[mode]
4955

5056
assert curves_npy_path
5157
if not isinstance(curves_npy_path, (list, tuple)):
@@ -137,14 +143,6 @@ def draw_curves(
137143

138144
for idx, (dataset_name, dataset_alias) in enumerate(dataset_aliases.items()):
139145
dataset_results = curves[dataset_name]
140-
curve_drawer.set_axis_property(
141-
idx=idx,
142-
title=dataset_alias.upper(),
143-
x_label=x_label,
144-
y_label=y_label,
145-
x_ticks=x_ticks,
146-
y_ticks=y_ticks,
147-
)
148146

149147
for method_name in target_unique_method_names:
150148
method_setting = unique_method_settings[method_name]
@@ -154,30 +152,19 @@ def draw_curves(
154152
continue
155153

156154
method_results = dataset_results[method_name]
157-
if mode == "pr":
158-
y_data = method_results.get("p")
159-
if y_data is None:
160-
y_data = method_results["precision"]
161-
assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())
162-
163-
x_data = method_results.get("r")
164-
if x_data is None:
165-
x_data = method_results["recall"]
166-
assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys())
167-
elif mode == "fm":
168-
y_data = method_results.get("fm")
169-
if y_data is None:
170-
y_data = method_results["fmeasure"]
171-
assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())
172155

173-
x_data = np.linspace(0, 1, 256)
174-
elif mode == "em":
175-
y_data = method_results["em"]
156+
if y_axis_name is None:
157+
y_data = np.linspace(0, 1, 256)
158+
else:
159+
y_data = method_results[y_axis_name]
176160
assert isinstance(y_data, (list, tuple)), (method_name, method_results.keys())
177161

162+
if x_axis_name is None:
178163
x_data = np.linspace(0, 1, 256)
164+
else:
165+
x_data = method_results[x_axis_name]
166+
assert isinstance(x_data, (list, tuple)), (method_name, method_results.keys())
179167

180-
curve_drawer.plot_at_axis(
181-
idx=idx, method_curve_setting=method_setting, x_data=x_data, y_data=y_data
182-
)
168+
curve_drawer.plot_at_axis(idx, method_setting, x_data=x_data, y_data=y_data)
169+
curve_drawer.set_axis_property(idx, dataset_alias, **axes_setting[mode])
183170
curve_drawer.save(path=save_name)

plot.py

+49-56
Original file line numberDiff line numberDiff line change
@@ -47,41 +47,18 @@ def get_args():
4747
),
4848
formatter_class=argparse.RawTextHelpFormatter,
4949
)
50+
# fmt: off
5051
parser.add_argument("--alias-yaml", type=str, help="Yaml file for datasets and methods alias.")
51-
parser.add_argument(
52-
"--style-cfg",
53-
type=str,
54-
required=True,
55-
help="Yaml file for plotting curves.",
56-
)
57-
parser.add_argument(
58-
"--curves-npys",
59-
required=True,
60-
type=str,
61-
nargs="+",
62-
help="Npy file for saving curve results.",
63-
)
64-
parser.add_argument(
65-
"--our-methods", type=str, nargs="+", help="Names of our methods for highlighting it."
66-
)
67-
parser.add_argument(
68-
"--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1"
69-
)
70-
parser.add_argument(
71-
"--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1"
72-
)
73-
parser.add_argument(
74-
"--mode",
75-
type=str,
76-
choices=["pr", "fm", "em"],
77-
default="pr",
78-
help="Mode for plotting. Default: pr",
79-
)
80-
parser.add_argument(
81-
"--separated-legend", action="store_true", help="Use the separated legend."
82-
)
52+
parser.add_argument("--style-cfg", type=str, required=True, help="Yaml file for plotting curves.")
53+
parser.add_argument("--curves-npys", required=True, type=str, nargs="+", help="Npy file for saving curve results.")
54+
parser.add_argument("--our-methods", type=str, nargs="+", help="Names of our methods for highlighting it.")
55+
parser.add_argument("--num-rows", type=int, default=1, help="Number of rows for subplots. Default: 1")
56+
parser.add_argument("--num-col-legend", type=int, default=1, help="Number of columns in the legend. Default: 1")
57+
parser.add_argument("--mode", type=str, choices=["pr", "fm", "em", "iou", "dice"], default="pr", help="Mode for plotting. Default: pr")
58+
parser.add_argument("--separated-legend", action="store_true", help="Use the separated legend.")
8359
parser.add_argument("--sharey", action="store_true", help="Use the shared y-axis.")
8460
parser.add_argument("--save-name", type=str, help="the exported file path")
61+
# fmt: on
8562
args = parser.parse_args()
8663

8764
return args
@@ -95,32 +72,48 @@ def main(args):
9572
method_aliases = aliases.get("method")
9673
dataset_aliases = aliases.get("dataset")
9774

75+
# TODO: Better method to set axes_setting
76+
axes_setting = {
77+
# pr curve
78+
"pr": {
79+
"x_label": "Recall",
80+
"y_label": "Precision",
81+
"x_ticks": np.linspace(0.5, 1, 6),
82+
"y_ticks": np.linspace(0.7, 1, 6),
83+
},
84+
# fm curve
85+
"fm": {
86+
"x_label": "Threshold",
87+
"y_label": r"F$_{\beta}$",
88+
"x_ticks": np.linspace(0, 1, 6),
89+
"y_ticks": np.linspace(0.6, 1, 6),
90+
},
91+
# em curve
92+
"em": {
93+
"x_label": "Threshold",
94+
"y_label": r"E$_{m}$",
95+
"x_ticks": np.linspace(0, 1, 6),
96+
"y_ticks": np.linspace(0.7, 1, 6),
97+
},
98+
# iou curve
99+
"iou": {
100+
"x_label": "Threshold",
101+
"y_label": "IoU",
102+
"x_ticks": np.linspace(0, 1, 6),
103+
"y_ticks": np.linspace(0.4, 1, 6),
104+
},
105+
# dice curve
106+
"dice": {
107+
"x_label": "Threshold",
108+
"y_label": "Dice",
109+
"x_ticks": np.linspace(0, 1, 6),
110+
"y_ticks": np.linspace(0.4, 1, 6),
111+
},
112+
}
113+
98114
draw_curves.draw_curves(
99115
mode=args.mode,
100-
# 不同曲线的绘图配置
101-
axes_setting={
102-
# pr曲线的配置
103-
"pr": {
104-
"x_label": "Recall",
105-
"y_label": "Precision",
106-
"x_ticks": np.linspace(0.5, 1, 6),
107-
"y_ticks": np.linspace(0.7, 1, 6),
108-
},
109-
# fm曲线的配置
110-
"fm": {
111-
"x_label": "Threshold",
112-
"y_label": r"F$_{\beta}$",
113-
"x_ticks": np.linspace(0, 1, 6),
114-
"y_ticks": np.linspace(0.6, 1, 6),
115-
},
116-
# em曲线的配置
117-
"em": {
118-
"x_label": "Threshold",
119-
"y_label": r"E$_{m}$",
120-
"x_ticks": np.linspace(0, 1, 6),
121-
"y_ticks": np.linspace(0.7, 1, 6),
122-
},
123-
},
116+
axes_setting=axes_setting,
124117
curves_npy_path=args.curves_npys,
125118
row_num=args.num_rows,
126119
method_aliases=method_aliases,

tools/converter.py

+8-23
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,15 @@
1010
import numpy as np
1111
import yaml
1212

13-
parser = argparse.ArgumentParser(
14-
description="A useful and convenient tool to convert your .npy results into the table code in latex."
15-
)
16-
parser.add_argument(
17-
"-i",
18-
"--result-file",
19-
required=True,
20-
nargs="+",
21-
action="extend",
22-
help="The path of the *_metrics.npy file.",
23-
)
24-
parser.add_argument(
25-
"-o", "--tex-file", required=True, type=str, help="The path of the exported tex file."
26-
)
27-
parser.add_argument(
28-
"-c", "--config-file", type=str, help="The path of the customized config yaml file."
29-
)
30-
parser.add_argument(
31-
"--contain-table-env",
32-
action="store_true",
33-
help="Whether to containe the table env in the exported code.",
34-
)
13+
# fmt: off
14+
parser = argparse.ArgumentParser(description="A useful and convenient tool to convert your .npy results into the table code in latex.")
15+
parser.add_argument("-i", "--result-file", required=True, nargs="+", action="extend", help="The path of the *_metrics.npy file.")
16+
parser.add_argument("-o", "--tex-file", required=True, type=str, help="The path of the exported tex file.")
17+
parser.add_argument("-c", "--config-file", type=str, help="The path of the customized config yaml file.")
18+
parser.add_argument("--contain-table-env", action="store_true", help="Whether to containe the table env in the exported code.")
3519
parser.add_argument("--num-bits", type=int, default=3, help="Number of valid digits.")
3620
parser.add_argument("--transpose", action="store_true", help="Whether to transpose the table.")
21+
# fmt: on
3722
args = parser.parse_args()
3823

3924
arg_head = f"%% Generated by: {vars(args)}"
@@ -139,7 +124,7 @@ def update_dict(parent_dict, sub_dict):
139124
metric_row_head=" ",
140125
metric_column_head="& ",
141126
body=[
142-
"& \\best{{{txt:.03f}}}", # style for top1
127+
"& \\first{{{txt:.03f}}}", # style for top1
143128
"& \\second{{{txt:.03f}}}", # style for top2
144129
"& \\third{{{txt:.03f}}}", # style for top3
145130
"& {txt:.03f}", # style for other

0 commit comments

Comments
 (0)