-
Notifications
You must be signed in to change notification settings - Fork 25
[wave2water] E2e execution matrix add test using water middle-end #575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
abeeefe to
8a0d767
Compare
ftynse
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall nit: please be systematic about using or not using the full stop at the end of sentences in comments. I'm fine either way as long as it is locally consistent in a file.
| # CHECK-NOT: wave.read | ||
| # CHECK-NOT: wave.write | ||
| # CHECK-NOT: wave.mma | ||
| # CHECK-NOT: wave.iterate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These appear misplaced (or I don't understand something). Don't we want to have an mfma inside the scf.for? It is currently matched before. Also if we check for there not being a wave.read after an scf.for, there may well be one before it and this structure won't complain.
| let arguments = (ins | ||
| Arg<WaveSymbolAttr, "Iterator symbol">:$iterator, | ||
| Arg<Variadic<WaveTensorType>, "Carried values">:$iter_args, | ||
| Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Carried values">:$iter_args, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
WaveTensorInRegister? Do we want to support non-register tensors here actually?
|
|
||
| let results = (outs | ||
| Res<Variadic<WaveTensorType>, "Yielded values">:$results | ||
| Res<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Yielded values">:$results |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
|
|
||
| let arguments = (ins | ||
| Arg<Variadic<WaveTensorType>, "Yielded values">:$values | ||
| Arg<Variadic<AnyTypeOf<[WaveTensorType, AnyVectorOfAnyRank]>>, "Yielded values">:$values |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as iterate
| if (!elementsPerThread) | ||
| return mlir::ChangeResult::NoChange; | ||
|
|
||
| // Only propagate to operands[0] (register), not operands[1] (memory) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trick to avoid hardcoding operand positions: getRegisterMutable().getOperandNumber().
| // TODO: Implement proper iteration variable handling | ||
| baseSymVals.emplace_back( | ||
| arith::ConstantIndexOp::create(rewriter, loc, 0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please implement. There is also no guarantee that iterate is the direct parent and that the iterator symbol belongs to the immediately enclosing loop rather than some of the other nested loops.
|
|
||
| // Not from wave.allocate, use original type shape | ||
| ArrayRef<Attribute> symbols(originalType.getShape().begin(), originalType.getShape().end()); | ||
| return {symbols, AffineMap{}}; // Empty map means identity |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be simpler / more consistent to create an actual identity map?
| if (effectiveShapeInfo.has_value()) { | ||
| ArrayRef<Attribute> symbols = effectiveShapeInfo->first; | ||
| SmallVector<wave::WaveSymbolAttr> convertedSyms; | ||
| convertedSyms.reserve(symbols.size()); | ||
| for (Attribute attr : symbols) { | ||
| auto symbolAttr = dyn_cast<wave::WaveSymbolAttr>(attr); | ||
| if (!symbolAttr) { | ||
| return rewriter.notifyMatchFailure( | ||
| op, "expected WaveSymbolAttr in distributed shape"); | ||
| } | ||
| convertedSyms.push_back(symbolAttr); | ||
| } | ||
| orderedSyms = convertedSyms; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unlike shapes in the type, distrbuted shapes may be arbitrarily complex expressions AFAIK, which may need to be handled before beyond just the list of symbols.
| // Phase 2: Convert control flow operations | ||
| { | ||
| RewritePatternSet phase2Patterns(ctx); | ||
|
|
||
| // Add only control flow patterns | ||
| wave::populateWaveControlFlowLoweringPatterns(typeConverter, phase2Patterns); | ||
|
|
||
| if (failed(applyPartialConversion(op, target, std::move(phase2Patterns), | ||
| config))) { | ||
| op->emitError() << "failed to convert in phase 2"; | ||
| return WalkResult::interrupt(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather make a separate pass then, and potentially a named pipeline to manage that complexity
e1694bb to
0e44243
Compare
This showcases end to end execution using water transformations. However it's not straightforward to put this as a pytest, since it messes up with all the package setup... I guess a lit_test with assertions to check computed results is good for now. Fixes iree-org#584 Signed-off-by: tyb0807 <[email protected]>
|
Ok, I simplified this PR to only lowering matrix add test, to make review easier. Working on breaking down the changes to smaller PR. Will address your comments along the way. |
This showcases end to end execution using water transformations. However it's not straightforward to put this as a
pytest, since it messes up with all the package setup... I guess a lit_test with assertions to check computed results is good for now.Fixes #584