[MLIR][TORCH] Add op verifier for aten.index_put op #4184
base: main
This commit adds a check to verify whether the shape of the `values` operand of the index_put op is broadcast compatible with the indexing result. Signed-off-by: Vivek Khandelwal <[email protected]>
I was expecting that the op would get a verifier that can be run at any point. For example, this check cannot be used to verify input IR that contains said op.
What is the reasoning for doing the check during conversion rather than in a verifier?
Also it is desirable to have a test to show that the check works as expected.
You mentioned that for dynamic shapes the lowering would insert runtime asserts. How does this work?
if (!indexTy.hasSizes())
  return success();
Shouldn't the verification fail if the type does not have sizes? I assume this is the shape. What does it even mean for a type to have no sizes?
if (!inputType.hasSizes())
  return success();
Shouldn't we fail here, as in the other case?
if (!valuesType.hasSizes())
  return success();
Shouldn't we fail here, as in the other case?
// Determine the common broadcast shape of all the index tensors.
SmallVector<int64_t>
getIndexBroadcastShape(SmallVector<Torch::ValueTensorType> indicesTypes) {
The indices may not be broadcast compatible on their own. E.g.
t = torch.empty([3, 4, 5], dtype=torch.int32)
indices = [
torch.zeros(size=[5], dtype=torch.long),
torch.zeros(size=[3], dtype=torch.long),
]
t[indices]
raises
t[indices]
~^^^^^^^^^
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [5], [3]
This function should be failable.
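As a rough illustration of what "failable" could mean here, the sketch below computes the common broadcast shape of several static shapes and reports incompatibility instead of assuming success. It is plain Python rather than the C++ under review, and the name `broadcast_shapes` and the use of `None` as the failure value are assumptions made for illustration only.

```python
from typing import List, Optional, Sequence


def broadcast_shapes(shapes: Sequence[Sequence[int]]) -> Optional[List[int]]:
    """Return the common broadcast shape, or None if the shapes are incompatible.

    Hypothetical helper: static shapes only, no dynamic dimensions.
    """
    rank = max((len(s) for s in shapes), default=0)
    result = [1] * rank
    for shape in shapes:
        # Broadcasting aligns shapes on their trailing dimensions.
        offset = rank - len(shape)
        for i, dim in enumerate(shape):
            current = result[offset + i]
            if current == 1:
                result[offset + i] = dim
            elif dim != 1 and dim != current:
                return None  # e.g. 5 vs 3: cannot be broadcast together
    return result


print(broadcast_shapes([[5], [3]]))     # None
print(broadcast_shapes([[7], [7]]))     # [7]
print(broadcast_shapes([[2, 1], [3]]))  # [2, 3]
```

In the C++ helper, the equivalent would presumably be a `LogicalResult` or `FailureOr` return so that the verifier can emit an error for the `[5]` versus `[3]` case.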
The resulting shape after indexing is more complicated. I think this verification would reject valid IR.
Example:
indices = [
torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7, 4, 5])
indices = [
torch.zeros(size=[7], dtype=torch.long),
torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7, 5])
indices = [
torch.zeros(size=[7], dtype=torch.long),
torch.zeros(size=[7], dtype=torch.long),
torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7])
indices = [
torch.zeros(size=[7], dtype=torch.long),
torch.zeros(size=[7], dtype=torch.long),
torch.zeros(size=[7], dtype=torch.long),
torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
print(t[indices].shape)
~^^^^^^^^^
IndexError: too many indices for tensor of dimension 3
Here I am using t[indices] indexing because according to the doc torch.Tensor.index_put_ is equivalent to tensor[indices] = values.
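The shape rule that these examples follow can be sketched as below. This is a simplified model, not the verifier in this PR: it assumes every entry in `indices` is a tensor applied to the leading dimensions (no `None` entries, no boolean masks), and the function names are hypothetical. Under that assumption the result shape is the broadcast shape of the indices followed by the input dimensions that were not indexed, and `values` must broadcast one way into that shape.

```python
from typing import List, Optional, Sequence


def broadcast_index_shapes(shapes: Sequence[Sequence[int]]) -> Optional[List[int]]:
    """Common broadcast shape of the index tensors, or None if incompatible."""
    rank = max((len(s) for s in shapes), default=0)
    result = [1] * rank
    for shape in shapes:
        offset = rank - len(shape)  # align trailing dimensions
        for i, dim in enumerate(shape):
            current = result[offset + i]
            if current == 1:
                result[offset + i] = dim
            elif dim != 1 and dim != current:
                return None
    return result


def indexed_result_shape(input_shape: Sequence[int],
                         index_shapes: Sequence[Sequence[int]]) -> Optional[List[int]]:
    """Shape of t[indices] when all indices are leading tensor indices (assumption)."""
    if len(index_shapes) > len(input_shape):
        return None  # too many indices for the tensor rank
    bcast = broadcast_index_shapes(index_shapes)
    if bcast is None:
        return None  # indices are not broadcast compatible with each other
    # Indexed dims are replaced by the broadcast shape; the rest are kept.
    return bcast + list(input_shape[len(index_shapes):])


def values_fit(values_shape: Sequence[int], result_shape: Sequence[int]) -> bool:
    """One-way check: can values be expanded to result_shape?"""
    if len(values_shape) > len(result_shape):
        return False
    return all(v == r or v == 1
               for v, r in zip(reversed(values_shape), reversed(result_shape)))


print(indexed_result_shape([3, 4, 5], [[7]]))            # [7, 4, 5]
print(indexed_result_shape([3, 4, 5], [[7], [7]]))       # [7, 5]
print(indexed_result_shape([3, 4, 5], [[7]] * 3))        # [7]
print(indexed_result_shape([3, 4, 5], [[7]] * 4))        # None
print(values_fit([5], [7, 5]))                           # True
```

Cases with `None` entries or non-adjacent tensor indices place the broadcast dimensions differently and are deliberately left out of this sketch.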
@@ -158,6 +158,10 @@ LogicalResult getPermutedType(BaseTensorType inType,
                              SmallVector<int64_t> permuteDims,
                              Type &permutedType);

// Check whether the given shapes of 2 tensors are broadcastable or not.
LogicalResult areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
I am actually surprised that there are no functions already available in Torch MLIR to get the shape after broadcasting or to check if broadcastable.
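For reference, a pairwise static check of this kind is small. The Python sketch below mirrors what a helper like `areStaticallyBroadcastCompatible` might do, but it is not the code from this PR: the handling of dynamic dimensions and the `-1` sentinel (`K_UNKNOWN_SIZE`) are assumptions. A dynamic dimension is treated as compatible, since a mismatch cannot be proven statically.

```python
from typing import Sequence

K_UNKNOWN_SIZE = -1  # assumed sentinel for a dynamic dimension


def are_statically_broadcast_compatible(shape_a: Sequence[int],
                                        shape_b: Sequence[int]) -> bool:
    """Return False only when the two shapes provably cannot be broadcast."""
    # Compare trailing dimensions pairwise; extra leading dims are always fine.
    for a, b in zip(reversed(shape_a), reversed(shape_b)):
        if a == K_UNKNOWN_SIZE or b == K_UNKNOWN_SIZE:
            continue  # cannot decide statically; leave it to a runtime check
        if a != b and a != 1 and b != 1:
            return False
    return True


print(are_statically_broadcast_compatible([3, 4, 5], [5]))      # True
print(are_statically_broadcast_compatible([3, 4, 5], [4]))      # False
print(are_statically_broadcast_compatible([-1, 4], [7, 4]))     # True
```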
This commit adds an op verifier for the index_put op.