Skip to content

Commit 8727650

Browse files
committed
Allow dot operand hoisting for math dialect, arith.truncf, arith.trunci
1 parent 1ef30e9 commit 8727650

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -131,19 +131,21 @@ class HoistLayoutConversion : public OpRewritePattern<ConvertLayoutOp> {
131131
// bitwidth is unable to realize that there is a mixed-precision dot
132132
// (hence kWidth = 1) but wants to hoist through the type conversion.
133133
if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
134-
return failure();
134+
return failure();
135135

136-
// Only consider custom conversions or arith ops.
136+
// Only consider custom conversions, math or arith ops.
137137
// TODO(jlebar): Is this too restrictive?
138138
if (!isa<FpToFpOp, BitcastOp>(src) && !isPureUnaryInlineAsm(src) &&
139-
src->getDialect()->getTypeID() != TypeID::get<arith::ArithDialect>())
139+
src->getDialect()->getTypeID() != TypeID::get<arith::ArithDialect>() &&
140+
src->getDialect()->getTypeID() != TypeID::get<math::MathDialect>())
140141
return failure();
141142

142143
// Currently, these instructions are not supported during lowering of
143144
// shared -> dot_operand layout. Not all types and type conversions are
144145
// supported.
145-
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
146+
if (isa<arith::SelectOp>(src)) {
146147
return failure();
148+
}
147149

148150
// Don't hoist through u1 -> fp casts as they aren't supported in
149151
// ElementwiseOpToLLVM::reorderValues().

0 commit comments

Comments
 (0)