
Commit 8244f8a

Add support for non-square anchor layouts, and allow generation of anchor labels in the loader (collate) for reduced GPU load and less CPU load on the primary process.
1 parent 1f12d3a commit 8244f8a

9 files changed (+316, -222 lines)
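The two excerpted diffs below (effdet/anchors.py, effdet/bench.py) cover the anchor and train-bench changes; the loader/collate side lives in the other files of this commit. Purely as an illustrative sketch, not code from this commit, the pieces shown here could be combined roughly as follows. The target key names (label_cls_{l}, label_bbox_{l}, label_num_positives) come from the bench.py diff; the dataset sample format, the config values, and the collate_with_labels helper are assumptions.

import torch
from effdet.anchors import Anchors, AnchorLabeler

# Anchor/labeler construction mirrors what DetBenchTrain does internally; values assumed.
anchors = Anchors(
    min_level=3, max_level=7, num_scales=3, aspect_ratios=[1.0, 2.0, 0.5],
    anchor_scale=4.0, image_size=(512, 768))  # non-square (H, W) now supported
labeler = AnchorLabeler(anchors, num_classes=90, match_threshold=0.5)

def collate_with_labels(batch):
    # batch: list of (image, {'bbox': Tensor[N, 4], 'cls': Tensor[N]}) -- assumed sample format
    images = torch.stack([img for img, _ in batch])
    cls_targets, box_targets, num_positives = labeler.batch_label_anchors(
        [ann['bbox'] for _, ann in batch], [ann['cls'] for _, ann in batch])
    target = {'label_num_positives': num_positives}
    for level, (ct, bt) in enumerate(zip(cls_targets, box_targets)):
        target[f'label_cls_{level}'] = ct    # stacked per level: [batch, H_l, W_l, anchors_per_loc]
        target[f'label_bbox_{level}'] = bt   # stacked per level: [batch, H_l, W_l, anchors_per_loc * 4]
    return images, target

# With labels produced in the DataLoader workers, the train bench can skip its own labeler:
# bench = DetBenchTrain(model, no_labeler=True)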

effdet/anchors.py (+90, -120)
@@ -24,7 +24,7 @@
 This module is borrowed from TPU RetinaNet implementation:
 https://github.com/tensorflow/tpu/blob/master/models/official/retinanet/anchors.py
 """
-from typing import Optional
+from typing import Optional, Tuple, Sequence
 
 import numpy as np
 import torch
@@ -87,83 +87,6 @@ def decode_box_outputs(rel_codes, anchors, output_xyxy: bool=False):
     return out
 
 
-def _generate_anchor_configs(min_level, max_level, num_scales, aspect_ratios):
-    """Generates mapping from output level to a list of anchor configurations.
-
-    A configuration is a tuple of (num_anchors, scale, aspect_ratio).
-
-    Args:
-        min_level: integer number of minimum level of the output feature pyramid.
-
-        max_level: integer number of maximum level of the output feature pyramid.
-
-        num_scales: integer number representing intermediate scales added on each level.
-            For instances, num_scales=2 adds two additional anchor scales [2^0, 2^0.5] on each level.
-
-        aspect_ratios: list of tuples representing the aspect ratio anchors added on each level.
-            For instances, aspect_ratios = [(1, 1), (1.4, 0.7), (0.7, 1.4)] adds three anchors on each level.
-
-    Returns:
-        anchor_configs: a dictionary with keys as the levels of anchors and
-            values as a list of anchor configuration.
-    """
-    anchor_configs = {}
-    for level in range(min_level, max_level + 1):
-        anchor_configs[level] = []
-        for scale_octave in range(num_scales):
-            for aspect in aspect_ratios:
-                anchor_configs[level].append((2 ** level, scale_octave / float(num_scales), aspect))
-    return anchor_configs
-
-
-def _generate_anchor_boxes(image_size, anchor_scale, anchor_configs):
-    """Generates multiscale anchor boxes.
-
-    Args:
-        image_size: integer number of input image size. The input image has the same dimension for
-            width and height. The image_size should be divided by the largest feature stride 2^max_level.
-
-        anchor_scale: float number representing the scale of size of the base
-            anchor to the feature stride 2^level.
-
-        anchor_configs: a dictionary with keys as the levels of anchors and
-            values as a list of anchor configuration.
-
-    Returns:
-        anchor_boxes: a numpy array with shape [N, 4], which stacks anchors on all feature levels.
-
-    Raises:
-        ValueError: input size must be the multiple of largest feature stride.
-    """
-    boxes_all = []
-    for _, configs in anchor_configs.items():
-        boxes_level = []
-        for config in configs:
-            stride, octave_scale, aspect = config
-            if image_size % stride != 0:
-                raise ValueError("input size must be divided by the stride.")
-            base_anchor_size = anchor_scale * stride * 2 ** octave_scale
-            anchor_size_x_2 = base_anchor_size * aspect[0] / 2.0
-            anchor_size_y_2 = base_anchor_size * aspect[1] / 2.0
-
-            x = np.arange(stride / 2, image_size, stride)
-            y = np.arange(stride / 2, image_size, stride)
-            xv, yv = np.meshgrid(x, y)
-            xv = xv.reshape(-1)
-            yv = yv.reshape(-1)
-
-            boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
-                               yv + anchor_size_y_2, xv + anchor_size_x_2))
-            boxes = np.swapaxes(boxes, 0, 1)
-            boxes_level.append(np.expand_dims(boxes, axis=1))
-        # concat anchors on the same level to the reshape NxAx4
-        boxes_level = np.concatenate(boxes_level, axis=1)
-        boxes_all.append(boxes_level.reshape([-1, 4]))
-
-    anchor_boxes = np.vstack(boxes_all)
-    return anchor_boxes
-
-
 def clip_boxes_xyxy(boxes: torch.Tensor, size: torch.Tensor):
     boxes = boxes.clamp(min=0)
     size = torch.cat([size, size], dim=0)
@@ -247,10 +170,26 @@ def generate_detections(
     return detections
 
 
+def get_feat_sizes(image_size: Tuple[int, int], max_level: int):
+    """Get feat widths and heights for all levels.
+    Args:
+        image_size: a tuple (H, W)
+        max_level: maximum feature level.
+    Returns:
+        feat_sizes: a list of tuples (height, width) for each level.
+    """
+    feat_size = image_size
+    feat_sizes = [feat_size]
+    for _ in range(1, max_level + 1):
+        feat_size = ((feat_size[0] - 1) // 2 + 1, (feat_size[1] - 1) // 2 + 1)
+        feat_sizes.append(feat_size)
+    return feat_sizes
+
+
 class Anchors(nn.Module):
     """RetinaNet Anchors class."""
 
-    def __init__(self, min_level, max_level, num_scales, aspect_ratios, anchor_scale, image_size):
+    def __init__(self, min_level, max_level, num_scales, aspect_ratios, anchor_scale, image_size: Tuple[int, int]):
         """Constructs multiscale RetinaNet anchors.
 
         Args:
@@ -278,26 +217,77 @@ def __init__(self, min_level, max_level, num_scales, aspect_ratios, anchor_scale
         self.max_level = max_level
         self.num_scales = num_scales
         self.aspect_ratios = aspect_ratios
-        self.anchor_scale = anchor_scale
-        self.image_size = image_size
+        if isinstance(anchor_scale, Sequence):
+            assert len(anchor_scale) == max_level - min_level + 1
+            self.anchor_scales = anchor_scale
+        else:
+            self.anchor_scales = [anchor_scale] * (max_level - min_level + 1)
+
+        assert isinstance(image_size, Sequence) and len(image_size) == 2
+        # FIXME this restriction can likely be relaxed with some additional changes
+        assert image_size[0] % 2 ** max_level == 0, 'Image size must be divisible by 2 ** max_level (128)'
+        assert image_size[1] % 2 ** max_level == 0, 'Image size must be divisible by 2 ** max_level (128)'
+        self.image_size = tuple(image_size)
+        self.feat_sizes = get_feat_sizes(image_size, max_level)
         self.config = self._generate_configs()
         self.register_buffer('boxes', self._generate_boxes())
 
     def _generate_configs(self):
         """Generate configurations of anchor boxes."""
-        return _generate_anchor_configs(self.min_level, self.max_level, self.num_scales, self.aspect_ratios)
+        anchor_configs = {}
+        feat_sizes = self.feat_sizes
+        for level in range(self.min_level, self.max_level + 1):
+            anchor_configs[level] = []
+            for scale_octave in range(self.num_scales):
+                for aspect in self.aspect_ratios:
+                    anchor_configs[level].append(
+                        ((feat_sizes[0][0] // feat_sizes[level][0],
+                          feat_sizes[0][1] // feat_sizes[level][1]),
+                         scale_octave / float(self.num_scales), aspect,
+                         self.anchor_scales[level - self.min_level]))
+        return anchor_configs
 
     def _generate_boxes(self):
         """Generates multiscale anchor boxes."""
-        boxes = _generate_anchor_boxes(self.image_size, self.anchor_scale, self.config)
-        boxes = torch.from_numpy(boxes).float()
-        return boxes
+        boxes_all = []
+        for _, configs in self.config.items():
+            boxes_level = []
+            for config in configs:
+                stride, octave_scale, aspect, anchor_scale = config
+                base_anchor_size_x = anchor_scale * stride[1] * 2 ** octave_scale
+                base_anchor_size_y = anchor_scale * stride[0] * 2 ** octave_scale
+                if isinstance(aspect, Sequence):
+                    aspect_x = aspect[0]
+                    aspect_y = aspect[1]
+                else:
+                    aspect_x = np.sqrt(aspect)
+                    aspect_y = 1.0 / aspect_x
+                anchor_size_x_2 = base_anchor_size_x * aspect_x / 2.0
+                anchor_size_y_2 = base_anchor_size_y * aspect_y / 2.0
+
+                x = np.arange(stride[1] / 2, self.image_size[1], stride[1])
+                y = np.arange(stride[0] / 2, self.image_size[0], stride[0])
+                xv, yv = np.meshgrid(x, y)
+                xv = xv.reshape(-1)
+                yv = yv.reshape(-1)
+
+                boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
+                                   yv + anchor_size_y_2, xv + anchor_size_x_2))
+                boxes = np.swapaxes(boxes, 0, 1)
+                boxes_level.append(np.expand_dims(boxes, axis=1))
+
+            # concat anchors on the same level to the reshape NxAx4
+            boxes_level = np.concatenate(boxes_level, axis=1)
+            boxes_all.append(boxes_level.reshape([-1, 4]))
+
+        anchor_boxes = np.vstack(boxes_all)
+        anchor_boxes = torch.from_numpy(anchor_boxes).float()
+        return anchor_boxes
 
     def get_anchors_per_location(self):
         return self.num_scales * len(self.aspect_ratios)
 
 
-#@torch.jit.script
 class AnchorLabeler(object):
     """Labeler for multiscale anchor boxes.
     """
@@ -325,9 +315,6 @@ def __init__(self, anchors, num_classes: int, match_threshold: float = 0.5):
         self.anchors = anchors
         self.match_threshold = match_threshold
         self.num_classes = num_classes
-        self.feat_size = {}
-        for level in range(self.anchors.min_level, self.anchors.max_level + 1):
-            self.feat_size[level] = int(self.anchors.image_size / 2 ** level)
         self.indices_cache = {}
 
     def label_anchors(self, gt_boxes, gt_labels):
@@ -360,44 +347,25 @@ def label_anchors(self, gt_boxes, gt_labels):
         cls_targets, _, box_targets, _, matches = self.target_assigner.assign(anchor_box_list, gt_box_list, gt_labels)
 
         # class labels start from 1 and the background class = -1
-        cls_targets -= 1
-        cls_targets = cls_targets.long()
+        cls_targets = (cls_targets - 1).long()
 
         # Unpack labels.
         """Unpacks an array of cls/box into multiple scales."""
         count = 0
         for level in range(self.anchors.min_level, self.anchors.max_level + 1):
-            feat_size = self.feat_size[level]
-            steps = feat_size ** 2 * self.anchors.get_anchors_per_location()
-            indices = torch.arange(count, count + steps, device=cls_targets.device)
+            feat_size = self.anchors.feat_sizes[level]
+            steps = feat_size[0] * feat_size[1] * self.anchors.get_anchors_per_location()
+            cls_targets_out.append(cls_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
+            box_targets_out.append(box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
             count += steps
-            cls_targets_out.append(
-                torch.index_select(cls_targets, 0, indices).view([feat_size, feat_size, -1]))
-            box_targets_out.append(
-                torch.index_select(box_targets, 0, indices).view([feat_size, feat_size, -1]))
 
         num_positives = (matches.match_results != -1).float().sum()
 
         return cls_targets_out, box_targets_out, num_positives
 
-    def _build_indices(self, device):
-        anchors_per_loc = self.anchors.get_anchors_per_location()
-        indices_dict = {}
-        count = 0
-        for level in range(self.anchors.min_level, self.anchors.max_level + 1):
-            feat_size = self.feat_size[level]
-            steps = feat_size ** 2 * anchors_per_loc
-            indices = torch.arange(count, count + steps, device=device)
-            indices_dict[level] = indices
-            count += steps
-        return indices_dict
-
-    def _get_indices(self, device, level):
-        if device not in self.indices_cache:
-            self.indices_cache[device] = self._build_indices(device)
-        return self.indices_cache[device][level]
-
-    def batch_label_anchors(self, batch_size: int, gt_boxes, gt_classes):
+    def batch_label_anchors(self, gt_boxes, gt_classes):
+        batch_size = len(gt_boxes)
+        assert len(gt_classes) == len(gt_boxes)
         num_levels = self.anchors.max_level - self.anchors.min_level + 1
         cls_targets_out = [[] for _ in range(num_levels)]
         box_targets_out = [[] for _ in range(num_levels)]
@@ -416,14 +384,16 @@ def batch_label_anchors(self, batch_size: int, gt_boxes, gt_classes):
 
             # Unpack labels.
             """Unpacks an array of cls/box into multiple scales."""
+            count = 0
            for level in range(self.anchors.min_level, self.anchors.max_level + 1):
                 level_index = level - self.anchors.min_level
-                feat_size = self.feat_size[level]
-                indices = self._get_indices(cls_targets.device, level)
+                feat_size = self.anchors.feat_sizes[level]
+                steps = feat_size[0] * feat_size[1] * self.anchors.get_anchors_per_location()
                 cls_targets_out[level_index].append(
-                    torch.index_select(cls_targets, 0, indices).view([feat_size, feat_size, -1]))
+                    cls_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
                 box_targets_out[level_index].append(
-                    torch.index_select(box_targets, 0, indices).view([feat_size, feat_size, -1]))
+                    box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
+                count += steps
                 if last_sample:
                     cls_targets_out[level_index] = torch.stack(cls_targets_out[level_index])
                     box_targets_out[level_index] = torch.stack(box_targets_out[level_index])
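As a quick illustration of the new non-square handling (illustrative values only, not code from this commit): get_feat_sizes ceil-divides height and width independently per level, and _generate_configs then derives a per-level (stride_h, stride_w) from the ratio of the level-0 size to each level's size.

from effdet.anchors import get_feat_sizes

feat_sizes = get_feat_sizes((512, 768), max_level=7)
print(feat_sizes)
# [(512, 768), (256, 384), (128, 192), (64, 96), (32, 48), (16, 24), (8, 12), (4, 6)]

# Per-level stride as computed in _generate_configs, e.g. for level 3:
stride = (feat_sizes[0][0] // feat_sizes[3][0], feat_sizes[0][1] // feat_sizes[3][1])
print(stride)  # (8, 8)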

effdet/bench.py (+14, -4)
@@ -85,21 +85,31 @@ def forward(self, x, img_info: Dict[str, torch.Tensor] = None):
 
 
 class DetBenchTrain(nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, no_labeler=False):
         super(DetBenchTrain, self).__init__()
         self.model = model
         self.config = model.config
         self.anchors = Anchors(
             self.config.min_level, self.config.max_level,
             self.config.num_scales, self.config.aspect_ratios,
             self.config.anchor_scale, self.config.image_size)
-        self.anchor_labeler = AnchorLabeler(self.anchors, self.config.num_classes, match_threshold=0.5)
+        self.anchor_labeler = None
+        if not no_labeler:
+            self.anchor_labeler = AnchorLabeler(self.anchors, self.config.num_classes, match_threshold=0.5)
         self.loss_fn = DetectionLoss(self.config)
 
     def forward(self, x, target: Dict[str, torch.Tensor]):
         class_out, box_out = self.model(x)
-        cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
-            x.shape[0], target['bbox'], target['cls'])
+        if self.anchor_labeler is None:
+            # target should contain pre-computed anchor labels
+            assert 'label_num_positives' in target
+            cls_targets = [target[f'label_cls_{l}'] for l in range(self.config.num_levels)]
+            box_targets = [target[f'label_bbox_{l}'] for l in range(self.config.num_levels)]
+            num_positives = target['label_num_positives']
+        else:
+            cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
+                target['bbox'], target['cls'])
+
         loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
         output = dict(loss=loss, class_loss=class_loss, box_loss=box_loss)
         if not self.training:
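For orientation, a rough sketch (not from this commit) of the pre-computed-label path in use. The model, loader, and config values here are assumptions; the per-level shapes follow from the anchors.py changes above (label targets are viewed to [H_l, W_l, anchors_per_location] and [H_l, W_l, anchors_per_location * 4], then stacked over the batch).

# Assumes: image_size=(512, 768), min_level=3, max_level=7, num_scales=3 and
# 3 aspect ratios (9 anchors per location, config.num_levels == 5), plus a
# pre-built `model` (EfficientDet) and a `loader` whose collate fn fills the
# label_* keys (see the collate sketch near the top of this page).
bench = DetBenchTrain(model, no_labeler=True)
images, target = next(iter(loader))
# Example shapes at level index 0 (pyramid level 3, 512/8 x 768/8 locations):
#   target['label_cls_0'].shape  -> [batch, 64, 96, 9]
#   target['label_bbox_0'].shape -> [batch, 64, 96, 36]
output = bench(images, target)
output['loss'].backward()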
