diff --git a/rfdetr/main.py b/rfdetr/main.py index f52a238..41629ff 100644 --- a/rfdetr/main.py +++ b/rfdetr/main.py @@ -516,7 +516,7 @@ def lr_lambda(current_step: int): for callback in callbacks["on_train_end"]: callback() - def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, shape=None, batch_size=1, **kwargs): + def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_only=False, opset_version=17, verbose=True, force=False, batch_size=1, **kwargs): """Export the trained model to ONNX format""" print(f"Exporting model to ONNX format") try: @@ -532,11 +532,7 @@ def export(self, output_dir="output", infer_dir=None, simplify=False, backbone_ os.makedirs(output_dir, exist_ok=True) output_dir = Path(output_dir) - if shape is None: - shape = (self.resolution, self.resolution) - else: - if shape[0] % 14 != 0 or shape[1] % 14 != 0: - raise ValueError("Shape must be divisible by 14") + shape = (self.resolution, self.resolution) input_tensors = make_infer_image(infer_dir, shape, batch_size, device).to(device) input_names = ['input']