diff --git a/rfdetr/main.py b/rfdetr/main.py index 14fe89c..2185eb1 100644 --- a/rfdetr/main.py +++ b/rfdetr/main.py @@ -524,7 +524,7 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_ input_names = ['input'] output_names = ['features'] if backbone_only else ['dets', 'labels'] dynamic_axes = None - self.model.eval() + model.eval() with torch.no_grad(): if backbone_only: features = model(input_tensors) @@ -562,7 +562,6 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_ print(f"Successfully simplified ONNX model to: {sim_output_file}") print("ONNX export completed successfully") - self.model = self.model.to(device) if __name__ == '__main__': @@ -1059,4 +1058,4 @@ def populate_args( gradient_checkpointing=gradient_checkpointing, **extra_kwargs ) - return args \ No newline at end of file + return args