@@ -3701,16 +3701,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           return rewriter.notifyMatchFailure(
               binder.op, "expected center_point_box attribute to be 0 or 1");
 
-        Value cst0 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
-        Value cst1 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(1));
-        Value cst2 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(2));
-        Value cst3 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(3));
-        Value cst4 = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(4));
+        Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
+        Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
+        Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
+        Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
+        Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
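+        // cst0-cst3 are reused below as dim and index constants; cst4 is the
+        // torch dtype code for int64, used as the dtype of the result buffer.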
         Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
             loc, rewriter.getF64FloatAttr(2.0));
         Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -3813,36 +3808,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         // Create an empty tensor of shape (B*C*N, 3) to store the final result.
         // We slice this down to the required number of elements at the end.
 
-        // FIXME:: Currently empty tensors created with dynamic sizes are not
-        // fully supported. Uncomment the below lines once dynamic sizes for
-        // empty tensors are supported end to end.
-
-        /*
         Value numResults = rewriter.create<Torch::AtenMulIntOp>(
             loc, numClasses.getType(), numBatches, numClasses);
         numResults = rewriter.create<Torch::AtenMulIntOp>(
             loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
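+        // numResults = numBatches * numClasses * maxOutputBoxesPerClass is an
+        // upper bound on the number of selected boxes.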
-        auto finalResultType = resultType;
-        */
-
-        if (!scoreTensorType.toBuiltinTensor().hasStaticShape()) {
-          llvm_unreachable("Unimplemented: Encountered dynamic shaped tensors "
-                           "while lowering Onnx NonMaxSuppression op to torch");
-        }
-        auto numResultElements =
-            scoreTensorType.toBuiltinTensor().getNumElements();
-        auto numResults = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(numResultElements));
 
         auto intTy = rewriter.getType<Torch::IntType>();
         auto intListTy = rewriter.getType<Torch::ListType>(intTy);
 
         Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
             loc, intListTy, SmallVector<Value>{numResults, cst3});
-        auto finalResultType = rewriter.getType<Torch::ValueTensorType>(
-            ArrayRef<int64_t>{numResultElements, 3}, resultType.getDtype());
         Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
-            loc, finalResultType, resultShapeList, /*dtype=*/cst4,
+            loc, resultType, resultShapeList, /*dtype=*/cst4,
             /*layout=*/cstNone,
             /*device=*/cstNone, /*pinMemory=*/cstNone,
             /*memoryFormat=*/cstNone);
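+        // aten.empty.memory_format allocates the (numResults, 3) buffer; the
+        // nested loops below fill it one row per selected box.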
@@ -3855,16 +3832,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             SmallVector<int64_t>{}, nmsTy.getDtype());
 
         auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
-            loc, TypeRange({finalResultType, intTy, intTy}), numBatches,
-            cstTrue,
+            loc, TypeRange({resultType, intTy, intTy}), numBatches, cstTrue,
             ValueRange({finalResult, /*Index to finalResult*/ cst0,
                         /*Num values in result*/ cst0}));
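+        // The loop carries three values: the result buffer, the next write
+        // index into it, and the running count of valid rows.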
         {
           // Batch loop body
           PatternRewriter::InsertionGuard guard(rewriter);
           Block *batchLoopBody = rewriter.createBlock(
               &nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
-              TypeRange({intTy, finalResultType, intTy, intTy}),
+              TypeRange({intTy, resultType, intTy, intTy}),
               {loc, loc, loc, loc});
 
           auto batchIV = batchLoopBody->getArgument(0);
@@ -3877,31 +3853,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
               loc, emptyTensorTy, batchIV);
 
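+          // Select this batch's score slice once here, rather than re-doing
+          // it on every class-loop iteration.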
+          auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
+              loc, scoreSlicedType, scores, cst0, batchIV);
+          auto scoreSelectType =
+              cast<Torch::ValueTensorType>(scoreSelect.getType());
+          auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
+              scoreSelectType.getSizes().slice(1), scoreSelectType.getDtype());
+
           auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
-              loc, TypeRange({finalResultType, intTy, intTy}), numClasses,
-              cstTrue, ValueRange({currRes, finalResIdx, numResultValues}));
+              loc, TypeRange({resultType, intTy, intTy}), numClasses, cstTrue,
+              ValueRange({currRes, finalResIdx, numResultValues}));
 
           {
             // Class loop body
             PatternRewriter::InsertionGuard guard(rewriter);
             Block *classLoopBody = rewriter.createBlock(
                 &nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
-                TypeRange({intTy, finalResultType, intTy, intTy}),
+                TypeRange({intTy, resultType, intTy, intTy}),
                 {loc, loc, loc, loc});
 
             auto classIV = classLoopBody->getArgument(0);
             auto currRes = classLoopBody->getArgument(1);
             auto finalResIdx = classLoopBody->getArgument(2);
             Value numResultValues = classLoopBody->getArgument(3);
 
-            auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
-                loc, scoreSlicedType, scores, cst0, batchIV);
-            auto scoreSelectType =
-                cast<Torch::ValueTensorType>(scoreSelect.getType());
-            auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
-                scoreSelectType.getSizes().slice(1),
-                scoreSelectType.getDtype());
-
             auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
                 loc, scoreValueType, scoreSelect, cst0, classIV);
             auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
@@ -3920,20 +3895,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                 rewriter.create<Torch::PrimNumToTensorScalarOp>(
                     loc, emptyTensorTy, maxOutputBoxesPerClass);
             auto minVal = rewriter.create<Torch::AtenMinimumOp>(
-                loc, numOutputBoxes.getType(), numOutputBoxes,
-                maxBoxesPerClass);
+                loc, emptyTensorTy, numOutputBoxes, maxBoxesPerClass);
             numOutputBoxes =
                 rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
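+            // Clamp the number of boxes kept for this class to
+            // max_output_boxes_per_class.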
 
             // Loop through the nms result
+            // The torchvision nms op returns a tensor of shape [num_selected],
+            // while the ONNX op produces [num_selected, 3], where each row is
+            // the triplet [batch_index, class_index, box_index]. Insert that
+            // triplet into `finalResult` element by element for each selected
+            // box.
+
+            // TODO: This could be simplified by concatenating the nms result
+            // with tensors filled with the batch and class indices, instead of
+            // using the loop below. Currently that approach fails during
+            // lowering because of dynamic dims.
+
             auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
-                loc, TypeRange({finalResultType, intTy}), numOutputBoxes,
-                cstTrue, ValueRange({currRes, finalResIdx}));
+                loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
+                ValueRange({currRes, finalResIdx}));
             {
               PatternRewriter::InsertionGuard guard(rewriter);
               Block *loopBody = rewriter.createBlock(
                   &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
-                  TypeRange({intTy, finalResultType, intTy}), {loc, loc, loc});
+                  TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
               auto iter = loopBody->getArgument(0);
               auto currRes = loopBody->getArgument(1);
               auto idxCst = loopBody->getArgument(2);
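+              // idxCst is the loop-carried write index: the row of the result
+              // buffer that this iteration fills.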
@@ -3955,7 +3940,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
                   loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
               auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
-                  loc, finalResultType, currRes, scatterBatch, cst0, idxCst);
+                  loc, resultType, currRes, scatterBatch, cst0, idxCst);
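+              // The two select_scatter ops write batch_index into column 0 of
+              // row idxCst.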
 
               // Update class dimension
               auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3970,8 +3955,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
                   loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
               auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
-                  loc, finalResultType, batchResult, scatterClass, cst0,
-                  idxCst);
+                  loc, resultType, batchResult, scatterClass, cst0, idxCst);
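+              // Likewise, write class_index into column 1 of the same row.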
 
               // Update nms result dimension
               auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3988,7 +3972,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
                   loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
               Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
-                  loc, finalResultType, classRes, scatterRes, cst0, idxCst);
+                  loc, resultType, classRes, scatterRes, cst0, idxCst);
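+              // Finally, write box_index (the nms output itself) into column
+              // 2, completing the [batch_index, class_index, box_index] row.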
 
               // Increment the result index
               Value next =