[water] Fix elements_per_thread propagation to ignore memory operands #623
base: main
Implements elements per thread propagation for MMA operations.

Fixes iree-org#608.

Signed-off-by: tyb0807 <[email protected]>
Changes:
- ReadOp: Only propagate attribute to result (register), ignore memory
- WriteOp: Only validate/propagate with register operand, ignore memory

This fixes false positives where memory resharding was incorrectly flagged as propagation errors.

Fixes iree-org#622.

Signed-off-by: tyb0807 <[email protected]>
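The rule described in the commit message can be sketched in isolation as follows. This is a minimal standalone model, not the code in this PR: `Lattice`, `Status`, `join`, `propagateReadForward`, and `propagateWriteBackward` are hypothetical names, and `std::optional<unsigned>` stands in for `wave::ElementsPerThreadLatticeValue`. It only illustrates that the memory operand's lattice is neither read nor written during propagation.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for wave::ElementsPerThreadLatticeValue:
// an empty optional means "not known yet".
using Lattice = std::optional<unsigned>;

enum class Status { NoChange, Changed, Failure };

// Joins the attribute value into a lattice slot. Two different known
// values are a conflict and produce an error message.
Status join(Lattice &slot, unsigned attr, const std::string &what,
            std::string &err) {
  if (!slot) {
    slot = attr;
    return Status::Changed;
  }
  if (*slot != attr) {
    err = "mismatch between elements_per_thread attribute (" +
          std::to_string(attr) + ") and " + what + " (" +
          std::to_string(*slot) + ")";
    return Status::Failure;
  }
  return Status::NoChange;
}

// Read: the attribute flows forward to the register result only.
// The memory operand is deliberately left untouched.
Status propagateReadForward(unsigned attr, const Lattice &memory,
                            Lattice &result, std::string &err) {
  (void)memory;
  return join(result, attr, "result #0", err);
}

// Write: the attribute is validated against (or propagated backward to)
// the register operand only. The memory operand is deliberately ignored.
Status propagateWriteBackward(unsigned attr, Lattice &valueToStore,
                              const Lattice &memory, std::string &err) {
  (void)memory;
  return join(valueToStore, attr, "operand #0", err);
}
```

Under this model, a read with `elements_per_thread = 4` followed by a write with `elements_per_thread = 8` of the same register value still fails, while the value recorded for memory never participates in the check.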
ftynse left a comment
stacked commit LGTM modulo nits
def ReadOp : WaveOp<"read", [
WaveInferTypeOpInterface, IdentityTypeInferenceOpTrait,
WaveElementsPerThreadOpInterface, AttrBasedElementsPerThreadOpTrait,
This trait becomes unused, remove it entirely.
// ReadOp doesn't propagate backward to memory operand
// Memory is decoupled from register dataflow for elements_per_thread
Ultra-nit: use full sentences in comments and terminate them with a full stop. This will make it more readable when the comment is inevitably reflowed by tooling.
llvm::MutableArrayRef<wave::ElementsPerThreadLatticeValue>,
llvm::raw_ostream &errs) {
// WriteOp only validates that elements_per_thread attribute matches register
// operand Memory operand is ignored for propagation - you can write to memory
Like here, if you had a full stop, it would have been more readable.
// Validate register operand (value_to_store) matches attribute
wave::ElementsPerThreadLatticeValue expectedValue(*elementsPerThread);
llvm::ArrayRef<wave::ElementsPerThreadLatticeValue> valueOnly =
operandElements.slice(0, 1); // Only first operand (value_to_store)
Nit: use the getValueToStoreMutable().getOperandNumber() trick to avoid hardcoding positions.
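A rough illustration of the suggestion, using mock types so that it compiles on its own. `MockOperand`, `MockWriteOp`, and `registerOperandLattice` are hypothetical stand-ins invented for this sketch; in the real code the accessor would be the generated `getValueToStoreMutable()` on the write op, assumed here to return an operand that can report its own position via `getOperandNumber()`.

```cpp
#include <vector>

// Hypothetical, simplified stand-in for an operand handle that knows
// its own position in the operation's operand list.
struct MockOperand {
  unsigned number;
  unsigned getOperandNumber() const { return number; }
};

// Hypothetical stand-in for the write op and its generated accessor.
struct MockWriteOp {
  MockOperand valueToStore{0};
  MockOperand memory{1};
  MockOperand &getValueToStoreMutable() { return valueToStore; }
};

// Selects the lattice entry of the register operand by asking the
// operand for its position instead of hardcoding index 0.
template <typename T>
const T &registerOperandLattice(MockWriteOp &op,
                                const std::vector<T> &operandLattices) {
  unsigned position = op.getValueToStoreMutable().getOperandNumber();
  return operandLattices[position];
}
```

If the operand order of the op ever changes, the lookup keeps pointing at `value_to_store` without any edits at the call site, which is the point of avoiding the literal `slice(0, 1)`.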
func.func @read_write_conflict(%mem: !wave.tensor<[@M] of f16, <global>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
%reg = wave.read %mem {elements_per_thread = 4} : (!wave.tensor<[@M] of f16, <global>>) -> !wave.tensor<[@M] of f16, <register>>
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and operand #0 (4)}}
// expected-error @below {{failed to propagate elements per thread backward: mismatch between elements_per_thread attribute (8) and register operand #0 (4)}}
Nit: "register operand #0" reads like there are multiple register operands. I'd consider just keeping as is.
// CHECK-LABEL: @write_backward_propagation
func.func @write_backward_propagation(%mem: !wave.tensor<[@M] of f16, <shared>>) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128}>} {
%cst = arith.constant 0.0 : f16
// RegisterOp without explicit elements_per_thread - should get it from backward propagation
I think we punted on having explicit elements_per_thread on registers
Stacked PRs, do not merge.
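Changes:
- ReadOp: Only propagate attribute to result (register), ignore memory
- WriteOp: Only validate/propagate with register operand, ignore memory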
This fixes false positives where memory resharding was incorrectly
flagged as propagation errors.
Fixes #622.