Skip to content

Conversation

@tyb0807
Copy link
Contributor

@tyb0807 tyb0807 commented Dec 23, 2025

Stacked PRs, do not merge.

Changes:

  • ReadOp: Only propagate attribute to result (register), ignore memory
  • WriteOp: Only validate/propagate with register operand, ignore memory

This fixes false positives where memory resharding was incorrectly
flagged as propagation errors.

Fixes #622.

@tyb0807 tyb0807 requested a review from ftynse December 23, 2025 00:45
Implements elements per thread propagation for MMA operations.

Fixes iree-org#608.

Signed-off-by: tyb0807 <[email protected]>
Changes:
- ReadOp: Only propagate attribute to result (register), ignore memory
- WriteOp: Only validate/propagate with register operand, ignore memory

This fixes false positives where memory resharding was incorrectly
flagged as propagation errors.

Fixes iree-org#622.

Signed-off-by: tyb0807 <[email protected]>
Copy link
Contributor

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stacked commit LGTM modulo nits


def ReadOp : WaveOp<"read", [
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trait becomes unused, remove it entirely.

Comment on lines +1353 to +1354
// ReadOp doesn't propagate backward to memory operand
// Memory is decoupled from register dataflow for elements_per_thread
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ultra-nit: use full sentences in comments and terminate them with a full stop. This will make it more readable when the comment is inevitably reflowed by tooling.

llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
llvm::raw_ostream &errs) {
// WriteOp only validates that elements_per_thread attribute matches register
// operand Memory operand is ignored for propagation - you can write to memory
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like here, if you had a full stop, it would have been more readable.

// Validate register operand (value_to_store) matches attribute
wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
operandElements.slice(0, 1); // Only first operand (value_to_store)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: use the getValueToStoreMutable().getOperandNumber() trick to avoid hardcoding positions.

func.func @read_write_conflict(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
%reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: "register operand #0" reads like there are multiple register operands. I'd consider just keeping as is.

// CHECK-LABEL: @write_backward_propagation
func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
%cst = arith.constant 0.0 : f16
// RegisterOp without explicit elements_per_thread - should get it from backward propagation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we punted on having explicit elements_per_thread on registers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fix elements_per_thread propagation to ignore memory operands

2 participants