Skip to content

Commit 460c9f3

Browse files
authored
fix(ONNX): avoids resizing unsupported dimensions (#3945)
Partially resolves #3453 by introducing better error reporting for unsupported configurations in the `onnx.Resize` lowering.
1 parent d4ee6ba commit 460c9f3

File tree

2 files changed

+141
-47
lines changed

2 files changed

+141
-47
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 119 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2700,12 +2700,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27002700
});
27012701
patterns.onOp(
27022702
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
2703-
Torch::ValueTensorType resultType;
2703+
Torch::ValueTensorType outputTensorType;
27042704
llvm::SmallVector<Value> operands;
27052705
std::string mode, nearest_mode, coordTfMode;
27062706
int64_t antialias, exclude_outside;
27072707
float extrapolation_value, cubic_coeff_a;
2708-
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
27092708

27102709
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
27112710
return rewriter.notifyMatchFailure(
@@ -2720,7 +2719,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27202719
}
27212720

27222721
if (binder.tensorOperandsList(operands) ||
2723-
binder.tensorResultType(resultType) ||
2722+
binder.tensorResultType(outputTensorType) ||
27242723
binder.customOpNameStringAttr(mode, "mode", "nearest") ||
27252724
binder.customOpNameStringAttr(
27262725
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
@@ -2732,6 +2731,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27322731
"round_prefer_floor") ||
27332732
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
27342733
return failure();
2734+
2735+
int64_t const /* */ batchDim = 0;
2736+
int64_t const /**/ channelDim = 1;
2737+
2738+
SmallVector<int64_t> nonResizableDims{
2739+
batchDim,
2740+
channelDim,
2741+
};
2742+
2743+
Value inputTensor = operands[0];
2744+
auto inputTensorType =
2745+
cast<Torch::BaseTensorType>(inputTensor.getType());
2746+
auto sizesOfInputTensor = inputTensorType.getSizes();
2747+
auto sizesOfOutputTensor = outputTensorType.getSizes();
2748+
2749+
auto unknownSize = Torch::kUnknownSize;
2750+
2751+
// Compile-time check for dimensions of static size
2752+
for (auto &eachDim : nonResizableDims) {
2753+
auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim];
2754+
auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim];
2755+
2756+
if (eachSizeOfInputTensor == unknownSize ||
2757+
eachSizeOfOutputTensor == unknownSize)
2758+
continue;
2759+
if (eachSizeOfInputTensor == eachSizeOfOutputTensor)
2760+
continue;
2761+
2762+
auto resizingIntentErrorMessage =
2763+
"unsupported: non-trivial intent to resize dimension: " +
2764+
std::to_string(eachDim);
2765+
2766+
return rewriter.notifyMatchFailure(binder.op,
2767+
resizingIntentErrorMessage);
2768+
};
2769+
27352770
if (antialias != 0) {
27362771
return rewriter.notifyMatchFailure(
27372772
binder.op,
@@ -2764,35 +2799,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
27642799
binder.op, "unimplemented: cubic coeff must be -0.75");
27652800
}
27662801

2767-
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
2768-
.getSizes()
2769-
.size();
2802+
auto loc = binder.getLoc();
27702803

2771-
Value cstFalse =
2772-
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
2773-
Value cstTrue =
2774-
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
2804+
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
2805+
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
27752806
Value modeStrValue;
27762807

