Skip to content

Commit e18ad8e

Browse files
committed
Update train script to use native or apex AMP, update to latest timm utils interface, fix sotabench, add resdet50 weights, add csp model defs, cleanup collate fn
1 parent b523cf6 commit e18ad8e

File tree

9 files changed

+215
-133
lines changed

9 files changed

+215
-133
lines changed

effdet/bench.py

+12 -5
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
33
Hacked together by Ross Wightman
44
"""
5+
from typing import Optional, Dict
56
import torch
67
import torch.nn as nn
78
from timm.utils import ModelEma
@@ -50,12 +51,15 @@ def _post_process(config, cls_outputs, box_outputs):
5051

5152
@torch.jit.script
5253
def _batch_detection(
53-
batch_size: int, class_out, box_out, anchor_boxes, indices, classes, img_scale, img_size):
54+
batch_size: int, class_out, box_out, anchor_boxes, indices, classes,
55+
img_scale: Optional[torch.Tensor] = None, img_size: Optional[torch.Tensor] = None):
5456
batch_detections = []
5557
# FIXME we may be able to do this as a batch with some tensor reshaping/indexing, PR welcome
5658
for i in range(batch_size):
59+
img_scale_i = None if img_scale is None else img_scale[i]
60+
img_size_i = None if img_size is None else img_size[i]
5761
detections = generate_detections(
58-
class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale[i], img_size[i])
62+
class_out[i], box_out[i], anchor_boxes, indices[i], classes[i], img_scale_i, img_size_i)
5963
batch_detections.append(detections)
6064
return torch.stack(batch_detections, dim=0)
6165

@@ -70,11 +74,14 @@ def __init__(self, model):
7074
self.config.num_scales, self.config.aspect_ratios,
7175
self.config.anchor_scale, self.config.image_size)
7276

73-
def forward(self, x, img_scales, img_size):
77+
def forward(self, x, img_info: Dict[str, torch.Tensor] = None):
7478
class_out, box_out = self.model(x)
7579
class_out, box_out, indices, classes = _post_process(self.config, class_out, box_out)
80+
img_info = img_info or {}
81+
img_scale = img_info['img_scale'] if 'img_scale' in img_info else None
82+
img_size = img_info['img_size'] if 'img_size' in img_info else None
7683
return _batch_detection(
77-
x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scales, img_size)
84+
x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes, img_scale, img_size)
7885

7986

8087
class DetBenchTrain(nn.Module):
@@ -89,7 +96,7 @@ def __init__(self, model):
8996
self.anchor_labeler = AnchorLabeler(self.anchors, self.config.num_classes, match_threshold=0.5)
9097
self.loss_fn = DetectionLoss(self.config)
9198

92-
def forward(self, x, target):
99+
def forward(self, x, target: Dict[str, torch.Tensor]):
93100
class_out, box_out = self.model(x)
94101
cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
95102
x.shape[0], target['bbox'], target['cls'])

effdet/config/model_config.py

+45 -3
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,7 @@ def default_detection_model_configs():
140140
# My own experimental configs with alternate models, training TBD
141141
# Note: any 'timm' model in the EfficientDet family can be used as a backbone here.
142142
resdet50=dict(
143-
name='resdet50', # 'wide'
143+
name='resdet50',
144144
backbone_name='resnet50',
145145
image_size=640,
146146
fpn_channels=88,
@@ -150,8 +150,50 @@ def default_detection_model_configs():
150150
act_type='relu',
151151
redundant_bias=False,
152152
separable_conv=False,
153-
backbone_args=dict(drop_path_rate=0.1),
154-
url='', # no pretrained weights yet
153+
backbone_args=dict(drop_path_rate=0.2),
154+
url='https://github.com/rwightman/efficientdet-pytorch/releases/download/v0.1/resdet50_416-08676892.pth',
155+
),
156+
cspresdet50=dict(
157+
name='cspresdet50',
158+
backbone_name='cspresnet50',
159+
image_size=640,
160+
fpn_channels=88,
161+
fpn_cell_repeats=4,
162+
box_class_repeats=3,
163+
pad_type='',
164+
act_type='leaky_relu',
165+
redundant_bias=False,
166+
separable_conv=False,
167+
backbone_args=dict(drop_path_rate=0.2),
168+
url='',
169+
),
170+
cspresdext50=dict(
171+
name='cspresdext50',
172+
backbone_name='cspresnext50',
173+
image_size=640,
174+
fpn_channels=88,
175+
fpn_cell_repeats=4,
176+
box_class_repeats=3,
177+
pad_type='',
178+
act_type='leaky_relu',
179+
redundant_bias=False,
180+
separable_conv=False,
181+
backbone_args=dict(drop_path_rate=0.2),
182+
url='',
183+
),
184+
cspdarkdet53=dict(
185+
name='cspdarkdet53',
186+
backbone_name='cspdarknet53',
187+
image_size=640,
188+
fpn_channels=88,
189+
fpn_cell_repeats=4,
190+
box_class_repeats=3,
191+
pad_type='',
192+
act_type='leaky_relu',
193+
redundant_bias=False,
194+
separable_conv=False,
195+
backbone_args=dict(drop_path_rate=0.2),
196+
url='',
155197
),
156198
efficientdet_w0=dict(
157199
name='efficientdet_w0', # 'wide'

effdet/data/loader.py

+39 -33
Original file line number | Diff line number | Diff line change
@@ -11,48 +11,54 @@
1111
MAX_NUM_INSTANCES = 100
1212

1313

14-
class FastCollate:
14+
class DetectionFastCollate:
1515

16-
def __init__(self):
17-
pass
16+
def __init__(self, instance_keys=None, instance_shapes=None, instance_fill=-1, max_instances=MAX_NUM_INSTANCES):
17+
instance_keys = instance_keys or {'bbox', 'bbox_ignore', 'cls'}
18+
instance_shapes = instance_shapes or dict(
19+
bbox=(max_instances, 4), bbox_ignore=(max_instances, 4), cls=(max_instances,))
20+
self.instance_info = {k: dict(fill=instance_fill, shape=instance_shapes[k]) for k in instance_keys}
21+
self.max_instances = max_instances
1822

1923
def __call__(self, batch):
2024
batch_size = len(batch)
21-
22-
# FIXME this needs to be more robust
2325
target = dict()
24-
for k, v in batch[0][1].items():
25-
if isinstance(v, np.ndarray):
26-
# if a numpy array, assume it relates to object instances, pad to MAX_NUM_INSTANCES
27-
target_shape = (batch_size, MAX_NUM_INSTANCES)
28-
if len(v.shape) > 1:
29-
target_shape = target_shape + v.shape[1:]
30-
target_dtype = torch.float32
26+
27+
def _get_target(k, v):
28+
if k in target:
29+
return target[k], k in self.instance_info
30+
is_instance = False
31+
fill_value = 0
32+
if k in self.instance_info:
33+
info = self.instance_info[k]
34+
is_instance = True
35+
fill_value = info['fill']
36+
shape = (batch_size,) + info['shape']
37+
dtype = torch.float32
3138
elif isinstance(v, (tuple, list)):
32-
# if tuple or list, assume per elem
33-
target_shape = (batch_size, len(v))
34-
target_dtype = torch.float32 if isinstance(v[0], float) else torch.int32
39+
# per batch elem sequence
40+
shape = (batch_size, len(v))
41+
dtype = torch.float32 if isinstance(v[0], (float, np.floating)) else torch.int32
3542
else:
36-
# scalar, assume per elem
37-
target_shape = batch_size,
38-
target_dtype = torch.float32 if isinstance(v, float) else torch.int64
39-
target[k] = torch.zeros(target_shape, dtype=target_dtype)
40-
41-
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
43+
# per batch elem scalar
44+
shape = batch_size,
45+
dtype = torch.float32 if isinstance(v, (float, np.floating)) else torch.int64
46+
target_tensor = torch.full(shape, fill_value, dtype=dtype)
47+
target[k] = target_tensor
48+
return target_tensor, is_instance
49+
50+
img_tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
4251
for i in range(batch_size):
43-
tensor[i] += torch.from_numpy(batch[i][0])
52+
img_tensor[i] += torch.from_numpy(batch[i][0])
4453
for tk, tv in batch[i][1].items():
45-
if isinstance(tv, np.ndarray) and len(tv.shape):
46-
num_elem = min(tv.shape[0], MAX_NUM_INSTANCES)
47-
target[tk][i, 0:num_elem] = torch.from_numpy(tv[0:num_elem])
54+
target_tensor, is_instance = _get_target(tk, tv)
55+
if is_instance:
56+
num_elem = min(tv.shape[0], self.max_instances)
57+
target_tensor[i, 0:num_elem] = torch.from_numpy(tv[0:num_elem])
4858
else:
49-
target[tk][i] = torch.tensor(tv, dtype=target[tk].dtype)
50-
51-
return tensor, target
52-
59+
target_tensor[i] = torch.tensor(tv, dtype=target_tensor.dtype)
5360

54-
def _to_gpu(v):
55-
return v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v
61+
return img_tensor, target
5662

5763

5864
class PrefetchLoader:
@@ -81,7 +87,7 @@ def __iter__(self):
8187
with torch.cuda.stream(stream):
8288
next_input = next_input.cuda(non_blocking=True)
8389
next_input = next_input.float().sub_(self.mean).div_(self.std)
84-
next_target = {k: _to_gpu(v) for k, v in next_target.items()}
90+
next_target = {k: v.cuda(non_blocking=True) for k, v in next_target.items()}
8591
if self.random_erasing is not None:
8692
next_input = self.random_erasing(next_input, next_target)
8793

@@ -165,7 +171,7 @@ def create_loader(
165171
num_workers=num_workers,
166172
sampler=sampler,
167173
pin_memory=pin_mem,
168-
collate_fn=FastCollate() if use_prefetcher else torch.utils.data.dataloader.default_collate,
174+
collate_fn=DetectionFastCollate() if use_prefetcher else torch.utils.data.dataloader.default_collate,
169175
)
170176
if use_prefetcher:
171177
if is_train:

effdet/efficientdet.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -252,7 +252,7 @@ def __init__(self, config, feature_info):
252252
super(BiFpn, self).__init__()
253253
norm_layer = config.norm_layer or nn.BatchNorm2d
254254
norm_kwargs = config.norm_kwargs or {}
255-
act_layer = get_act_layer(config.act_layer) or _ACT_LAYER
255+
act_layer = get_act_layer(config.act_type) or _ACT_LAYER
256256
self.config = config
257257
fpn_config = config.fpn_config or get_fpn_config(
258258
config.fpn_name, min_level=config.min_level, max_level=config.max_level)
@@ -314,7 +314,7 @@ def __init__(self, config, num_outputs):
314314
super(HeadNet, self).__init__()
315315
norm_layer = config.norm_layer or nn.BatchNorm2d
316316
norm_kwargs = config.norm_kwargs or {}
317-
act_layer = get_act_layer(config.act_layer) or _ACT_LAYER
317+
act_layer = get_act_layer(config.act_type) or _ACT_LAYER
318318
self.config = config
319319
num_anchors = len(config.aspect_ratios) * config.num_scales
320320

effdet/factory.py

+6 -18
Original file line number | Diff line number | Diff line change
@@ -22,24 +22,12 @@ def create_model_from_config(
2222
if pretrained or checkpoint_path:
2323
pretrained_backbone = False # no point in loading backbone weights
2424

25-
# Config overrides, override some config value from args. FIXME need a cleaner mechanism or allow
26-
# config defs via files.
27-
redundant_bias = kwargs.pop('redundant_bias', None)
28-
if redundant_bias is not None:
29-
# override config if set to something
30-
config.redundant_bias = redundant_bias
31-
32-
label_smoothing = kwargs.pop('label_smoothing', None)
33-
if label_smoothing is not None:
34-
config.label_smoothing = label_smoothing
35-
36-
legacy_focal = kwargs.pop('legacy_focal', None)
37-
if legacy_focal is not None:
38-
config.legacy_focal = legacy_focal
39-
40-
jit_loss = kwargs.pop('jit_loss', None)
41-
if jit_loss is not None:
42-
config.jit_loss = jit_loss
25+
# Config overrides, override some config values via kwargs.
26+
overrides = ('redundant_bias', 'label_smoothing', 'legacy_focal', 'jit_loss')
27+
for ov in overrides:
28+
value = kwargs.pop(ov, None)
29+
if value is not None:
30+
setattr(config, ov, value)
4331

4432
# create the base model
4533
model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs)

effdet/loss.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,8 @@ def focal_loss(logits, targets, alpha: float, gamma: float, normalizer, label_sm
101101
modulating_factor = (1. - p_t) ** gamma
102102

103103
# apply label smoothing for cross_entropy for each entry.
104-
targets = targets * (1. - label_smoothing) + .5 * label_smoothing
104+
if label_smoothing > 0.:
105+
targets = targets * (1. - label_smoothing) + .5 * label_smoothing
105106
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
106107

107108
# compute the final loss and return
@@ -229,7 +230,7 @@ def loss_fn(
229230
alpha=alpha, gamma=gamma, normalizer=num_positives_sum, label_smoothing=label_smoothing)
230231
cls_loss = cls_loss.view(bs, height, width, -1, num_classes)
231232
cls_loss = cls_loss * (cls_targets_at_level != -2).unsqueeze(-1)
232-
cls_losses.append(cls_loss.sum())
233+
cls_losses.append(cls_loss.sum()) # FIXME reference code added a clamp here at some point ...clamp(0, 2))
233234

234235
box_losses.append(_box_loss(
235236
box_outputs[l].permute(0, 2, 3, 1).float(),
@@ -271,8 +272,6 @@ def forward(
271272
box_targets: List[torch.Tensor],
272273
num_positives: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
273274

274-
# FIXME I'd like to assign and script the loss fun in the init but deepcopy doesn't work with
275-
# ScriptedFunction/ScriptedModule members right now and deepcopy is required for ModelEma as currently impl
276275
loss_kwargs = dict(
277276
num_classes=self.num_classes, alpha=self.alpha, gamma=self.gamma, delta=self.delta,
278277
box_loss_weight=self.box_loss_weight, label_smoothing=self.label_smoothing, legacy_focal=self.legacy_focal)

sotabench.py

+21 -12
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,7 @@
88
has_amp = False
99
from sotabencheval.object_detection import COCOEvaluator
1010
from sotabencheval.utils import is_server, extract_archive
11-
from effdet import create_model
12-
from effdet.data import DetectionDatset, create_loader
11+
from effdet import create_model, create_loader, create_dataset
1312

1413
NUM_GPU = 1
1514
BATCH_SIZE = (128 if has_amp else 64) * NUM_GPU
@@ -42,6 +41,17 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
4241
# NOTE For any original PyTorch models, I'll remove from this list when you add to sotabench to
4342
# avoid overlap and confusion. Please contact me.
4443
model_list = [
44+
45+
## Weights trained by myself or others in PyTorch
46+
_entry('resdet50', 'ResDet50', '1911.09070', batch_size=_bs(72),
47+
model_desc='Trained in PyTorch with https://github.com/rwightman/efficientdet-pytorch'),
48+
_entry('tf_efficientdet_lite0', 'EfficientDet-Lite0', '1911.09070', batch_size=_bs(128),
49+
model_desc='Trained in PyTorch with https://github.com/rwightman/efficientdet-pytorch'),
50+
_entry('efficientdet_d0', 'EfficientDet-D0', '1911.09070', batch_size=_bs(112),
51+
model_desc='Trained in PyTorch with https://github.com/rwightman/efficientdet-pytorch'),
52+
_entry('efficientdet_d1', 'EfficientDet-D1', '1911.09070', batch_size=_bs(72),
53+
model_desc='Trained in PyTorch with https://github.com/rwightman/efficientdet-pytorch'),
54+
4555
## Weights ported by myself from other frameworks
4656
_entry('tf_efficientdet_d0', 'EfficientDet-D0', '1911.09070', batch_size=_bs(112),
4757
model_desc='Ported from official Google AI Tensorflow weights'),
@@ -59,10 +69,8 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
5969
model_desc='Ported from official Google AI Tensorflow weights'),
6070
_entry('tf_efficientdet_d7', 'EfficientDet-D7', '1911.09070', batch_size=_bs(4),
6171
model_desc='Ported from official Google AI Tensorflow weights'),
62-
63-
## Weights trained by myself in PyTorch
64-
_entry('efficientdet_d0', 'EfficientDet-D0', '1911.09070', batch_size=_bs(112),
65-
model_desc='Trained in PyTorch with https://github.com/rwightman/efficientdet-pytorch'),
72+
# _entry('tf_efficientdet_d7x', 'EfficientDet-D7X', '1911.09070', batch_size=_bs(4),
73+
# model_desc='Ported from official Google AI Tensorflow weights'),
6674
]
6775

6876

@@ -87,14 +95,13 @@ def eval_model(model_name, paper_model_name, paper_arxiv_id, batch_size=64, mode
8795
else:
8896
print('AMP not installed, running network in FP32.')
8997

90-
annotation_path = os.path.join(DATA_ROOT, 'annotations', f'instances_{ANNO_SET}.json')
9198
evaluator = COCOEvaluator(
9299
root=DATA_ROOT,
93100
model_name=paper_model_name,
94101
model_description=model_description,
95102
paper_arxiv_id=paper_arxiv_id)
96103

97-
dataset = DetectionDatset(os.path.join(DATA_ROOT, ANNO_SET), annotation_path)
104+
dataset = create_dataset('coco', DATA_ROOT, splits='val')
98105

99106
loader = create_loader(
100107
dataset,
@@ -106,16 +113,17 @@ def eval_model(model_name, paper_model_name, paper_arxiv_id, batch_size=64, mode
106113
pin_mem=True)
107114

108115
iterator = tqdm.tqdm(loader, desc="Evaluation", mininterval=5)
116+
sample_count = 0
109117
evaluator.reset_time()
110-
111118
with torch.no_grad():
112119
for i, (input, target) in enumerate(iterator):
113-
output = bench(input, target['img_scale'], target['img_size'])
120+
output = bench(input, target)
114121
output = output.cpu()
115-
sample_ids = target['img_id'].cpu()
116122
results = []
117123
for index, sample in enumerate(output):
118-
image_id = int(sample_ids[index])
124+
image_id = int(dataset.parser.img_ids[sample_count])
125+
sample[:, 2] -= sample[:, 0]
126+
sample[:, 3] -= sample[:, 1]
119127
for det in sample:
120128
score = float(det[4])
121129
if score < .001: # stop when below this threshold, scores in descending order
@@ -126,6 +134,7 @@ def eval_model(model_name, paper_model_name, paper_arxiv_id, batch_size=64, mode
126134
score=score,
127135
category_id=int(det[5]))
128136
results.append(coco_det)
137+
sample_count += 1
129138
evaluator.add(results)
130139

131140
if evaluator.cache_exists:

0 commit comments

Comments
 (0)