
Commit 4e6ec4e

add layout
1 parent 07436c2 commit 4e6ec4e

29 files changed, +1334 -111 lines changed
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import cv2
import numpy as np

import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))

from mindocr.data.layout_dataset import xyxy2xywh


def letterbox(scaleup):
    def func(data):
        image = data["image"]
        hw_ori = data["raw_img_shape"]
        new_shape = data["target_size"]
        color = (114, 114, 114)
        # Resize and pad image while meeting stride-multiple constraints
        shape = image.shape[:2]  # current shape [height, width]
        h, w = shape[:]
        # h0, w0 = hw_ori
        h0, w0 = new_shape
        # hw_scale = np.array([h / h0, w / w0])
        hw_scale = np.array([h0 / h, w0 / w])
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not scaleup:  # only scale down, do not scale up (for better test mAP)
            r = min(r, 1.0)

        # Compute padding
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

        dw, dh = dw / 2, dh / 2  # divide padding into 2 sides
        hw_pad = np.array([dh, dw])

        if shape[::-1] != new_unpad:  # resize
            image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border

        data["image"] = image
        data["image_ids"] = 0
        data["hw_ori"] = hw_ori
        data["hw_scale"] = hw_scale
        data["pad"] = hw_pad
        return data

    return func


def image_norm(scale=255.0):
    def func(data):
        image = data["image"]
        image = image.astype(np.float32, copy=False)
        image /= scale
        data["image"] = image
        return data

    return func


def image_transpose(bgr2rgb=True, hwc2chw=True):
    def func(data):
        image = data["image"]
        if bgr2rgb:
            image = image[:, :, ::-1]
        if hwc2chw:
            image = image.transpose(2, 0, 1)
        data["image"] = image
        return data

    return func

def label_norm(labels, xyxy2xywh_=True):
    def func(data):
        if len(labels) == 0:
            return data, labels

        if xyxy2xywh_:
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh

        labels[:, [2, 4]] /= data.shape[0]  # normalized height 0-1
        labels[:, [1, 3]] /= data.shape[1]  # normalized width 0-1

        return data, labels
    return func
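
Each factory above returns a closure that reads and writes a per-sample data dict, so the image transforms can be chained like any other MindOCR transform. A minimal sketch of that chaining on a dummy sample (not part of the commit; the image size, `target_size`, and printed values are illustrative assumptions, and the factories are assumed to be in scope from the new module):

import numpy as np

# Dummy sample carrying the keys the closures above read.
sample = {
    "image": np.zeros((600, 800, 3), dtype=np.uint8),  # HWC, BGR
    "raw_img_shape": (600, 800),                        # original (h, w)
    "target_size": (800, 800),                          # letterbox target (h, w)
}

# Chain the factories the way a transform pipeline would call them.
for transform in (letterbox(scaleup=False), image_norm(scale=255.0), image_transpose()):
    sample = transform(sample)

print(sample["image"].shape)  # (3, 800, 800): padded to 800x800, then HWC -> CHW
print(sample["hw_scale"])     # [1.333..., 1.0]  (target / original, per axis)
print(sample["pad"])          # [100., 0.]       (padding added per side: dh, dw)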

mindocr/data/transforms/transforms_factory.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .rec_transforms import *
 from .svtr_transform import *
 from .table_transform import *
+from .layout_transform import *

 __all__ = ["create_transforms", "run_transforms", "transforms_dbnet_icdar15"]
 _logger = logging.getLogger(__name__)
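
The wildcard import exposes the new layout transforms to the transform factory alongside the existing ones. A hypothetical pipeline entry (not from the commit; the exact config format accepted by `create_transforms` is an assumption here) might then reference them by name:

# Hypothetical pipeline spec referencing the new transforms by name; whether
# create_transforms accepts these function factories in exactly this form is an assumption.
layout_transform_pipeline = [
    {"letterbox": {"scaleup": False}},
    {"image_norm": {"scale": 255.0}},
    {"image_transpose": {"bgr2rgb": True, "hwc2chw": True}},
]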

mindocr/infer/classification/cls_infer_node.py

Lines changed: 3 additions & 2 deletions
@@ -39,7 +39,7 @@ def process(self, input_data):
            self.send_to_next_module(input_data)
            return

-        data = input_data.data
+        data = input_data.data["cls_pre_res"]
        data = [np.expand_dims(d, 0) for d in data if len(d.shape) == 3]
        data = np.concatenate(data, axis=0)

@@ -53,5 +53,6 @@ def process(self, input_data):
            pred = self.cls_model([d])
            preds.append(pred[0])
        preds = np.concatenate(preds, axis=0)
-        input_data.data = {"pred": preds}
+        # input_data.data = {"pred": preds}
+        input_data.data["cls_infer_res"] = {"pred": preds}
        self.send_to_next_module(input_data)

mindocr/infer/classification/cls_post_node.py

Lines changed: 2 additions & 3 deletions
@@ -42,14 +42,14 @@ def process(self, input_data):
            self.send_to_next_module(input_data)
            return

-        data = input_data.data
+        data = input_data.data["cls_infer_res"]
        pred = data["pred"]
        output = self.cls_postprocess(pred)
        angles = output["angles"]
        scores = np.array(output["scores"]).tolist()

        batch = input_data.sub_image_size
-        if self.task_type.value == TaskType.DET_CLS_REC.value:
+        if self.task_type.value in (TaskType.DET_CLS_REC.value, TaskType.Layout_DET_CLS_REC.value):
            sub_images = input_data.sub_image_list
            for i in range(batch):
                angle, score = angles[i], scores[i]
@@ -59,5 +59,4 @@ def process(self, input_data):
        else:
            input_data.infer_result = [(angle, score) for angle, score in zip(angles, scores)]

-        input_data.data = None
        self.send_to_next_module(input_data)

mindocr/infer/classification/cls_pre_node.py

Lines changed: 1 addition & 1 deletion
@@ -36,5 +36,5 @@ def process(self, input_data):
        else:
            sub_image_list = input_data.sub_image_list
        data = [self.cls_preprocesser(split_image)["image"] for split_image in sub_image_list]
-        input_data.data = data
+        input_data.data["cls_pre_res"] = data
        self.send_to_next_module(input_data)
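
Taken together, the three classification-node diffs above stop overwriting `input_data.data` at every stage and instead treat it as a dict keyed by stage (`cls_pre_res`, `cls_infer_res`), which is what lets the layout pipeline keep earlier results alive for downstream nodes. A minimal sketch of that convention (not part of the commit; the array shapes are made-up placeholders):

import numpy as np

data = {}  # stands in for input_data.data

# cls_pre_node: preprocessed sub-images are stored under "cls_pre_res".
data["cls_pre_res"] = [np.zeros((3, 48, 192), dtype=np.float32) for _ in range(2)]

# cls_infer_node: model outputs are stored under "cls_infer_res" (the old
# `input_data.data = {"pred": preds}` overwrite survives only as a comment).
batch = np.concatenate([np.expand_dims(d, 0) for d in data["cls_pre_res"]], axis=0)
data["cls_infer_res"] = {"pred": np.zeros((batch.shape[0], 2), dtype=np.float32)}

# cls_post_node: reads data["cls_infer_res"]["pred"] and no longer resets
# input_data.data to None, so later (layout) nodes still see every stage.
print(sorted(data.keys()))  # ['cls_infer_res', 'cls_pre_res']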

mindocr/infer/common/collect_node.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@
    TaskType.DET_REC: "pipeline_results.txt",
    TaskType.DET_CLS_REC: "pipeline_results.txt",
    TaskType.LAYOUT: "layout_results.txt",
+   TaskType.LAYOUT_DET_REC: "pipeline_results.txt",
+
}

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
import os
from collections import defaultdict
from ctypes import c_uint64
from multiprocessing import Manager

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))

from pipeline.data_process.utils import cv_utils
from pipeline.tasks import TaskType
from pipeline.utils import log, safe_list_writer, visual_utils
from pipeline.datatype import ProcessData, ProfilingData, StopData
from pipeline.framework.module_base import ModuleBase

RESULTS_SAVE_FILENAME = {
    TaskType.DET: "det_results.txt",
    TaskType.CLS: "cls_results.txt",
    TaskType.REC: "rec_results.txt",
    TaskType.DET_REC: "pipeline_results.txt",
    TaskType.DET_CLS_REC: "pipeline_results.txt",
    TaskType.LAYOUT: "layout_results.txt",
    TaskType.LAYOUT_DET_REC: "pipeline_results.txt",
    TaskType.LAYOUT_DET_CLS_REC: "pipeline_results.txt",
}


class CollectNode(ModuleBase):
    def __init__(self, args, msg_queue, tqdm_info):
        super().__init__(args, msg_queue, tqdm_info)
        self.image_sub_remaining = defaultdict(defaultdict)
        self.image_pipeline_res = defaultdict(defaultdict)
        self.infer_size = defaultdict(int)
        self.image_total = Manager().Value(c_uint64, 0)
        self.task_type = args.task_type
        self.res_save_dir = args.res_save_dir
        self.save_filename = RESULTS_SAVE_FILENAME[TaskType(self.task_type.value)]

    def init_self_args(self):
        super().init_self_args()

    def _collect_stop(self, input_data):
        self.image_total.value = input_data.image_total

    def _vis_results(self, image_name, image, taskid, data_type, task=None):
        if self.args.crop_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
            basename = os.path.basename(image_name)
            filename = os.path.join(self.args.crop_save_dir, os.path.splitext(basename)[0])
            box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
            crop_list = visual_utils.vis_crop(image, box_list)
            for i, crop in enumerate(crop_list):
                cv_utils.img_write(filename + "_crop_" + str(i) + ".jpg", crop)

        if self.args.vis_pipeline_save_dir:
            basename = os.path.basename(image_name)
            filename = os.path.join(self.args.vis_pipeline_save_dir, os.path.splitext(basename)[0])
            box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
            text_list = [x["transcription"] for x in self.image_pipeline_res[taskid][image_name]]
            box_text = visual_utils.vis_bbox_text(image, box_list, text_list, font_path=self.args.vis_font_path)
            cv_utils.img_write(filename + ".jpg", box_text)

        if self.args.vis_det_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
            basename = os.path.basename(image_name)
            filename = os.path.join(self.args.vis_det_save_dir, os.path.splitext(basename)[0])
            box_list = [np.array(x).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
            box_line = visual_utils.vis_bbox(image, box_list, [255, 255, 0], 2)
            cv_utils.img_write(filename + ".jpg", box_line)

        if self.args.vis_layout_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
            basename = os.path.basename(image_name)
            filename = os.path.join(self.args.vis_layout_save_dir, os.path.splitext(basename)[0])
            box_list = []
            for x in self.image_pipeline_res[taskid][image_name]:
                x, y, dx, dy = x['bbox']
                box_list.append(np.array([[x, y+dy], [x+dx, y+dy], [x+dx, y], [x, y]]))
            box_line = visual_utils.vis_bbox(image, box_list, [255, 255, 0], 2)
            cv_utils.img_write(filename + ".jpg", box_line)
        # log.info(f"{image_name} is finished.")

    def final_text_save(self):
        rst_dict = dict()
        for rst in self.image_pipeline_res.values():
            rst_dict.update(rst)
        save_filename = os.path.join(self.res_save_dir, self.save_filename)
        safe_list_writer(rst_dict, save_filename)
        # log.info(f"save infer result to {save_filename} successfully")

    def _update_layout_result(self, input_data):
        taskid = input_data.taskid
        image_path = input_data.image_path[0]
        layout_rsts = input_data.data

        for layout_rst in layout_rsts["layout_collect_res"]:
            # X, Y = layout_rst.data["raw_img_shape"]
            layout_bbox = layout_rst.data["layout_result"]
            lx, ly, _, _ = layout_bbox['bbox']
            for rec_rst in layout_rst.infer_result:
                bbox, transcription, score = rec_rst[:-2], rec_rst[-2], rec_rst[-1]
                bbox = [[b[0]+lx, b[1]+ly] for b in bbox]
                if score > 0.5:
                    if self.args.result_contain_score:
                        self.image_pipeline_res[taskid][image_path].append(
                            {"transcription": transcription, "points": bbox, "score": str(score)}
                        )
                    else:
                        self.image_pipeline_res[taskid][image_path].append(
                            {"transcription": transcription, "points": bbox}
                        )

    def _collect_results(self, input_data: ProcessData):
        taskid = input_data.taskid
        if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):
            image_path = input_data.image_path[0]  # bs=1
            # print(f"input_data.infer_result:{input_data.infer_result}")
            for result in input_data.infer_result:
                # print(f"result:{result}")
                if result[-1] > 0.5:
                    if self.args.result_contain_score:
                        self.image_pipeline_res[taskid][image_path].append(
                            {"transcription": result[-2], "points": result[:-2], "score": str(result[-1])}
                        )
                    else:
                        self.image_pipeline_res[taskid][image_path].append(
                            {"transcription": result[-2], "points": result[:-2]}
                        )
            if not input_data.infer_result:
                self.image_pipeline_res[taskid][image_path] = []
        elif self.task_type.value == TaskType.DET.value:
            image_path = input_data.image_path[0]  # bs=1
            self.image_pipeline_res[taskid][image_path] = input_data.infer_result
        elif self.task_type.value in (TaskType.REC.value, TaskType.CLS.value):
            for image_path, infer_result in zip(input_data.image_path, input_data.infer_result):
                self.image_pipeline_res[taskid][image_path] = infer_result
        elif self.task_type.value == TaskType.LAYOUT.value:
            for infer_result in input_data.infer_result:
                image_path = infer_result.pop("image_id")[0]
                if image_path in self.image_pipeline_res[taskid]:
                    self.image_pipeline_res[taskid][image_path].append(infer_result)
                else:
                    self.image_pipeline_res[taskid][image_path] = [infer_result]
        elif self.task_type.value in (TaskType.LAYOUT_DET_REC.value, TaskType.LAYOUT_DET_CLS_REC.value,):
            self._update_layout_result(input_data)
        else:
            raise NotImplementedError("Task type do not support.")

        self._update_remaining(input_data)

    def _update_remaining(self, input_data: ProcessData):
        taskid = input_data.taskid
        data_type = input_data.data_type
        # if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value, TaskType.LAYOUT_DET_REC.value):  # with sub image
        #     for idx, image_path in enumerate(input_data.image_path):
        #         if image_path in self.image_sub_remaining[taskid]:
        #             self.image_sub_remaining[taskid][image_path] -= input_data.sub_image_size
        #             if not self.image_sub_remaining[taskid][image_path]:
        #                 self.image_sub_remaining[taskid].pop(image_path)
        #                 self.infer_size[taskid] += 1
        #                 if self.task_type.value in (TaskType.LAYOUT_DET_REC.value, ):
        #                     self._vis_results(image_path, input_data.data["layout_images"][idx], taskid, data_type) if input_data.frame else ...
        #                 else:
        #                     self._vis_results(
        #                         image_path, input_data.frame[idx], taskid, data_type
        #                     ) if input_data.frame else ...
        #         else:
        #             remaining = input_data.sub_image_total - input_data.sub_image_size
        #             if remaining:
        #                 self.image_sub_remaining[taskid][image_path] = remaining
        #             else:
        #                 self.infer_size[taskid] += 1
        #                 if self.task_type.value in (TaskType.LAYOUT_DET_REC.value, ):
        #                     self._vis_results(image_path, input_data.data["layout_images"][idx], taskid, data_type) if input_data.frame else ...
        #                 else:
        #                     self._vis_results(
        #                         image_path, input_data.frame[idx], taskid, data_type
        #                     ) if input_data.frame else ...
        # else:  # without sub image
        #     if self.task_type.value not in (TaskType.LAYOUT_DET_REC, ):
        for idx, image_path in enumerate(input_data.image_path):
            self.infer_size[taskid] += 1
            if self.task_type.value in (TaskType.LAYOUT_DET_REC.value, ):
                self._vis_results(image_path, input_data.frame[idx], taskid, data_type) if input_data.frame else ...
            else:
                self._vis_results(image_path, input_data.frame[idx], taskid, data_type) if input_data.frame else ...

    def process(self, input_data):
        if isinstance(input_data, ProcessData):
            # print(f"ProcessData:{input_data.image_path}")
            taskid = input_data.taskid
            if input_data.taskid not in self.image_sub_remaining.keys():
                self.image_sub_remaining[input_data.taskid] = defaultdict(int)
            if input_data.taskid not in self.image_pipeline_res.keys():
                self.image_pipeline_res[input_data.taskid] = defaultdict(list)
            self._collect_results(input_data)
            if self.infer_size[taskid] == input_data.task_images_num:
                self.send_to_next_module({taskid: self.image_pipeline_res[taskid]})

        elif isinstance(input_data, StopData):
            self._collect_stop(input_data)
            if input_data.exception:
                self.stop_manager.value = True
        else:
            raise ValueError("unknown input data")

        infer_size_sum = sum(self.infer_size.values())
        if self.image_total.value and infer_size_sum == self.image_total.value:
            self.final_text_save()
            self.stop_manager.value = True

    def stop(self):
        profiling_data = ProfilingData(
            module_name=self.module_name,
            instance_id=self.instance_id,
            process_cost_time=self.process_cost.value,
            send_cost_time=self.send_cost.value,
            image_total=self.image_total.value,
        )
        self.msg_queue.put(profiling_data, block=False)
        self.is_stop = True
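
The layout-specific step in `_update_layout_result` above is the coordinate shift: recognized text boxes arrive in the local coordinates of a layout crop and are offset by that crop's top-left corner `(lx, ly)` before being saved (and kept only when the score exceeds 0.5). A small worked example (not from the commit; the bbox and score values are invented):

# Invented values: a layout region at (100, 200) with width 300, height 80,
# and one recognized box expressed in the crop's local coordinates.
layout_bbox = {"bbox": [100, 200, 300, 80]}
rec_rst = [[10, 5], [90, 5], [90, 25], [10, 25], "hello", 0.93]

lx, ly = layout_bbox["bbox"][0], layout_bbox["bbox"][1]
bbox, transcription, score = rec_rst[:-2], rec_rst[-2], rec_rst[-1]
bbox = [[b[0] + lx, b[1] + ly] for b in bbox]  # shift back to page coordinates

print(bbox)  # [[110, 205], [190, 205], [190, 225], [110, 225]]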

mindocr/infer/detection/det_infer_node.py

Lines changed: 2 additions & 2 deletions
@@ -32,9 +32,9 @@ def process(self, input_data):
            self.send_to_next_module(input_data)
            return

-        data = input_data.data["image"]
+        data = input_data.data["det_pre_res"]["image"]
        pred = self.det_model([data])

-        input_data.data = {"pred": pred, "shape_list": input_data.data["shape_list"]}
+        input_data.data["det_infer_res"] = pred

        self.send_to_next_module(input_data)
