Skip to content

Commit df74223

Browse files
Pass input image shape as arg (#546)
1 parent 1ced664 commit df74223

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

rtdetrv2_pytorch/tools/export_onnx.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def forward(self, images, orig_target_sizes):
4343

4444
model = Model()
4545

46-
data = torch.rand(1, 3, 640, 640)
47-
size = torch.tensor([[640, 640]])
46+
data = torch.rand(1, 3, args.input_size, args.input_size)
47+
size = torch.tensor([[args.input_size, args.input_size]])
4848
_ = model(data, size)
4949

5050
dynamic_axes = {
@@ -87,8 +87,10 @@ def forward(self, images, orig_target_sizes):
8787
parser.add_argument('--config', '-c', type=str, )
8888
parser.add_argument('--resume', '-r', type=str, )
8989
parser.add_argument('--output_file', '-o', type=str, default='model.onnx')
90+
parser.add_argument('--input_size', '-s', type=int, default=640)
9091
parser.add_argument('--check', action='store_true', default=False,)
9192
parser.add_argument('--simplify', action='store_true', default=False,)
93+
9294

9395
args = parser.parse_args()
9496

0 commit comments

Comments
 (0)