Skip to content

Commit 39f030d

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add torch._check hints for torch.export
Summary: Fixes #5347 This allows for torch.export.export ImageList.from_tensors. Differential Revision: D74835454
1 parent 536dc9d commit 39f030d

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

detectron2/structures/image_list.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch.nn import functional as F
77

88
from detectron2.layers.wrappers import move_device_like, shapes_to_tensor
9+
from detectron2.utils.torch_version_utils import min_torch_version
910

1011

1112
class ImageList:
@@ -111,7 +112,12 @@ def from_tensors(
111112
# This seems slightly (2%) faster.
112113
# TODO: check whether it's faster for multiple images as well
113114
image_size = image_sizes[0]
114-
padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]]
115+
u0 = max_size[-1] - image_size[1]
116+
u1 = max_size[-2] - image_size[0]
117+
padding_size = [0, u0, 0, u1]
118+
if min_torch_version("2.6.0") and torch.compiler.is_compiling():
119+
torch._check(u0.item() >= 0)
120+
torch._check(u1.item() >= 0)
115121
batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
116122
else:
117123
# max_size can be a tensor in tracing mode, therefore convert to list

detectron2/utils/testing.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Callable
99
import torch
1010
import torch.onnx.symbolic_helper as sym_help
11-
from packaging import version
1211
from torch._C import ListType
1312
from torch.onnx import register_custom_op_symbolic
1413

@@ -19,6 +18,7 @@
1918
from detectron2.modeling import build_model
2019
from detectron2.structures import Boxes, Instances, ROIMasks
2120
from detectron2.utils.file_io import PathManager
21+
from detectron2.utils.torch_version_utils import min_torch_version
2222

2323

2424
"""
@@ -162,20 +162,6 @@ def reload_lazy_config(cfg):
162162
return LazyConfig.load(fname)
163163

164164

165-
def min_torch_version(min_version: str) -> bool:
166-
"""
167-
Returns True when torch's version is at least `min_version`.
168-
"""
169-
try:
170-
import torch
171-
except ImportError:
172-
return False
173-
174-
installed_version = version.parse(torch.__version__.split("+")[0])
175-
min_version = version.parse(min_version)
176-
return installed_version >= min_version
177-
178-
179165
def has_dynamic_axes(onnx_model):
180166
"""
181167
Return True when all ONNX input/output have only dynamic axes for all ranks
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from packaging import version
2+
3+
def min_torch_version(min_version: str) -> bool:
4+
"""
5+
Returns True when torch's version is at least `min_version`.
6+
"""
7+
try:
8+
import torch
9+
except ImportError:
10+
return False
11+
12+
installed_version = version.parse(torch.__version__.split("+")[0])
13+
min_version = version.parse(min_version)
14+
return installed_version >= min_version

0 commit comments

Comments
 (0)