Skip to content

Commit faadf03

Browse files
committed
A bit of cleanup, testing w/ PyTorch 2.0 nightlies, initial torchcompile support
1 parent b035355 commit faadf03

File tree

7 files changed

+377
-140
lines changed

7 files changed

+377
-140
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ Aside from the default model configs, there is a lot of flexibility to facilitat
1616

1717
## Updates
1818

19+
### 2023-02-09
20+
* Testing with PyTorch 2.0 (nightlies), add --torchcompile support to train and validate scripts
21+
* A small code cleanup pass, support bwd/fwd compat across timm 0.8.x and previous releases
22+
* Use `timm` convert_sync_batchnorm function as it handles updated models w/ BatchNormAct2d layers
23+
1924
### 2022-01-06
2025
* New `efficientnetv2_ds` weights 50.1 mAP @ 1024x1024, using AGC clipping and `timm`'s `efficientnetv2_rw_s` backbone. Memory use comparable to D3, speed faster than D4. Smaller than optimal training batch size so can probably do better...
2126

effdet/bench.py

+59-15
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,16 @@ def _batch_detection(
8080
img_scale_i = None if img_scale is None else img_scale[i]
8181
img_size_i = None if img_size is None else img_size[i]
8282
detections = generate_detections(
83-
class_out[i], box_out[i], anchor_boxes, indices[i], classes[i],
84-
img_scale_i, img_size_i, max_det_per_image=max_det_per_image, soft_nms=soft_nms)
83+
class_out[i],
84+
box_out[i],
85+
anchor_boxes,
86+
indices[i],
87+
classes[i],
88+
img_scale_i,
89+
img_size_i,
90+
max_det_per_image=max_det_per_image,
91+
soft_nms=soft_nms,
92+
)
8593
batch_detections.append(detections)
8694
return torch.stack(batch_detections, dim=0)
8795

@@ -101,15 +109,27 @@ def __init__(self, model):
101109
def forward(self, x, img_info: Optional[Dict[str, torch.Tensor]] = None):
102110
class_out, box_out = self.model(x)
103111
class_out, box_out, indices, classes = _post_process(
104-
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
105-
max_detection_points=self.max_detection_points)
112+
class_out,
113+
box_out,
114+
num_levels=self.num_levels,
115+
num_classes=self.num_classes,
116+
max_detection_points=self.max_detection_points,
117+
)
106118
if img_info is None:
107119
img_scale, img_size = None, None
108120
else:
109121
img_scale, img_size = img_info['img_scale'], img_info['img_size']
110122
return _batch_detection(
111-
x.shape[0], class_out, box_out, self.anchors.boxes, indices, classes,
112-
img_scale, img_size, max_det_per_image=self.max_det_per_image, soft_nms=self.soft_nms
123+
x.shape[0],
124+
class_out,
125+
box_out,
126+
self.anchors.boxes,
127+
indices,
128+
classes,
129+
img_scale,
130+
img_size,
131+
max_det_per_image=self.max_det_per_image,
132+
soft_nms=self.soft_nms,
113133
)
114134

115135

@@ -126,7 +146,11 @@ def __init__(self, model, create_labeler=True):
126146
self.soft_nms = model.config.soft_nms
127147
self.anchor_labeler = None
128148
if create_labeler:
129-
self.anchor_labeler = AnchorLabeler(self.anchors, self.num_classes, match_threshold=0.5)
149+
self.anchor_labeler = AnchorLabeler(
150+
self.anchors,
151+
self.num_classes,
152+
match_threshold=0.5,
153+
)
130154
self.loss_fn = DetectionLoss(model.config)
131155

132156
def forward(self, x, target: Dict[str, torch.Tensor]):
@@ -139,19 +163,39 @@ def forward(self, x, target: Dict[str, torch.Tensor]):
139163
num_positives = target['label_num_positives']
140164
else:
141165
cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
142-
target['bbox'], target['cls'])
143-
144-
loss, class_loss, box_loss = self.loss_fn(class_out, box_out, cls_targets, box_targets, num_positives)
166+
target['bbox'],
167+
target['cls'],
168+
)
169+
170+
loss, class_loss, box_loss = self.loss_fn(
171+
class_out,
172+
box_out,
173+
cls_targets,
174+
box_targets,
175+
num_positives,
176+
)
145177
output = {'loss': loss, 'class_loss': class_loss, 'box_loss': box_loss}
146178
if not self.training:
147179
# if eval mode, output detections for evaluation
148180
class_out_pp, box_out_pp, indices, classes = _post_process(
149-
class_out, box_out, num_levels=self.num_levels, num_classes=self.num_classes,
150-
max_detection_points=self.max_detection_points)
181+
class_out,
182+
box_out,
183+
num_levels=self.num_levels,
184+
num_classes=self.num_classes,
185+
max_detection_points=self.max_detection_points,
186+
)
151187
output['detections'] = _batch_detection(
152-
x.shape[0], class_out_pp, box_out_pp, self.anchors.boxes, indices, classes,
153-
target['img_scale'], target['img_size'],
154-
max_det_per_image=self.max_det_per_image, soft_nms=self.soft_nms)
188+
x.shape[0],
189+
class_out_pp,
190+
box_out_pp,
191+
self.anchors.boxes,
192+
indices,
193+
classes,
194+
target['img_scale'],
195+
target['img_size'],
196+
max_det_per_image=self.max_det_per_image,
197+
soft_nms=self.soft_nms,
198+
)
155199
return output
156200

157201

0 commit comments

Comments
 (0)