diff --git a/tests/kernel/wave/dist_gemm_test.py b/tests/kernel/wave/dist_gemm_test.py
index 8fe7651ab..e49554702 100644
--- a/tests/kernel/wave/dist_gemm_test.py
+++ b/tests/kernel/wave/dist_gemm_test.py
@@ -55,7 +55,9 @@
         MMAType.F32_32x32x8_F16,
     ],
 )
-@pytest.mark.parametrize("devices", [(1, 1), (2, 1), (4, 1), (8, 1)])
+@pytest.mark.parametrize(
+    "devices", [(1, 1), (2, 1), (4, 1), (8, 1), (1, 2), (1, 4), (1, 8), (2, 2), (2, 4), (4, 2)]
+)
 @pytest.mark.parametrize("datatype", [torch.float16])
 def testPureGemm(
     shape: tuple[int],
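The expanded `devices` parametrization now covers 2-D device grids (e.g. `(2, 4)`) in addition to the original 1-D cases. For reference, here is a minimal sketch of how such a `(device_m, device_n)` tuple is expected to partition the M x N GEMM output; `tile_bounds` is a hypothetical helper, not part of the test suite, and row-major device numbering plus evenly divisible shapes are assumptions:

```python
def tile_bounds(shape_mn, devices, device_id):
    """Return ((row_start, row_end), (col_start, col_end)) for one device."""
    (m, n), (device_m, device_n) = shape_mn, devices
    # Assumption: shapes divide evenly across the device grid.
    assert m % device_m == 0 and n % device_n == 0
    # Assumption: flat device ids are numbered row-major over the grid.
    row, col = divmod(device_id, device_n)
    tile_m, tile_n = m // device_m, n // device_n
    return (
        (row * tile_m, (row + 1) * tile_m),
        (col * tile_n, (col + 1) * tile_n),
    )

# A 128x128 output on a (2, 4) grid: device 5 owns rows 64:128, cols 32:64.
assert tile_bounds((128, 128), (2, 4), 5) == ((64, 128), (32, 64))
```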
diff --git a/wave_lang/kernel/compiler/host_utils.py b/wave_lang/kernel/compiler/host_utils.py
index 0d74eee78..f07a0754e 100644
--- a/wave_lang/kernel/compiler/host_utils.py
+++ b/wave_lang/kernel/compiler/host_utils.py
@@ -9,6 +9,7 @@
     Value,
     arith_d,
     flow_d,
+    tensor_d,
 )

 from .kernel_codegen import BindingDesc, KernelSignature, BindingType
@@ -34,6 +35,18 @@ def substitute_dimensions_in_shape(symbolic_shape, symbol_map):
     return symbolic_shape


+def get_or_create_index_constant(value: int, constant_map: dict) -> Value:
+    """
+    Return a cached index constant, creating it on first use so each value is emitted only once in the MLIR output.
+    """
+    if value not in constant_map:
+        constant_map[value] = arith_d.constant(
+            IndexType.get(),
+            IntegerAttr.get(IndexType.get(), value),
+        )
+    return constant_map[value]
+
+
 class HostSignature:
     """
     With the introduction of multi-device support, the host signature may not be the same as the kernel signature.
@@ -151,9 +164,7 @@ def split_input_tensors(

     # Get the mapping from dimensions to device layout
     constraint_map = {c.dim: (c.tile_size, c.device_dim) for c in device_constraints}
-
     host_shape = host_buffer_binding.kernel_buffer_type.symbolic_shape
-
     device_map = device_map if device_map is not None else {}

     # Tracks constants to avoid repeated constants in the IR
@@ -226,35 +237,6 @@ def split_input_tensors(
                 if hasattr(tile_dim_size, "__int__")
                 else tile_dim_size
             )
-
-            # Add to slice signature for caching
-            slice_signature.append((start_offset_int, tile_dim_size_int))
-
-            if start_offset_int not in constant_map:
-                constant_map[start_offset_int] = arith_d.constant(
-                    IndexType.get(),
-                    IntegerAttr.get(IndexType.get(), start_offset_int),
-                )
-            start_idx = constant_map[start_offset_int]
-            start_indices.append(start_idx)
-
-            if tile_dim_size_int not in constant_map:
-                constant_map[tile_dim_size_int] = arith_d.constant(
-                    IndexType.get(),
-                    IntegerAttr.get(IndexType.get(), tile_dim_size_int),
-                )
-            length = constant_map[tile_dim_size_int]
-            lengths.append(length)
-            result_shape.append(tile_dim_size_int)
-
-            # Store slice info for this dimension
-            slice_info[f"dim_{i}"] = {
-                "symbol": dim,
-                "start_offset": start_offset_int,
-                "length": tile_dim_size_int,
-                "device_coord": device_coord,
-                "device_dim": device_dim,
-            }
         else:
             # This dimension is not split across devices
             full_dim_size = mlir_shape[i]
@@ -264,37 +246,35 @@ def split_input_tensors(
                 if hasattr(full_dim_size, "__int__")
                 else full_dim_size
             )
-
-            slice_signature.append((0, full_dim_size))
-
-            if 0 not in constant_map:
-                constant_map[0] = arith_d.constant(
-                    IndexType.get(), IntegerAttr.get(IndexType.get(), 0)
-                )
-            start_indices.append(constant_map[0])
-
-            if full_dim_size not in constant_map:
-                constant_map[full_dim_size] = arith_d.constant(
-                    IndexType.get(), IntegerAttr.get(IndexType.get(), full_dim_size)
-                )
-            length = constant_map[full_dim_size]
-            lengths.append(length)
-            result_shape.append(full_dim_size)
-
-            # Store slice info for this dimension
-            slice_info[f"dim_{i}"] = {
-                "symbol": dim,
-                "start_offset": 0,
-                "length": full_dim_size,
-                "device_coord": None,
-                "device_dim": None,
-            }
+            start_offset_int = 0
+            tile_dim_size_int = full_dim_size
+            device_coord = None
+            device_dim = None
+
+        # Add to slice signature for caching
+        slice_signature.append((start_offset_int, tile_dim_size_int))
+        start_idx = get_or_create_index_constant(start_offset_int, constant_map)
+        length = get_or_create_index_constant(tile_dim_size_int, constant_map)
+
+        start_indices.append(start_idx)
+        lengths.append(length)
+        result_shape.append(tile_dim_size_int)
+
+        # Store slice info for this dimension
+        slice_info[f"dim_{i}"] = {
+            "symbol": dim,
+            "start_offset": start_offset_int,
+            "length": tile_dim_size_int,
+            "device_coord": device_coord,
+            "device_dim": device_dim,
+        }

     # Convert slice signature to a hashable key
     slice_key = tuple(slice_signature)

     # check if we've already created this slice
     if slice_key in slice_cache:
-        device_slice_result, cached_slice_info = slice_cache[slice_key]
+        device_slice_result = slice_cache[slice_key]
     else:
         # check if the host_tensor and the slice dimensions match
         if host_type.shape != result_shape:
@@ -320,7 +300,7 @@ def split_input_tensors(
             device_slice_result = transferred_slice.result

             # Cache the slice
-            slice_cache[slice_key] = (device_slice_result, slice_info)
+            slice_cache[slice_key] = device_slice_result

     # store device mapping
     if device_id not in device_map:
@@ -347,36 +327,42 @@ def merge_output_slices(
     constant_map: dict,
     device_tensor_map: list[dict],
 ):
-    # getting the orignial output tensor
-    result_tensor = arguments[
+
+    original_tensor = arguments[
         len(host_sig.buffer_bindings) - len(host_sig.output_buffer_bindings)
     ]
     output_idx = len(host_sig.buffer_bindings) - len(host_sig.output_buffer_bindings)
-
+    result_tensor = original_tensor
+
     for i, dispatch_result in enumerate(output_list):
-        # Get the device coordinates for this result
+        # Get the slice metadata for the i-th device's output
         # TODO: Handle multiple outputs per device, currently assuming one output per device
        device_info = device_tensor_map[i][output_idx]
-        slice_shape = device_info["result_shape"]
         slice_info = device_info["slice_info"]
-        offsets = []
-        for dim_key in sorted(slice_info.keys()):  # dim_0, dim_1, etc.
+
+        # For each dimension of the output tensor, collect the start offset and data length
+        start_offsets = []
+        data_lengths = []
+        strides = []
+
+        for dim_key in sorted(slice_info.keys()):  # dim_0, dim_1, ...
             start_offset = slice_info[dim_key]["start_offset"]
-            if start_offset not in constant_map:
-                constant_map[start_offset] = arith_d.constant(
-                    IndexType.get(),
-                    IntegerAttr.get(IndexType.get(), start_offset),
-                )
-            offsets.append(constant_map[start_offset])
+            length = slice_info[dim_key]["length"]
+            start_offsets.append(start_offset)
+            data_lengths.append(length)
+            strides.append(1)  # device slices are contiguous, so unit strides

         # Update the result tensor with this device's output
-        # flow_d.TensorUpdateOp signature: (target, target_dims, start_indices, update, update_dims)
-        result_value = flow_d.TensorUpdateOp(
-            result_tensor,  # target tensor
-            [],  # target_dims (empty list for dynamic dims)
-            offsets,  # start_indices (where to place it)
-            dispatch_result.results[0],  # update (the slice to insert)
-            [],  # update_dims (empty list for dynamic dims)
+        # tensor_d.InsertSliceOp signature: (source, dest, offsets, sizes, strides, static_offsets, static_sizes, static_strides)
+        result_value = tensor_d.InsertSliceOp(
+            dispatch_result.results[0],  # source
+            result_tensor,  # dest
+            [],  # offsets (dynamic; empty since all offsets are static)
+            [],  # sizes (dynamic; empty since all sizes are static)
+            [],  # strides (dynamic; empty since all strides are static)
+            start_offsets,  # static_offsets, e.g. [start_m, start_n]
+            data_lengths,  # static_sizes, e.g. [tile_m, tile_n]
+            strides,  # static_strides, e.g. [1, 1]
         ).result
         result_tensor = result_value
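For intuition, the IR that `split_input_tensors` and `merge_output_slices` now emit corresponds to ordinary slicing and in-place insertion: each device receives a slice of the host tensor at static offsets, and each device result is written back via `tensor.insert_slice` at the same offsets with unit strides. A minimal torch sketch of that host-side behavior, with hypothetical helper names and contiguous tiles assumed:

```python
import torch

def split_tensor(host, start_offsets, lengths):
    # Analogue of the per-device slice emitted by split_input_tensors.
    idx = tuple(slice(s, s + l) for s, l in zip(start_offsets, lengths))
    return host[idx]

def merge_slice(result, device_out, start_offsets):
    # Analogue of tensor.insert_slice: static offsets/sizes, unit strides.
    idx = tuple(slice(s, s + l) for s, l in zip(start_offsets, device_out.shape))
    result[idx] = device_out
    return result

host = torch.arange(16.0).reshape(4, 4)
tile = split_tensor(host, start_offsets=[2, 0], lengths=[2, 4])  # rows 2:4
merged = merge_slice(torch.zeros_like(host), tile, start_offsets=[2, 0])
assert torch.equal(merged[2:4], host[2:4])
```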
diff --git a/wave_lang/kernel/wave/templates/dist_gemm.py b/wave_lang/kernel/wave/templates/dist_gemm.py
index 064d0ca78..c7a87d1e8 100644
--- a/wave_lang/kernel/wave/templates/dist_gemm.py
+++ b/wave_lang/kernel/wave/templates/dist_gemm.py
@@ -25,7 +25,6 @@ def get_dist_gemm_kernel(
     device_m: int = 1,
     device_n: int = 1,
 ):
-    assert device_m == 1 or device_n == 1, "Only one of device_m or device_n can be 1"

     if not isinstance(dynamic_dims, Sequence):
         dynamic_dims = (dynamic_dims,) * 3
@@ -47,10 +46,8 @@ def get_dist_gemm_kernel(
     # Expose user-constraints
     constraints: list[tkw.Constraint] = []
-    # Only support distribution along outer dimension
-    if device_m > 1:
-        constraints += [tkw.DeviceConstraint(M, DEVICE_M, 0)]
-    if device_n > 1:
-        constraints += [tkw.DeviceConstraint(N, DEVICE_N, 0)]
+    # Distribute M along device-grid dim 0 and N along device-grid dim 1
+    constraints += [tkw.DeviceConstraint(M, DEVICE_M, 0)]
+    constraints += [tkw.DeviceConstraint(N, DEVICE_N, 1)]
     constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
     constraints += [tkw.TilingConstraint(K, BLOCK_K)]
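With the assert removed and both `DeviceConstraint`s always applied, M is distributed along device-grid dim 0 and N along dim 1, so the old 1-D layouts become the degenerate cases `device_m == 1` or `device_n == 1`. A small sketch of the resulting per-device output tile, assuming (as the constraints imply) that `DEVICE_M` resolves to `M // device_m` and `DEVICE_N` to `N // device_n`:

```python
def per_device_tile(m, n, device_m, device_n):
    # Each device computes an (m // device_m) x (n // device_n) output tile.
    assert m % device_m == 0 and n % device_n == 0
    return m // device_m, n // device_n

assert per_device_tile(1024, 1024, 2, 4) == (512, 256)
assert per_device_tile(1024, 1024, 8, 1) == (128, 1024)  # old 1-D case
```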