diff --git a/rfdetr/main.py b/rfdetr/main.py index f52a238..ae34ec2 100644 --- a/rfdetr/main.py +++ b/rfdetr/main.py @@ -540,7 +540,13 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_ input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device) input_names = ['input'] - output_names = ['features'] if backbone_only else ['dets', 'labels'] + if backbone_only: + output_names = ['features'] + elif self.args.segmentation_head: + output_names = ['dets', 'labels', 'masks'] + else: + output_names = ['dets', 'labels'] + dynamic_axes = None self.model.eval() with torch.no_grad():