diff --git a/rfdetr/detr.py b/rfdetr/detr.py index 1cb03a3..000283b 100644 --- a/rfdetr/detr.py +++ b/rfdetr/detr.py @@ -319,8 +319,9 @@ def predict( predictions = { "pred_logits": predictions[1], "pred_boxes": predictions[0], - "pred_masks": predictions[2] } + if len(predictions) == 3: + predictions["pred_masks"] = predictions[2] target_sizes = torch.tensor(orig_sizes, device=self.model.device) results = self.model.postprocess(predictions, target_sizes=target_sizes) diff --git a/rfdetr/models/lwdetr.py b/rfdetr/models/lwdetr.py index 9c1f058..95ae7e4 100644 --- a/rfdetr/models/lwdetr.py +++ b/rfdetr/models/lwdetr.py @@ -244,7 +244,10 @@ def forward_export(self, tensors): if self.segmentation_head is not None: outputs_masks = self.segmentation_head(srcs[0], [hs_enc,], tensors.shape[-2:], skip_blocks=True)[0] - return outputs_coord, outputs_class, outputs_masks + if outputs_masks is not None: + return outputs_coord, outputs_class, outputs_masks + else: + return outputs_coord, outputs_class @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_coord, outputs_masks):