@@ -131,7 +131,7 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
   if (shape.back() % ctx.target_shape[1] != 0) {
     return emitError(loc, "Unaligned scratch shape on minormost dimension");
   }
-  int packing = 32 / elem_ty.getIntOrFloatBitWidth();
+  int packing = 32 / getTypeBitwidth(elem_ty).value();
   int sublane_count = llvm::divideCeil(
       std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) /
           ctx.target_shape[1],
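The pattern repeated throughout this commit is swapping the infallible `Type::getIntOrFloatBitWidth()` for a fallible bitwidth query, unwrapped with `.value()` wherever the bitwidth is known to exist. A minimal sketch of the idea, using `std::optional<int>` as a stand-in return type (the actual `getTypeBitwidth` declaration is an assumption here, not shown in this diff):

```cpp
// Sketch only: a fallible element-bitwidth query. The real getTypeBitwidth
// presumably returns a FailureOr/optional-like wrapper; the names and types
// here are illustrative stand-ins, not the actual declarations.
#include <cassert>
#include <optional>

struct ScalarType {
  bool is_int_or_float;  // does the type have a fixed scalar bitwidth?
  int bitwidth;
};

std::optional<int> getTypeBitwidthSketch(const ScalarType &ty) {
  if (!ty.is_int_or_float) return std::nullopt;  // fail instead of asserting
  return ty.bitwidth;
}

int main() {
  ScalarType bf16{true, 16};
  // Mirrors the hunk above: packing = 32 / bitwidth, e.g. 2 for 16-bit types.
  int packing = 32 / getTypeBitwidthSketch(bf16).value();
  assert(packing == 2);
  ScalarType opaque{false, 0};
  assert(!getTypeBitwidthSketch(opaque).has_value());
  return 0;
}
```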
@@ -364,7 +364,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
   MLIRContext *mlir_ctx = func.getContext();
   Block &entry_block = func.getBody().front();
   auto value_ty = cast<VectorType>(value.getType());
-  if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
+  if (getTypeBitwidth(value_ty.getElementType()) != 32) {
     return func.emitOpError("Not implemented: Only 32-bit constants supported");
   }
   if (func->getAttr("scratch_operands")) {
@@ -1403,7 +1403,7 @@ FailureOr<xla::Array<Value>> unpackVregs(RewriteContext &ctx,
       //
       // 28 24 20 16 12 8 4 0  bit index
       // yyyyyyyyyyyyyyyyxxxxxxxxxxxxxxxx
-      if (res_vreg_ty.getElementTypeBitWidth() == 32) {
+      if (getElementTypeBitwidth(res_vreg_ty) == 32) {
         // If the result vreg is 32-bit, we can just interleaved unpack the
         // input vreg, as there are no multiple subelements to unpack.
         *v = builder.create<UnpackSubelementsOp>(
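The bit-layout comment above describes a 32-bit vreg lane holding two packed 16-bit subelements, x in the low half and y in the high half. A standalone illustration of that layout (values are arbitrary):

```cpp
// Two 16-bit subelements packed into one 32-bit word: y occupies bits 16..31,
// x occupies bits 0..15, matching the yyyy...xxxx diagram above.
#include <cassert>
#include <cstdint>

int main() {
  const uint16_t x = 0x1234, y = 0xABCD;
  const uint32_t packed = (uint32_t{y} << 16) | x;
  assert(packed == 0xABCD1234u);
  // "Unpacking" subelement 0 and subelement 1 back out of the word.
  assert(uint16_t(packed & 0xFFFF) == x);
  assert(uint16_t(packed >> 16) == y);
  return 0;
}
```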
@@ -1417,7 +1417,7 @@ FailureOr<xla::Array<Value>> unpackVregs(RewriteContext &ctx,
                 : cast<Type>(builder.getF32Type()),
             ctx.target_shape);
         const int dst_packing_factor =
-            32 / res_vreg_ty.getElementTypeBitWidth();
+            32 / getElementTypeBitwidth(res_vreg_ty).value();
         // `vreg_part` is with respect to result vreg bitwidth. Expand it to
         // base on 32-bit.
         const int vreg_part_unpacked_to_32b = vreg_part * dst_packing_factor;
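The rebasing arithmetic above converts a part index counted in result-bitwidth units into a 32-bit basis. A worked instance with assumed concrete numbers:

```cpp
// vreg_part is counted in result-bitwidth parts; multiplying by the packing
// factor (32 / result bitwidth) rebases it onto 32-bit parts. The numbers
// below are illustrative only.
#include <cassert>

int main() {
  const int dst_bitwidth = 16;                       // e.g. a 16-bit result vreg
  const int dst_packing_factor = 32 / dst_bitwidth;  // 2
  const int vreg_part = 1;                           // second 16-bit part
  const int vreg_part_unpacked_to_32b = vreg_part * dst_packing_factor;
  assert(vreg_part_unpacked_to_32b == 2);
  return 0;
}
```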
@@ -1442,7 +1442,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> unpackVregs(
     const xla::Array<Value> &input_vregs, VectorType input_ty,
     VectorType result_ty, const VectorLayout &layout_in,
     const std::array<int64_t, 2> tiling_out) {
-  const int unpacked_bitwidth = result_ty.getElementTypeBitWidth();
+  const int unpacked_bitwidth = getElementTypeBitwidth(result_ty).value();
   const LayoutOffsets offsets_out = alignedToVregSlice(
       layout_in.offsets(), ctx.target_shape, unpacked_bitwidth, tiling_out);
   const VectorLayout layout_out(unpacked_bitwidth, offsets_out, tiling_out,
@@ -1539,7 +1539,7 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
   auto extui_op = cast<arith::ExtUIOp>(op);
   const auto in_ty = cast<VectorType>(extui_op.getIn().getType());
   const auto out_ty = cast<VectorType>(extui_op.getType());
-  const unsigned in_bitwidth = in_ty.getElementTypeBitWidth();
+  const unsigned in_bitwidth = getElementTypeBitwidth(in_ty).value();
   if (in_bitwidth == 1) {
     return elementwise_op_rule(ctx, op, layouts_in, layouts_out);
   }
@@ -1548,7 +1548,7 @@ LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
       xla::Array<Value> output_vregs,
       ext_op_rule_impl(ctx, builder, extui_op, *layouts_in.front(),
                        *layouts_out.front()));
-  unsigned out_bitwidth = out_ty.getElementTypeBitWidth();
+  unsigned out_bitwidth = getElementTypeBitwidth(out_ty).value();
   // Generate a mask to mask out the sign extension. e.g., for u8 -> u16,
   // the mask is 0x00ff00ff.
   unsigned mask = (1 << in_bitwidth) - 1;
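The comment above gives the u8 -> u16 case, where the mask is 0x00ff00ff. The per-lane mask and its replication across a 32-bit word can be reproduced in isolation (the replication loop below is an illustrative sketch, not the code from this file):

```cpp
// Build the 0x00ff00ff-style mask: keep the low in_bitwidth bits of every
// packed out_bitwidth lane inside a 32-bit word, zeroing the sign-extended
// upper bits.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned in_bitwidth = 8, out_bitwidth = 16;   // u8 -> u16
  const uint32_t lane_mask = (1u << in_bitwidth) - 1;  // 0x000000ff
  uint32_t mask = 0;
  for (unsigned shift = 0; shift < 32; shift += out_bitwidth) {
    mask |= lane_mask << shift;                         // one copy per u16 lane
  }
  assert(mask == 0x00ff00ffu);
  return 0;
}
```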
@@ -1656,15 +1656,15 @@ FailureOr<xla::Array<Value>> packVregs(RewriteContext &ctx, OpBuilder &builder,
   // achieve this, we can unpack all subelements in each part to 32-bit and
   // then interleaved pack them into desired type.
   SmallVector<Value> unpacks;
-  if (input_ty.getElementType().getIntOrFloatBitWidth() == 32) {
+  if (getTypeBitwidth(input_ty.getElementType()) == 32) {
     unpacks.append(parts.begin(), parts.end());
   } else {
     VectorType unpacked_vty =
         getNativeVregType(input_ty.getElementType().isSignlessInteger()
                               ? cast<Type>(builder.getI32Type())
                               : cast<Type>(builder.getF32Type()),
                           ctx.target_shape);
-    const int32_t packing_factor = 32 / input_ty.getElementTypeBitWidth();
+    const int32_t packing_factor = 32 / getElementTypeBitwidth(input_ty).value();
     for (Value part : parts) {
       if (part) {
         for (int i = 0; i < packing_factor; ++i) {
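As a counterpart to the unpack illustration earlier, packing goes the other way: values that were widened to 32 bits are truncated and re-interleaved into the narrower lanes. A minimal bit-level sketch (values arbitrary; the real op works on whole vregs, not scalars):

```cpp
// Interleaved pack of two 32-bit values into one 32-bit word of two 16-bit
// subelements, dropping the upper half of each input.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t part0 = 0x00001234;  // already unpacked to 32-bit
  const uint32_t part1 = 0x0000ABCD;
  const uint32_t packed = (part1 << 16) | (part0 & 0xFFFF);
  assert(packed == 0xABCD1234u);
  return 0;
}
```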
@@ -1692,7 +1692,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> packVregs(
     const xla::Array<Value> &input_vregs, VectorType input_ty,
     VectorType result_ty, const VectorLayout &layout_in,
     const std::array<int64_t, 2> tiling_out, const LayoutOffsets offset_hints) {
-  const int packed_bitwidth = result_ty.getElementTypeBitWidth();
+  const int packed_bitwidth = getElementTypeBitwidth(result_ty).value();
   const std::array<int64_t, 2> unpacked_vreg_slice =
       layout_in.vregSlice(ctx.target_shape);
   const std::array<int64_t, 2> packed_vreg_slice =
@@ -2600,7 +2600,7 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
           "Not implemented: Unsupported matmul operand tiling");
     }
   }
-  if (acc.getType().getElementType().getIntOrFloatBitWidth() != 32) {
+  if (getTypeBitwidth(acc.getType().getElementType()) != 32) {
     return op.emitOpError("Not implemented: Non-32-bit matmul acc");
   }
   const ArrayRef<int64_t> lhs_shape = lhs.getType().getShape();
@@ -2659,8 +2659,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op,
   // second dim to be a multiple of mxu_size.
   auto mxu_contracting_size = ctx.mxu_shape[0];
   auto mxu_noncontracting_size = ctx.mxu_shape[1];
-  if (lhs.getType().getElementTypeBitWidth() < 8 &&
-      rhs.getType().getElementTypeBitWidth() < 8) {
+  if (getElementTypeBitwidth(lhs.getType()) < 8 &&
+      getElementTypeBitwidth(rhs.getType()) < 8) {
     mxu_contracting_size *= 2;
   }
   auto rhs_row_size = mxu_contracting_size;
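A numeric reading of the branch above: when both matmul operands are narrower than 8 bits, the contracting dimension handled per MXU pass doubles. The 128x128 MXU shape below is an assumed example value, not read from `ctx`:

```cpp
// Doubling the contracting size for sub-8-bit operands, with an assumed
// 128x128 MXU shape for illustration.
#include <cassert>

int main() {
  int mxu_contracting_size = 128;  // assumed ctx.mxu_shape[0]
  const int lhs_bits = 4, rhs_bits = 4;
  if (lhs_bits < 8 && rhs_bits < 8) {
    mxu_contracting_size *= 2;
  }
  assert(mxu_contracting_size == 256);
  return 0;
}
```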
@@ -4530,7 +4530,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
     broadcast_op.erase();
     return success();
   } else if (layout_out.bitwidth() == 32 &&
-             broadcast_op.getSourceType().getIntOrFloatBitWidth() == 1) {
+             getTypeBitwidth(broadcast_op.getSourceType()) == 1) {
     // Broadcasting the i1 scalar involves first converting i1 to i32, followed
     // by broadcasting i32 to the target shape. Finally, the comparison with 0s
     // yields the vmask.
@@ -4553,14 +4553,14 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
     return success();
   } else if (layout_out.bitwidth() < 32) {
     CHECK_EQ(layout_out.bitwidth(),
-             broadcast_op.getSourceType().getIntOrFloatBitWidth());
+             getTypeBitwidth(broadcast_op.getSourceType()).value());
     // Broadcasting the scalar with narrower type involves first packing (32 /
     // bitwidth) copies to i32, followed by broadcasting i32 to the target
     // shape. Finally, bitcast i32 vector back to the original narrower type
     // vector.
     auto loc = broadcast_op.getLoc();
     auto src_ty = broadcast_op.getSourceType();
-    auto bitwidth = src_ty.getIntOrFloatBitWidth();
+    auto bitwidth = getTypeBitwidth(src_ty).value();
     auto unpacked_src = broadcast_op.getSource();
     if (!src_ty.isSignlessInteger(bitwidth)) {
       unpacked_src = builder.create<arith::BitcastOp>(
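The comment above describes packing (32 / bitwidth) copies of the narrow scalar into an i32 before broadcasting. The scalar-side packing step looks like this in isolation (the broadcast and the bitcast back are not modeled):

```cpp
// Replicate a 16-bit scalar twice (32 / 16) inside one 32-bit word, i.e. the
// value that would then be broadcast as an i32 vector.
#include <cassert>
#include <cstdint>

int main() {
  const uint16_t scalar = 0xBEEF;
  const int bitwidth = 16;
  uint32_t packed = 0;
  for (int shift = 0; shift < 32; shift += bitwidth) {
    packed |= uint32_t{scalar} << shift;  // 2 copies for 16-bit, 4 for 8-bit
  }
  assert(packed == 0xBEEFBEEFu);
  return 0;
}
```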
@@ -4904,12 +4904,12 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
       case vector::CombiningKind::MAXSI: {
         neutral = builder.getIntegerAttr(
             element_type,
-            APInt::getSignedMinValue(element_type.getIntOrFloatBitWidth()));
+            APInt::getSignedMinValue(getTypeBitwidth(element_type).value()));
       } break;
       case vector::CombiningKind::MINSI: {
         neutral = builder.getIntegerAttr(
             element_type,
-            APInt::getSignedMaxValue(element_type.getIntOrFloatBitWidth()));
+            APInt::getSignedMaxValue(getTypeBitwidth(element_type).value()));
       } break;
       default:
         return multi_reduction_op.emitOpError(
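The neutral elements chosen above are the identities of the signed max/min reductions: the signed minimum for MAXSI and the signed maximum for MINSI. Spelled out for an assumed 16-bit element type, with plain integers standing in for APInt:

```cpp
// Identities for signed max/min reductions at 16 bits: INT16_MIN never wins a
// max, INT16_MAX never wins a min.
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int bitwidth = 16;
  const int32_t signed_min = -(1 << (bitwidth - 1));     // -32768
  const int32_t signed_max = (1 << (bitwidth - 1)) - 1;  //  32767
  assert(std::max<int32_t>(signed_min, 42) == 42);       // neutral for MAXSI
  assert(std::min<int32_t>(signed_max, 42) == 42);       // neutral for MINSI
  return 0;
}
```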
@@ -5007,7 +5007,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
   }
   for (int i = 0; i < 2; ++i) {
     if (reduces[i] && src_layout.offsets()[i] == std::nullopt &&
-        element_type.getIntOrFloatBitWidth() != 32) {
+        getTypeBitwidth(element_type) != 32) {
       return multi_reduction_op.emitOpError(
           "Not implemented: Non-32-bit reductions over replicated axes");
     }
@@ -5550,7 +5550,7 @@ Value copyOneRow(OpBuilder &builder, Value src_vreg, int src_row_idx,
   if (dst_vreg) {
     int bitwidth = 32 / packing_factor;
     CHECK_EQ(bitwidth,
-             cast<VectorType>(dst_vreg.getType()).getElementTypeBitWidth());
+             getElementTypeBitwidth(cast<VectorType>(dst_vreg.getType())).value());
     const VectorType i32_vreg_ty =
         getNativeVregType(builder.getI32Type(), target_shape);
     src_vreg = builder.create<tpu::BitcastVregOp>(src_vreg.getLoc(),
@@ -5856,10 +5856,10 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
       pack_format = tpu::PackFormat::kInterleaved;
     }
     VectorType packed_vty = getNativeVregType(
-        builder.getIntegerType(src_ty.getElementTypeBitWidth()),
+        builder.getIntegerType(getElementTypeBitwidth(src_ty).value()),
         ctx.target_shape);
     VectorType unpacked_vty = getNativeVregType(
-        builder.getIntegerType(src_ty.getElementTypeBitWidth() * 2),
+        builder.getIntegerType(getElementTypeBitwidth(src_ty).value() * 2),
         ctx.target_shape);
     src_vregs.Each(
         [&](absl::Span<const int64_t> src_vreg_indices, Value* src_vreg) {
@@ -6901,7 +6901,7 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
   VectorType vty = rng_op.getResult().getType();
   TPU_ASSERT_OP(vty.getElementType().isInteger());
   // Only 32-bit output supported currently.
-  TPU_ASSERT_OP(vty.getElementType().getIntOrFloatBitWidth() == 32);
+  TPU_ASSERT_OP(getTypeBitwidth(vty.getElementType()) == 32);
   xla::Array<Value> tiles(
       layout_out.tileArrayShape(vty.getShape(), ctx.target_shape));
   VectorType tile_ty = VectorType::get(ctx.target_shape, vty.getElementType());