Skip to content

Commit 2f8229d

Browse files
yushangdifacebook-github-bot
authored andcommitted
Add torch._check hints for torch.export
Summary: Fixes #5347 This allows for torch.export.export ImageList.from_tensors. Differential Revision: D74835454
1 parent 536dc9d commit 2f8229d

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

detectron2/structures/image_list.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
22
from __future__ import division
3-
from typing import Any, Dict, List, Optional, Tuple
43
import torch
54
from torch import device
65
from torch.nn import functional as F
76

87
from detectron2.layers.wrappers import move_device_like, shapes_to_tensor
8+
from detectron2.structures.torch_version_utils import torch_version_at_least
9+
10+
from typing import Any, Dict, List, Optional, Tuple
911

1012

1113
class ImageList:
@@ -111,7 +113,12 @@ def from_tensors(
111113
# This seems slightly (2%) faster.
112114
# TODO: check whether it's faster for multiple images as well
113115
image_size = image_sizes[0]
114-
padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]]
116+
u0 = max_size[-1] - image_size[1]
117+
u1 = max_size[-2] - image_size[0]
118+
padding_size = [0, u0, 0, u1]
119+
if torch_version_at_least("2.6.0") and torch.compiler.is_compiling():
120+
torch._check(u0.item() >= 0)
121+
torch._check(u1.item() >= 0)
115122
batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
116123
else:
117124
# max_size can be a tensor in tracing mode, therefore convert to list
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
3+
import re
4+
5+
6+
def is_fbcode():
7+
return not hasattr(torch.version, "git_version")
8+
9+
def parse_version(version_string):
10+
# Extract just the X.Y.Z part from the version string
11+
match = re.match(r"(\d+\.\d+\.\d+)", version_string)
12+
if match:
13+
version = match.group(1)
14+
return [int(x) for x in version.split(".")]
15+
else:
16+
raise ValueError(f"Invalid version string format: {version_string}")
17+
18+
def compare_versions(v1, v2):
19+
v1_parts = parse_version(v1)
20+
v2_parts = parse_version(v2)
21+
return (v1_parts > v2_parts) - (v1_parts < v2_parts)
22+
23+
def torch_version_at_least(min_version):
24+
return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0

0 commit comments

Comments
 (0)