Skip to content

[TORCH] Add support for logcumsumexp Op #4187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -8625,6 +8625,54 @@ def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
}];
}

def Torch_AtenLogcumsumexpOp : Torch_Op<"aten.logcumsumexp", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logcumsumexp : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogcumsumexpOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_Aten_LogcumsumexpOp : Torch_Op<"aten._logcumsumexp", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_logcumsumexp : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_LogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten_LogcumsumexpOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
24 changes: 24 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9361,6 +9361,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
Expand Down Expand Up @@ -12507,6 +12513,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" %none = torch.constant.none\n"
Expand Down
66 changes: 66 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2884,6 +2884,68 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
};
} // namespace

// Decompose AtenLogCumsumExpOp into: AtenExpOp,
// AtenCumsumOp and AtenLogOp
// logcumsumexp(x)[i][j] = log(sum_{k=0}^{j} exp(x[i][k]))

namespace {
template <typename OpTy>

class DecomposeAtenLogCumsumExpOp : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.getSelf();

auto inputType = dyn_cast<BaseTensorType>(input.getType());
if (!inputType)
return rewriter.notifyMatchFailure(op, "Supports only tensor type");

if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
return rewriter.notifyMatchFailure(
op, "Currently Support only floating point type");

int64_t inputRank = inputType.getSizes().size();
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: Only constant dim value is supported");
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "invalid dim");

Type elementType = inputType.getDtype();
torch_upstream::ScalarType scalarType;
// Currently it supports for float datatype
if (elementType.isF16())
scalarType = torch_upstream::ScalarType::Half;
else if (elementType.isF32())
scalarType = torch_upstream::ScalarType::Float;
else if (elementType.isF64())
scalarType = torch_upstream::ScalarType::Double;
else
return rewriter.notifyMatchFailure(op, "Unsupported data type");

int64_t scalarVal = static_cast<int64_t>(scalarType);

Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(scalarVal));

Value expInput = rewriter.create<AtenExpOp>(loc, op.getType(), input);

Value cumsum = rewriter.create<AtenCumsumOp>(loc, op.getType(), expInput,
op.getDim(), dtypeVal);

Value result = rewriter.create<AtenLogOp>(loc, op.getType(), cumsum);

rewriter.replaceOp(op, result);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenLogSigmoidOp : public OpRewritePattern<AtenLogSigmoidOp> {
public:
Expand Down Expand Up @@ -11929,6 +11991,10 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenLogCumsumExpOp<AtenLogcumsumexpOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenLogCumsumExpOp<Aten_LogcumsumexpOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_LogSoftmaxOp>();
target.addIllegalOp<AtenLogSoftmaxIntOp>();
target.addIllegalOp<AtenLogSigmoidOp>();
target.addIllegalOp<Aten_LogcumsumexpOp, AtenLogcumsumexpOp>();
target.addIllegalOp<AtenHardshrinkOp>();
target.addIllegalOp<AtenSoftshrinkOp>();
target.addIllegalOp<AtenEmptyLikeOp>();
Expand Down
18 changes: 18 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2965,6 +2965,10 @@
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"LogCumsumExpModule_basic",
"LogCumsumExpStaticModule_basic",
"LogCumsumExpStaticNegativeDimModule_basic",
"LogCumsumExpDtypeModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
"MaxPool2dCeilModeTrueModule_basic",
Expand Down Expand Up @@ -3322,6 +3326,8 @@
# RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
"AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
"MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
"_LogCumsumExpStaticModule_basic",
"_LogCumsumExpStaticNegativeDimModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Expand Down Expand Up @@ -3683,6 +3689,10 @@
"LinalgNormKeepDimComplexModule_basic",
"LinalgVectorNormComplexModule_basic",
"LinspaceEmptyModule_basic",
"LogCumsumExpModule_basic",
"LogCumsumExpStaticModule_basic",
"LogCumsumExpStaticNegativeDimModule_basic",
"LogCumsumExpDtypeModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dModule_basic",
Expand Down Expand Up @@ -3865,6 +3875,8 @@
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
"_LogCumsumExpStaticModule_basic",
"_LogCumsumExpStaticNegativeDimModule_basic",
}

ONNX_TOSA_CRASHING_SET = {
Expand Down Expand Up @@ -4472,6 +4484,10 @@
"LinalgVectorNormComplexModule_basic",
"LogSoftmaxBackwardModule_basic",
"LogSoftmaxIntModule_basic",
"logCumsumExpModule_basic",
"LogCumsumExpStaticModule_basic",
"LogCumsumExpStaticNegativeDimModule_basic",
"LogCumsumExpDtypeModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MatmulBroadcastBatchDim_basic",
"MatmulSingleDynamicBatchDim_basic",
Expand Down Expand Up @@ -4926,6 +4942,8 @@
"_ConvolutionDeprecated2DDeterministicModule_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"_LogCumsumExpStaticModule_basic",
"_LogCumsumExpStaticNegativeDimModule_basic",
}

if torch_version_for_comparison() > version.parse("2.5.1"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1537,6 +1537,12 @@ def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None
def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
return self

def aten〇logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
return self

def aten〇_logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
return self

def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self

Expand Down Expand Up @@ -3217,6 +3223,25 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
return torch.int64
return self_dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
def aten〇logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(1, 1)],
tensor_device="cpu",
dim=0,
error_types={*all_integer_dtypes()}
)
)
def aten〇_logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)
return self_dtype


@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,8 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
emit("aten::logcumsumexp : (Tensor, int) -> (Tensor)")
emit("aten::_logcumsumexp : (Tensor, int) -> (Tensor)")
emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
Expand Down
96 changes: 96 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5048,6 +5048,102 @@ def CumsumWithDtypeModule_basic(module, tu: TestUtils):
# ==============================================================================


class LogCumsumExpModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1, -1], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.logcumsumexp(x, dim=1)


@register_test_case(module_factory=lambda: LogCumsumExpModule())
def LogCumsumExpModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 3))


class LogCumsumExpStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([1, 2, 3], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.logcumsumexp(x, dim=1)


@register_test_case(module_factory=lambda: LogCumsumExpStaticModule())
def LogCumsumExpStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 2, 3))


class LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([8, 5, 6], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.logcumsumexp(x, dim=-2)


@register_test_case(module_factory=lambda: LogCumsumExpStaticNegativeDimModule())
def LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(8, 5, 6))


class LogCumsumExpDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([5, 3, 6, 9], torch.float64, True)])
def forward(self, x):
return torch.ops.aten.logcumsumexp(x, dim=1)


@register_test_case(module_factory=lambda: LogCumsumExpDtypeModule())
def LogCumsumExpDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, 6, 9).to(torch.float64))


# ==============================================================================


class _LogCumsumExpStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([5, 3, 6, 9], torch.float32, True)])
def forward(self, x):
return torch.ops.aten._logcumsumexp(x, dim=1)


@register_test_case(module_factory=lambda: _LogCumsumExpStaticModule())
def _LogCumsumExpStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, 6, 9))


class _LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([6, 2, 3], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.logcumsumexp(x, dim=-1)


@register_test_case(module_factory=lambda: _LogCumsumExpStaticNegativeDimModule())
def _LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 2, 3))


# ==============================================================================


class CumprodModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down