2777-
Value scalesValueList = noneVal;
2778-
Value sizesValueList = noneVal;
27792808
Value alignCorners =
27802809
coordTfMode == "align_corners" ? cstTrue : cstFalse;
27812810
if (mode == "cubic") {
27822811
std::string modeStr = "cubic";
27832812
if (coordTfMode != "half_pixel")
27842813
modeStr = modeStr + "_" + coordTfMode;
2785-
modeStrValue =
2786-
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
2814+
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
27872815
}
2816+
2817+
auto rankOfInputTensor = sizesOfInputTensor.size();
2818+
27882819
// supported modes:
27892820
// bilinear (half_pixel), bilinear with align_corners,
27902821
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
27912822
// (asymmetric), nearest with align_corners, nearest_half_pixel,
27922823
// nearest_pytorch_half_pixel
27932824
if (mode == "linear") {
27942825
std::string modeStr;
2795-
switch (rank) {
2826+
switch (rankOfInputTensor) {
27962827
case 3:
27972828
modeStr = "linear";
27982829
break;
@@ -2809,8 +2840,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
28092840
// mode is apparently half_pixel, NOT pytorch_half_pixel
28102841
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
28112842
modeStr = (modeStr + "_") + coordTfMode;
2812-
modeStrValue =
2813-
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
2843+
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
28142844
}
28152845
if (mode == "nearest") {
28162846
std::string modeStr = "nearest";
@@ -2820,33 +2850,84 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
28202850
modeStr = (modeStr + "_") + coordTfMode;
28212851
if (nearest_mode != "floor" && nearest_mode != "")
28222852
modeStr = modeStr + "," + nearest_mode;
2823-
modeStrValue =
2824-
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
2853+
modeStrValue = rewriter.create<Torch::ConstantStrOp>(loc, modeStr);
28252854
}
28262855

2827-
int64_t assumedForemostSpatialDim = 2;
2856+
auto numberOfOperands = operands.size();
28282857

2829-
if (operands.size() < 4) {
2830-
Value scaleOperand = operands[2];
2831-
scalesValueList =
2832-
createScalarSublist(binder.getLoc(), scaleOperand,
2833-
assumedForemostSpatialDim, rewriter);
2834-
sizesValueList = noneVal;
2835-
} else {
2836-
Value sizeOperand = operands[3];
2837-
scalesValueList = noneVal;
2838-
sizesValueList =
2839-
createScalarSublist(binder.getLoc(), sizeOperand,
2840-
assumedForemostSpatialDim, rewriter);
2841-
}
2842-
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
2843-
isa<Torch::NoneType>(sizesValueList.getType())) {
2858+
Type boolType = rewriter.getType<Torch::BoolType>();
2859+
2860+
int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back();
2861+
2862+
Value supportedScaleFactors;
2863+
Value supportedSizes;
2864+
2865+
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
2866+
2867+
if (numberOfOperands == 3) {
2868+
Value proposedScaleFactors = operands[2];
2869+
2870+
Value scaleIdentity = rewriter.create<Torch::ConstantFloatOp>(
2871+
loc, rewriter.getF64FloatAttr(1.0));
2872+
2873+
// run-time scale factor check for dynamic sizes
2874+
for (auto &eachDim : nonResizableDims) {
2875+
Value eachProposedScaleFactor = extractTorchScalar(
2876+
loc, eachDim, proposedScaleFactors, rewriter);
2877+
2878+
Value eachScaleFactorIsIdentity =
2879+
rewriter.create<Torch::AtenEqFloatOp>(
2880+
loc, boolType, eachProposedScaleFactor, scaleIdentity);
2881+
2882+
auto errorMessageForEachDim =
2883+
"Unsupported: non-trivial scale factor for dimension " +
2884+
std::to_string(eachDim);
2885+
2886+
rewriter.create<Torch::RuntimeAssertOp>(
2887+
loc, eachScaleFactorIsIdentity,
2888+
rewriter.getStringAttr(errorMessageForEachDim));
2889+
};
2890+
2891+
supportedScaleFactors = createScalarSublist(
2892+
loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter);
2893+
supportedSizes = noneVal;
2894+
} else if (numberOfOperands == 4) {
2895+
Value proposedSizes = operands[3];
2896+
2897+
// run-time target size check for dynamic sizes
2898+
for (auto &eachDimAsInt : nonResizableDims) {
2899+
Value eachDimAsValue =
2900+
rewriter.create<Torch::ConstantIntOp>(loc, eachDimAsInt);
2901+
2902+
Value eachSizeOfInputTensor = rewriter.create<Torch::AtenSizeIntOp>(
2903+
loc, inputTensor, eachDimAsValue);
2904+
2905+
Value eachProposedSize =
2906+
extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter);
2907+
2908+
Value eachProposedSizeIsTrivial =
2909+
rewriter.create<Torch::AtenEqIntOp>(
2910+
loc, boolType, eachProposedSize, eachSizeOfInputTensor);
2911+
2912+
auto errorMessageForEachDim =
2913+
"Unsupported: non-trivial resizing of dimension " +
2914+
std::to_string(eachDimAsInt);
2915+
2916+
rewriter.create<Torch::RuntimeAssertOp>(
2917+
loc, eachProposedSizeIsTrivial,
2918+
rewriter.getStringAttr(errorMessageForEachDim));
2919+
};
2920+
2921+
supportedScaleFactors = noneVal;
2922+
supportedSizes = createScalarSublist(
2923+
loc, proposedSizes, assumedForemostSpatialDim, rewriter);
2924+
} else
28442925
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
2845-
}
2926+
28462927
rewriter
28472928
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
2848-
binder.op, resultType, operands[0], sizesValueList,
2849-
scalesValueList, modeStrValue,
2929+
binder.op, outputTensorType, inputTensor, supportedSizes,
2930+
supportedScaleFactors, modeStrValue,
28502931
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
28512932
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
28522933
/*Torch_BoolType:$antialias*/ cstFalse);

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2256,21 +2256,30 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
22562256
// CHECK-LABEL: func.func @test_resize_sizes_nearest
22572257
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22582258
%none = torch.constant.none
2259-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2260-
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2259+
// CHECK: %[[MODE_STR:.*]] = torch.constant.str "nearest"
2260+
// CHECK: torch.aten.__interpolate.size_list_scale_list
2261+
// CHECK-SAME: %[[MODE_STR]]
2262+
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
2263+
torch.onnx.coordinate_transformation_mode = "asymmetric",
2264+
torch.onnx.cubic_coeff_a = -7.500000e-01 : f32,
2265+
torch.onnx.mode = "nearest",
2266+
torch.onnx.nearest_mode = "floor"
2267+
} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22612268
return %0 : !torch.vtensor<[?,?,?,?],f32>
22622269
}
22632270

