Skip to content

Commit 420bbca

Browse files
committed
Address review comments and simplify lit tests
1 parent 7bb91f4 commit 420bbca

File tree

2 files changed

+112
-310
lines changed

2 files changed

+112
-310
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

+36-52
Original file line numberDiff line numberDiff line change
@@ -3701,16 +3701,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37013701
return rewriter.notifyMatchFailure(
37023702
binder.op, "expected center_point_box attribute to be 0 or 1");
37033703

3704-
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
3705-
loc, rewriter.getI64IntegerAttr(0));
3706-
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
3707-
loc, rewriter.getI64IntegerAttr(1));
3708-
Value cst2 = rewriter.create<Torch::ConstantIntOp>(
3709-
loc, rewriter.getI64IntegerAttr(2));
3710-
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
3711-
loc, rewriter.getI64IntegerAttr(3));
3712-
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
3713-
loc, rewriter.getI64IntegerAttr(4));
3704+
Value cst0 = rewriter.create<Torch::ConstantIntOp>(loc, 0);
3705+
Value cst1 = rewriter.create<Torch::ConstantIntOp>(loc, 1);
3706+
Value cst2 = rewriter.create<Torch::ConstantIntOp>(loc, 2);
3707+
Value cst3 = rewriter.create<Torch::ConstantIntOp>(loc, 3);
3708+
Value cst4 = rewriter.create<Torch::ConstantIntOp>(loc, 4);
37143709
Value cst2F = rewriter.create<Torch::ConstantFloatOp>(
37153710
loc, rewriter.getF64FloatAttr(2.0));
37163711
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -3813,36 +3808,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
38133808
// Create an empty tensor of shape (B*C*N, 3) to store the final result.
38143809
// We slice this to the required elements at the end.
38153810

3816-
// FIXME:: Currently empty tensors created with dynamic sizes are not
3817-
// fully supported. Uncomment the below lines once dynamic sizes for
3818-
// empty tensors are supported end to end.
3819-
3820-
/*
38213811
Value numResults = rewriter.create<Torch::AtenMulIntOp>(
38223812
loc, numClasses.getType(), numBatches, numClasses);
38233813
numResults = rewriter.create<Torch::AtenMulIntOp>(
38243814
loc, numClasses.getType(), numResults, maxOutputBoxesPerClass);
3825-
auto finalResultType = resultType;
3826-
*/
3827-
3828-
if (!scoreTensorType.toBuiltinTensor().hasStaticShape()) {
3829-
llvm_unreachable("Unimplemented: Encountered dynamic shaped tensors "
3830-
"while lowering Onnx NonMaxSuppression op to torch");
3831-
}
3832-
auto numResultElements =
3833-
scoreTensorType.toBuiltinTensor().getNumElements();
3834-
auto numResults = rewriter.create<Torch::ConstantIntOp>(
3835-
loc, rewriter.getI64IntegerAttr(numResultElements));
38363815

38373816
auto intTy = rewriter.getType<Torch::IntType>();
38383817
auto intListTy = rewriter.getType<Torch::ListType>(intTy);
38393818

38403819
Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
38413820
loc, intListTy, SmallVector<Value>{numResults, cst3});
3842-
auto finalResultType = rewriter.getType<Torch::ValueTensorType>(
3843-
ArrayRef<int64_t>{numResultElements, 3}, resultType.getDtype());
38443821
Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
3845-
loc, finalResultType, resultShapeList, /*dtype=*/cst4,
3822+
loc, resultType, resultShapeList, /*dtype=*/cst4,
38463823
/*layout=*/cstNone,
38473824
/*device=*/cstNone, /*pinMemory=*/cstNone,
38483825
/*memoryFormat=*/cstNone);
@@ -3855,16 +3832,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
38553832
SmallVector<int64_t>{}, nmsTy.getDtype());
38563833

38573834
auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
3858-
loc, TypeRange({finalResultType, intTy, intTy}), numBatches,
3859-
cstTrue,
3835+
loc, TypeRange({resultType, intTy, intTy}), numBatches, cstTrue,
38603836
ValueRange({finalResult, /*Index to finalResult*/ cst0,
38613837
/*Num values in result*/ cst0}));
38623838
{
38633839
// Batch loop body
38643840
PatternRewriter::InsertionGuard guard(rewriter);
38653841
Block *batchLoopBody = rewriter.createBlock(
38663842
&nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
3867-
TypeRange({intTy, finalResultType, intTy, intTy}),
3843+
TypeRange({intTy, resultType, intTy, intTy}),
38683844
{loc, loc, loc, loc});
38693845

38703846
auto batchIV = batchLoopBody->getArgument(0);
@@ -3877,31 +3853,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
38773853
auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
38783854
loc, emptyTensorTy, batchIV);
38793855

