Skip to content

Commit 4df6b17

Browse files
authored
[Relax][ONNX] Add roi_pool op and MaxRoiPool frontend support (#18952)
## Summary

Add Relax `roi_pool` support and wire it through the ONNX frontend for `MaxRoiPool`.

## Changes

- add `relax.vision.roi_pool`, including attrs, Python wrapper, struct info inference, and legalization
- add TOPI `roi_pool` compute for NCHW layout
- support ONNX `MaxRoiPool` in the Relax ONNX frontend
- handle empty / out-of-bound pooled bins according to ONNX/reference semantics, returning `0` instead of propagating invalid reductions
- add regression tests for Relax op inference, legalization, and ONNX frontend import
- add out-of-bound ROI coverage to make sure fully invalid pooled bins still match ONNX Runtime

## Validation

- `pytest tests/python/relax/test_op_vision.py -k roi_pool`
- `pytest tests/python/relax/test_frontend_onnx.py -k 'max_roi_pool'`

This PR completes the `MaxRoiPool` portion of the Relax ONNX frontend operator work tracked in #18945.
1 parent 4de1f11 commit 4df6b17

File tree

15 files changed

+538
-12
lines changed

15 files changed

+538
-12
lines changed

include/tvm/relax/attrs/vision.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,23 @@ struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
7373
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode);
7474
}; // struct ROIAlignAttrs
7575

76+
/*! \brief Attributes used in ROIPool operator */
77+
struct ROIPoolAttrs : public AttrsNodeReflAdapter<ROIPoolAttrs> {
78+
ffi::Array<int64_t> pooled_size;
79+
double spatial_scale;
80+
ffi::String layout;
81+
82+
static void RegisterReflection() {
83+
namespace refl = tvm::ffi::reflection;
84+
refl::ObjectDef<ROIPoolAttrs>()
85+
.def_ro("pooled_size", &ROIPoolAttrs::pooled_size, "Output size of roi pool.")
86+
.def_ro("spatial_scale", &ROIPoolAttrs::spatial_scale,
87+
"Ratio of input feature map height (or width) to raw image height (or width).")
88+
.def_ro("layout", &ROIPoolAttrs::layout, "Dimension ordering of the input data.");
89+
}
90+
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIPoolAttrs", ROIPoolAttrs, BaseAttrsNode);
91+
}; // struct ROIPoolAttrs
92+
7693
/*! \brief Attributes used in GetValidCounts operator */
7794
struct GetValidCountsAttrs : public AttrsNodeReflAdapter<GetValidCountsAttrs> {
7895
double score_threshold;
@@ -132,7 +149,6 @@ struct NonMaximumSuppressionAttrs
132149
NonMaximumSuppressionAttrs, BaseAttrsNode);
133150
}; // struct NonMaximumSuppressionAttrs
134151

135-
136152
/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box decode). */
137153
struct MultiboxTransformLocAttrs : public AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
138154
bool clip;

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2519,6 +2519,28 @@ def _impl_v16(cls, bb, inputs, attr, params):
25192519
return cls._impl(bb, inputs, attr, params, b"half_pixel")
25202520

25212521

2522+
class MaxRoiPool(OnnxOpConverter):
    """Converts an onnx MaxRoiPool node into an equivalent Relax expression.

    The node is lowered to ``relax.op.vision.roi_pool`` using the ``NCHW``
    layout.
    """

    @classmethod
    def _impl_v1(cls, bb, inputs, attr, params):
        """Build the ``roi_pool`` call for a MaxRoiPool node.

        Parameters
        ----------
        bb : relax.BlockBuilder
            Block builder of the importer (not used by this converter).

        inputs : list
            Node inputs. Exactly two are required: ``inputs[0]`` is passed as
            the ``data`` operand and ``inputs[1]`` as the ``rois`` operand.

        attr : dict
            Node attributes. ``pooled_shape`` is mandatory and must hold two
            values; ``spatial_scale`` falls back to ``1.0`` when absent.

        params : dict
            Importer parameters (not used by this converter).

        Returns
        -------
        result : relax.Expr
            The ``relax.op.vision.roi_pool`` call.

        Raises
        ------
        ValueError
            If the input count is not 2, if ``pooled_shape`` is missing, or if
            ``pooled_shape`` does not contain exactly two elements.
        """
        if len(inputs) != 2:
            raise ValueError("MaxRoiPool expects exactly 2 inputs")

        pooled_shape = attr.get("pooled_shape")
        if pooled_shape is None:
            raise ValueError("MaxRoiPool requires pooled_shape attribute")

        # Reject a malformed attribute here so the failure names the ONNX
        # attribute instead of surfacing later from the operator itself.
        pooled_size = tuple(pooled_shape)
        if len(pooled_size) != 2:
            raise ValueError(
                "MaxRoiPool pooled_shape must have exactly 2 elements (height, width), "
                f"got {len(pooled_size)}"
            )

        spatial_scale = attr.get("spatial_scale", 1.0)
        return relax.op.vision.roi_pool(
            inputs[0],
            inputs[1],
            pooled_size=pooled_size,
            spatial_scale=spatial_scale,
            layout="NCHW",
        )
2542+
2543+
25222544
class Range(OnnxOpConverter):
25232545
"""Converts an onnx Range node into an equivalent Relax expression."""
25242546

