Skip to content

Commit 0b78eba

Browse files
authored
Implemented yolo dataset support (#487)
* implemented yolo data loader * added yolo example configuration * fixed super call for yolo data loader * converted normalized values to pixels for yolo dataset * run pre-commit and fixed coordinate bug * fixed yolo categories indexed by zero * added readme hint for yolo format
1 parent 0036f94 commit 0b78eba

File tree

4 files changed

+314
-0
lines changed

4 files changed

+314
-0
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ NanoDet-RepVGG | RepVGG-A0 | 416*416 | 27.8 | 11.3G | 6.75M |
220220

221221
If your dataset annotations are in Pascal VOC XML format, refer to [config/nanodet_custom_xml_dataset.yml](config/nanodet_custom_xml_dataset.yml)
222222

223+
Otherwise, if your dataset annotations are in YOLO format ([Darknet TXT](https://github.com/AlexeyAB/Yolo_mark/issues/60#issuecomment-401854885)), refer to [config/nanodet-plus-m_416-yolo.yml](config/nanodet-plus-m_416-yolo.yml)
224+
223225
Or convert your dataset annotations to MS COCO format [(COCO annotation format details)](https://cocodataset.org/#format-data).
224226

225227
2. **Prepare config file**

config/nanodet-plus-m_416-yolo.yml

+134
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# nanodet-plus-m_416
2+
# COCO mAP(0.5:0.95) = 0.304
3+
# AP_50 = 0.459
4+
# AP_75 = 0.317
5+
# AP_small = 0.106
6+
# AP_m = 0.322
7+
# AP_l = 0.477
8+
save_dir: workspace/nanodet-plus-m_416
9+
model:
10+
weight_averager:
11+
name: ExpMovingAverager
12+
decay: 0.9998
13+
arch:
14+
name: NanoDetPlus
15+
detach_epoch: 10
16+
backbone:
17+
name: ShuffleNetV2
18+
model_size: 1.0x
19+
out_stages: [2,3,4]
20+
activation: LeakyReLU
21+
fpn:
22+
name: GhostPAN
23+
in_channels: [116, 232, 464]
24+
out_channels: 96
25+
kernel_size: 5
26+
num_extra_level: 1
27+
use_depthwise: True
28+
activation: LeakyReLU
29+
head:
30+
name: NanoDetPlusHead
31+
num_classes: 80
32+
input_channel: 96
33+
feat_channels: 96
34+
stacked_convs: 2
35+
kernel_size: 5
36+
strides: [8, 16, 32, 64]
37+
activation: LeakyReLU
38+
reg_max: 7
39+
norm_cfg:
40+
type: BN
41+
loss:
42+
loss_qfl:
43+
name: QualityFocalLoss
44+
use_sigmoid: True
45+
beta: 2.0
46+
loss_weight: 1.0
47+
loss_dfl:
48+
name: DistributionFocalLoss
49+
loss_weight: 0.25
50+
loss_bbox:
51+
name: GIoULoss
52+
loss_weight: 2.0
53+
# Auxiliary head, only used during training.
54+
aux_head:
55+
name: SimpleConvHead
56+
num_classes: 80
57+
input_channel: 192
58+
feat_channels: 192
59+
stacked_convs: 4
60+
strides: [8, 16, 32, 64]
61+
activation: LeakyReLU
62+
reg_max: 7
63+
64+
class_names: &class_names ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
65+
'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
66+
'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
67+
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
68+
'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
69+
'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
70+
'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
71+
'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
72+
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
73+
'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
74+
'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
75+
'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
76+
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
77+
'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush']
78+
79+
data:
80+
train:
81+
name: YoloDataset
82+
img_path: coco/train2017
83+
ann_path: coco/train2017
84+
class_names: *class_names
85+
input_size: [416,416] #[w,h]
86+
keep_ratio: False
87+
pipeline:
88+
perspective: 0.0
89+
scale: [0.6, 1.4]
90+
stretch: [[0.8, 1.2], [0.8, 1.2]]
91+
rotation: 0
92+
shear: 0
93+
translate: 0.2
94+
flip: 0.5
95+
brightness: 0.2
96+
contrast: [0.6, 1.4]
97+
saturation: [0.5, 1.2]
98+
normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
99+
val:
100+
name: YoloDataset
101+
img_path: coco/val2017
102+
ann_path: coco/val2017
103+
class_names: *class_names
104+
input_size: [416,416] #[w,h]
105+
keep_ratio: False
106+
pipeline:
107+
normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
108+
device:
109+
gpu_ids: [0]
110+
workers_per_gpu: 10
111+
batchsize_per_gpu: 96
112+
schedule:
113+
# resume:
114+
# load_model:
115+
optimizer:
116+
name: AdamW
117+
lr: 0.001
118+
weight_decay: 0.05
119+
warmup:
120+
name: linear
121+
steps: 500
122+
ratio: 0.0001
123+
total_epochs: 300
124+
lr_schedule:
125+
name: CosineAnnealingLR
126+
T_max: 300
127+
eta_min: 0.00005
128+
val_intervals: 10
129+
grad_clip: 35
130+
evaluator:
131+
name: CocoDetectionEvaluator
132+
save_key: mAP
133+
log:
134+
interval: 50

nanodet/data/dataset/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .coco import CocoDataset
1919
from .xml_dataset import XMLDataset
20+
from .yolo import YoloDataset
2021

2122

2223
def build_dataset(cfg, mode):
@@ -27,6 +28,8 @@ def build_dataset(cfg, mode):
2728
"Dataset name coco has been deprecated. Please use CocoDataset instead."
2829
)
2930
return CocoDataset(mode=mode, **dataset_cfg)
31+
elif name == "yolo":
32+
return YoloDataset(mode=mode, **dataset_cfg)
3033
elif name == "xml_dataset":
3134
warnings.warn(
3235
"Dataset name xml_dataset has been deprecated. "
@@ -35,6 +38,8 @@ def build_dataset(cfg, mode):
3538
return XMLDataset(mode=mode, **dataset_cfg)
3639
elif name == "CocoDataset":
3740
return CocoDataset(mode=mode, **dataset_cfg)
41+
elif name == "YoloDataset":
42+
return YoloDataset(mode=mode, **dataset_cfg)
3843
elif name == "XMLDataset":
3944
return XMLDataset(mode=mode, **dataset_cfg)
4045
else:

nanodet/data/dataset/yolo.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright 2023 cansik.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import os
17+
import time
18+
from collections import defaultdict
19+
from typing import Optional, Sequence
20+
21+
import cv2
22+
import numpy as np
23+
from pycocotools.coco import COCO
24+
25+
from .coco import CocoDataset
26+
from .xml_dataset import get_file_list
27+
28+
29+
class CocoYolo(COCO):
    """COCO helper that is populated from an in-memory annotation dict.

    Unlike the parent class, which is given an annotation file, this variant
    receives the already-built annotation dict and indexes it directly.
    """

    def __init__(self, annotation):
        """
        Constructor of Microsoft COCO helper class for
        reading and visualizing annotations.
        :param annotation: annotation dict in COCO layout (any ``dict``
            subclass is accepted)
        :raises AssertionError: if ``annotation`` is not a dict
        """
        # No annotation file is passed to the parent; the dataset is
        # assigned from the given dict below instead.
        super().__init__()
        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        # isinstance (rather than an exact type comparison) so that dict
        # subclasses such as OrderedDict are accepted as well.
        assert isinstance(
            annotation, dict
        ), "annotation file format {} not supported".format(type(annotation))
        self.dataset = annotation
        self.createIndex()
47+
48+
49+
class YoloDataset(CocoDataset):
    """Dataset that reads YOLO (Darknet TXT) annotations through the COCO API.

    Every ``.txt`` annotation file is paired with an image that has the same
    path minus the extension. The annotations are converted to a COCO-style
    dict in memory and indexed with :class:`CocoYolo`.
    """

    def __init__(self, class_names, **kwargs):
        """
        :param class_names: ordered class names; the position of a name is
            the zero-based class index used in the YOLO txt files
        :param kwargs: forwarded unchanged to ``CocoDataset``
        """
        # NOTE(review): assigned before the parent constructor runs,
        # presumably because the parent calls get_data_info() - confirm.
        self.class_names = class_names
        super(YoloDataset, self).__init__(**kwargs)

    @staticmethod
    def _find_image(
        image_prefix: str,
        image_types: Sequence[str] = (".png", ".jpg", ".jpeg", ".bmp", ".tiff"),
    ) -> Optional[str]:
        """
        Return the first existing file named ``image_prefix + extension``.
        :param image_prefix: path of the image without its extension
        :param image_types: extensions to try, in order of preference
        :return: path of the image, or None if no candidate exists
        """
        for image_type in image_types:
            path = f"{image_prefix}{image_type}"
            if os.path.exists(path):
                return path
        return None

    def yolo_to_coco(self, ann_path):
        """
        Convert YOLO txt annotations to a COCO-style annotation dict.

        Each annotation line is ``class cx cy w h`` with the box centre and
        size normalized to [0, 1]; boxes are converted to absolute pixel
        ``[x, y, w, h]`` with a top-left origin. Only the first four values
        after the class index are used. Annotation files without a matching
        or readable image, blank lines, malformed lines and boxes with an
        unknown class or negative size are skipped with a warning.

        :param ann_path: directory containing the ``.txt`` files (and images)
        :return: dict with "images", "categories" and "annotations" lists
        """
        logging.info("loading annotations into memory...")
        tic = time.time()
        ann_file_names = get_file_list(ann_path, type=".txt")
        logging.info("Found {} annotation files.".format(len(ann_file_names)))
        image_info = []
        annotations = []
        # COCO category ids are 1-based, YOLO class indices are 0-based.
        categories = [
            {"supercategory": supercat, "id": idx + 1, "name": supercat}
            for idx, supercat in enumerate(self.class_names)
        ]
        num_classes = len(self.class_names)
        ann_id = 1

        for idx, txt_name in enumerate(ann_file_names):
            ann_file = os.path.join(ann_path, txt_name)
            image_file = self._find_image(os.path.splitext(ann_file)[0])

            if image_file is None:
                logging.warning(f"Could not find image for {ann_file}")
                continue

            # The image is decoded only to learn its size, which is needed
            # to turn the normalized coordinates into pixels.
            image = cv2.imread(image_file)
            if image is None:
                # cv2.imread signals an unreadable file by returning None.
                logging.warning(f"Could not read image {image_file}")
                continue
            height, width = image.shape[:2]

            with open(ann_file, "r") as f:
                lines = f.readlines()

            file_name = os.path.basename(image_file)
            info = {
                "file_name": file_name,
                "height": height,
                "width": width,
                "id": idx + 1,
            }
            image_info.append(info)
            for line in lines:
                # split() without an argument tolerates tabs, repeated
                # spaces and the trailing newline.
                fields = line.split()
                if not fields:
                    continue  # blank line

                try:
                    data = [float(t) for t in fields]
                except ValueError:
                    logging.warning(
                        f"Could not parse annotation line in {txt_name}: "
                        f"{line.strip()!r}"
                    )
                    continue

                if len(data) < 5:
                    logging.warning(
                        f"Annotation line in {txt_name} has fewer than 5 "
                        f"values: {line.strip()!r}"
                    )
                    continue

                cat_id = int(data[0])
                # Rows: [center_x, center_y] and [box_w, box_h].
                bbox = np.array(data[1:5]).reshape((2, 2))

                # Move the reference point from the box centre to top-left.
                bbox[0] -= bbox[1] * 0.5

                bbox = np.round(bbox * np.array([width, height])).astype(int)
                # Plain Python ints keep the COCO dict free of numpy scalars.
                x, y, w, h = (int(v) for v in bbox.ravel())

                if cat_id < 0 or cat_id >= num_classes:
                    logging.warning(
                        f"Category {cat_id} is not defined in config ({txt_name})"
                    )
                    continue

                if w < 0 or h < 0:
                    logging.warning(
                        "WARNING! Find error data in file {}! Box w and "
                        "h should > 0. Pass this box annotation.".format(txt_name)
                    )
                    continue

                coco_box = [max(x, 0), max(y, 0), min(w, width), min(h, height)]
                ann = {
                    "image_id": idx + 1,
                    "bbox": coco_box,
                    "category_id": cat_id + 1,
                    "iscrowd": 0,
                    "id": ann_id,
                    "area": coco_box[2] * coco_box[3],
                }
                annotations.append(ann)
                ann_id += 1

        coco_dict = {
            "images": image_info,
            "categories": categories,
            "annotations": annotations,
        }
        logging.info(
            "Load {} txt files and {} boxes".format(len(image_info), len(annotations))
        )
        logging.info("Done (t={:0.2f}s)".format(time.time() - tic))
        return coco_dict

    def get_data_info(self, ann_path):
        """
        Load basic information of dataset such as image path, label and so on.
        :param ann_path: directory containing the YOLO txt annotation files
        :return: image info:
        [{'file_name': '000000000139.jpg',
          'height': 426,
          'width': 640,
          'id': 139},
         ...
        ]
        """
        coco_dict = self.yolo_to_coco(ann_path)
        self.coco_api = CocoYolo(coco_dict)
        self.cat_ids = sorted(self.coco_api.getCatIds())
        # Map the COCO category ids to contiguous zero-based labels.
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cats = self.coco_api.loadCats(self.cat_ids)
        self.img_ids = sorted(self.coco_api.imgs.keys())
        img_info = self.coco_api.loadImgs(self.img_ids)
        return img_info

0 commit comments

Comments
 (0)