Skip to content

Commit 65184fc

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add torch._check hints for torch.export
Summary: Pull Request resolved: #5468 Fixes #5347 This allows for torch.export.export ImageList.from_tensors. Reviewed By: wat3rBro Differential Revision: D74835454 fbshipit-source-id: 0da4c30cb7fdebcab60313e43a41f73f21533d72
1 parent 536dc9d commit 65184fc

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

detectron2/structures/image_list.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch.nn import functional as F
77

88
from detectron2.layers.wrappers import move_device_like, shapes_to_tensor
9+
from detectron2.utils.torch_version_utils import min_torch_version
910

1011

1112
class ImageList:
@@ -111,7 +112,13 @@ def from_tensors(
111112
# This seems slightly (2%) faster.
112113
# TODO: check whether it's faster for multiple images as well
113114
image_size = image_sizes[0]
114-
padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]]
115+
u0 = max_size[-1] - image_size[1]
116+
u1 = max_size[-2] - image_size[0]
117+
padding_size = [0, u0, 0, u1]
118+
if not torch.jit.is_scripting():
119+
if min_torch_version("2.6.0") and torch.compiler.is_compiling():
120+
torch._check(u0.item() >= 0)
121+
torch._check(u1.item() >= 0)
115122
batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
116123
else:
117124
# max_size can be a tensor in tracing mode, therefore convert to list

detectron2/utils/testing.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing import Callable
99
import torch
1010
import torch.onnx.symbolic_helper as sym_help
11-
from packaging import version
1211
from torch._C import ListType
1312
from torch.onnx import register_custom_op_symbolic
1413

@@ -19,6 +18,7 @@
1918
from detectron2.modeling import build_model
2019
from detectron2.structures import Boxes, Instances, ROIMasks
2120
from detectron2.utils.file_io import PathManager
21+
from detectron2.utils.torch_version_utils import min_torch_version
2222

2323

2424
"""
@@ -162,20 +162,6 @@ def reload_lazy_config(cfg):
162162
return LazyConfig.load(fname)
163163

164164

165-
def min_torch_version(min_version: str) -> bool:
166-
"""
167-
Returns True when torch's version is at least `min_version`.
168-
"""
169-
try:
170-
import torch
171-
except ImportError:
172-
return False
173-
174-
installed_version = version.parse(torch.__version__.split("+")[0])
175-
min_version = version.parse(min_version)
176-
return installed_version >= min_version
177-
178-
179165
def has_dynamic_axes(onnx_model):
180166
"""
181167
Return True when all ONNX input/output have only dynamic axes for all ranks
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from packaging import version
2+
3+
4+
def min_torch_version(min_version: str) -> bool:
5+
"""
6+
Returns True when torch's version is at least `min_version`.
7+
"""
8+
try:
9+
import torch
10+
except ImportError:
11+
return False
12+
13+
installed_version = version.parse(torch.__version__.split("+")[0])
14+
min_version = version.parse(min_version)
15+
return installed_version >= min_version

0 commit comments

Comments
 (0)