diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 92b8abb36d..7e86869d93 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4218,23 +4218,9 @@ def _aten_index_onnx(
     #     ]
     #
     # Need to transpose the result of GatherND to match this axes ordering.
-    first_not_none_position = reordered_positions[0]  # x_None_front_m + 1
-    starting_position_of_none_in_back = (
-        advanced_indexing_rank + first_not_none_position
-    )  # x_None_back_1
-    result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
-    perm = [
-        *range(
-            advanced_indexing_rank, starting_position_of_none_in_back
-        ),  # None_front_1...x_None_back_1
-        *range(advanced_indexing_rank),  # 0...len(broadcasted_shape)
-        *range(
-            starting_position_of_none_in_back,
-            result_rank,
-        ),  # None_back_1...None_back_m
-    ]
+    inverse_positions = np.argsort(reordered_positions).tolist()

-    return op.Transpose(self, perm=perm)
+    return op.Transpose(self, perm=inverse_positions)


 @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
@@ -4324,91 +4310,57 @@ def aten_index_copy(
 @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
 def aten_index_put(
     self: TReal,
-    indices: Sequence[INT64],
+    indices: Sequence[Optional[INT64]],
     values: TReal,
     accumulate: bool = False,
 ) -> TReal:
-    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
-
-    See implementation of `torch.onnx.symbolic_opset11.index_put
-    `_.
-    """
-
-    def _make_reshape_list_broadcastable(reshape_list, values_shape):
-        # Remove ones until the rank of reshape_list matches values_shape.
-        while len(reshape_list) > len(values_shape) and 1 in reshape_list:
-            reshape_list.remove(1)
-
-        # Now ensure each dimension is broadcastable:
-        # This is mandatory when mixing basic and advanced indexing
-        # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
-        # the reshape list should be : [[2, 1], [1, 3], [2, 1]]
-        for i, r in enumerate(reshape_list):
-            if r not in (1, values_shape[i]):
-                value_index = values_shape.index(r)
-                # Swap elements
-                # For the example above the current reshape list is [1, 2] for last dim,
-                # to make it broadcastable, we swap the elements
-                reshape_list[value_index], reshape_list[i] = r, 1
-
-        return reshape_list
+    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

     # Ensure the number of indices matches the tensor rank.
     self_rank = len(self.shape)
-    if len(indices) < self_rank:
-        indices = list(indices) + [None] * (self_rank - len(indices))
-
-    # Get values shape
-    values_shape = tuple(values.shape)
-
-    index_vectors = []
-    for i in range(self_rank):
-        if indices[i] is None:
-            # For a full slice along dim i, create a range index [0, self.shape[i]).
-            idx = op.Range(0, self.shape[i], 1)
-            reshape_update = self.shape[i]
-        else:
-            idx = indices[i]
-            reshape_update = math.prod(idx.shape)
-            # when Index is more than 1D, flatten it and also the values shape
-            # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
-            # Indices -> (2*4,) and values shape (2*4, 32)
-            if len(idx.shape) > 1:
-                values_shape = (reshape_update, *values_shape[len(idx.shape) :])
-
-            # Flatten index (always working with 1D index in each dim)
-            idx = op.Reshape(idx, [-1])
-
-        # Create a reshape pattern: one value per index dimension,
-        # with the current dimension set to the update size.
-        reshape_list = [1] * len(indices)
-        reshape_list[i] = reshape_update
+    # 1. Reorder input tensor so that None-indexed axes are last
+    # This logic is identical to the aten.index implementation.
+    reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
+    remaining_dims = [i for i in range(self_rank) if i not in reordered_positions]
+    reordered_positions.extend(remaining_dims)

-        # Adjust the reshape list to match the values shape.
-        reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
+    # Transpose the input data to group the indexed dimensions first
+    transposed_self = op.Transpose(self, perm=reordered_positions)

-        # Reshape and expand the index.
-        idx = op.Reshape(idx, reshape_list, allowzero=True)
-        idx = op.Expand(idx, values_shape)
+    # 2. Prepare indices for ScatterND
+    # This logic is also identical to the aten.index implementation.
+    not_none_indices = [idx for idx in indices if idx is not None]
+    broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)

-        # Flatten the index to 1D and unsqueeze to form a column vector.
-        idx = op.Reshape(idx, [-1])
-        idx = op.Unsqueeze(idx, axes=[1])
-        index_vectors.append(idx)
+    final_index_parts = []
+    for idx in not_none_indices:
+        # Unsqueeze is needed to make indices broadcastable to the common shape
+        expanded_idx = op.Expand(idx, broadcast_shape)
+        final_index_parts.append(op.Unsqueeze(expanded_idx, [-1]))

-    # Concatenate the index vectors along axis=1 to form the final indices.
-    new_index = op.Concat(*index_vectors, axis=1)
+    final_index = op.Concat(*final_index_parts, axis=-1)

-    # Flatten values to match the indices
-    flat_values = op.Reshape(values, [-1])
+    # 3. Prepare the 'updates' tensor (values)
+    # The 'values' tensor must be broadcast to match the shape of the
+    # broadcasted indices.
+    expanded_values = op.Expand(values, broadcast_shape)
+    # TODO: Handle None
+    expanded_values = op.Transpose(expanded_values, perm=reordered_positions)

+    # 4. Perform the scatter operation
     if accumulate:
-        result = op.ScatterND(self, new_index, flat_values, reduction="add")
+        scattered_data = op.ScatterND(transposed_self, final_index, expanded_values, reduction="add")
     else:
-        result = op.ScatterND(self, new_index, flat_values)
+        scattered_data = op.ScatterND(transposed_self, final_index, expanded_values)

-    return result
+    # 5. Restore original dimension order
+    # The output of ScatterND has the same shape as the transposed input.
+    # We must apply an "inverse" transpose to get the final result.
+    inverse_positions = np.argsort(reordered_positions).tolist()
+    final_output = op.Transpose(scattered_data, perm=inverse_positions)
+
+    return final_output


 @torch_op("aten::index_put", trace_only=True)
 def aten_index_put_bool(
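Both hunks rest on the same identity: np.argsort of a permutation yields its inverse, so a second Transpose with the argsorted perm undoes the first one. A minimal NumPy sketch of that identity; the tensor shape and permutation below are illustrative only and are not taken from the patch:

    import numpy as np

    # Illustrative rank-3 tensor and permutation (indexed axes first,
    # non-indexed axes last, mirroring the reordering done in the patch).
    data = np.arange(24).reshape(2, 3, 4)
    reordered_positions = [2, 0, 1]

    transposed = np.transpose(data, reordered_positions)

    # argsort of a permutation is its inverse permutation, so transposing
    # again with it restores the original axis order.
    inverse_positions = np.argsort(reordered_positions).tolist()
    restored = np.transpose(transposed, inverse_positions)

    assert np.array_equal(restored, data)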