@@ -4179,7 +4201,7 @@ def _get_convert_map():
41794201
"OneHot": OneHot,
41804202
"Unique": Unique,
41814203
"NonZero": NonZero,
4182-
# "MaxRoiPool": MaxRoiPool,
4204+
"MaxRoiPool": MaxRoiPool,
41834205
"RoiAlign": RoiAlign,
41844206
"NonMaxSuppression": NonMaxSuppression,
41854207
"AllClassNMS": AllClassNMS,

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
multibox_transform_loc,
164164
non_max_suppression,
165165
roi_align,
166+
roi_pool,
166167
)
167168

168169

python/tvm/relax/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,11 @@ class ROIAlignAttrs(Attrs):
266266
"""Attributes for vision.roi_align"""
267267

268268

269+
@tvm_ffi.register_object("relax.attrs.ROIPoolAttrs")
270+
class ROIPoolAttrs(Attrs):
271+
"""Attributes for vision.roi_pool"""
272+
273+
269274
@tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs")
270275
class MultiboxTransformLocAttrs(Attrs):
271276
"""Attributes for vision.multibox_transform_loc"""

python/tvm/relax/op/vision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from .multibox_transform_loc import *
2121
from .nms import *
2222
from .roi_align import *
23+
from .roi_pool import *
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""ROI Pool operator"""
18+
19+
from ..base import Expr
20+
from . import _ffi_api
21+
22+
23+
def roi_pool(
    data: Expr,
    rois: Expr,
    pooled_size: int | tuple[int, int] | list[int],
    spatial_scale: float,
    layout: str = "NCHW",
):
    """Apply ROI pooling to ``data`` over the regions listed in ``rois``.

    Parameters
    ----------
    data : relax.Expr
        4-D input tensor.

    rois : relax.Expr
        2-D input tensor with shape `(num_roi, 5)` in
        `[batch_idx, x1, y1, x2, y2]` format.

    pooled_size : Union[int, Tuple[int, int], List[int]]
        Output pooled size. A single integer ``s`` is expanded to ``(s, s)``;
        a tuple or list is forwarded unchanged.

    spatial_scale : float
        Ratio of input feature map height (or width) to raw image height (or width).

    layout : str, optional
        Layout of the input data. Currently only `NCHW` is supported.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    # Normalise the scalar shorthand before crossing the FFI boundary.
    size = (pooled_size, pooled_size) if isinstance(pooled_size, int) else pooled_size
    return _ffi_api.roi_pool(data, rois, size, spatial_scale, layout)

python/tvm/relax/transform/legalize_ops/vision.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,18 @@ def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
150150
)
151151

152152

153+
@register_legalize("relax.vision.roi_pool")
def _roi_pool(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize ``relax.vision.roi_pool`` into a TE call to ``topi.vision.roi_pool``.

    The first two call arguments are forwarded as the operands, and the
    ``pooled_size``, ``spatial_scale`` and ``layout`` attributes are passed
    through as keyword arguments.
    """
    data = call.args[0]
    rois = call.args[1]
    attrs = call.attrs
    return bb.call_te(
        topi.vision.roi_pool,
        data,
        rois,
        pooled_size=attrs.pooled_size,
        spatial_scale=attrs.spatial_scale,
        layout=attrs.layout,
    )
163+
164+
153165
@register_legalize("relax.vision.multibox_transform_loc")
154166
def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
155167
variances = tuple(float(x) for x in call.attrs.variances)

python/tvm/runtime/support.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,17 @@ def method(*args, **kwargs):
146146
fields = metadata.get("fields", [])
147147
methods = metadata.get("methods", [])
148148

149-
class TVMDerivedObject(metadata["cls"]): # type: ignore
149+
base_cls = metadata["cls"]
150+
slots = []
151+
if getattr(base_cls, "__dictoffset__", 0) == 0:
152+
slots.append("__dict__")
153+
if getattr(base_cls, "__weakrefoffset__", 0) == 0:
154+
slots.append("__weakref__")
155+
156+
class TVMDerivedObject(base_cls): # type: ignore
150157
"""The derived object to avoid cyclic dependency."""
151158

152-
__slots__ = ("__dict__", "__weakref__",)
159+
__slots__ = tuple(slots)
153160

154161
_cls = cls
155162
_type = "TVMDerivedObject"

python/tvm/s_tir/meta_schedule/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,17 @@ def method(*args, **kwargs):
106106
fields = metadata.get("fields", [])
107107
methods = metadata.get("methods", [])
108108

109-
class TVMDerivedObject(metadata["cls"]): # type: ignore
109+
base_cls = metadata["cls"]
110+
slots = []
111+
if getattr(base_cls, "__dictoffset__", 0) == 0:
112+
slots.append("__dict__")
113+
if getattr(base_cls, "__weakrefoffset__", 0) == 0:
114+
slots.append("__weakref__")
115+
116+
class TVMDerivedObject(base_cls): # type: ignore
110117
"""The derived object to avoid cyclic dependency."""
111118

112-
__slots__ = ("__dict__", "__weakref__",)
119+
__slots__ = tuple(slots)
113120

114121
_cls = cls
115122
_type = "TVMDerivedObject"

python/tvm/topi/vision/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from .multibox_transform_loc import *
2121
from .nms import *
2222
from .roi_align import *
23+
from .roi_pool import *

0 commit comments

Comments
 (0)