Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions examples/python/7.1_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,21 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_wide_stores(

def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm(
is_debug=False,
shape=(1024, 1024, 8192),
block=(128, 256, 256),
shape=(1024, 3072, 8192),
block=(128, 128, 256),
eliminate_epilogue=False,
):
"""Preshuffle-B MXFP4 GEMM with dynamic M, N, K."""
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
shape, block, wave_shape=(1, 4), reorder_workgroups=False
"""Preshuffle-B MXFP4 GEMM with coalesced dwordx4 stores (WaveASM backend).

Same kernel as the LLVM coalesced-stores test but compiled through the
C++ WaveASM backend. Emits v_permlane16_swap_b32 + buffer_store_dwordx4.
"""
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b_wide_store(
shape,
block,
wave_shape=(1, 4),
reorder_workgroups=True,
)
# Make M, N, K dynamic so the compiler does not specialize on problem size.
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
for sym in dynamic_symbols:
del options.subs[sym]
Expand All @@ -483,18 +489,15 @@ def test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm_asm(
options.use_wave_asm_backend = True
options.wave_runtime = True
options.eliminate_epilogue = eliminate_epilogue
options.dump_intermediates = "build/intermediates/"
schedule = get_mxfp4_asymmetric_schedule(
eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True
)
options.print_ir_after = "all" if is_debug else []
options = set_default_run_config(options)
gemm = wave_compile(options, gemm, schedule)

_run_mxfp_gemm_preshuffle(gemm, shape, all=True)
print(
"MXFP GEMM preshuffle-B 4-wave dynamic M, N, K (WaveASM backend) test passed!"
)
_run_mxfp_gemm_preshuffle(gemm, shape, all=True, output_dtype=torch.bfloat16)
print("MXFP GEMM preshuffle-B 4-wave dwordx4 (WaveASM backend) test passed!")


if __name__ == "__main__":
Expand Down
241 changes: 204 additions & 37 deletions wave_lang/kernel/compiler/wave_codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Attribute,
BF16Type,
DenseElementsAttr,
F32Type,
IndexType,
InsertionPoint,
IntegerAttr,
Expand Down Expand Up @@ -1323,9 +1324,29 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):

use_llvm_store = flags != MemoryAccessFlags.NONE

if getattr(node, "_permlane_pack_global", False):
import os as _os

_use_old_asm = _os.environ.get("WAVE_OLD_ASM_STORES", "") == "1"
_has_pack = getattr(node, "_permlane_pack_global", False)

if _has_pack and _use_old_asm and emitter.options.backend == "asm":
_write_permlane_pack_to_global_asm(
emitter,
insert_vector,
kb_dest,
output_shape,
start_indices,
start_indices_wg,
start_indices_th,
get_custom(memory),
index,
)
return

if _has_pack:
is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE
if not is_shared and isinstance(element_type, BF16Type):
is_bf16 = isinstance(element_type, BF16Type)
if not is_shared and is_bf16:
role = getattr(node, "_permlane_pack_role", "unpaired")

if role == "first":
Expand Down Expand Up @@ -1438,34 +1459,61 @@ def _write_permlane_pair_to_global(
``_create_vec_read_write`` call would need to be split into
two lane-divergent stores.
"""
num_elems_a = vec_a.type.shape[0] if hasattr(vec_a.type, "shape") else 1
num_elems_b = vec_b.type.shape[0] if hasattr(vec_b.type, "shape") else 1
assert num_elems_a == 4 and num_elems_b == 4, (
f"_write_permlane_pair_to_global expects 4 bf16 elements per thread "
f"per tile, got {num_elems_a} and {num_elems_b}."
)

bf16_type = BF16Type.get()
f32_type = F32Type.get()
i16_type = IntegerType.get_signless(16)
i32_type = IntegerType.get_signless(32)
idx_type = IndexType.get()
v2i32_type = VectorType.get([2], i32_type)
v4i32_type = VectorType.get([4], i32_type)
v8bf16_type = VectorType.get([8], bf16_type)

i32_a = vector_d.bitcast(v2i32_type, vec_a)
a_lo = vector_d.extract(i32_a, static_position=[0], dynamic_position=[])
a_hi = vector_d.extract(i32_a, static_position=[1], dynamic_position=[])

i32_b = vector_d.bitcast(v2i32_type, vec_b)
b_lo = vector_d.extract(i32_b, static_position=[0], dynamic_position=[])
b_hi = vector_d.extract(i32_b, static_position=[1], dynamic_position=[])
is_asm = emitter.options.backend == "asm"

def _extract_i32_dwords(vec_bf16):
"""Extract two i32 dwords from a vector<4xbf16>."""
if not is_asm:
i32_vec = vector_d.bitcast(v2i32_type, vec_bf16)
lo = vector_d.extract(i32_vec, static_position=[0], dynamic_position=[])
hi = vector_d.extract(i32_vec, static_position=[1], dynamic_position=[])
return lo, hi

f32_src = None
defn = vec_bf16.owner
if defn is not None and defn.name == "arith.truncf":
f32_src = defn.operands[0]
assert f32_src is not None, (
"ASM wide-store expects vec from arith.truncf; "
f"got {defn.name if defn else 'block arg'}"
)
e0 = vector_d.extract(f32_src, static_position=[0], dynamic_position=[])
e1 = vector_d.extract(f32_src, static_position=[1], dynamic_position=[])
e2 = vector_d.extract(f32_src, static_position=[2], dynamic_position=[])
e3 = vector_d.extract(f32_src, static_position=[3], dynamic_position=[])
b0 = arith_d.truncf(bf16_type, e0)
b1 = arith_d.truncf(bf16_type, e1)
b2 = arith_d.truncf(bf16_type, e2)
b3 = arith_d.truncf(bf16_type, e3)
c16 = arith_d.constant(i32_type, 16)

def _pack_pair(lo_bf16, hi_bf16):
lo_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, lo_bf16))
hi_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, hi_bf16))
return arith_d.ori(arith_d.shli(hi_i32, c16), lo_i32)

return _pack_pair(b0, b1), _pack_pair(b2, b3)

a_lo, a_hi = _extract_i32_dwords(vec_a)
b_lo, b_hi = _extract_i32_dwords(vec_b)

swap_type = llvm_d.StructType.get_literal([i32_type, i32_type])

# old_dst = a, src = b → result[0] = partner's b, result[1] = partner's a
# 2 swaps, extract both elements from each.
# Element [0] = new_dst = partner's src, Element [1] = new_src = partner's old_dst.
# The WaveASM handler creates V_PERMLANE16_SWAP_B32_PAIR when it detects
# that element [1] is extracted, properly modeling both hardware outputs.
swap_lo = rocdl_d.permlane16_swap(swap_type, a_lo, b_lo, False, False)
swap_hi = rocdl_d.permlane16_swap(swap_type, a_hi, b_hi, False, False)

partner_b_lo = llvm_d.extractvalue(i32_type, swap_lo, [0])
partner_a_lo = llvm_d.extractvalue(i32_type, swap_lo, [1])
partner_b_hi = llvm_d.extractvalue(i32_type, swap_hi, [0])
Expand All @@ -1490,29 +1538,45 @@ def _write_permlane_pair_to_global(
elems_per_thread = arith_d.constant(idx_type, 4)

# Lower lane uses tile A's address; upper lane uses tile B's address.
# Upper lane subtracts elems_per_thread from the last dim to align
# to the lower lane's column position (same as the single-write path).
adj_th = list(start_indices_th_a)
adj_full = list(start_indices_a)
for dim_idx in range(len(adj_th)):
if dim_idx == len(adj_th) - 1:
adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread)
adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread)

if is_asm:
# ASM backend: the SRD must be uniform (goes through readfirstlane).
# Both tiles share the same workgroup, so wg_a == wg_b at runtime.
# Use wg_a for the SRD and th_b directly for the thread offset.
for dim_idx in range(len(adj_th)):
if dim_idx == len(adj_th) - 1:
adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread)
adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread)
else:
adj_b_th = start_indices_th_b[dim_idx]
adj_b_full = start_indices_b[dim_idx]
adj_th[dim_idx] = arith_d.select(is_lower, adj_th[dim_idx], adj_b_th)
adj_full[dim_idx] = arith_d.select(is_lower, adj_full[dim_idx], adj_b_full)
else:
adj_th[dim_idx] = arith_d.select(
is_lower, start_indices_th_a[dim_idx], start_indices_th_b[dim_idx]
)
adj_full[dim_idx] = arith_d.select(
is_lower, start_indices_a[dim_idx], start_indices_b[dim_idx]
adj_wg = start_indices_wg_a
else:
for dim_idx in range(len(adj_th)):
if dim_idx == len(adj_th) - 1:
adj_b_th = arith_d.subi(start_indices_th_b[-1], elems_per_thread)
adj_b_full = arith_d.subi(start_indices_b[-1], elems_per_thread)
adj_th[dim_idx] = arith_d.select(is_lower, adj_th[dim_idx], adj_b_th)
adj_full[dim_idx] = arith_d.select(
is_lower, adj_full[dim_idx], adj_b_full
)
else:
adj_th[dim_idx] = arith_d.select(
is_lower, start_indices_th_a[dim_idx], start_indices_th_b[dim_idx]
)
adj_full[dim_idx] = arith_d.select(
is_lower, start_indices_a[dim_idx], start_indices_b[dim_idx]
)
adj_wg_list = list(start_indices_wg_a)
for dim_idx in range(len(adj_wg_list)):
adj_wg_list[dim_idx] = arith_d.select(
is_lower, start_indices_wg_a[dim_idx], start_indices_wg_b[dim_idx]
)

adj_wg = list(start_indices_wg_a)
for dim_idx in range(len(adj_wg)):
adj_wg[dim_idx] = arith_d.select(
is_lower, start_indices_wg_a[dim_idx], start_indices_wg_b[dim_idx]
)
adj_wg = tuple(adj_wg_list)

sel_output_shape = output_shape_a
sel_memory_custom = memory_custom_a
Expand All @@ -1526,7 +1590,7 @@ def _write_permlane_pair_to_global(
wide_vec,
None,
tuple(adj_full),
tuple(adj_wg),
adj_wg,
tuple(adj_th),
8,
sel_memory_custom,
Expand All @@ -1535,6 +1599,109 @@ def _write_permlane_pair_to_global(
)


def _write_permlane_pack_to_global_asm(
emitter: WaveEmitter,
insert_vector: Value,
kb_dest: Value,
output_shape: tuple,
start_indices: tuple,
start_indices_wg: tuple,
start_indices_th: tuple,
memory_custom,
index: dict,
):
"""Single-tile wide store for ASM backend with explicit bf16 conversion.

Each lane exchanges its own 4 bf16 values with a partner 16 apart,
assembles an 8-element wide vector, and writes buffer_store_dwordx4.
Both lane halves write identical data (duplicate stores).
"""
bf16_type = BF16Type.get()
f32_type = F32Type.get()
i16_type = IntegerType.get_signless(16)
i32_type = IntegerType.get_signless(32)
idx_type = IndexType.get()
v4i32_type = VectorType.get([4], i32_type)
v8bf16_type = VectorType.get([8], bf16_type)

f32_src = None
defn = insert_vector.owner
if defn is not None and defn.name == "arith.truncf":
f32_src = defn.operands[0]
assert (
f32_src is not None
), "_write_permlane_pack_to_global_asm expects vec from arith.truncf"

e0 = vector_d.extract(f32_src, static_position=[0], dynamic_position=[])
e1 = vector_d.extract(f32_src, static_position=[1], dynamic_position=[])
e2 = vector_d.extract(f32_src, static_position=[2], dynamic_position=[])
e3 = vector_d.extract(f32_src, static_position=[3], dynamic_position=[])
b0 = arith_d.truncf(bf16_type, e0)
b1 = arith_d.truncf(bf16_type, e1)
b2 = arith_d.truncf(bf16_type, e2)
b3 = arith_d.truncf(bf16_type, e3)

c16 = arith_d.constant(i32_type, 16)

def _pack_pair(lo_bf16, hi_bf16):
lo_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, lo_bf16))
hi_i32 = arith_d.extui(i32_type, arith_d.bitcast(i16_type, hi_bf16))
return arith_d.ori(arith_d.shli(hi_i32, c16), lo_i32)

own_lo = _pack_pair(b0, b1)
own_hi = _pack_pair(b2, b3)

swap_type = llvm_d.StructType.get_literal([i32_type, i32_type])
partner_lo = llvm_d.extractvalue(
i32_type, rocdl_d.permlane16_swap(swap_type, own_lo, own_lo, False, False), [0]
)
partner_hi = llvm_d.extractvalue(
i32_type, rocdl_d.permlane16_swap(swap_type, own_hi, own_hi, False, False), [0]
)

lane_in_wave = arith_d.remui(emitter.thread_ids[0], arith_d.constant(idx_type, 64))
half_pos = arith_d.remui(lane_in_wave, arith_d.constant(idx_type, 32))
is_lower = arith_d.cmpi(
arith_d.CmpIPredicate.ult, half_pos, arith_d.constant(idx_type, 16)
)

d0 = arith_d.select(is_lower, own_lo, partner_lo)
d1 = arith_d.select(is_lower, own_hi, partner_hi)
d2 = arith_d.select(is_lower, partner_lo, own_lo)
d3 = arith_d.select(is_lower, partner_hi, own_hi)

wide_i32 = vector_d.from_elements(v4i32_type, [d0, d1, d2, d3])
wide_vec = vector_d.bitcast(v8bf16_type, wide_i32)

num_elems = 4
elems_per_thread = arith_d.constant(idx_type, num_elems)

adj_th = list(start_indices_th)
adj_th[-1] = arith_d.select(
is_lower, adj_th[-1], arith_d.subi(adj_th[-1], elems_per_thread)
)

adj_full = list(start_indices)
adj_full[-1] = arith_d.select(
is_lower, adj_full[-1], arith_d.subi(adj_full[-1], elems_per_thread)
)

_create_vec_read_write(
emitter,
output_shape,
kb_dest,
wide_vec,
None,
tuple(adj_full),
start_indices_wg,
tuple(adj_th),
8,
memory_custom,
None,
node_index=index,
)


def assume_index_subgroup_uniform(value: Value, element_type: IrType) -> Value:
res = gpu_d.subgroup_broadcast(value, gpu_d.BroadcastType.first_active_lane)
return res
Expand Down
9 changes: 9 additions & 0 deletions wave_lang/kernel/wave/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ def build_graph_passes(
partial(decompose_topk_ops, trace, launchable.constraints),
]

# Tags bf16 global writes that use source/target transposition
# (wide_stores layout) with _permlane_pack_global so the codegen emits
# v_permlane16_swap_b32 + buffer_store_dwordx4. Runs unconditionally;
# writes without source/target are left untouched.
from .wide_store_coalescing import coalesce_wide_stores

graph_passes.append(partial(coalesce_wide_stores, trace))
Expand Down Expand Up @@ -1334,6 +1338,11 @@ def _generate_asm_code(mb, options):
mlir_file.write(kernel_mlir)
mlir_path = mlir_file.name

# DEBUG: save a copy of the MLIR IR for inspection
import shutil as _shutil

_shutil.copy2(mlir_path, "/tmp/waveasm_input.mlir")

try:
base_passes = [
"--mlir-cse",
Expand Down
6 changes: 2 additions & 4 deletions wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,6 @@ def _get_tagged_mxfp4_gemm_preshuffle_b_impl(
K_PACKED = tkl.sym.K_PACKED
K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED

if wide_stores:
m_symbol = tkl.sym.m_symbol
n_symbol = tkl.sym.n_symbol

constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
Expand All @@ -426,6 +422,8 @@ def _get_tagged_mxfp4_gemm_preshuffle_b_impl(
constraints += [tkw.Assumption(K > BLOCK_K * 6)]

if wide_stores:
m_symbol = tkl.sym.m_symbol
n_symbol = tkl.sym.n_symbol
constraints += [tkw.IteratorBindings({m_symbol: M, n_symbol: N})]
constraints += [tkw.Assumption(Eq(M % BLOCK_M, 0))]
constraints += [tkw.Assumption(Eq(N % BLOCK_N, 0))]
Expand Down
Loading
Loading