diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td
index 145dcc861..f64c4b0b0 100644
--- a/water/include/water/Dialect/Wave/IR/WaveOps.td
+++ b/water/include/water/Dialect/Wave/IR/WaveOps.td
@@ -282,7 +282,7 @@ def ExtractSliceOp : WaveOp<"extract_slice", [WaveInferTypeOpInterface, Identity
 def ReadOp : WaveOp<"read", [
     WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
-    WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
+    DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
     CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
     WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
   let summary = "Reads from memory";
@@ -336,7 +336,7 @@ def RegisterOp : WaveOp<"register", [
 def WriteOp : WaveOp<"write", [
     WaveInferTypeOpInterface, NoOpTypeInferenceOpTrait,
-    WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
+    DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
     CompatibleOperandsAndResultsIgnoreSpaceOpTrait,
     DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
   let summary = "Writes into memory";
diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp
index 22f743b99..172357c12 100644
--- a/water/lib/Dialect/Wave/IR/WaveOps.cpp
+++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp
@@ -1448,6 +1448,34 @@ LogicalResult ReadOp::verify() {
                                bounds.getMapping());
 }
 
+llvm::FailureOr<mlir::ChangeResult>
+wave::ReadOp::propagateElementsPerThreadForward(
+    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
+    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
+    llvm::raw_ostream &errs) {
+  // ReadOp only propagates elements_per_thread attribute to result (register).
+  // Memory operand is ignored for propagation - you can read any number of
+  // elements from memory regardless of how many were written.
+  std::optional<int64_t> elementsPerThread = getElementsPerThread();
+  if (!elementsPerThread)
+    return mlir::ChangeResult::NoChange;
+
+  wave::ElementsPerThreadLatticeValue expectedResult(*elementsPerThread);
+  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+      expectedResult, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
+      resultElements, "elements_per_thread attribute", "", "result", errs);
+}
+
+llvm::FailureOr<mlir::ChangeResult>
+wave::ReadOp::propagateElementsPerThreadBackward(
+    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
+    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> resultElements,
+    llvm::raw_ostream &) {
+  // ReadOp doesn't propagate backward to memory operand.
+  // Memory is decoupled from register dataflow for elements_per_thread.
+  return mlir::ChangeResult::NoChange;
+}
+
 //-----------------------------------------------------------------------------
 // RegisterOp
 //-----------------------------------------------------------------------------
@@ -1529,6 +1557,50 @@ LogicalResult WriteOp::verify() {
                                bounds.getMapping());
 }
 
+llvm::FailureOr<mlir::ChangeResult>
+wave::WriteOp::propagateElementsPerThreadForward(
+    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
+    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
+    llvm::raw_ostream &errs) {
+  // WriteOp only validates that elements_per_thread attribute matches register
+  // operand. Memory operand is ignored for propagation - you can write to
+  // memory with any layout.
+  std::optional<int64_t> elementsPerThread = getElementsPerThread();
+  if (!elementsPerThread)
+    return mlir::ChangeResult::NoChange;
+
+  // Validate register operand (value_to_store) matches attribute.
+  wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
+  llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
+      operandElements.slice(getValueToStoreMutable().getOperandNumber(), 1);
+
+  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+      expectedValue, valueOnly,
+      llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>(),
+      "elements_per_thread attribute", "operand", "", errs);
+}
+
+llvm::FailureOr<mlir::ChangeResult>
+wave::WriteOp::propagateElementsPerThreadBackward(
+    llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> operandElements,
+    llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>,
+    llvm::raw_ostream &errs) {
+  // WriteOp only propagates backward to register operand (value_to_store).
+  // Memory operand is ignored - you can write any layout to memory.
+  std::optional<int64_t> elementsPerThread = getElementsPerThread();
+  if (!elementsPerThread)
+    return mlir::ChangeResult::NoChange;
+
+  // Propagate to register operand only.
+  wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
+  llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
+      operandElements.slice(getValueToStoreMutable().getOperandNumber(), 1);
+
+  return wave::detail::checkAndPropagateElementsPerThreadFromConstant(
+      expectedValue, llvm::ArrayRef<wave::ElementsPerThreadLatticeValue>(),
+      valueOnly, "elements_per_thread attribute", "", "operand", errs);
+}
+
 // Propagate index expressions forward from the operands to the result of the
 // WriteOp. Since WriteOp has no results, this is a no-op.
 llvm::FailureOr<mlir::ChangeResult> wave::WriteOp::propagateIndexExprsForward(
diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir
index d9ef3382b..9b592df8f 100644
--- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir
+++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir
@@ -162,6 +162,66 @@ module {
 
 // -----
 
+// CHECK: #wave.normal_form<full_types>
+module attributes {wave.normal_form = #wave.normal_form<full_types>} {
+// CHECK-LABEL: @memory_resharding_allowed
+func.func @memory_resharding_allowed(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
+  %cst = arith.constant 0.0 : f16
+  // Register gets 8 elements per thread from write operation's backward propagation.
+  // CHECK: wave.register {{.*}} : vector<8xf16>
+  %reg8 = wave.register %cst : !wave.tensor<[@M] of f16, <register>>
+
+  // Write 8 elements per thread to memory.
+  // CHECK: wave.write {{.*}} : vector<8xf16>, !wave.tensor<[@M] of f16, <global>>
+  wave.write %reg8, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
+
+  // Read 4 elements per thread from same memory - this should be allowed (memory resharding).
+  // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <global>>) -> vector<4xf16>
+  %reg4 = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
+
+  return
+}
+}
+
+// -----
+
+// CHECK: #wave.normal_form<full_types>
+module attributes {wave.normal_form = #wave.normal_form<full_types>} {
+// CHECK-LABEL: @write_backward_propagation
+func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
+  %cst = arith.constant 0.0 : f16
+  // RegisterOp doesn't have explicit elements_per_thread - should get it from backward propagation.
+  // CHECK: wave.register {{.*}} : vector<4xf16>
+  %reg = wave.register %cst : !wave.tensor<[@M] of f16, <register>>
+
+  // WriteOp should propagate elements_per_thread backward to register operand.
+  // CHECK: wave.write {{.*}} : vector<4xf16>, !wave.tensor<[@M] of f16, <global>>
+  wave.write %reg, %mem {elements_per_thread = 4} : !wave.tensor<[@M] of f16, <register>>, !wave.tensor<[@M] of f16, <global>>
+
+  return
+}
+}
+
+// -----
+
+// CHECK: #wave.normal_form<full_types>
+module attributes {wave.normal_form = #wave.normal_form<full_types>} {
+// CHECK-LABEL: @read_register_propagation
+func.func @read_register_propagation(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
+  // ReadOp should only propagate to its register result, not validate memory.
+  // CHECK: wave.read {{.*}} : (!wave.tensor<[@M] of f16, <global>>) -> vector<6xf16>
+  %reg = wave.read %mem {elements_per_thread = 6} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
+
+  // Downstream operation should get 6 elements per thread.
+  // CHECK: wave.exp2 {{.*}} : (vector<6xf16>) -> vector<6xf16>
+  %result = wave.exp2 %reg : (!wave.tensor<[@M] of f16, <register>>) -> !wave.tensor<[@M] of f16, <register>>
+
+  return
+}
+}
+
+// -----
+
 module attributes {wave.normal_form = #wave.normal_form<full_types>} {
 func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
   // LHS without elements_per_thread - will be computed from RHS + MMA constraints.
@@ -174,7 +234,7 @@ func.func @mma_compute_lhs_from_rhs(%mem1: !wave.tensor<[@N, @K] of f16,
   %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
 
-  // LHS elements_per_thread computed via MMA backward propagation
+  // LHS elements_per_thread computed via MMA backward propagation.
   %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
   return
 }
@@ -194,7 +254,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16,
   %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
 
-  // RHS elements_per_thread computed via MMA backward propagation
+  // RHS elements_per_thread computed via MMA backward propagation.
   %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
   return
 }
@@ -205,7 +265,7 @@ func.func @mma_compute_rhs_from_lhs(%mem1: !wave.tensor<[@M, @K] of f16,
 } {
 func.func @mma_compute_both_lhs_rhs(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@N, @K] of f16, <global>>, %mem3: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
-  // Both LHS and RHS without elements_per_thread - can compute from MMA formulas
+  // Both LHS and RHS without elements_per_thread - can compute from MMA formulas.
   %lhs_init = arith.constant 0.0 : f16
   %lhs = wave.register %lhs_init : !wave.tensor<[@M, @K] of f16, <register>>
   %rhs_init = arith.constant 0.0 : f16
@@ -215,7 +275,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
   %acc = wave.read %mem3 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
 
   // With proper MMA formulas, we can now compute both LHS and RHS from constraints,
-  // so this should succeed instead of failing
+  // so this should succeed instead of failing.
   %result = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
   return
 }
@@ -226,14 +286,14 @@ module attributes {wave.normal_form = #wave.normal_form<full_types>} {
 // Test MMA error when operand has wrong elements_per_thread
 module attributes {wave.normal_form = #wave.normal_form<full_types>} {
 func.func @mma_operand_mismatch(%mem1: !wave.tensor<[@M, @K] of f16, <global>>, %mem2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 16, N = 16, K = 16}>, wave.constraints = [#wave.hardware_constraint<threads_per_wave = 64, vector_shapes = {M = 1, N = 1, K = 16}, max_bits_per_load = 128>]} {
-  // LHS with wrong elements_per_thread (should be 8, not 4)
+  // LHS with wrong elements_per_thread (should be 8, not 4).
   %lhs = wave.read %mem1 {elements_per_thread = 4} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
 
   // RHS without elements_per_thread - will be computed from MMA constraints.
   %rhs_init = arith.constant 0.0 : f16
   %rhs = wave.register %rhs_init : !wave.tensor<[@N, @K] of f16, <register>>
 
-  // ACC properly initialized
+  // ACC properly initialized.
   %acc = wave.read %mem2 {elements_per_thread = 8} : (!wave.tensor<[@M, @N] of f32, <global>>) -> !wave.tensor<[@M, @N] of f32, <register>>
 
   // expected-error @below {{failed to propagate elements per thread backward: mismatch between computed from MMA kind (8) and LHS operand #0 (4)}}