3856+
auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
3857+
loc, scoreSlicedType, scores, cst0, batchIV);
3858+
auto scoreSelectType =
3859+
cast<Torch::ValueTensorType>(scoreSelect.getType());
3860+
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
3861+
scoreSelectType.getSizes().slice(1), scoreSelectType.getDtype());
3862+
38803863
auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
3881-
loc, TypeRange({finalResultType, intTy, intTy}), numClasses,
3882-
cstTrue, ValueRange({currRes, finalResIdx, numResultValues}));
3864+
loc, TypeRange({resultType, intTy, intTy}), numClasses, cstTrue,
3865+
ValueRange({currRes, finalResIdx, numResultValues}));
38833866

38843867
{
38853868
// Class loop body
38863869
PatternRewriter::InsertionGuard guard(rewriter);
38873870
Block *classLoopBody = rewriter.createBlock(
38883871
&nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
3889-
TypeRange({intTy, finalResultType, intTy, intTy}),
3872+
TypeRange({intTy, resultType, intTy, intTy}),
38903873
{loc, loc, loc, loc});
38913874

38923875
auto classIV = classLoopBody->getArgument(0);
38933876
auto currRes = classLoopBody->getArgument(1);
38943877
auto finalResIdx = classLoopBody->getArgument(2);
38953878
Value numResultValues = classLoopBody->getArgument(3);
38963879

3897-
auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
3898-
loc, scoreSlicedType, scores, cst0, batchIV);
3899-
auto scoreSelectType =
3900-
cast<Torch::ValueTensorType>(scoreSelect.getType());
3901-
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
3902-
scoreSelectType.getSizes().slice(1),
3903-
scoreSelectType.getDtype());
3904-
39053880
auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
39063881
loc, scoreValueType, scoreSelect, cst0, classIV);
39073882
auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
@@ -3920,20 +3895,30 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
39203895
rewriter.create<Torch::PrimNumToTensorScalarOp>(
39213896
loc, emptyTensorTy, maxOutputBoxesPerClass);
39223897
auto minVal = rewriter.create<Torch::AtenMinimumOp>(
3923-
loc, numOutputBoxes.getType(), numOutputBoxes,
3924-
maxBoxesPerClass);
3898+
loc, emptyTensorTy, numOutputBoxes, maxBoxesPerClass);
39253899
numOutputBoxes =
39263900
rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);
39273901

39283902
// Loop through the nms result
3903+
// The resulting shape of torchvision nms op is [num_selected] while
3904+
// that of onnx is [num_selected, 3] where the selected format is
3905+
// [batch_index, class_index, box_index].
3906+
// Insert the triplet [batch_index, class_index, box_index] into
3907+
// `finalResult` element by element for each box.
3908+
3909+
// TODO:: This can be simplified by concatenating the result of nms
3910+
// with that of tensors filled with batch and class indices instead
3911+
// of using the below loop. Currently this approach results in
3912+
// failures while lowering due to dynamic dims
3913+
39293914
auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
3930-
loc, TypeRange({finalResultType, intTy}), numOutputBoxes,
3931-
cstTrue, ValueRange({currRes, finalResIdx}));
3915+
loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
3916+
ValueRange({currRes, finalResIdx}));
39323917
{
39333918
PatternRewriter::InsertionGuard guard(rewriter);
39343919
Block *loopBody = rewriter.createBlock(
39353920
&nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
3936-
TypeRange({intTy, finalResultType, intTy}), {loc, loc, loc});
3921+
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
39373922
auto iter = loopBody->getArgument(0);
39383923
auto currRes = loopBody->getArgument(1);
39393924
auto idxCst = loopBody->getArgument(2);
@@ -3955,7 +3940,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
39553940
auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
39563941
loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
39573942
auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
3958-
loc, finalResultType, currRes, scatterBatch, cst0, idxCst);
3943+
loc, resultType, currRes, scatterBatch, cst0, idxCst);
39593944

39603945
// Update class dimension
39613946
auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3970,8 +3955,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
39703955
auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
39713956
loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
39723957
auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
3973-
loc, finalResultType, batchResult, scatterClass, cst0,
3974-
idxCst);
3958+
loc, resultType, batchResult, scatterClass, cst0, idxCst);
39753959

39763960
// Update nms result dimension
39773961
auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
@@ -3988,7 +3972,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
39883972
auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
39893973
loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
39903974
Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
3991-
loc, finalResultType, classRes, scatterRes, cst0, idxCst);
3975+
loc, resultType, classRes, scatterRes, cst0, idxCst);
39923976

39933977
// Increment the result index
39943978
Value next =

0 commit comments

Comments
 (0)