diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index ef300f038a..3ca105c506 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -465,15 +465,21 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_wide_stores( def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( is_debug=False, - shape=(1024, 1024, 8192), - block=(128, 256, 256), + shape=(1024, 3072, 8192), + block=(128, 128, 256), eliminate_epilogue=False, ): - """Preshuffle-B MXFP4 GEMM with dynamic M, N, K.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( - shape, block, wave_shape=(1, 4), reorder_workgroups=False + """Preshuffle-B MXFP4 GEMM with coalesced dwordx4 stores (WaveASM backend). + + Same kernel as the LLVM coalesced-stores test but compiled through the + C++ WaveASM backend. Emits v_permlane16_swap_b32 + buffer_store_dwordx4. + """ + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b_wide_store( + shape, + block, + wave_shape=(1, 4), + reorder_workgroups=True, ) - # Make M, N, K dynamic so the compiler does not specialize on problem size. dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] for sym in dynamic_symbols: del options.subs[sym] @@ -483,7 +489,6 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( options.use_wave_asm_backend = True options.wave_runtime = True options.eliminate_epilogue = eliminate_epilogue - options.dump_intermediates = "build/intermediates/" schedule = get_mxfp4_asymmetric_schedule( eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True ) @@ -491,10 +496,8 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm( options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) - _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print( - "MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!" - ) + _run_mxfp_gemm_preshuffle(gemm, shape, all=True, output_dtype=torch.bfloat16) + print("MXFP GEMM preshuffle-B 4-wave dwordx4 (WaveASM backend) test passed!") if __name__ == "__main__": diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 4b755d2123..ee28b1c68c 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -16,6 +16,7 @@ Attribute, BF16Type, DenseElementsAttr, + F32Type, IndexType, InsertionPoint, IntegerAttr, @@ -1323,9 +1324,29 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): use_llvm_store = flags != MemoryAccessFlags.NONE - if getattr(node, "_permlane_pack_global", False): + import os as _os + + _use_old_asm = _os.environ.get("WAVE_OLD_ASM_STORES", "") == "1" + _has_pack = getattr(node, "_permlane_pack_global", False) + + if _has_pack and _use_old_asm and emitter.options.backend == "asm": + _write_permlane_pack_to_global_asm( + emitter, + insert_vector, + kb_dest, + output_shape, + start_indices, + start_indices_wg, + start_indices_th, + get_custom(memory), + index, + ) + return + + if _has_pack: is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE - if not is_shared and isinstance(element_type, BF16Type): + is_bf16 = isinstance(element_type, BF16Type) + if not is_shared and is_bf16: role = getattr(node, "_permlane_pack_role", "unpaired") if role == "first": @@ -1438,34 +1459,61 @@ def _write_permlane_pair_to_global( ``_create_vec_read_write`` call would need to be split into two lane-divergent stores. """ - num_elems_a = vec_a.type.shape[0] if hasattr(vec_a.type, "shape") else 1 - num_elems_b = vec_b.type.shape[0] if hasattr(vec_b.type, "shape") else 1 - assert num_elems_a == 4 and num_elems_b == 4, ( - f"_write_permlane_pair_to_global expects 4 bf16 elements per thread " - f"per tile, got {num_elems_a} and {num_elems_b}." - ) - bf16_type = BF16Type.get() + f32_type = F32Type.get() + i16_type = IntegerType.get_signless(16) i32_type = IntegerType.get_signless(32) idx_type = IndexType.get() v2i32_type = VectorType.get([2], i32_type) v4i32_type = VectorType.get([4], i32_type) v8bf16_type = VectorType.get([8], bf16_type) - i32_a = vector_d.bitcast(v2i32_type, vec_a) - a_lo = vector_d.extract(i32_a, static_position=[0], dynamic_position=[]) - a_hi = vector_d.extract(i32_a, static_position=[1], dynamic_position=[]) - - i32_b = vector_d.bitcast(v2i32_type, vec_b) - b_lo = vector_d.extract(i32_b, static_position=[0], dynamic_position=[]) - b_hi = vector_d.extract(i32_b, static_position=[1], dynamic_position=[]) + is_asm = emitter.options.backend == "asm" + + def _extract_i32_dwords(vec_bf16): + """Extract two i32 dwords from a vector<4xbf16>.""" + if not is_asm: + i32_vec = vector_d.bitcast(v2i32_type, vec_bf16) + lo = vector_d.extract(i32_vec, static_position=[0], dynamic_position=[]) + hi = vector_d.extract(i32_vec, static_position=[1], dynamic_position=[]) + return lo, hi + + f32_src = None + defn = vec_bf16.owner + if defn is not None and defn.name == "arith.truncf": + f32_src = defn.operands[0] + assert f32_src is not None, ( + "ASM wide-store expects vec from arith.truncf; " + f"got {defn.name if defn else 'block arg'}" + ) + e0 = vector_d.extract(f32_src, static_position=[0], dynamic_position=[]) + e1 = vector_d.extract(f32_src, static_position=[1], dynamic_position=[]) + e2 = vector_d.extract(f32_src, static_position=[2], dynamic_position=[]) + e3 = vector_d.extract(f32_src, static_position=[3], dynamic_position=[]) + b0 = arith_d.truncf(bf16_type, e0) + b1 = arith_d.truncf(bf16_type, e1) + b2 = arith_d.truncf(bf16_type, e2) + b3 = arith_d.truncf(bf16_type, e3) + c16 = arith_d.constant(i32_type, 16) + + def _pack_pair(lo_bf16, hi_bf16): + lo_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, lo_bf16)) + hi_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, hi_bf16)) + return arith_d.ori(arith_d.shli(hi_i32, c16), lo_i32) + + return _pack_pair(b0, b1), _pack_pair(b2, b3) + + a_lo, a_hi = _extract_i32_dwords(vec_a) + b_lo, b_hi = _extract_i32_dwords(vec_b) swap_type = llvm_d.StructType.get_literal([i32_type, i32_type]) - # old_dst = a, src = b → result[0] = partner's b, result[1] = partner's a + # 2 swaps, extract both elements from each. + # Element [0] = new_dst = partner's src, Element [1] = new_src = partner's old_dst. + # The WaveASM handler creates V_PERMLANE16_SWAP_B32_PAIR when it detects + # that element [1] is extracted, properly modeling both hardware outputs. swap_lo = rocdl_d.permlane16_swap(swap_type, a_lo, b_lo, False, False) swap_hi = rocdl_d.permlane16_swap(swap_type, a_hi, b_hi, False, False) - partner_b_lo = llvm_d.extractvalue(i32_type, swap_lo, [0]) partner_a_lo = llvm_d.extractvalue(i32_type, swap_lo, [1]) partner_b_hi = llvm_d.extractvalue(i32_type, swap_hi, [0]) @@ -1490,29 +1538,45 @@ def _write_permlane_pair_to_global( elems_per_thread = arith_d.constant(idx_type, 4) # Lower lane uses tile A's address; upper lane uses tile B's address. - # Upper lane subtracts elems_per_thread from the last dim to align - # to the lower lane's column position (same as the single-write path). adj_th = list(start_indices_th_a) adj_full = list(start_indices_a) - for dim_idx in range(len(adj_th)): - if dim_idx == len(adj_th) - 1: - adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread) - adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread) + + if is_asm: + # ASM backend: the SRD must be uniform (goes through readfirstlane). + # Both tiles share the same workgroup, so wg_a == wg_b at runtime. + # Use wg_a for the SRD and th_b directly for the thread offset. + for dim_idx in range(len(adj_th)): + if dim_idx == len(adj_th) - 1: + adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread) + adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread) + else: + adj_b_th = start_indices_th_b[dim_idx] + adj_b_full = start_indices_b[dim_idx] adj_th[dim_idx] = arith_d.select(is_lower, adj_th[dim_idx], adj_b_th) adj_full[dim_idx] = arith_d.select(is_lower, adj_full[dim_idx], adj_b_full) - else: - adj_th[dim_idx] = arith_d.select( - is_lower, start_indices_th_a[dim_idx], start_indices_th_b[dim_idx] - ) - adj_full[dim_idx] = arith_d.select( - is_lower, start_indices_a[dim_idx], start_indices_b[dim_idx] + adj_wg = start_indices_wg_a + else: + for dim_idx in range(len(adj_th)): + if dim_idx == len(adj_th) - 1: + adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread) + adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread) + adj_th[dim_idx] = arith_d.select(is_lower, adj_th[dim_idx], adj_b_th) + adj_full[dim_idx] = arith_d.select( + is_lower, adj_full[dim_idx], adj_b_full + ) + else: + adj_th[dim_idx] = arith_d.select( + is_lower, start_indices_th_a[dim_idx], start_indices_th_b[dim_idx] + ) + adj_full[dim_idx] = arith_d.select( + is_lower, start_indices_a[dim_idx], start_indices_b[dim_idx] + ) + adj_wg_list = list(start_indices_wg_a) + for dim_idx in range(len(adj_wg_list)): + adj_wg_list[dim_idx] = arith_d.select( + is_lower, start_indices_wg_a[dim_idx], start_indices_wg_b[dim_idx] ) - - adj_wg = list(start_indices_wg_a) - for dim_idx in range(len(adj_wg)): - adj_wg[dim_idx] = arith_d.select( - is_lower, start_indices_wg_a[dim_idx], start_indices_wg_b[dim_idx] - ) + adj_wg = tuple(adj_wg_list) sel_output_shape = output_shape_a sel_memory_custom = memory_custom_a @@ -1526,7 +1590,7 @@ def _write_permlane_pair_to_global( wide_vec, None, tuple(adj_full), - tuple(adj_wg), + adj_wg, tuple(adj_th), 8, sel_memory_custom, @@ -1535,6 +1599,109 @@ def _write_permlane_pair_to_global( ) +def _write_permlane_pack_to_global_asm( + emitter: WaveEmitter, + insert_vector: Value, + kb_dest: Value, + output_shape: tuple, + start_indices: tuple, + start_indices_wg: tuple, + start_indices_th: tuple, + memory_custom, + index: dict, +): + """Single-tile wide store for ASM backend with explicit bf16 conversion. + + Each lane exchanges its own 4 bf16 values with a partner 16 apart, + assembles an 8-element wide vector, and writes buffer_store_dwordx4. + Both lane halves write identical data (duplicate stores). + """ + bf16_type = BF16Type.get() + f32_type = F32Type.get() + i16_type = IntegerType.get_signless(16) + i32_type = IntegerType.get_signless(32) + idx_type = IndexType.get() + v4i32_type = VectorType.get([4], i32_type) + v8bf16_type = VectorType.get([8], bf16_type) + + f32_src = None + defn = insert_vector.owner + if defn is not None and defn.name == "arith.truncf": + f32_src = defn.operands[0] + assert ( + f32_src is not None + ), "_write_permlane_pack_to_global_asm expects vec from arith.truncf" + + e0 = vector_d.extract(f32_src, static_position=[0], dynamic_position=[]) + e1 = vector_d.extract(f32_src, static_position=[1], dynamic_position=[]) + e2 = vector_d.extract(f32_src, static_position=[2], dynamic_position=[]) + e3 = vector_d.extract(f32_src, static_position=[3], dynamic_position=[]) + b0 = arith_d.truncf(bf16_type, e0) + b1 = arith_d.truncf(bf16_type, e1) + b2 = arith_d.truncf(bf16_type, e2) + b3 = arith_d.truncf(bf16_type, e3) + + c16 = arith_d.constant(i32_type, 16) + + def _pack_pair(lo_bf16, hi_bf16): + lo_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, lo_bf16)) + hi_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, hi_bf16)) + return arith_d.ori(arith_d.shli(hi_i32, c16), lo_i32) + + own_lo = _pack_pair(b0, b1) + own_hi = _pack_pair(b2, b3) + + swap_type = llvm_d.StructType.get_literal([i32_type, i32_type]) + partner_lo = llvm_d.extractvalue( + i32_type, rocdl_d.permlane16_swap(swap_type, own_lo, own_lo, False, False), [0] + ) + partner_hi = llvm_d.extractvalue( + i32_type, rocdl_d.permlane16_swap(swap_type, own_hi, own_hi, False, False), [0] + ) + + lane_in_wave = arith_d.remui(emitter.thread_ids[0], arith_d.constant(idx_type, 64)) + half_pos = arith_d.remui(lane_in_wave, arith_d.constant(idx_type, 32)) + is_lower = arith_d.cmpi( + arith_d.CmpIPredicate.ult, half_pos, arith_d.constant(idx_type, 16) + ) + + d0 = arith_d.select(is_lower, own_lo, partner_lo) + d1 = arith_d.select(is_lower, own_hi, partner_hi) + d2 = arith_d.select(is_lower, partner_lo, own_lo) + d3 = arith_d.select(is_lower, partner_hi, own_hi) + + wide_i32 = vector_d.from_elements(v4i32_type, [d0, d1, d2, d3]) + wide_vec = vector_d.bitcast(v8bf16_type, wide_i32) + + num_elems = 4 + elems_per_thread = arith_d.constant(idx_type, num_elems) + + adj_th = list(start_indices_th) + adj_th[-1] = arith_d.select( + is_lower, adj_th[-1], arith_d.subi(adj_th[-1], elems_per_thread) + ) + + adj_full = list(start_indices) + adj_full[-1] = arith_d.select( + is_lower, adj_full[-1], arith_d.subi(adj_full[-1], elems_per_thread) + ) + + _create_vec_read_write( + emitter, + output_shape, + kb_dest, + wide_vec, + None, + tuple(adj_full), + start_indices_wg, + tuple(adj_th), + 8, + memory_custom, + None, + node_index=index, + ) + + def assume_index_subgroup_uniform(value: Value, element_type: IrType) -> Value: res = gpu_d.subgroup_broadcast(value, gpu_d.BroadcastType.first_active_lane) return res diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index a4f569ef2c..d41468f43a 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -530,6 +530,10 @@ def build_graph_passes( partial(decompose_topk_ops, trace, launchable.constraints), ] + # Tags bf16 global writes that use source/target transposition + # (wide_stores layout) with _permlane_pack_global so the codegen emits + # v_permlane16_swap_b32 + buffer_store_dwordx4. Runs unconditionally; + # writes without source/target are left untouched. from .wide_store_coalescing import coalesce_wide_stores graph_passes.append(partial(coalesce_wide_stores, trace)) @@ -1334,6 +1338,11 @@ def _generate_asm_code(mb, options): mlir_file.write(kernel_mlir) mlir_path = mlir_file.name + # DEBUG: save a copy of the MLIR IR for inspection + import shutil as _shutil + + _shutil.copy2(mlir_path, "/tmp/waveasm_input.mlir") + try: base_passes = [ "--mlir-cse", diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index 3301d46c55..44e6e50dec 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -404,10 +404,6 @@ def _get_tagged_mxfp4_gemm_preshuffle_b_impl( K_PACKED = tkl.sym.K_PACKED K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED - if wide_stores: - m_symbol = tkl.sym.m_symbol - n_symbol = tkl.sym.n_symbol - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] constraints += [tkw.TilingConstraint(K, BLOCK_K)] @@ -426,6 +422,8 @@ def _get_tagged_mxfp4_gemm_preshuffle_b_impl( constraints += [tkw.Assumption(K > BLOCK_K * 6)] if wide_stores: + m_symbol = tkl.sym.m_symbol + n_symbol = tkl.sym.n_symbol constraints += [tkw.IteratorBindings({m_symbol: M, n_symbol: N})] constraints += [tkw.Assumption(Eq(M % BLOCK_M, 0))] constraints += [tkw.Assumption(Eq(N % BLOCK_N, 0))] diff --git a/waveasm/include/waveasm/Dialect/WaveASMOps.td b/waveasm/include/waveasm/Dialect/WaveASMOps.td index c8bac3147b..1abe6c6a7b 100644 --- a/waveasm/include/waveasm/Dialect/WaveASMOps.td +++ b/waveasm/include/waveasm/Dialect/WaveASMOps.td @@ -521,6 +521,48 @@ def WaveASM_V_READFIRSTLANE_B32 : WAVEASMOp<"v_readfirstlane_b32", [Pure]> { let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; } +// Lane swap operations (VGPR <-> VGPR across lanes) +// [Pure] follows the dialect convention (same as V_READFIRSTLANE_B32 etc.) +// even though the hardware clobbers the source register. The handler +// inserts a v_mov_b32 copy before the swap to preserve the original value, +// so from the SSA perspective the op is side-effect-free. +def WaveASM_V_PERMLANE16_SWAP_B32 : WAVEASMOp<"v_permlane16_swap_b32", [Pure]> { + let summary = "Swap VGPR values between lanes 16 apart"; + let description = [{ + Exchanges a 32-bit value between paired lanes that are 16 positions apart. + Lane i swaps with lane i^16 within each 32-lane half-wave. + The hardware writes the swapped value to dst AND clobbers src. + The handler inserts a v_mov_b32 copy before the swap to preserve the + original source value for downstream uses. + }]; + let arguments = (ins WaveASM_AnyVGPR:$src); + let results = (outs WaveASM_AnyVGPR:$dst); + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst)"; +} + +// Dual-output lane swap: exposes BOTH hardware outputs of +// v_permlane16_swap_b32 to SSA. The hardware exchanges values between +// lanes i and i^16: +// new_dst = partner's old src +// new_src = partner's old_dst +// The assembly emitter copies old_dst and src into scratch VGPRs, executes +// the swap on those scratches, then copies both results to the allocated +// output registers. This avoids the clobbering problem of the single-output +// V_PERMLANE16_SWAP_B32 and enables paired wide stores without duplicate +// writes. +def WaveASM_V_PERMLANE16_SWAP_B32_PAIR : WAVEASMOp<"v_permlane16_swap_b32_pair", [Pure]> { + let summary = "Swap VGPR values between lanes 16 apart (dual output)"; + let description = [{ + Like v_permlane16_swap_b32 but models both hardware outputs: + new_dst = partner lane's src + new_src = partner lane's old_dst + Both inputs are preserved (copied to scratch before the swap). + }]; + let arguments = (ins WaveASM_AnyVGPR:$old_dst, WaveASM_AnyVGPR:$src); + let results = (outs WaveASM_AnyVGPR:$new_dst, WaveASM_AnyVGPR:$new_src); + let assemblyFormat = "$old_dst `,` $src attr-dict `:` type($old_dst) `,` type($src) `->` type($new_dst) `,` type($new_src)"; +} + // Bit operations def WaveASM_V_NOT_B32 : VALUUnaryOp<"v_not_b32">; def WaveASM_V_NOT_B64 : VALUUnaryOp<"v_not_b64">; diff --git a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h index cbfe19176e..694790586d 100644 --- a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h +++ b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h @@ -52,8 +52,24 @@ class ValueMapper { return valueMap.contains(mlirValue); } + /// Map a sub-element of a struct-typed MLIR value (for llvm.extractvalue). + void setExtraMapping(mlir::Value structVal, int64_t index, + mlir::Value elemVal) { + extraMap[{structVal, index}] = elemVal; + } + + /// Get a sub-element of a struct-typed MLIR value. + std::optional getExtraMapping(mlir::Value structVal, + int64_t index) const { + auto it = extraMap.find({structVal, index}); + if (it != extraMap.end()) + return it->second; + return std::nullopt; + } + private: llvm::DenseMap valueMap; + llvm::DenseMap, mlir::Value> extraMap; }; //===----------------------------------------------------------------------===// diff --git a/waveasm/lib/Transforms/AssemblyEmitter.cpp b/waveasm/lib/Transforms/AssemblyEmitter.cpp index a2e1cf7a1e..52c45c9823 100644 --- a/waveasm/lib/Transforms/AssemblyEmitter.cpp +++ b/waveasm/lib/Transforms/AssemblyEmitter.cpp @@ -21,9 +21,12 @@ #include "waveasm/Transforms/RegAlloc.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "waveasm-assembly-emitter" + #include using namespace mlir; @@ -972,6 +975,91 @@ std::optional KernelGenerator::generateOp(Operation *op) { return formatter.format("v_cvt_pk_bf16_f32", operands); }) + // V_PERMLANE16_SWAP_B32: swap lanes 16 apart. + // The hardware clobbers BOTH dst and src. The handler inserts a + // v_mov_b32 copy before the swap, so the allocator should always + // assign dst != src. The fallback uses two scratch VGPRs above + // the allocator's range (kScratchVGPR and kScratchVGPR + 1). + .Case( + [&](V_PERMLANE16_SWAP_B32 swapOp) -> std::optional { + std::string dst = resolveValue(swapOp.getDst()); + std::string src = resolveValue(swapOp.getSrc()); + if (dst != src) { + llvm::SmallVector operands = {dst, src}; + return formatter.format("v_permlane16_swap_b32", operands); + } + // dst==src fallback: use two scratch VGPRs both above allocator + // range + std::string scratch0 = formatVGPRRange(kScratchVGPR, 1); + std::string scratch1 = formatVGPRRange(kScratchVGPR + 1, 1); + peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 2); + invalidateScratchCache(); + return " v_mov_b32 " + scratch0 + ", " + src + "\n" + + " v_mov_b32 " + scratch1 + ", " + src + "\n" + + " s_nop 1\n" + " v_permlane16_swap_b32 " + dst + ", " + + scratch1 + "\n" + " v_mov_b32 " + src + ", " + scratch0; + }) + + // V_PERMLANE16_SWAP_B32_PAIR: dual-output swap for paired wide stores. + // Copies both inputs to scratch VGPRs, executes the swap on the + // scratches, then copies both results to the allocated output registers. + .Case( + [&](V_PERMLANE16_SWAP_B32_PAIR pairOp) -> std::optional { + std::string newDst = resolveValue(pairOp.getNewDst()); + std::string newSrc = resolveValue(pairOp.getNewSrc()); + std::string oldDst = resolveValue(pairOp.getOldDst()); + std::string src = resolveValue(pairOp.getSrc()); + + // Always use scratch VGPRs: the in-place swap clobbers the input + // registers, but downstream selects still need the originals. + std::string s0 = formatVGPRRange(kScratchVGPR, 1); + std::string s1 = formatVGPRRange(kScratchVGPR + 1, 1); + peakVGPRs = std::max(peakVGPRs, kScratchVGPR + 2); + invalidateScratchCache(); + return " v_mov_b32 " + s0 + ", " + oldDst + "\n" + " v_mov_b32 " + + s1 + ", " + src + "\n" + " s_nop 1\n" + + " v_permlane16_swap_b32 " + s0 + ", " + s1 + "\n" + + " v_mov_b32 " + newDst + ", " + s0 + "\n" + " v_mov_b32 " + + newSrc + ", " + s1; + }) + + // V_ACCVGPR_READ_B32: unroll multi-register reads into scalar ops + .Case( + [&](V_ACCVGPR_READ_B32 readOp) -> std::optional { + Value dst = readOp.getDst(); + Value src = readOp.getSrc(); + int64_t dstSize = getRegSize(dst.getType()); + int64_t srcSize = getRegSize(src.getType()); + int64_t size = std::max(dstSize, srcSize); + if (size <= 1) { + return emitDefaultFormat(readOp, "v_accvgpr_read_b32"); + } + int64_t dstBase = -1, srcBase = -1; + if (auto pv = dyn_cast(dst.getType())) + dstBase = pv.getIndex(); + else if (isVirtualRegType(dst.getType())) + dstBase = mapping.getPhysReg(dst); + if (auto pa = dyn_cast(src.getType())) + srcBase = pa.getIndex(); + else if (isVirtualRegType(src.getType())) + srcBase = mapping.getPhysReg(src); + if (dstBase < 0 || srcBase < 0) { + readOp.emitError( + "V_ACCVGPR_READ_B32: cannot resolve base registers for " + "multi-register unroll (dstBase=") + << dstBase << " srcBase=" << srcBase << ")"; + return std::nullopt; + } + std::string lines; + for (int64_t i = 0; i < size; ++i) { + if (i > 0) + lines += "\n"; + lines += " v_accvgpr_read_b32 v" + std::to_string(dstBase + i) + + ", a" + std::to_string(srcBase + i); + } + return lines; + }) + // Carry ops: on GFX9, carry-out is implicit VCC. // v_add_co_u32: dst, vcc, src0, src1 // v_addc_co_u32: dst, vcc, src0, src1, vcc (carry-in). diff --git a/waveasm/lib/Transforms/ExtractScalarization.cpp b/waveasm/lib/Transforms/ExtractScalarization.cpp index f1c1ff99b4..dee544992e 100644 --- a/waveasm/lib/Transforms/ExtractScalarization.cpp +++ b/waveasm/lib/Transforms/ExtractScalarization.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/DebugLog.h" @@ -39,7 +40,8 @@ using namespace mlir; namespace { /// Read element `index` from a DenseElementsAttr with integer elements, -/// returning its value as int64_t. +/// returning its value as int64_t. Returns nullopt for float constants; +/// callers that need float scalarization should add a float-element path. static std::optional getIntElement(DenseElementsAttr dense, int64_t index) { auto vecType = dyn_cast(dense.getType()); @@ -73,6 +75,100 @@ static LogicalResult peelSelect(Value v, Value &cond, Value &trueVal, return failure(); } +/// Recursively scalarize a vector value at element `k`. +/// Propagates extract through elementwise ops, broadcasts, and constants. +static Value scalarizeAtIndex(Value v, int64_t k, OpBuilder &rewriter, + Location loc); + +/// Scalarize a single operand at index k. For vector operands, recurse. +/// For scalar operands (already scalar or broadcast source), return as-is. +static Value scalarizeOperand(Value v, int64_t k, OpBuilder &rewriter, + Location loc) { + if (!isa(v.getType())) + return v; + return scalarizeAtIndex(v, k, rewriter, loc); +} + +static Value scalarizeAtIndex(Value v, int64_t k, OpBuilder &rewriter, + Location loc) { + if (!isa(v.getType())) + return v; + + Operation *def = v.getDefiningOp(); + if (!def) + return nullptr; + + // broadcast(scalar) → scalar + if (auto bcast = dyn_cast(def)) + return bcast.getSource(); + + // Dense constant → extract element k + if (auto constOp = dyn_cast(def)) { + auto dense = dyn_cast(constOp.getValue()); + if (!dense) + return nullptr; + auto elemVal = getIntElement(dense, k); + if (!elemVal) + return nullptr; + Type scalarType = getElementTypeOrSelf(v.getType()); + return arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(scalarType, *elemVal)); + } + + // Binary elementwise ops: propagate extract to both operands + if (isa(def)) { + Value lhs = scalarizeOperand(def->getOperand(0), k, rewriter, loc); + Value rhs = scalarizeOperand(def->getOperand(1), k, rewriter, loc); + if (!lhs || !rhs) + return nullptr; + return rewriter + .create(loc, def->getName().getIdentifier(), ValueRange{lhs, rhs}, + lhs.getType(), def->getAttrs()) + ->getResult(0); + } + + // cmpi: propagate to both operands, result is scalar i1 + if (auto cmpOp = dyn_cast(def)) { + Value lhs = scalarizeOperand(cmpOp.getLhs(), k, rewriter, loc); + Value rhs = scalarizeOperand(cmpOp.getRhs(), k, rewriter, loc); + if (!lhs || !rhs) + return nullptr; + return arith::CmpIOp::create(rewriter, loc, cmpOp.getPredicate(), lhs, rhs); + } + + // select: propagate to condition, true, false + if (auto selOp = dyn_cast(def)) { + Value cond = scalarizeOperand(selOp.getCondition(), k, rewriter, loc); + Value trueVal = scalarizeOperand(selOp.getTrueValue(), k, rewriter, loc); + Value falseVal = scalarizeOperand(selOp.getFalseValue(), k, rewriter, loc); + if (!cond || !trueVal || !falseVal) + return nullptr; + return arith::SelectOp::create(rewriter, loc, cond, trueVal, falseVal); + } + + // index_cast: propagate through + if (auto castOp = dyn_cast(def)) { + Value inner = scalarizeOperand(castOp.getIn(), k, rewriter, loc); + if (!inner) + return nullptr; + Type resultScalar = getElementTypeOrSelf(castOp.getResult().getType()); + return arith::IndexCastOp::create(rewriter, loc, resultScalar, inner); + } + + // truncf: propagate through + if (auto truncOp = dyn_cast(def)) { + Value inner = scalarizeOperand(truncOp.getIn(), k, rewriter, loc); + if (!inner) + return nullptr; + Type resultScalar = getElementTypeOrSelf(truncOp.getResult().getType()); + return arith::TruncFOp::create(rewriter, loc, resultScalar, inner); + } + + return nullptr; +} + /// Try to match and scalarize: /// vector.extract[k] ( /// index_cast? ( @@ -81,6 +177,7 @@ static LogicalResult peelSelect(Value v, Value &cond, Value &trueVal, /// ) /// ) /// ) +/// Falls back to general recursive scalarization for other patterns. /// Returns the scalar replacement value, or nullptr on failure. static Value tryScalarize(vector::ExtractOp extractOp, OpBuilder &rewriter) { auto staticPos = extractOp.getStaticPosition(); @@ -179,41 +276,42 @@ static Value tryScalarize(vector::ExtractOp extractOp, OpBuilder &rewriter) { scalarFalse); } - // Non-select path: extract[k]( index_cast?( addi(broadcast, dense) ) ) - auto addOp = preIndexCast.getDefiningOp(); - if (!addOp) - return nullptr; - - Value broadcastSide = nullptr; - DenseElementsAttr dense; - - for (int swap = 0; swap < 2; ++swap) { - Value lhs = swap ? addOp.getRhs() : addOp.getLhs(); - Value rhs = swap ? addOp.getLhs() : addOp.getRhs(); - - auto bcast = lhs.getDefiningOp(); - auto constOp = rhs.getDefiningOp(); - if (bcast && constOp) { - auto d = dyn_cast(constOp.getValue()); - if (d && !d.isSplat()) { - broadcastSide = bcast.getSource(); - dense = d; - break; + // Fast path for the common pattern extract[k]( index_cast?( addi(broadcast, + // dense) ) ). The general scalarizeAtIndex fallback handles this too, but + // this avoids recursion for the most frequent case. + if (auto addOp = preIndexCast.getDefiningOp()) { + Value broadcastSide = nullptr; + DenseElementsAttr dense; + + for (int swap = 0; swap < 2; ++swap) { + Value lhs = swap ? addOp.getRhs() : addOp.getLhs(); + Value rhs = swap ? addOp.getLhs() : addOp.getRhs(); + + auto bcast = lhs.getDefiningOp(); + auto constOp = rhs.getDefiningOp(); + if (bcast && constOp) { + auto d = dyn_cast(constOp.getValue()); + if (d && !d.isSplat()) { + broadcastSide = bcast.getSource(); + dense = d; + break; + } + } + } + if (broadcastSide) { + auto elemVal = getIntElement(dense, k); + if (elemVal) { + Type scalarType = broadcastSide.getType(); + Value elemConst = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(scalarType, *elemVal)); + return arith::AddIOp::create(rewriter, loc, scalarType, broadcastSide, + elemConst); } } } - if (!broadcastSide) - return nullptr; - - auto elemVal = getIntElement(dense, k); - if (!elemVal) - return nullptr; - Type scalarType = broadcastSide.getType(); - Value elemConst = arith::ConstantOp::create( - rewriter, loc, rewriter.getIntegerAttr(scalarType, *elemVal)); - return arith::AddIOp::create(rewriter, loc, scalarType, broadcastSide, - elemConst); + // General fallback: recursively scalarize through elementwise ops. + return scalarizeAtIndex(src, k, rewriter, loc); } struct ExtractScalarizationPass @@ -273,7 +371,10 @@ struct ExtractScalarizationPass // Vector-typed arith ops (addi, index_cast, select on vectors). if (op->getNumResults() == 1 && isa(op->getResult(0).getType()) && - isa(op)) + isa(op)) return true; // Dense vector constants. if (auto constOp = dyn_cast(op)) { diff --git a/waveasm/lib/Transforms/Liveness.cpp b/waveasm/lib/Transforms/Liveness.cpp index 236e7314c9..42ec188a5b 100644 --- a/waveasm/lib/Transforms/Liveness.cpp +++ b/waveasm/lib/Transforms/Liveness.cpp @@ -477,6 +477,57 @@ LivenessInfo computeLiveness(ProgramOp program) { } } + // Check for permlane swap ops — used by passes 2c and 3a1. + bool hasPermlaneSwaps = false; + for (auto *op : ops) { + if (isa(op)) { + hasPermlaneSwaps = true; + break; + } + } + + // Pass 2c: Extend V_PERMLANE16_SWAP_B32 source live ranges + // (defense-in-depth). The handler already preserves the original source via + // v_mov_b32 (srcCopy) and remaps future MLIR lookups to the copy, so srcVal + // has no uses after the swap op. This extension is largely redundant — the + // copy naturally keeps srcVal alive through the copy point — but is retained + // as a safety net in case the remap is ever bypassed. + if (hasPermlaneSwaps) { + for (auto *op : ops) { + if (auto swapOp = dyn_cast(op)) { + Value src = swapOp.getSrc(); + Value dst = swapOp.getDst(); + auto srcIt = info.ranges.find(src); + auto dstIt = info.ranges.find(dst); + if (srcIt != info.ranges.end() && dstIt != info.ranges.end()) { + srcIt->second.end = std::max(srcIt->second.end, dstIt->second.end); + } + } + // Extend pair op inputs to cover the lifetime of both outputs. + // The assembly emitter reads both inputs at the swap point, so they + // must remain live until both outputs have been consumed. + if (auto pairOp = dyn_cast(op)) { + Value oldDst = pairOp.getOldDst(); + Value src = pairOp.getSrc(); + Value newDst = pairOp.getNewDst(); + Value newSrc = pairOp.getNewSrc(); + int64_t maxEnd = 0; + auto newDstIt = info.ranges.find(newDst); + auto newSrcIt = info.ranges.find(newSrc); + if (newDstIt != info.ranges.end()) + maxEnd = std::max(maxEnd, newDstIt->second.end); + if (newSrcIt != info.ranges.end()) + maxEnd = std::max(maxEnd, newSrcIt->second.end); + auto oldDstIt = info.ranges.find(oldDst); + auto srcIt = info.ranges.find(src); + if (oldDstIt != info.ranges.end()) + oldDstIt->second.end = std::max(oldDstIt->second.end, maxEnd); + if (srcIt != info.ranges.end()) + srcIt->second.end = std::max(srcIt->second.end, maxEnd); + } + } + } + // Note: Pass 3 (CFG-based backward dataflow liveness extension) has been // removed. It was needed for the old label-based control flow path where // loop back-edges were represented as explicit branch instructions. With @@ -551,6 +602,30 @@ LivenessInfo computeLiveness(ProgramOp program) { } }); + // Pass 3a1: Extend pack result ranges to cover downstream users of extracted + // sub-values. Only needed when V_PERMLANE16_SWAP_B32 ops exist, since those + // allocate dst registers that can conflict with post-hoc pack sub-register + // assignments. When no permlane swaps exist, the standard pack handling + // (pass 3a) is sufficient. + if (hasPermlaneSwaps) { + for (auto *op : ops) { + auto extractOp = dyn_cast(op); + if (!extractOp) + continue; + Value source = extractOp.getVector(); + auto sourceIt = info.ranges.find(source); + if (sourceIt == info.ranges.end()) + continue; + Value extractResult = extractOp.getResult(); + auto useIt = info.usePoints.find(extractResult); + if (useIt != info.usePoints.end()) { + for (int64_t use : useIt->second) { + sourceIt->second.end = std::max(sourceIt->second.end, use); + } + } + } + } + // Pass 3a2: InsertOp pass -- treat insert result as an alias of the source // vector, but keep the inserted value in the allocator worklist. // diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 8edeba6b0e..8f8de05273 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -778,6 +778,7 @@ LogicalResult handleArithShRSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithExtUI(Operation *op, TranslationContext &ctx); LogicalResult handleArithExtSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithTruncI(Operation *op, TranslationContext &ctx); +LogicalResult handleArithBitcast(Operation *op, TranslationContext &ctx); LogicalResult handleArithMinSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithMaxSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithMinUI(Operation *op, TranslationContext &ctx); @@ -1645,12 +1646,22 @@ LogicalResult handleVectorStore(Operation *op, TranslationContext &ctx) { // BF16 store conversion: the arith.truncf handler defers vector f32->bf16 // conversion, so data registers may still contain f32 values. + // Skip conversion if data is already packed bf16 (e.g. from permlane + // coalescing path where the data went through v_cvt_pk_bf16_f32 + + // PackOp). Packed bf16 uses numElems/2 registers (2 bf16 per dword); + // unpacked f32 uses numElems registers. if (elementType.isBF16() && data.has_value()) { int64_t numElems = vectorType.getNumElements(); - auto [converted, newNumBytes] = - convertF32ToBF16ForStore(*data, numElems, ctx, builder, loc); - data = converted; - numBytes = newNumBytes; + int64_t dataRegs = getRegSize(data->getType()); + bool alreadyPacked = (dataRegs <= numElems / 2); + if (!alreadyPacked) { + auto [converted, newNumBytes] = + convertF32ToBF16ForStore(*data, numElems, ctx, builder, loc); + data = converted; + numBytes = newNumBytes; + } else { + numBytes = numElems * 2; + } } // Split large stores into multiple buffer_store_dwordx4 (16 bytes each) @@ -1877,6 +1888,9 @@ LogicalResult handleReadFirstLane(Operation *op, TranslationContext &ctx); LogicalResult handleROCDLSBarrier(Operation *op, TranslationContext &ctx); LogicalResult handleROCDLSetPrio(Operation *op, TranslationContext &ctx); LogicalResult handleSWaitcnt(Operation *op, TranslationContext &ctx); +LogicalResult handlePermlane16Swap(Operation *op, TranslationContext &ctx); +LogicalResult handleLLVMExtractValue(Operation *op, TranslationContext &ctx); +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx); //===----------------------------------------------------------------------===// // OpHandlerRegistry Implementation @@ -1982,6 +1996,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(arith::ExtUIOp, handleArithExtUI); REGISTER_HANDLER(arith::ExtSIOp, handleArithExtSI); REGISTER_HANDLER(arith::TruncIOp, handleArithTruncI); + REGISTER_HANDLER(arith::BitcastOp, handleArithBitcast); REGISTER_HANDLER(arith::MinSIOp, handleArithMinSI); REGISTER_HANDLER(arith::MaxSIOp, handleArithMaxSI); REGISTER_HANDLER(arith::MinUIOp, handleArithMinUI); @@ -2019,6 +2034,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(vector::InsertOp, handleVectorInsert); REGISTER_HANDLER(vector::ShapeCastOp, handleVectorShapeCast); REGISTER_HANDLER(vector::BitCastOp, handleVectorBitCast); + REGISTER_HANDLER(vector::FromElementsOp, handleVectorFromElements); REGISTER_HANDLER(vector::TransferReadOp, handleVectorTransferRead); REGISTER_HANDLER(vector::TransferWriteOp, handleVectorTransferWrite); REGISTER_HANDLER(vector::FMAOp, handleVectorFma); @@ -2040,6 +2056,10 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(ROCDL::SBarrierOp, handleROCDLSBarrier); REGISTER_HANDLER(ROCDL::SetPrioOp, handleROCDLSetPrio); REGISTER_HANDLER(ROCDL::SWaitcntOp, handleSWaitcnt); + REGISTER_HANDLER(ROCDL::Permlane16SwapOp, handlePermlane16Swap); + + // LLVM dialect + REGISTER_HANDLER(LLVM::ExtractValueOp, handleLLVMExtractValue); // IREE/Stream dialect (unregistered operations) registerHandler(mlir::OperationName("stream.binding.subspan", ctx), diff --git a/waveasm/lib/Transforms/VGPRCompaction.cpp b/waveasm/lib/Transforms/VGPRCompaction.cpp index 3f01184666..beb03e1c99 100644 --- a/waveasm/lib/Transforms/VGPRCompaction.cpp +++ b/waveasm/lib/Transforms/VGPRCompaction.cpp @@ -293,8 +293,8 @@ static bool overlaps(const PhysVGPRRange &a, const PhysVGPRRange &b) { } static llvm::DenseMap -computeCompaction(llvm::SmallVectorImpl &ranges, - int64_t maxRegs) { +computeCompaction(llvm::SmallVectorImpl &ranges, int64_t maxRegs, + int64_t scratchCount) { llvm::SmallVector order(ranges.size()); std::iota(order.begin(), order.end(), 0); llvm::sort(order, [&](int64_t a, int64_t b) { @@ -324,8 +324,9 @@ computeCompaction(llvm::SmallVectorImpl &ranges, int64_t align = r.alignment; occupied.reset(); - if (kScratchVGPR < maxRegs) - occupied.set(kScratchVGPR); + for (int64_t s = 0; s < scratchCount; ++s) + if (kScratchVGPR + s < maxRegs) + occupied.set(kScratchVGPR + s); for (size_t j = 0; j < ranges.size(); ++j) { if (newAssignment[j] < 0) continue; @@ -523,7 +524,12 @@ struct WAVEASMVGPRCompaction // tighter than a hardcoded 512 and adapts to the actual allocation. int64_t maxRegs = maxBefore; - auto oldToNew = computeCompaction(ranges, maxRegs); + // Only reserve the second scratch VGPR when the kernel uses PAIR ops. + bool hasPairOps = false; + program->walk([&](V_PERMLANE16_SWAP_B32_PAIR) { hasPairOps = true; }); + int64_t scratchCount = hasPairOps ? 2 : 1; + + auto oldToNew = computeCompaction(ranges, maxRegs, scratchCount); int64_t maxAfter = 0; for (const auto &r : ranges) { diff --git a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp index 6d99671326..41f632a3b9 100644 --- a/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/AMDGPUHandlers.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" @@ -1289,6 +1290,92 @@ LogicalResult handleReadFirstLane(Operation *op, TranslationContext &ctx) { return success(); } +LogicalResult handlePermlane16Swap(Operation *op, TranslationContext &ctx) { + auto swapOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + auto src = ctx.getMapper().getMapped(swapOp.getSrc()); + if (!src) { + return op->emitError("permlane16_swap source not mapped"); + } + + // Check whether downstream code extracts element [1] (partner's old_dst). + // If so, we need the dual-output pair op to properly model both hardware + // outputs in SSA. Otherwise, the single-output op with the mapValue + // workaround suffices (and avoids extra scratch register traffic). + bool needsBothElements = false; + for (auto user : swapOp.getRes().getUsers()) { + if (auto ev = dyn_cast(user)) { + if (!ev.getPosition().empty() && ev.getPosition()[0] == 1) { + needsBothElements = true; + break; + } + } + } + + if (needsBothElements) { + // --- Dual-output pair path --- + // Both old_dst and src are explicit inputs; both new_dst and new_src + // are explicit outputs. The assembly emitter uses scratch VGPRs to + // execute the swap without clobbering either input. + auto oldDst = ctx.getMapper().getMapped(swapOp.getOld()); + if (!oldDst) { + return op->emitError("permlane16_swap old_dst not mapped"); + } + + Value oldDstVal = ensureVGPR(*oldDst, ctx, builder, loc); + Value srcVal = ensureVGPR(*src, ctx, builder, loc); + + auto dstType = ctx.createVRegType(); + auto srcType = ctx.createVRegType(); + auto pairOp = V_PERMLANE16_SWAP_B32_PAIR::create( + builder, loc, dstType, srcType, oldDstVal, srcVal); + Value newDst = pairOp.getNewDst(); + Value newSrc = pairOp.getNewSrc(); + + ctx.getMapper().mapValue(swapOp.getRes(), newDst); + ctx.getMapper().setExtraMapping(swapOp.getRes(), 0, newDst); + ctx.getMapper().setExtraMapping(swapOp.getRes(), 1, newSrc); + + return success(); + } + + // --- Single-output path (original handler) --- + Value srcVal = ensureVGPR(*src, ctx, builder, loc); + + Value srcCopy = V_MOV_B32::create(builder, loc, ctx.createVRegType(), srcVal); + auto dstType = ctx.createVRegType(); + Value swapped = V_PERMLANE16_SWAP_B32::create(builder, loc, dstType, srcVal); + + ctx.getMapper().mapValue(swapOp.getSrc(), srcCopy); + + ctx.getMapper().mapValue(swapOp.getRes(), swapped); + ctx.getMapper().setExtraMapping(swapOp.getRes(), 0, swapped); + ctx.getMapper().setExtraMapping(swapOp.getRes(), 1, srcCopy); + + return success(); +} + +LogicalResult handleLLVMExtractValue(Operation *op, TranslationContext &ctx) { + auto extractOp = cast(op); + + auto position = extractOp.getPosition(); + if (position.size() != 1) { + return op->emitError("only single-level extractvalue supported"); + } + int64_t idx = position[0]; + + Value container = extractOp.getContainer(); + auto elem = ctx.getMapper().getExtraMapping(container, idx); + if (!elem) { + return op->emitError("extractvalue element ") << idx << " not mapped"; + } + + ctx.getMapper().mapValue(extractOp.getResult(), *elem); + return success(); +} + LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx) { auto atomicOp = cast(op); auto &builder = ctx.getBuilder(); diff --git a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp index 825b659e8f..9e633084cc 100644 --- a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp @@ -327,6 +327,15 @@ LogicalResult handleArithTruncI(Operation *op, TranslationContext &ctx) { return success(); } +LogicalResult handleArithBitcast(Operation *op, TranslationContext &ctx) { + auto castOp = cast(op); + auto src = ctx.getMapper().getMapped(castOp.getIn()); + if (src) { + ctx.getMapper().mapValue(castOp.getResult(), *src); + } + return success(); +} + //===----------------------------------------------------------------------===// // Comparison and Select Operations //===----------------------------------------------------------------------===// @@ -451,8 +460,12 @@ LogicalResult handleArithSelect(Operation *op, TranslationContext &ctx) { Value zeroConst = createImmConst(0, builder, loc, ctx); V_CMP_NE_U32::create(builder, loc, *cond, zeroConst); + // v_cndmask_b32 requires VGPR sources — move from AGPR if needed. + Value trueV = ensureVGPR(*trueVal, ctx, builder, loc); + Value falseV = ensureVGPR(*falseVal, ctx, builder, loc); + auto result = - V_CNDMASK_B32::create(builder, loc, vregType, *falseVal, *trueVal, *cond); + V_CNDMASK_B32::create(builder, loc, vregType, falseV, trueV, *cond); ctx.getMapper().mapValue(selectOp.getResult(), result); return success(); } @@ -540,14 +553,14 @@ LogicalResult handleArithTruncF(Operation *op, TranslationContext &ctx) { int64_t numElems = vecType ? vecType.getNumElements() : 1; if (numElems > 1) { + // Multi-element truncf: pass through the source register mapping. + // The actual bf16 conversion (v_cvt_pk_bf16_f32) is deferred to the + // store handler in TranslateFromMLIR, which has the context to decide + // whether elements are extracted individually (non-wide-stores) or + // packed into dwordx4 (wide-stores). ctx.getMapper().mapValue(truncOp.getResult(), *src); } else { - Value srcVal = *src; - // VALU conversion instructions cannot read from AGPR. - if (isAGPRType(srcVal.getType())) { - auto vregTmp = ctx.createVRegType(); - srcVal = V_ACCVGPR_READ_B32::create(builder, loc, vregTmp, srcVal); - } + Value srcVal = ensureVGPR(*src, ctx, builder, loc); auto vregType = ctx.createVRegType(); Value result; if (srcElemType.isF32() && dstElemType.isBF16()) { @@ -575,12 +588,7 @@ LogicalResult handleArithExtF(Operation *op, TranslationContext &ctx) { Type srcElemType = getElementTypeOrSelf(extOp.getIn().getType()); Type dstElemType = getElementTypeOrSelf(extOp.getResult().getType()); - Value srcVal = *src; - // VALU conversion instructions cannot read from AGPR. - if (isAGPRType(srcVal.getType())) { - auto vregTmp = ctx.createVRegType(); - srcVal = V_ACCVGPR_READ_B32::create(builder, loc, vregTmp, srcVal); - } + Value srcVal = ensureVGPR(*src, ctx, builder, loc); auto vregType = ctx.createVRegType(); Value result; diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index c3d2265f9c..c39942de3b 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -38,6 +38,16 @@ namespace waveasm { +/// Move a value from AGPR to VGPR if needed. Returns the original value +/// unchanged when it is already a VGPR (or immediate/scalar). +inline mlir::Value ensureVGPR(mlir::Value val, TranslationContext &ctx, + mlir::OpBuilder &builder, mlir::Location loc) { + if (!isAGPRType(val.getType())) + return val; + auto vregTmp = ctx.createVRegType(); + return V_ACCVGPR_READ_B32::create(builder, loc, vregTmp, val); +} + //===----------------------------------------------------------------------===// // GPU Dialect Handlers //===----------------------------------------------------------------------===// diff --git a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp index 03f74f122a..9c397c7d47 100644 --- a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp @@ -114,6 +114,34 @@ LogicalResult handleVectorShapeCast(Operation *op, TranslationContext &ctx) { return success(); } +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx) { + auto fromOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + SmallVector mappedElems; + for (Value elem : fromOp.getElements()) { + auto mapped = ctx.getMapper().getMapped(elem); + if (!mapped) { + return op->emitError("from_elements operand not mapped"); + } + mappedElems.push_back(*mapped); + } + + int64_t numElems = mappedElems.size(); + if (numElems == 1) { + ctx.getMapper().mapValue(fromOp.getResult(), mappedElems[0]); + } else { + // Alignment must match the widest store the packed register feeds: + // 4 for buffer_store_dwordx4 (16 bytes), 2 for dwordx2, 1 otherwise. + int64_t alignment = (numElems >= 4) ? 4 : (numElems >= 2) ? 2 : 1; + auto packedType = ctx.createVRegType(numElems, alignment); + auto packed = PackOp::create(builder, loc, packedType, mappedElems); + ctx.getMapper().mapValue(fromOp.getResult(), packed); + } + return success(); +} + LogicalResult handleVectorBitCast(Operation *op, TranslationContext &ctx) { auto castOp = cast(op); diff --git a/waveasm/test/Transforms/agpr-bf16-store.mlir b/waveasm/test/Transforms/agpr-bf16-store.mlir index 28abc1d8c1..d965c21789 100644 --- a/waveasm/test/Transforms/agpr-bf16-store.mlir +++ b/waveasm/test/Transforms/agpr-bf16-store.mlir @@ -61,7 +61,10 @@ module { // The AGPR data must be moved to VGPR before the VALU bf16 conversion. %bf16_result = arith.truncf %result : vector<4xf32> to vector<4xbf16> - // CHECK: v_accvgpr_read_b32 v[{{[0-9]+}}:{{[0-9]+}}], a[{{[0-9]+}}:{{[0-9]+}}] + // CHECK: v_accvgpr_read_b32 v{{[0-9]+}}, a{{[0-9]+}} + // CHECK: v_accvgpr_read_b32 v{{[0-9]+}}, a{{[0-9]+}} + // CHECK: v_accvgpr_read_b32 v{{[0-9]+}}, a{{[0-9]+}} + // CHECK: v_accvgpr_read_b32 v{{[0-9]+}}, a{{[0-9]+}} // CHECK: v_cvt_pk_bf16_f32 v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} // CHECK: v_cvt_pk_bf16_f32 v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} // CHECK: buffer_store_dwordx2 diff --git a/waveasm/test/Transforms/extract-scalarization.mlir b/waveasm/test/Transforms/extract-scalarization.mlir index 071800493b..1ee5642e84 100644 --- a/waveasm/test/Transforms/extract-scalarization.mlir +++ b/waveasm/test/Transforms/extract-scalarization.mlir @@ -76,11 +76,11 @@ module { } // --------------------------------------------------------------- - // Negative: splat dense constant should not be rewritten + // Splat dense constant: now scalarized via general fallback // --------------------------------------------------------------- // CHECK-LABEL: func @no_rewrite_splat - // CHECK: arith.constant dense<42> : vector<4xi32> - // CHECK: vector.extract + // CHECK: arith.constant 42 : i32 + // CHECK: arith.addi func.func @no_rewrite_splat(%x: i32) -> i32 { %bcast = vector.broadcast %x : i32 to vector<4xi32> %splat = arith.constant dense<42> : vector<4xi32> diff --git a/waveasm/test/Transforms/permlane-swap-wide-store.mlir b/waveasm/test/Transforms/permlane-swap-wide-store.mlir new file mode 100644 index 0000000000..664844095c --- /dev/null +++ b/waveasm/test/Transforms/permlane-swap-wide-store.mlir @@ -0,0 +1,148 @@ +// RUN: waveasm-translate --waveasm-linear-scan --emit-assembly %s | FileCheck %s +// +// Test: v_permlane16_swap_b32 → v_cndmask_b32 → buffer_store_dwordx4 +// assembly sequence for coalesced bf16 wide stores. +// +// The WaveASM epilogue for wide bf16 stores works as follows: +// 1. v_permlane16_swap_b32 exchanges packed bf16 data between partner +// lanes (16 apart) so each lane sees the other's data. +// 2. v_cndmask_b32 selects between own and partner data based on VCC +// (set by a prior lane-position comparison). +// 3. The 4 selected registers are packed and written via +// buffer_store_dwordx4 (128-bit store). + +// CHECK-LABEL: permlane_cndmask_dwordx4_store: +waveasm.program @permlane_cndmask_dwordx4_store + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi<> + attributes {vgprs = 32 : i64, sgprs = 32 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %v1 = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + %v2 = waveasm.precolored.vreg 2 : !waveasm.pvreg<2> + %v3 = waveasm.precolored.vreg 3 : !waveasm.pvreg<3> + + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %voff = waveasm.precolored.vreg 4 : !waveasm.pvreg<4> + + %vcc = waveasm.precolored.sreg 106, 2 : !waveasm.psreg<106, 2> + %lane_id = waveasm.precolored.vreg 5 : !waveasm.pvreg<5> + %c16 = waveasm.constant 16 : !waveasm.imm<16> + + // CHECK: v_cmp_lt_u32 vcc, v5, 16 + waveasm.v_cmp_lt_u32 %lane_id, %c16 : !waveasm.pvreg<5>, !waveasm.imm<16> + + // CHECK: v_permlane16_swap_b32 v{{[0-9]+}}, v0 + %swap0 = waveasm.v_permlane16_swap_b32 %v0 : !waveasm.pvreg<0> -> !waveasm.vreg + // CHECK: v_permlane16_swap_b32 v{{[0-9]+}}, v1 + %swap1 = waveasm.v_permlane16_swap_b32 %v1 : !waveasm.pvreg<1> -> !waveasm.vreg + // CHECK: v_permlane16_swap_b32 v{{[0-9]+}}, v2 + %swap2 = waveasm.v_permlane16_swap_b32 %v2 : !waveasm.pvreg<2> -> !waveasm.vreg + // CHECK: v_permlane16_swap_b32 v{{[0-9]+}}, v3 + %swap3 = waveasm.v_permlane16_swap_b32 %v3 : !waveasm.pvreg<3> -> !waveasm.vreg + + // CHECK: v_cndmask_b32 v{{[0-9]+}}, v{{[0-9]+}}, v0 + %sel0 = waveasm.v_cndmask_b32 %swap0, %v0, %vcc + : !waveasm.vreg, !waveasm.pvreg<0>, !waveasm.psreg<106, 2> -> !waveasm.vreg + // CHECK: v_cndmask_b32 v{{[0-9]+}}, v{{[0-9]+}}, v1 + %sel1 = waveasm.v_cndmask_b32 %swap1, %v1, %vcc + : !waveasm.vreg, !waveasm.pvreg<1>, !waveasm.psreg<106, 2> -> !waveasm.vreg + // CHECK: v_cndmask_b32 v{{[0-9]+}}, v{{[0-9]+}}, v2 + %sel2 = waveasm.v_cndmask_b32 %swap2, %v2, %vcc + : !waveasm.vreg, !waveasm.pvreg<2>, !waveasm.psreg<106, 2> -> !waveasm.vreg + // CHECK: v_cndmask_b32 v{{[0-9]+}}, v{{[0-9]+}}, v3 + %sel3 = waveasm.v_cndmask_b32 %swap3, %v3, %vcc + : !waveasm.vreg, !waveasm.pvreg<3>, !waveasm.psreg<106, 2> -> !waveasm.vreg + + %packed = waveasm.pack %sel0, %sel1, %sel2, %sel3 + : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<4, 4> + + // CHECK: buffer_store_dwordx4 v[{{[0-9]+}}:{{[0-9]+}}], v4, s[0:3], 0 offen + waveasm.buffer_store_dwordx4 %packed, %srd, %voff + : !waveasm.vreg<4, 4>, !waveasm.psreg<0, 4>, !waveasm.pvreg<4> + + // CHECK: s_endpgm + waveasm.s_endpgm +} + +// ----- + +// Test 2: Dual-output pair op for paired wide stores (no duplicate stores). +// V_PERMLANE16_SWAP_B32_PAIR takes two inputs (old_dst, src) and produces +// two outputs (new_dst = partner's src, new_src = partner's old_dst). +// CHECK-LABEL: permlane_pair_wide_store: +waveasm.program @permlane_pair_wide_store + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi<> + attributes {vgprs = 32 : i64, sgprs = 32 : i64} { + + %a_lo = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %a_hi = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + %b_lo = waveasm.precolored.vreg 2 : !waveasm.pvreg<2> + %b_hi = waveasm.precolored.vreg 3 : !waveasm.pvreg<3> + + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %voff = waveasm.precolored.vreg 4 : !waveasm.pvreg<4> + + %vcc = waveasm.precolored.sreg 106, 2 : !waveasm.psreg<106, 2> + %lane_id = waveasm.precolored.vreg 5 : !waveasm.pvreg<5> + %c16 = waveasm.constant 16 : !waveasm.imm<16> + + waveasm.v_cmp_lt_u32 %lane_id, %c16 : !waveasm.pvreg<5>, !waveasm.imm<16> + + // CHECK: v_permlane16_swap_b32 + %partner_b_lo, %partner_a_lo = waveasm.v_permlane16_swap_b32_pair %a_lo, %b_lo + : !waveasm.pvreg<0>, !waveasm.pvreg<2> -> !waveasm.vreg, !waveasm.vreg + // CHECK: v_permlane16_swap_b32 + %partner_b_hi, %partner_a_hi = waveasm.v_permlane16_swap_b32_pair %a_hi, %b_hi + : !waveasm.pvreg<1>, !waveasm.pvreg<3> -> !waveasm.vreg, !waveasm.vreg + + // Lower lane: [a_lo, a_hi, partner_a_lo, partner_a_hi] + // Upper lane: [partner_b_lo, partner_b_hi, b_lo, b_hi] + %d0 = waveasm.v_cndmask_b32 %partner_b_lo, %a_lo, %vcc + : !waveasm.vreg, !waveasm.pvreg<0>, !waveasm.psreg<106, 2> -> !waveasm.vreg + %d1 = waveasm.v_cndmask_b32 %partner_b_hi, %a_hi, %vcc + : !waveasm.vreg, !waveasm.pvreg<1>, !waveasm.psreg<106, 2> -> !waveasm.vreg + %d2 = waveasm.v_cndmask_b32 %b_lo, %partner_a_lo, %vcc + : !waveasm.pvreg<2>, !waveasm.vreg, !waveasm.psreg<106, 2> -> !waveasm.vreg + %d3 = waveasm.v_cndmask_b32 %b_hi, %partner_a_hi, %vcc + : !waveasm.pvreg<3>, !waveasm.vreg, !waveasm.psreg<106, 2> -> !waveasm.vreg + + %packed = waveasm.pack %d0, %d1, %d2, %d3 + : (!waveasm.vreg, !waveasm.vreg, !waveasm.vreg, !waveasm.vreg) -> !waveasm.vreg<4, 4> + + // CHECK: buffer_store_dwordx4 + waveasm.buffer_store_dwordx4 %packed, %srd, %voff + : !waveasm.vreg<4, 4>, !waveasm.psreg<0, 4>, !waveasm.pvreg<4> + + waveasm.s_endpgm +} + +// ----- + +// Test 3: dst==src fallback path for v_permlane16_swap_b32. +// When the allocator assigns the same register for dst and src, the emitter +// uses scratch VGPRs to avoid clobbering. +// This test uses a tight register budget to encourage dst==src allocation. +// CHECK-LABEL: permlane_dstsrc_fallback: +waveasm.program @permlane_dstsrc_fallback + target = #waveasm.target<#waveasm.gfx950, 5> + abi = #waveasm.abi<> + attributes {vgprs = 8 : i64, sgprs = 8 : i64} { + + %v0 = waveasm.precolored.vreg 0 : !waveasm.pvreg<0> + %srd = waveasm.precolored.sreg 0, 4 : !waveasm.psreg<0, 4> + %voff = waveasm.precolored.vreg 1 : !waveasm.pvreg<1> + + // Whether the allocator chooses dst==src or dst!=src, we should see + // a valid v_permlane16_swap_b32 in the output. + // CHECK: v_permlane16_swap_b32 + %swap = waveasm.v_permlane16_swap_b32 %v0 : !waveasm.pvreg<0> -> !waveasm.vreg + + // CHECK: buffer_store_dword v{{[0-9]+}}, v1, s[0:3], 0 offen + waveasm.buffer_store_dword %swap, %srd, %voff + : !waveasm.vreg, !waveasm.psreg<0, 4>, !waveasm.pvreg<1> + + // CHECK: s_endpgm + waveasm.s_endpgm +}