22642271
// -----
22652272

2266-
// CHECK-LABEL: func.func @test_resize_sizes_nearest
2267-
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2273+
// CHECK-LABEL: func.func @test_resize_sizes_nearest_half_pixel
2274+
func.func @test_resize_sizes_nearest_half_pixel(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22682275
%none = torch.constant.none
2269-
// CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2270-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2276+
// CHECK: %[[MODE_STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor"
2277+
// CHECK: torch.aten.__interpolate.size_list_scale_list
2278+
// CHECK-SAME: %[[MODE_STR]]
22712279
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
22722280
torch.onnx.coordinate_transformation_mode = "half_pixel",
2273-
torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2281+
torch.onnx.mode = "nearest"
2282+
} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22742283
return %0 : !torch.vtensor<[?,?,?,?],f32>
22752284
}
22762285

@@ -2280,8 +2289,12 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1
22802289
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],
22812290
f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
22822291
%none = torch.constant.none
2283-
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
2284-
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
2292+
// CHECK: %[[MODE_STR:.*]] = torch.constant.str "bilinear"
2293+
// CHECK: torch.aten.__interpolate.size_list_scale_list
2294+
// CHECK-SAME: %[[MODE_STR]]
2295+
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {
2296+
torch.onnx.mode = "linear"
2297+
} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
22852298
return %0 : !torch.vtensor<[?,?,?,?],f32>
22862299
}
22872300

0 commit comments

Comments
 (0)