Skip to content

[MLIR][TORCH] Add op verifier for aten.index_put op #4184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

vivekkhandelwal1
Copy link
Collaborator

@vivekkhandelwal1 vivekkhandelwal1 commented May 16, 2025

This commit adds an op verifier for the index_put op.

This commit adds a check to verify whether the shapes of the
`values` operand of index_put op is broadcast compatible with
the indexing result or not.

Signed-off-by: Vivek Khandelwal <[email protected]>
Copy link
Collaborator

@sogartar sogartar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting that the op would get a verifier that can be ran at any point. For example this can not be used to verify input IR that contains said op.
What is the reasoning to do the check during conversion and not have a verifier?

Also it is desirable to have a test to show that the check works as expected.

@vivekkhandelwal1 vivekkhandelwal1 requested a review from sogartar May 19, 2025 08:48
@vivekkhandelwal1 vivekkhandelwal1 changed the title [MLIR][TORCH] Add shape verifier check for index_put op [MLIR][TORCH] Add op verifier for aten.index_put op May 19, 2025
Copy link
Collaborator

@sogartar sogartar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mentioned that for dynamic shapes the lowering would insert runtime asserts. How does this work?

Comment on lines +6135 to +6136
if (!indexTy.hasSizes())
return success();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the verification fail if does not have sizes. I assume this is the shape. What does it mean even if there is no size?

Comment on lines +6142 to +6143
if (!inputType.hasSizes())
return success();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we fail here as the other case?

Comment on lines +6147 to +6148
if (!valuesType.hasSizes())
return success();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we fail here as the other case?


// Determine the common broadcast shape of all the index tensors.
SmallVector<int64_t>
getIndexBroadcastShape(SmallVector<Torch::ValueTensorType> indicesTypes) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indices may not be broadcast compatible on their own. E.g.

t = torch.empty([3, 4, 5], dtype=torch.int32)
indices = [
    torch.zeros(size=[5], dtype=torch.long),
    torch.zeros(size=[3], dtype=torch.long),
]
t[indices]

raises

    t[indices]
    ~^^^^^^^^^
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [5], [3]

This function should be failable.

The resulting shape after indexing is more complicated. I think this verification would reject valid IR.
Example:

indices = [
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7, 4, 5])

indices = [
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7, 5])

indices = [
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
torch.Size([7])

indices = [
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
    torch.zeros(size=[7], dtype=torch.long),
]
print(t[indices].shape)
    print(t[indices].shape)
          ~^^^^^^^^^
IndexError: too many indices for tensor of dimension 3

Here I am using t[indices] indexing because according to the doc torch.Tensor.index_put_ is equivalent to tensor[indices] = values.

@@ -158,6 +158,10 @@ LogicalResult getPermutedType(BaseTensorType inType,
SmallVector<int64_t> permuteDims,
Type &permutedType);

// Check whether the given shapes of 2 tensors are broadcastable or not.
LogicalResult areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am actually surprised that there are no functions already available in Torch MLIR to get the shape after broadcasting or to check if broadcastable.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants