diff --git a/wave_lang/kernel/wave/schedule_reordering.py b/wave_lang/kernel/wave/schedule_reordering.py
index 38ff0f3a9..3256f5972 100644
--- a/wave_lang/kernel/wave/schedule_reordering.py
+++ b/wave_lang/kernel/wave/schedule_reordering.py
@@ -78,6 +78,7 @@ class SchedReorderStrategy(Enum):
     TWO_PP_CLUSTER = 0x220
     ASYNC_TWO_PP_CLUSTER = 0x2201
     MXFP4_PP_CLUSTER = 0x101
+    FOUR_WAVE_INTERWEAVE = 0x120
 
 
 def is_pingpong_strategy(strategy):
@@ -103,6 +104,7 @@ class CompatibleBlockSize:
 twoPPConfig = CompatibleBlockSize(128, 128, 64, 16, False, MMA)
 asyncTwoPPConfig = CompatibleBlockSize(128, 128, 64, 16, True, MMA)
 MXFP4PPConfig = CompatibleBlockSize(256, 128, 256, 4, False, ScaledMMA)
+fourWaveConfig = CompatibleBlockSize(64, 64, 32, 16, True, MMA)
 
 
 class InsertionMode(Enum):
@@ -546,6 +548,16 @@ def select_reorder_strategy(
     hardware_constraint,
 ):
     flat_wave_count = math.prod(hardware_constraint.waves_per_block)
+    if flat_wave_count == 4 and is_compatible_strategy(
+        mTile,
+        nTile,
+        kTile,
+        mma_bitwidth,
+        use_global_to_shared,
+        mma_type,
+        fourWaveConfig,
+    ):
+        return SchedReorderStrategy.FOUR_WAVE_INTERWEAVE
     if flat_wave_count != 8:
         return SchedReorderStrategy.NONE
     if is_compatible_strategy(
@@ -824,6 +836,93 @@ def transform_MXFP4_PP_clusters(
     return clusters
 
 
+def transform_four_wave_clusters(
+    mma_nodes,
+    local_load_lhs,
+    local_load_rhs,
+    global_to_shared_lhs,
+    global_to_shared_rhs,
+):
+    num_slices = 2
+    sliced_mma_nodes, sliced_local_load_lhs, sliced_local_load_rhs = slice_mma(
+        mma_nodes, local_load_lhs, local_load_rhs, num_slice=num_slices
+    )
+    # Check that slicing produced the expected number of local_loads and mmas.
+    assert len(sliced_mma_nodes) == len(sliced_local_load_rhs)
+    assert len(sliced_mma_nodes) == len(sliced_local_load_lhs)
+    assert len(sliced_mma_nodes) == num_slices
+
+    context_location = mma_nodes and mma_nodes[0].location
+
+    clusters = []
+    tmp_graph = fx.Graph()
+    # 1st cluster: interleaved local and global reads.
+    clusters.append(sliced_local_load_lhs[0])
+    clusters.append(sliced_local_load_rhs[0])
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, sliced_local_load_rhs[0]))
+
+    clusters.append(global_to_shared_lhs)
+    clusters.append(global_to_shared_rhs)
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, global_to_shared_rhs))
+
+    barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+
+    # 2nd cluster: mma_slice[0].
+    clusters.append(sliced_mma_nodes[0])
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, sliced_mma_nodes[0]))
+
+    independent_global_count = len(global_to_shared_lhs + global_to_shared_rhs)
+    barrier_op = MemoryCounterWait(load=independent_global_count).add_to_graph(
+        tmp_graph
+    )
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+
+    barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+
+    # 3rd cluster: local loads for the 2nd slice.
+    clusters.append(sliced_local_load_lhs[1])
+    clusters.append(sliced_local_load_rhs[1])
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, sliced_local_load_rhs[1]))
+
+    barrier_op = MemoryCounterWait(load=0).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+
+    barrier_op = WorkgroupBarrier().add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, clusters[-1].op))
+
+    # 4th cluster: mma_slice[1].
+    clusters.append(sliced_mma_nodes[1])
+    barrier_op = SchedulingBarrier([]).add_to_graph(tmp_graph)
+    barrier_op.location = context_location
+    clusters.append(insert_op_after(barrier_op, sliced_mma_nodes[1]))
+
+    return clusters
+
+
 ##############################################################
 # Helper fn to classify/detect ops.
 ##############################################################
@@ -1037,6 +1136,14 @@ def schedule_reordering(
             local_write_rhs_scale,
         )
         clusters = flatten_list(clusters)
+    elif reorder_strategy == SchedReorderStrategy.FOUR_WAVE_INTERWEAVE:
+        clusters = transform_four_wave_clusters(
+            mma_nodes,
+            local_load_lhs,
+            local_load_rhs,
+            global_to_shared_lhs,
+            global_to_shared_rhs,
+        )
     else:
         raise ValueError("Unhandled SchedReorderStrategy case.")
     reordered_graph = reorder_graph(graph, clusters)
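
Reviewer note: the snippet below is a rough sketch of the per-iteration ordering that transform_four_wave_clusters assembles, reconstructed from the cluster construction in the patch above. It is illustrative only; the constant name and the descriptive strings are not part of the change, and the SchedulingBarrier fences between phases are omitted.

# Illustrative sketch only (not part of the patch): phase ordering built by
# transform_four_wave_clusters(), with SchedulingBarrier fences omitted.
FOUR_WAVE_INTERWEAVE_PHASES = [
    # 1st cluster: local loads for slice 0, then the global-to-shared copies.
    "local_load_lhs[0], local_load_rhs[0]",
    "global_to_shared_lhs, global_to_shared_rhs",
    "workgroup_barrier",
    # 2nd cluster: first MMA slice, then a MemoryCounterWait whose load count
    # equals the number of global-to-shared ops issued in the 1st cluster.
    "mma_slice[0]",
    "memory_counter_wait(load=independent_global_count)",
    "workgroup_barrier",
    # 3rd cluster: local loads for slice 1, then MemoryCounterWait(load=0).
    "local_load_lhs[1], local_load_rhs[1]",
    "memory_counter_wait(load=0)",
    "workgroup_barrier",
    # 4th cluster: second MMA slice.
    "mma_slice[1]",
]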