Skip to content

Commit 9a37afd

Browse files
[Mosaic] Internal change.
PiperOrigin-RevId: 832476571
1 parent b6c67c5 commit 9a37afd

File tree

9 files changed

+125
-122
lines changed

9 files changed

+125
-122
lines changed

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 24 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,8 @@ LogicalResult UnrollVectorsOp::canonicalize(UnrollVectorsOp op,
106106
LogicalResult BitcastOp::verify() {
107107
auto in_ty = getInput().getType();
108108
auto out_ty = getOutput().getType();
109-
auto in_bitwidth = in_ty.getElementTypeBitWidth();
110-
auto out_bitwidth = out_ty.getElementTypeBitWidth();
109+
auto in_bitwidth = getElementTypeBitwidth(in_ty).value();
110+
auto out_bitwidth = getElementTypeBitwidth(out_ty).value();
111111
if (in_bitwidth != out_bitwidth) {
112112
if (in_ty.getRank() < 2 || out_ty.getRank() < 2) {
113113
return emitError(
@@ -632,8 +632,8 @@ LogicalResult MemRefBitcastOp::verify() {
632632
if (src_ty.getRank() <= 1) {
633633
return emitOpError("Not implemented: 1d memref bitcast.");
634634
}
635-
auto src_bitwidth = src_ty.getElementTypeBitWidth();
636-
auto tgt_bitwidth = tgt_ty.getElementTypeBitWidth();
635+
auto src_bitwidth = getElementTypeBitwidth(src_ty).value();
636+
auto tgt_bitwidth = getElementTypeBitwidth(tgt_ty).value();
637637
for (int i = 0; i < src_ty.getRank(); ++i) {
638638
auto src_dim_size = src_ty.getDimSize(i);
639639
auto tgt_dim_size = tgt_ty.getDimSize(i);
@@ -688,8 +688,8 @@ LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op,
688688
if (!erase_layout_op) {
689689
return failure();
690690
}
691-
auto src_bitwidth = src_ty.getElementTypeBitWidth();
692-
auto tgt_bitwidth = dst_ty.getElementTypeBitWidth();
691+
auto src_bitwidth = getElementTypeBitwidth(src_ty).value();
692+
auto tgt_bitwidth = getElementTypeBitwidth(dst_ty).value();
693693
auto layout_ref = erase_layout_op.getOperand();
694694
auto layout_ty = layout_ref.getType();
695695
auto layout = cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
@@ -765,7 +765,7 @@ LogicalResult verifyStoreOp(Op op) {
765765
"Expected base and valueToStore element type to match");
766766
}
767767
if (op.getMask()) {
768-
if (value_ty.getElementTypeBitWidth() != 32) {
768+
if (getElementTypeBitwidth(value_ty) != 32) {
769769
return op.emitError(
770770
"Not implemented: masked store with non-32-bit element type");
771771
}
@@ -799,7 +799,7 @@ LogicalResult verifyLoadOp(Op op) {
799799
return op.emitOpError("Expected base and result element type to match.");
800800
}
801801
if (op.getMask()) {
802-
if (value_ty.getElementTypeBitWidth() != 32) {
802+
if (getElementTypeBitwidth(value_ty) != 32) {
803803
return op.emitError(
804804
"Not implemented: masked load with non-32-bit element type");
805805
}
@@ -988,7 +988,7 @@ LogicalResult MatmulOp::verify() {
988988
return emitOpError(
989989
"Not implemented: matmul acc and result have different types");
990990
}
991-
if (acc_ty.getElementTypeBitWidth() != 32) {
991+
if (getElementTypeBitwidth(acc_ty) != 32) {
992992
return emitOpError("Expected matmul acc to be 32-bit");
993993
}
994994

@@ -1825,8 +1825,8 @@ LogicalResult ReciprocalOp::verify() {
18251825
}
18261826

18271827
LogicalResult UnpackSubelementsOp::verify() {
1828-
const int packing_factor = getType().getElementTypeBitWidth() /
1829-
getSource().getType().getElementTypeBitWidth();
1828+
const int packing_factor = getElementTypeBitwidth(getType()).value() /
1829+
getElementTypeBitwidth(getSource().getType()).value();
18301830
if (auto index = getIndex(); index >= packing_factor) {
18311831
return emitOpError("Index must be between 0 and the packing factor (")
18321832
<< packing_factor << "), got " << index;
@@ -1849,8 +1849,8 @@ LogicalResult UnpackSubelementsOp::canonicalize(UnpackSubelementsOp op,
18491849
rewriter.replaceAllOpUsesWith(
18501850
op, pack.getPaddedSources(
18511851
pack.getSources(), pack.getPositions(),
1852-
op.getType().getElementTypeBitWidth() /
1853-
pack.getType().getElementTypeBitWidth())[op.getIndex()]);
1852+
getElementTypeBitwidth(op.getType()).value() /
1853+
getElementTypeBitwidth(pack.getType()).value())[op.getIndex()]);
18541854
return success();
18551855
}
18561856
return failure();
@@ -1864,8 +1864,8 @@ LogicalResult UnpackSubelementsOp::canonicalize(UnpackSubelementsOp op,
18641864
}
18651865
auto packed_elem_ty = pack.getType().getElementType();
18661866
if (!packed_elem_ty.isSignlessInteger() ||
1867-
packed_elem_ty.getIntOrFloatBitWidth() >
1868-
src_elem_ty.getIntOrFloatBitWidth()) {
1867+
getTypeBitwidth(packed_elem_ty).value() >
1868+
getTypeBitwidth(src_elem_ty).value()) {
18691869
return failure();
18701870
}
18711871
}
@@ -1905,9 +1905,8 @@ LogicalResult PackSubelementsOp::verify() {
19051905
if (getPositions().size() != getSources().size()) {
19061906
return emitOpError("Size of sources and positions must match");
19071907
}
1908-
const int packing_factor = cast<VectorType>(getSources().front().getType())
1909-
.getElementTypeBitWidth() /
1910-
getType().getElementTypeBitWidth();
1908+
const int packing_factor = getElementTypeBitwidth(cast<VectorType>(getSources().front().getType())).value() /
1909+
getElementTypeBitwidth(getType()).value();
19111910
SmallVector<bool> seen_positions(packing_factor, false);
19121911
for (const int32_t position : getPositions()) {
19131912
if (position < 0 || packing_factor <= position) {
@@ -1950,9 +1949,8 @@ LogicalResult PackElementwiseOp::verify() {
19501949
getTargetType()))) {
19511950
return failure();
19521951
}
1953-
const int packing_factor =
1954-
src_vty.getElementTypeBitWidth() /
1955-
getTargetType().getIntOrFloatBitWidth();
1952+
const int packing_factor = getElementTypeBitwidth(src_vty).value() /
1953+
getTypeBitwidth(getTargetType()).value();
19561954
if (packing_factor != getSources().size()) {
19571955
return emitOpError("The number of sources must match the packing factor (")
19581956
<< packing_factor << "), got " << getSources().size();
@@ -1964,8 +1962,8 @@ LogicalResult UnpackElementwiseOp::verify() {
19641962
if (failed(verifyElementwisePacking(*this, getType(), getSourceType()))) {
19651963
return failure();
19661964
}
1967-
const int packing_factor = getType().getElementTypeBitWidth() /
1968-
getSourceType().getIntOrFloatBitWidth();
1965+
const int packing_factor = getElementTypeBitwidth(getType()).value() /
1966+
getTypeBitwidth(getSourceType()).value();
19691967
if (auto index = getIndex(); index >= packing_factor) {
19701968
return emitOpError("Index must be between 0 and the packing factor (")
19711969
<< packing_factor << "), got " << index;
@@ -2012,9 +2010,9 @@ LogicalResult DynamicGatherOp::verify() {
20122010

20132011
LogicalResult AllReduceOp::verify() {
20142012
auto in_ty = getInput().getType();
2015-
auto in_bitwidth = in_ty.getElementTypeBitWidth();
2013+
auto in_bitwidth = getElementTypeBitwidth(in_ty).value();
20162014
auto out_ty = getOutput().getType();
2017-
auto out_bitwidth = out_ty.getElementTypeBitWidth();
2015+
auto out_bitwidth = getElementTypeBitwidth(out_ty).value();
20182016
auto kind = getKind();
20192017

20202018
if (in_bitwidth == 1) {
@@ -2070,7 +2068,7 @@ LogicalResult AllReduceOp::verify() {
20702068
LogicalResult ReduceIndexOp::verify() {
20712069
auto in_ty = getInput().getType();
20722070
auto out_ty = getOutput().getType();
2073-
auto bitwidth = in_ty.getElementTypeBitWidth();
2071+
auto bitwidth = getElementTypeBitwidth(in_ty).value();
20742072
auto axis = getAxis();
20752073
auto kind = getKind();
20762074
if (kind != ReductionKind::kArgMax &&

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
131131
if (shape.back() % ctx.target_shape[1] != 0) {
132132
return emitError(loc, "Unaligned scratch shape on minormost dimension");
133133
}
134-
int packing = 32 / elem_ty.getIntOrFloatBitWidth();
134+
int packing = 32 / getTypeBitwidth(elem_ty).value();
135135
int sublane_count = llvm::divideCeil(
136136
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) /
137137
ctx.target_shape[1],
@@ -364,7 +364,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
364364
MLIRContext *mlir_ctx = func.getContext();
365365
Block &entry_block = func.getBody().front();
366366
auto value_ty = cast<VectorType>(value.getType());
367-
if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
367+
if (getTypeBitwidth(value_ty.getElementType()) != 32) {
368368
return func.emitOpError("Not implemented: Only 32-bit constants supported");
369369
}
370370
if (func->getAttr("scratch_operands")) {
@@ -1403,7 +1403,7 @@ FailureOr<xla::Array<Value>> unpackVregs(RewriteContext &ctx,
14031403
//
14041404
// 28 24 20 16 12 8 4 0 bit index
14051405
// yyyyyyyyyyyyyyyyxxxxxxxxxxxxxxxx
1406-
if (res_vreg_ty.getElementTypeBitWidth() == 32) {
1406+
if (getElementTypeBitwidth(res_vreg_ty) == 32) {
14071407
// If the result vreg is 32-bit, we can just interleaved unpack the
14081408
// input vreg, as there are no multiple subelements to unpack.
14091409
*v = builder.create<UnpackSubelementsOp>(
@@ -1417,7 +1417,7 @@ FailureOr<xla::Array<Value>> unpackVregs(RewriteContext &ctx,
14171417
: cast<Type>(builder.getF32Type()),
14181418
ctx.target_shape);
14191419
const int dst_packing_factor =
1420-
32 / res_vreg_ty.getElementTypeBitWidth();
1420+
32 / getElementTypeBitwidth(res_vreg_ty).value();
14211421
// `vreg_part` is with respect to result vreg bitwidth. Expand it to
14221422
// base on 32-bit.
14231423
const int vreg_part_unpacked_to_32b = vreg_part * dst_packing_factor;
@@ -1442,7 +1442,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> unpackVregs(
14421442
const xla::Array<Value> &input_vregs, VectorType input_ty,
14431443
VectorType result_ty, const VectorLayout &layout_in,
14441444
const std::array<int64_t, 2> tiling_out) {
1445-
const int unpacked_bitwidth = result_ty.getElementTypeBitWidth();
1445+
const int unpacked_bitwidth = getElementTypeBitwidth(result_ty).value();
14461446
const LayoutOffsets offsets_out = alignedToVregSlice(
14471447
layout_in.offsets(), ctx.target_shape, unpacked_bitwidth, tiling_out);
14481448
const VectorLayout layout_out(unpacked_bitwidth, offsets_out, tiling_out,
@@ -1539,7 +1539,7 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
15391539
auto extui_op = cast<arith::ExtUIOp>(op);
15401540
const auto in_ty = cast<VectorType>(extui_op.getIn().getType());
15411541
const auto out_ty = cast<VectorType>(extui_op.getType());
1542-
const unsigned in_bitwidth = in_ty.getElementTypeBitWidth();
1542+
const unsigned in_bitwidth = getElementTypeBitwidth(in_ty).value();
15431543
if (in_bitwidth == 1) {
15441544
return elementwise_op_rule(ctx, op, layouts_in, layouts_out);
15451545
}
@@ -1548,7 +1548,7 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
15481548
xla::Array<Value> output_vregs,
15491549
ext_op_rule_impl(ctx, builder, extui_op, *layouts_in.front(),
15501550
*layouts_out.front()));
1551-
unsigned out_bitwidth = out_ty.getElementTypeBitWidth();
1551+
unsigned out_bitwidth = getElementTypeBitwidth(out_ty).value();
15521552
// Generate a mask to mask out the sign extension. e.g., for u8 -> u16,
15531553
// the mask is 0x00ff00ff.
15541554
unsigned mask = (1 << in_bitwidth) - 1;
@@ -1656,15 +1656,15 @@ FailureOr<xla::Array<Value>> packVregs(RewriteContext &ctx, OpBuilder &builder,
16561656
// achieve this, we can unpack all subelements in each part to 32-bit and
16571657
// then interleaved pack them into desired type.
16581658
SmallVector<Value> unpacks;
1659-
if (input_ty.getElementType().getIntOrFloatBitWidth() == 32) {
1659+
if (getTypeBitwidth(input_ty.getElementType()) == 32) {
16601660
unpacks.append(parts.begin(), parts.end());
16611661
} else {
16621662
VectorType unpacked_vty =
16631663
getNativeVregType(input_ty.getElementType().isSignlessInteger()
16641664
? cast<Type>(builder.getI32Type())
16651665
: cast<Type>(builder.getF32Type()),
16661666
ctx.target_shape);
1667-
const int32_t packing_factor = 32 / input_ty.getElementTypeBitWidth();
1667+
const int32_t packing_factor = 32 / getElementTypeBitwidth(input_ty).value();
16681668
for (Value part : parts) {
16691669
if (part) {
16701670
for (int i = 0; i < packing_factor; ++i) {
@@ -1692,7 +1692,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> packVregs(
16921692
const xla::Array<Value> &input_vregs, VectorType input_ty,
16931693
VectorType result_ty, const VectorLayout &layout_in,
16941694
const std::array<int64_t, 2> tiling_out, const LayoutOffsets offset_hints) {
1695-
const int packed_bitwidth = result_ty.getElementTypeBitWidth();
1695+
const int packed_bitwidth = getElementTypeBitwidth(result_ty).value();
16961696
const std::array<int64_t, 2> unpacked_vreg_slice =
16971697
layout_in.vregSlice(ctx.target_shape);
16981698
const std::array<int64_t, 2> packed_vreg_slice =
@@ -2600,7 +2600,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
26002600
"Not implemented: Unsupported matmul operand tiling");
26012601
}
26022602
}
2603-
if (acc.getType().getElementType().getIntOrFloatBitWidth() != 32) {
2603+
if (getTypeBitwidth(acc.getType().getElementType()) != 32) {
26042604
return op.emitOpError("Not implemented: Non-32-bit matmul acc");
26052605
}
26062606
const ArrayRef<int64_t> lhs_shape = lhs.getType().getShape();
@@ -2659,8 +2659,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
26592659
// second dim to be a multiple of mxu_size.
26602660
auto mxu_contracting_size = ctx.mxu_shape[0];
26612661
auto mxu_noncontracting_size = ctx.mxu_shape[1];
2662-
if (lhs.getType().getElementTypeBitWidth() < 8 &&
2663-
rhs.getType().getElementTypeBitWidth() < 8) {
2662+
if (getElementTypeBitwidth(lhs.getType()) < 8 &&
2663+
getElementTypeBitwidth(rhs.getType()) < 8) {
26642664
mxu_contracting_size *= 2;
26652665
}
26662666
auto rhs_row_size = mxu_contracting_size;
@@ -4530,7 +4530,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
45304530
broadcast_op.erase();
45314531
return success();
45324532
} else if (layout_out.bitwidth() == 32 &&
4533-
broadcast_op.getSourceType().getIntOrFloatBitWidth() == 1) {
4533+
getTypeBitwidth(broadcast_op.getSourceType()) == 1) {
45344534
// Broadcasting the i1 scalar involves first converting i1 to i32, followed
45354535
// by broadcasting i32 to the target shape. Finally, the comparison with 0s
45364536
// yields the vmask.
@@ -4553,14 +4553,14 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
45534553
return success();
45544554
} else if (layout_out.bitwidth() < 32) {
45554555
CHECK_EQ(layout_out.bitwidth(),
4556-
broadcast_op.getSourceType().getIntOrFloatBitWidth());
4556+
getTypeBitwidth(broadcast_op.getSourceType()).value());
45574557
// Broadcasting the scalar with narrower type involves first packing (32 /
45584558
// bitwidth) copies to i32, followed by broadcasting i32 to the target
45594559
// shape. Finally, bitcast i32 vector back to the original narrower type
45604560
// vector.
45614561
auto loc = broadcast_op.getLoc();
45624562
auto src_ty = broadcast_op.getSourceType();
4563-
auto bitwidth = src_ty.getIntOrFloatBitWidth();
4563+
auto bitwidth = getTypeBitwidth(src_ty).value();
45644564
auto unpacked_src = broadcast_op.getSource();
45654565
if (!src_ty.isSignlessInteger(bitwidth)) {
45664566
unpacked_src = builder.create<arith::BitcastOp>(
@@ -4904,12 +4904,12 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
49044904
case vector::CombiningKind::MAXSI: {
49054905
neutral = builder.getIntegerAttr(
49064906
element_type,
4907-
APInt::getSignedMinValue(element_type.getIntOrFloatBitWidth()));
4907+
APInt::getSignedMinValue(getTypeBitwidth(element_type).value()));
49084908
} break;
49094909
case vector::CombiningKind::MINSI: {
49104910
neutral = builder.getIntegerAttr(
49114911
element_type,
4912-
APInt::getSignedMaxValue(element_type.getIntOrFloatBitWidth()));
4912+
APInt::getSignedMaxValue(getTypeBitwidth(element_type).value()));
49134913
} break;
49144914
default:
49154915
return multi_reduction_op.emitOpError(
@@ -5007,7 +5007,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
50075007
}
50085008
for (int i = 0; i < 2; ++i) {
50095009
if (reduces[i] && src_layout.offsets()[i] == std::nullopt &&
5010-
element_type.getIntOrFloatBitWidth() != 32) {
5010+
getTypeBitwidth(element_type) != 32) {
50115011
return multi_reduction_op.emitOpError(
50125012
"Not implemented: Non-32-bit reductions over replicated axes");
50135013
}
@@ -5550,7 +5550,7 @@ Value copyOneRow(OpBuilder &builder, Value src_vreg, int src_row_idx,
55505550
if (dst_vreg) {
55515551
int bitwidth = 32 / packing_factor;
55525552
CHECK_EQ(bitwidth,
5553-
cast<VectorType>(dst_vreg.getType()).getElementTypeBitWidth());
5553+
getElementTypeBitwidth(cast<VectorType>(dst_vreg.getType())).value());
55545554
const VectorType i32_vreg_ty =
55555555
getNativeVregType(builder.getI32Type(), target_shape);
55565556
src_vreg = builder.create<tpu::BitcastVregOp>(src_vreg.getLoc(),
@@ -5856,10 +5856,10 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
58565856
pack_format = tpu::PackFormat::kInterleaved;
58575857
}
58585858
VectorType packed_vty = getNativeVregType(
5859-
builder.getIntegerType(src_ty.getElementTypeBitWidth()),
5859+
builder.getIntegerType(getElementTypeBitwidth(src_ty).value()),
58605860
ctx.target_shape);
58615861
VectorType unpacked_vty = getNativeVregType(
5862-
builder.getIntegerType(src_ty.getElementTypeBitWidth() * 2),
5862+
builder.getIntegerType(getElementTypeBitwidth(src_ty).value() * 2),
58635863
ctx.target_shape);
58645864
src_vregs.Each(
58655865
[&](absl::Span<const int64_t> src_vreg_indices, Value* src_vreg) {
@@ -6901,7 +6901,7 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
69016901
VectorType vty = rng_op.getResult().getType();
69026902
TPU_ASSERT_OP(vty.getElementType().isInteger());
69036903
// Only 32-bit output supported currently.
6904-
TPU_ASSERT_OP(vty.getElementType().getIntOrFloatBitWidth() == 32);
6904+
TPU_ASSERT_OP(getTypeBitwidth(vty.getElementType()) == 32);
69056905
xla::Array<Value> tiles(
69066906
layout_out.tileArrayShape(vty.getShape(), ctx.target_shape));
69076907
VectorType tile_ty = VectorType::get(ctx.target_shape, vty.getElementType());

0 commit comments

Comments
 (0)