4 changes: 2 additions & 2 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
@@ -282,7 +282,7 @@ def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, Identity

def ReadOp : WaveOp<"read", [
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
let summary = "Reads from memory";
@@ -336,7 +336,7 @@ def RegisterOp : WaveOp<"register", [

def WriteOp : WaveOp<"write", [
WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
let summary = "Writes into memory";
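The change above swaps the attribute-based default (AttrBasedElementsPerThreadOpTrait) for DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface> on both ReadOp and WriteOp, so each op now supplies its own propagation hooks. As a minimal sketch (assumed shape only; the exact code ODS generates depends on the interface definition), the declarations injected into the generated op classes would match the signatures defined in WaveOps.cpp below:

// Sketch: assumed declarations emitted by
// DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface> into the
// generated ReadOp/WriteOp classes.
llvm::FailureOr<mlir::ChangeResult> propagateElementsPerThreadForward(
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
    llvm::raw_ostream &errs);
llvm::FailureOr<mlir::ChangeResult> propagateElementsPerThreadBackward(
    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
    llvm::raw_ostream &errs);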
72 changes: 72 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
@@ -1448,6 +1448,34 @@ LogicalResult ReadOp::verify() {
bounds.getMapping());
}

llvm::FailureOr<mlir::ChangeResult>
wave::ReadOp::propagateElementsPerThreadForward(
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
llvm::raw_ostream &errs) {
// ReadOp only propagates the elements_per_thread attribute to its result
// (a register). The memory operand is ignored for propagation: any number
// of elements may be read from memory regardless of how many were written.
std::optional<int64_t> elementsPerThread = getElementsPerThread();
if (!elementsPerThread)
return mlir::ChangeResult::NoChange;

wave::ElementsPerThreadLatticeValue expectedResult(*elementsPerThread);
return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
expectedResult, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
resultElements, "elements_per_thread attribute", "", "result", errs);
}

llvm::FailureOr<mlir::ChangeResult>
wave::ReadOp::propagateElementsPerThreadBackward(
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
llvm::raw_ostream &) {
// ReadOp does not propagate backward to the memory operand; memory is
// decoupled from register dataflow for elements_per_thread.
return mlir::ChangeResult::NoChange;
}

//-----------------------------------------------------------------------------
// RegisterOp
//-----------------------------------------------------------------------------
@@ -1529,6 +1557,50 @@ LogicalResult WriteOp::verify() {
bounds.getMapping());
}

llvm::FailureOr<mlir::ChangeResult>
wave::WriteOp::propagateElementsPerThreadForward(
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
llvm::raw_ostream &errs) {
// WriteOp only validates that the elements_per_thread attribute matches the
// register operand. The memory operand is ignored for propagation: memory
// may be written with any layout.
std::optional<int64_t> elementsPerThread = getElementsPerThread();
if (!elementsPerThread)
return mlir::ChangeResult::NoChange;

// Validate register operand (value_to_store) matches attribute.
wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
operandElements.slice(getValueToStoreMutable().getOperandNumber(), 1);

return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
expectedValue, valueOnly,
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>(),
"elements_per_thread attribute", "operand", "", errs);
}

llvm::FailureOr<mlir::ChangeResult>
wave::WriteOp::propagateElementsPerThreadBackward(
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
llvm::raw_ostream &errs) {
// WriteOp only propagates backward to the register operand (value_to_store).
// The memory operand is ignored: any layout may be written to memory.
std::optional<int64_t> elementsPerThread = getElementsPerThread();
if (!elementsPerThread)
return mlir::ChangeResult::NoChange;

// Propagate to register operand only.
wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
operandElements.slice(getValueToStoreMutable().getOperandNumber(), 1);

return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
expectedValue, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
valueOnly, "elements_per_thread attribute", "", "operand", errs);
}

// Propagate index expressions forward from the operands to the result of the
// WriteOp. Since WriteOp has no results, this is a no-op.
llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateIndexExprsForward(
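The two ops are deliberately asymmetric: ReadOp propagates the attribute forward into its result register only, WriteOp propagates it backward into its value_to_store operand only, and the memory operand never participates in either direction - which is exactly what makes the resharding pattern in the tests below legal. A self-contained toy in plain C++ (hypothetical stand-ins, not water code: Lattice and pin model the lattice machinery) illustrating that contract:

// Toy model (plain C++, not water code). A lattice slot is "unknown"
// (std::nullopt) or a concrete per-thread element count.
#include <cstdint>
#include <iostream>
#include <optional>

using Lattice = std::optional<int64_t>;
enum class Result { NoChange, Change, Failure }; // mirrors FailureOr<ChangeResult>

// Pin `slot` to the op's attribute value, failing on a conflicting
// constraint. This is the shared shape of both propagation hooks.
Result pin(std::optional<int64_t> attr, Lattice &slot) {
  if (!attr)
    return Result::NoChange; // no attribute: nothing to propagate
  if (slot && *slot != *attr) {
    std::cerr << "mismatch: " << *slot << " vs " << *attr << "\n";
    return Result::Failure;
  }
  Result r = slot ? Result::NoChange : Result::Change;
  slot = attr;
  return r;
}

int main() {
  // Write backward: the attribute pins the register operand
  // (value_to_store); the memory operand is never touched.
  Lattice storedReg;
  pin(/*attr=*/8, storedReg); // register now carries 8

  // Read forward: the attribute pins the result register; the memory
  // operand never constrains it, so reading 4 after writing 8 is legal
  // resharding.
  Lattice readResult;
  pin(/*attr=*/4, readResult);

  std::cout << *storedReg << " " << *readResult << "\n"; // prints: 8 4
}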
72 changes: 66 additions & 6 deletions water/test/Dialect/Wave/propagate-elements-per-thread.mlir
@@ -162,6 +162,66 @@ module {

// -----

// CHECK: #wave.normal_form<full_types,memory_only_types>
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
// CHECK-LABEL: @memory_resharding_allowed
func.func @memory_resharding_allowed(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
%cst = arith.constant 0.0 : f16
// The register gets 8 elements per thread from the write's backward propagation.
// CHECK: wave.register {{.*}} : vector<8xf16>
%reg8 = wave.register %cst : !wave.tensor<[@M] of f16, <register>>

// Write 8 elements per thread to memory.
// CHECK: wave.write {{.*}} : vector<8xf16>, !wave.tensor<[@M] of f16, <shared>>
wave.write %reg8, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <shared>>

// Read 4 elements per thread from the same memory - this is allowed (memory resharding).
// CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<4xf16>
%reg4 = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <shared>>) -> !wave.tensor<[@M] of f16, <register>>

return
}
}

// -----

// CHECK: #wave.normal_form<full_types,memory_only_types>
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
// CHECK-LABEL: @write_backward_propagation
func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
%cst = arith.constant 0.0 : f16
// RegisterOp has no explicit elements_per_thread - it should receive one via backward propagation.
// CHECK: wave.register {{.*}} : vector<4xf16>
%reg = wave.register %cst : !wave.tensor<[@M] of f16, <register>>

// WriteOp should propagate elements_per_thread backward to register operand.
// CHECK: wave.write {{.*}} : vector<4xf16>, !wave.tensor<[@M] of f16, <shared>>
wave.write %reg, %mem {elements_per_thread = 4} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <shared>>

return
}
}

// -----

// CHECK: #wave.normal_form<full_types,memory_only_types>
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
// CHECK-LABEL: @read_register_propagation
func.func @read_register_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
// ReadOp should only propagate to its register result, not validate the memory operand.
// CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <shared>>) -> vector<6xf16>
%reg = wave.read %mem {elements_per_thread = 6} : (!wave.tensor<[@M] of f16, <shared>>) -> !wave.tensor<[@M] of f16, <register>>

// Downstream operation should get 6 elements per thread.
// CHECK: wave.exp2 {{.*}} : (vector<6xf16>) -> vector<6xf16>
%result = wave.exp2 %reg : (!wave.tensor<[@M] of f16, <register>>) -> !wave.tensor<[@M] of f16, <register>>

return
}
}

// -----

module attributes {wave.normal_form = #wave.normal_form<full_types>} {
func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
// LHS without elements_per_thread - will be computed from RHS + MMA constraints.
@@ -174,7 +234,7 @@ func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global
// ACC properly initialized through read operation.
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>

// LHS elements_per_thread computed via MMA backward propagation
// LHS elements_per_thread computed via MMA backward propagation.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}
@@ -194,7 +254,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global
// ACC properly initialized through read operation.
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>

// RHS elements_per_thread computed via MMA backward propagation
// RHS elements_per_thread computed via MMA backward propagation.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}
@@ -205,7 +265,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16, <global
// Test MMA can compute both LHS and RHS when both are uninitialized
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@N, @K] of f16, <global>>, %mem3: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
// Both LHS and RHS without elements_per_thread - can compute from MMA formulas
// Both LHS and RHS without elements_per_thread - can compute from MMA formulas.
%lhs_init = arith.constant 0.0 : f16
%lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
%rhs_init = arith.constant 0.0 : f16
@@ -215,7 +275,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
%acc = wave.read %mem3 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>

// With proper MMA formulas, we can now compute both LHS and RHS from constraints,
// so this should succeed instead of failing
// so this should succeed instead of failing.
%result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
return
}
@@ -226,14 +286,14 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
// Test MMA error when operand has wrong elements_per_thread
module attributes {wave.normal_form = #wave.normal_form<full_types>} {
func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 32, waves_per_block = [1, 1, 1], mma_type = #wave.mma_kind<f32_16x16x16_f16>, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
// LHS with wrong elements_per_thread (should be 8, not 4)
// LHS with wrong elements_per_thread (should be 8, not 4).
%lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>

// RHS without elements_per_thread - will be computed from MMA constraints.
%rhs_init = arith.constant 0.0 : f16
%rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>

// ACC properly initialized
// ACC properly initialized.
%acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>

// expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}
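The final test exercises the failure path of the same machinery: the read pins the LHS at 4 elements per thread, while backward propagation from the MMA kind requires 8, a conflicting constant constraint. In the toy model sketched after WaveOps.cpp above (hypothetical code, not the actual diagnostic plumbing), the conflict surfaces the same way:

Lattice lhs = 4;      // pinned by the read's {elements_per_thread = 4}
pin(/*attr=*/8, lhs); // MMA-derived 8 conflicts: prints a mismatch, returns Result::Failure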