-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathvisualize_data.py
101 lines (89 loc) · 3.75 KB
/
visualize_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import numpy as np
import os
from itertools import chain
import cv2
import tqdm
from PIL import Image
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_train_loader
from detectron2.data import detection_utils as utils
from detectron2.data.build import filter_images_with_few_keypoints
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
from adet.config import get_cfg
from adet.data.dataset_mapper import DatasetMapperWithBasis
def setup(args):
    """Build the frozen config used by the visualization script.

    Starts from ``get_cfg()``, merges the file named by ``args.config_file``
    when one was given, then applies the ``args.opts`` overrides on top.

    Args:
        args: parsed command-line arguments; ``config_file`` and ``opts``
            are read.

    Returns:
        The config object after ``freeze()`` has been called on it.
    """
    config = get_cfg()
    config_path = args.config_file
    if config_path:
        config.merge_from_file(config_path)
    # Command-line overrides are merged last so they win over the file.
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def parse_args(in_args=None):
    """Parse the command-line options of the visualization script.

    Args:
        in_args: optional list of argument strings. When ``None``, argparse
            reads the process command line instead.

    Returns:
        argparse.Namespace with the attributes ``source``, ``config_file``,
        ``output_dir``, ``show`` and ``opts``.
    """
    cli = argparse.ArgumentParser(description="Visualize ground-truth data")

    # Which view of the data to render; this option is mandatory.
    source_spec = {
        "choices": ["annotation", "dataloader"],
        "required": True,
        "help": "visualize the annotations or the data loader (with pre-processing)",
    }
    cli.add_argument("--source", **source_spec)

    cli.add_argument("--config-file", metavar="FILE", help="path to config file")
    cli.add_argument("--output-dir", default="./", help="path to output directory")
    cli.add_argument("--show", action="store_true", help="show output in a window")

    # Everything after --opts is collected verbatim as config overrides.
    opts_spec = {
        "help": "Modify config options using the command-line",
        "default": [],
        "nargs": argparse.REMAINDER,
    }
    cli.add_argument("--opts", **opts_spec)

    parsed = cli.parse_args(in_args)
    return parsed
if __name__ == "__main__":
    # Script entry point: render ground truth either straight from the dataset
    # annotations or from what the training data loader produces, and show or
    # save one image per sample.
    args = parse_args()
    logger = setup_logger()
    logger.info("Arguments: " + str(args))
    cfg = setup(args)

    dirname = args.output_dir
    os.makedirs(dirname, exist_ok=True)
    # Metadata is looked up for the first training dataset only and reused
    # for every image, including images from any further datasets in TRAIN.
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

    def output(vis, fname):
        """Display the rendered image in a window (--show) or save it as ``fname``.

        Args:
            vis: rendered visualization; must provide ``get_image()`` and
                ``save(path)``.
            fname: file name, printed when showing and joined onto the output
                directory when saving.
        """
        if args.show:
            print(fname)
            # Reverse the channel axis before handing the image to OpenCV
            # (presumably RGB -> BGR, the order cv2.imshow expects).
            cv2.imshow("window", vis.get_image()[:, :, ::-1])
            cv2.waitKey()  # block until a key is pressed
        else:
            filepath = os.path.join(dirname, fname)
            print("Saving to {} ...".format(filepath))
            vis.save(filepath)

    # Draw at double size when displaying interactively.
    scale = 2.0 if args.show else 1.0
    if args.source == "dataloader":
        mapper = DatasetMapperWithBasis(cfg, True)
        train_data_loader = build_detection_train_loader(cfg, mapper)
        # NOTE(review): the training loader is presumably endless, so this loop
        # only stops when the process is interrupted - confirm.
        for batch in train_data_loader:
            for per_image in batch:
                # Pytorch tensor is in (C, H, W) format. Move channels last and
                # convert to a NumPy array up front: PIL's Image.fromarray
                # below cannot consume a torch.Tensor, so without this the
                # non-BGR branch fails.
                img = per_image["image"].permute(1, 2, 0).detach().cpu().numpy()
                if cfg.INPUT.FORMAT == "BGR":
                    # Swap the first and last channels: BGR -> RGB.
                    img = img[:, :, [2, 1, 0]]
                else:
                    # Let PIL interpret the data in the configured mode and
                    # convert it to RGB.
                    # NOTE(review): assumes a dtype PIL accepts for this mode
                    # (e.g. uint8) - confirm against the mapper's output.
                    img = np.asarray(Image.fromarray(img, mode=cfg.INPUT.FORMAT).convert("RGB"))

                visualizer = Visualizer(img, metadata=metadata, scale=scale)
                target_fields = per_image["instances"].get_fields()
                labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]]
                # Boxes, masks and keypoints are optional; missing ones are
                # passed as None.
                vis = visualizer.overlay_instances(
                    labels=labels,
                    boxes=target_fields.get("gt_boxes", None),
                    masks=target_fields.get("gt_masks", None),
                    keypoints=target_fields.get("gt_keypoints", None),
                )
                output(vis, str(per_image["image_id"]) + ".jpg")
    else:
        # Annotation mode: concatenate the dataset dicts of every training dataset.
        dicts = list(chain.from_iterable([DatasetCatalog.get(k) for k in cfg.DATASETS.TRAIN]))
        if cfg.MODEL.KEYPOINT_ON:
            dicts = filter_images_with_few_keypoints(dicts, 1)
        for dic in tqdm.tqdm(dicts):
            img = utils.read_image(dic["file_name"], "RGB")
            visualizer = Visualizer(img, metadata=metadata, scale=scale)
            vis = visualizer.draw_dataset_dict(dic)
            # Output is named after the source file's base name.
            output(vis, os.path.basename(dic["file_name"]))