@@ -2700,12 +2700,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2700
2700
});
2701
2701
patterns.onOp (
2702
2702
" Resize" , 11 , [](OpBinder binder, ConversionPatternRewriter &rewriter) {
2703
- Torch::ValueTensorType resultType ;
2703
+ Torch::ValueTensorType outputTensorType ;
2704
2704
llvm::SmallVector<Value> operands;
2705
2705
std::string mode, nearest_mode, coordTfMode;
2706
2706
int64_t antialias, exclude_outside;
2707
2707
float extrapolation_value, cubic_coeff_a;
2708
- Value noneVal = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
2709
2708
2710
2709
if (auto attr = binder.op ->getAttr (" torch.onnx.axes" )) {
2711
2710
return rewriter.notifyMatchFailure (
@@ -2720,7 +2719,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2720
2719
}
2721
2720
2722
2721
if (binder.tensorOperandsList (operands) ||
2723
- binder.tensorResultType (resultType ) ||
2722
+ binder.tensorResultType (outputTensorType ) ||
2724
2723
binder.customOpNameStringAttr (mode, " mode" , " nearest" ) ||
2725
2724
binder.customOpNameStringAttr (
2726
2725
coordTfMode, " coordinate_transformation_mode" , " half_pixel" ) ||
@@ -2732,6 +2731,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2732
2731
" round_prefer_floor" ) ||
2733
2732
binder.f32FloatAttr (cubic_coeff_a, " cubic_coeff_a" , -0.75 ))
2734
2733
return failure ();
2734
+
2735
+ int64_t const /* */ batchDim = 0 ;
2736
+ int64_t const /* */ channelDim = 1 ;
2737
+
2738
+ SmallVector<int64_t > nonResizableDims{
2739
+ batchDim,
2740
+ channelDim,
2741
+ };
2742
+
2743
+ Value inputTensor = operands[0 ];
2744
+ auto inputTensorType =
2745
+ cast<Torch::BaseTensorType>(inputTensor.getType ());
2746
+ auto sizesOfInputTensor = inputTensorType.getSizes ();
2747
+ auto sizesOfOutputTensor = outputTensorType.getSizes ();
2748
+
2749
+ auto unknownSize = Torch::kUnknownSize ;
2750
+
2751
+ // Compile-time check for dimensions of static size
2752
+ for (auto &eachDim : nonResizableDims) {
2753
+ auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim];
2754
+ auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim];
2755
+
2756
+ if (eachSizeOfInputTensor == unknownSize ||
2757
+ eachSizeOfOutputTensor == unknownSize)
2758
+ continue ;
2759
+ if (eachSizeOfInputTensor == eachSizeOfOutputTensor)
2760
+ continue ;
2761
+
2762
+ auto resizingIntentErrorMessage =
2763
+ " unsupported: non-trivial intent to resize dimension: " +
2764
+ std::to_string (eachDim);
2765
+
2766
+ return rewriter.notifyMatchFailure (binder.op ,
2767
+ resizingIntentErrorMessage);
2768
+ };
2769
+
2735
2770
if (antialias != 0 ) {
2736
2771
return rewriter.notifyMatchFailure (
2737
2772
binder.op ,
@@ -2764,35 +2799,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2764
2799
binder.op , " unimplemented: cubic coeff must be -0.75" );
2765
2800
}
2766
2801
2767
- unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0 ].getType ())
2768
- .getSizes ()
2769
- .size ();
2802
+ auto loc = binder.getLoc ();
2770
2803
2771
- Value cstFalse =
2772
- rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), false );
2773
- Value cstTrue =
2774
- rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), true );
2804
+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
2805
+ Value cstTrue = rewriter.create <Torch::ConstantBoolOp>(loc, true );
2775
2806
Value modeStrValue;
2776
2807
2777
- Value scalesValueList = noneVal;
2778
- Value sizesValueList = noneVal;
2779
2808
Value alignCorners =
2780
2809
coordTfMode == " align_corners" ? cstTrue : cstFalse;
2781
2810
if (mode == " cubic" ) {
2782
2811
std::string modeStr = " cubic" ;
2783
2812
if (coordTfMode != " half_pixel" )
2784
2813
modeStr = modeStr + " _" + coordTfMode;
2785
- modeStrValue =
2786
- rewriter.create <Torch::ConstantStrOp>(binder.getLoc (), modeStr);
2814
+ modeStrValue = rewriter.create <Torch::ConstantStrOp>(loc, modeStr);
2787
2815
}
2816
+
2817
+ auto rankOfInputTensor = sizesOfInputTensor.size ();
2818
+
2788
2819
// supported modes:
2789
2820
// bilinear (half_pixel), bilinear with align_corners,
2790
2821
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
2791
2822
// (asymmetric), nearest with align_corners, nearest_half_pixel,
2792
2823
// nearest_pytorch_half_pixel
2793
2824
if (mode == " linear" ) {
2794
2825
std::string modeStr;
2795
- switch (rank ) {
2826
+ switch (rankOfInputTensor ) {
2796
2827
case 3 :
2797
2828
modeStr = " linear" ;
2798
2829
break ;
@@ -2809,8 +2840,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2809
2840
// mode is apparently half_pixel, NOT pytorch_half_pixel
2810
2841
if (coordTfMode != " half_pixel" && coordTfMode != " align_corners" )
2811
2842
modeStr = (modeStr + " _" ) + coordTfMode;
2812
- modeStrValue =
2813
- rewriter.create <Torch::ConstantStrOp>(binder.getLoc (), modeStr);
2843
+ modeStrValue = rewriter.create <Torch::ConstantStrOp>(loc, modeStr);
2814
2844
}
2815
2845
if (mode == " nearest" ) {
2816
2846
std::string modeStr = " nearest" ;
@@ -2820,33 +2850,84 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
2820
2850
modeStr = (modeStr + " _" ) + coordTfMode;
2821
2851
if (nearest_mode != " floor" && nearest_mode != " " )
2822
2852
modeStr = modeStr + " ," + nearest_mode;
2823
- modeStrValue =
2824
- rewriter.create <Torch::ConstantStrOp>(binder.getLoc (), modeStr);
2853
+ modeStrValue = rewriter.create <Torch::ConstantStrOp>(loc, modeStr);
2825
2854
}
2826
2855
2827
- int64_t assumedForemostSpatialDim = 2 ;
2856
+ auto numberOfOperands = operands. size () ;
2828
2857
2829
- if (operands.size () < 4 ) {
2830
- Value scaleOperand = operands[2 ];
2831
- scalesValueList =
2832
- createScalarSublist (binder.getLoc (), scaleOperand,
2833
- assumedForemostSpatialDim, rewriter);
2834
- sizesValueList = noneVal;
2835
- } else {
2836
- Value sizeOperand = operands[3 ];
2837
- scalesValueList = noneVal;
2838
- sizesValueList =
2839
- createScalarSublist (binder.getLoc (), sizeOperand,
2840
- assumedForemostSpatialDim, rewriter);
2841
- }
2842
- if (isa<Torch::NoneType>(scalesValueList.getType ()) &&
2843
- isa<Torch::NoneType>(sizesValueList.getType ())) {
2858
+ Type boolType = rewriter.getType <Torch::BoolType>();
2859
+
2860
+ int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back ();
2861
+
2862
+ Value supportedScaleFactors;
2863
+ Value supportedSizes;
2864
+
2865
+ Value noneVal = rewriter.create <Torch::ConstantNoneOp>(loc);
2866
+
2867
+ if (numberOfOperands == 3 ) {
2868
+ Value proposedScaleFactors = operands[2 ];
2869
+
2870
+ Value scaleIdentity = rewriter.create <Torch::ConstantFloatOp>(
2871
+ loc, rewriter.getF64FloatAttr (1.0 ));
2872
+
2873
+ // run-time scale factor check for dynamic sizes
2874
+ for (auto &eachDim : nonResizableDims) {
2875
+ Value eachProposedScaleFactor = extractTorchScalar (
2876
+ loc, eachDim, proposedScaleFactors, rewriter);
2877
+
2878
+ Value eachScaleFactorIsIdentity =
2879
+ rewriter.create <Torch::AtenEqFloatOp>(
2880
+ loc, boolType, eachProposedScaleFactor, scaleIdentity);
2881
+
2882
+ auto errorMessageForEachDim =
2883
+ " Unsupported: non-trivial scale factor for dimension " +
2884
+ std::to_string (eachDim);
2885
+
2886
+ rewriter.create <Torch::RuntimeAssertOp>(
2887
+ loc, eachScaleFactorIsIdentity,
2888
+ rewriter.getStringAttr (errorMessageForEachDim));
2889
+ };
2890
+
2891
+ supportedScaleFactors = createScalarSublist (
2892
+ loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter);
2893
+ supportedSizes = noneVal;
2894
+ } else if (numberOfOperands == 4 ) {
2895
+ Value proposedSizes = operands[3 ];
2896
+
2897
+ // run-time target size check for dynamic sizes
2898
+ for (auto &eachDimAsInt : nonResizableDims) {
2899
+ Value eachDimAsValue =
2900
+ rewriter.create <Torch::ConstantIntOp>(loc, eachDimAsInt);
2901
+
2902
+ Value eachSizeOfInputTensor = rewriter.create <Torch::AtenSizeIntOp>(
2903
+ loc, inputTensor, eachDimAsValue);
2904
+
2905
+ Value eachProposedSize =
2906
+ extractTorchScalar (loc, eachDimAsInt, proposedSizes, rewriter);
2907
+
2908
+ Value eachProposedSizeIsTrivial =
2909
+ rewriter.create <Torch::AtenEqIntOp>(
2910
+ loc, boolType, eachProposedSize, eachSizeOfInputTensor);
2911
+
2912
+ auto errorMessageForEachDim =
2913
+ " Unsupported: non-trivial resizing of dimension " +
2914
+ std::to_string (eachDimAsInt);
2915
+
2916
+ rewriter.create <Torch::RuntimeAssertOp>(
2917
+ loc, eachProposedSizeIsTrivial,
2918
+ rewriter.getStringAttr (errorMessageForEachDim));
2919
+ };
2920
+
2921
+ supportedScaleFactors = noneVal;
2922
+ supportedSizes = createScalarSublist (
2923
+ loc, proposedSizes, assumedForemostSpatialDim, rewriter);
2924
+ } else
2844
2925
return rewriter.notifyMatchFailure (binder.op , " unknown scaling mode" );
2845
- }
2926
+
2846
2927
rewriter
2847
2928
.replaceOpWithNewOp <Torch::Aten__InterpolateSizeListScaleListOp>(
2848
- binder.op , resultType, operands[ 0 ], sizesValueList ,
2849
- scalesValueList , modeStrValue,
2929
+ binder.op , outputTensorType, inputTensor, supportedSizes ,
2930
+ supportedScaleFactors , modeStrValue,
2850
2931
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
2851
2932
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
2852
2933
/* Torch_BoolType:$antialias*/ cstFalse);
0 commit comments