Skip to content

Commit 85556c8

Browse files
Copilotgramalingamjustinchuby
authored
Cleanup elimination of redundant scatter-nd: consolidate rules and improve organization (#2426)
This PR consolidates redundant ScatterND elimination logic into a dedicated module and improves code organization as requested in the issue. ## Changes Made ### 1. **Moved redundant ScatterND rule** from `collapse_slices.py` to `redundant_scatter_nd.py` - Extracted `_potential_redundant_scatternd`, `_identity_to_updates`, and `_check_if_redundant_scatternd` functions - Converted to class-based `ScatterAllStatic` rule for consistency with existing patterns - Removed the rule from `collapse_slices.py` rules list ### 2. **Distinguished between static vs dynamic scenarios** with clear naming: - **`ScatterAllDynamic`** (renamed from `ScatterAll`): Handles cases where indices are constructed dynamically using Range operations but axis dimension is statically known - **`ScatterAllStatic`** (new): Handles cases where indices are statically known constants in form `[[0], [1], ..., [n-1]]` ### 3. **Moved corresponding test case** from `collapse_slices_test.py` to `redundant_scatter_nd_test.py` - Test renamed to `test_redundant_scatter_nd_static_indices` for clarity - Original test renamed to `test_redundant_scatter_nd_dynamic_indices` - Both tests validate their respective optimization scenarios ### 4. **Updated documentation** to clearly explain both rules and their use cases ## Key Benefits - **Better organization**: All ScatterND redundancy elimination logic is now in one dedicated module - **Clear separation of concerns**: Static vs dynamic index scenarios are clearly distinguished - **Consistent patterns**: Both rules follow the same class-based structure - **Improved maintainability**: Clear naming and documentation for future developers ## Verification All tests pass, including: - Existing dynamic indices optimization (complex Range-based pattern) - Moved static indices optimization (simple constant indices pattern) - No regressions in slice optimization functionality The changes maintain full backward compatibility while improving code organization and clarity. Fixes #2425. <!-- START COPILOT CODING AGENT TIPS --> --- 💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click [here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to start the survey. --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: gramalingam <[email protected]> Co-authored-by: justinchuby <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent 87d6f11 commit 85556c8

File tree

4 files changed

+110
-97
lines changed

4 files changed

+110
-97
lines changed

onnxscript/rewriter/collapse_slices.py

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -71,71 +71,17 @@ def _identity_to_itself(op, data, **_):
7171
return op.Identity(data)
7272

7373

74-
def _identity_to_updates(op, data, indices, updates, **_):
75-
"""Return the updates as the output.
76-
77-
This is used when the ScatterND is redundant in terms of
78-
updating the whole data with the updates.
79-
80-
"""
81-
return op.Identity(updates)
82-
83-
8474
def _potential_redundant_slice(op, data, starts, ends, axes, steps):
8575
"""To identify a slice op"""
8676
return op.Slice(data, starts, ends, axes, steps)
8777

8878

89-
def _potential_redundant_scatternd(op, data, indices, updates):
90-
"""To identify a ScatterND op"""
91-
return op.ScatterND(data, indices, updates)
92-
93-
94-
def _check_if_redundant_scatternd(
95-
context,
96-
data: ir.Value,
97-
indices: ir.Value,
98-
updates: ir.Value,
99-
**_,
100-
):
101-
"""If the indices is the same length as the first dim of data, and the shape of updates is equal to data, we can simply swap the whole value."""
102-
del context # Reserved for future extensions
103-
104-
# To validate data can be replaced directly by updates, we need to check the following:
105-
# 1. they have the same shape
106-
if data.shape is None:
107-
logger.info("The value 'data' shape is not statically known.")
108-
return False
109-
if updates.shape is None:
110-
logger.info("The value 'updates' shape is not statically known.")
111-
return False
112-
if data.shape != updates.shape:
113-
logger.info("The shape of 'data' and 'updates' are different.")
114-
return False
115-
116-
# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
117-
if indices.const_value is None:
118-
logger.info("The value 'indices' is not statically known.")
119-
return False
120-
if indices.const_value.numpy().tolist() != [[i] for i in range(data.shape[0])]: # type: ignore[arg-type]
121-
logger.info("The 'indices' is not referring to the whole data.")
122-
return False
123-
124-
return True
125-
126-
12779
# Register the rewrite rules
12880
remove_redundant_slice = pattern.RewriteRule(
12981
_potential_redundant_slice,
13082
_identity_to_itself,
13183
_check_if_redundant_slice,
13284
)
13385

134-
remove_redundant_scatternd = pattern.RewriteRule(
135-
_potential_redundant_scatternd,
136-
_identity_to_updates,
137-
_check_if_redundant_scatternd,
138-
)
139-
14086
# NOTE: The order of the rules is important. Larger pattern should be checked first.
141-
rules = pattern.RewriteRuleSet([remove_redundant_slice, remove_redundant_scatternd])
87+
rules = pattern.RewriteRuleSet([remove_redundant_slice])

onnxscript/rewriter/collapse_slices_test.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -82,35 +82,3 @@ def test_slice_pattern_is_not_matched_when_input_is_dynamic(self):
8282
model = ir.serde.deserialize_model(model_proto)
8383
count = collapse_slices.rules.apply_to_model(model)
8484
self.assertEqual(count, 0)
85-
86-
def test_scatternd_is_redundant_when_it_is_updating_the_whole_input_in_order(self):
87-
model_proto = onnx.parser.parse_model(
88-
"""
89-
<ir_version: 7, opset_import: [ "" : 17]>
90-
agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
91-
{
92-
output = ScatterND (data, indices, updates)
93-
}
94-
"""
95-
)
96-
# Use inserted initializers to avoid manually coding the large constants
97-
indices = np.arange(112).reshape(112, 1).astype(np.int64)
98-
model = ir.serde.deserialize_model(model_proto)
99-
# from numpy to ir.Tensor
100-
indices_ir_tensor = ir.Tensor(
101-
name="indices",
102-
value=indices,
103-
)
104-
# assign the tensor to a value
105-
indices = model.graph[0].inputs[1]
106-
indices.const_value = indices_ir_tensor
107-
model.graph.initializers["indices"] = indices
108-
original_model_proto = ir.serde.serialize_model(model)
109-
110-
count = collapse_slices.rules.apply_to_model(model)
111-
self.assertEqual(count, 1)
112-
self.assertEqual(len(model.graph), 1)
113-
self.assertIn("Identity", [node.op_type for node in model.graph])
114-
115-
input = np.random.rand(112, 16, 512).astype(np.float32)
116-
testing.assert_numerically_equal(original_model_proto, model, (input, input))

onnxscript/rewriter/redundant_scatter_nd.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Rewrite rule to eliminate redundant ScatterND operations.
3+
"""Rewrite rules to eliminate redundant ScatterND operations.
44
5-
Identify ScatterND(data, indices, updates) that can be replaced by Identity(updates).
6-
This is generated by the translation of `x[:, ...] = y` in PyTorch.
7-
The specific pattern is that the updated indices take the form [[0], ..., [S-1]] for the first dimension,
8-
where S is the size of the first dimension of the updated-data tensor.
9-
In effect, the scatter-update ends up being an assignment of a new value to the entire tensor.
5+
This module contains two rewrite rules:
6+
7+
1. ScatterAllDynamic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates)
8+
when the indices are computed dynamically using Range operations but represent a complete update
9+
of an entire axis. This is generated by the translation of `x[:, ...] = y` in PyTorch.
10+
11+
2. ScatterAllStatic: Identifies ScatterND(data, indices, updates) that can be replaced by Identity(updates)
12+
when the indices are statically known constants in the form [[0], [1], ..., [n-1]] covering
13+
the entire first dimension of the data tensor.
14+
15+
Both rules detect when the scatter-update ends up being an assignment of a new value to the entire tensor.
1016
"""
1117

1218
from __future__ import annotations
@@ -22,7 +28,7 @@ def fail(*args):
2228
return onnxscript.rewriter.MatchResult().fail(*args)
2329

2430

25-
class ScatterAll(orp.RewriteRuleClassBase):
31+
class ScatterAllDynamic(orp.RewriteRuleClassBase):
2632
def pattern(self, op, data, axis, transposed_data, updates):
2733
# Construct update-indices spanning an entire axis:
2834
shape = op.Shape(data, start=0)
@@ -60,6 +66,44 @@ def rewrite(self, op, updates, **_):
6066
return op.Identity(updates)
6167

6268

63-
rule = ScatterAll.rule()
69+
class ScatterAllStatic(orp.RewriteRuleClassBase):
70+
"""Rewrite rule for eliminating redundant ScatterND with statically known indices.
71+
72+
This handles the case where indices are constant values in the form [[0], [1], ..., [n-1]]
73+
that update the entire first dimension of the data tensor.
74+
"""
75+
76+
def pattern(self, op, data, indices, updates):
77+
"""Pattern to match ScatterND with static indices."""
78+
return op.ScatterND(data, indices, updates)
79+
80+
def check(self, context, data, indices, updates, **_):
81+
"""Check if the ScatterND is redundant due to static indices covering entire tensor."""
82+
# To validate data can be replaced directly by updates, we need to check the following:
83+
# 1. they have the same shape
84+
if data.shape is None:
85+
return fail("The value 'data' shape is not statically known.", data)
86+
if updates.shape is None:
87+
return fail("The value 'updates' shape is not statically known.", updates)
88+
if data.shape != updates.shape:
89+
return fail("The shape of 'data' and 'updates' are different.", data, updates)
90+
91+
# 2. the indices is referring to the whole data, which is from 0 to data.shape[0]
92+
if indices.const_value is None:
93+
return fail("The value 'indices' is not statically known.", indices)
94+
expected_indices = [[i] for i in range(data.shape[0])]
95+
actual_indices = indices.const_value.numpy().tolist()
96+
if actual_indices != expected_indices:
97+
return fail("The 'indices' is not referring to the whole data.", indices)
98+
99+
return True
100+
101+
def rewrite(self, op, updates, **_):
102+
"""Replace ScatterND with Identity since updates covers entire tensor."""
103+
return op.Identity(updates)
104+
105+
106+
rule = ScatterAllDynamic.rule()
107+
static_rule = ScatterAllStatic.rule()
64108

65-
rules = orp.RewriteRuleSet([rule])
109+
rules = orp.RewriteRuleSet([rule, static_rule])

onnxscript/rewriter/redundant_scatter_nd_test.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66

77
import numpy as np
8+
import onnx.parser
89
import onnx_ir as ir
910
import onnxruntime
1011
from onnx_ir.passes.common import CheckerPass, ShapeInferencePass
@@ -19,7 +20,9 @@
1920

2021

2122
class RedundantScatterNdTest(unittest.TestCase):
22-
def test_redundant_scatter_nd(self):
23+
def test_redundant_scatter_nd_dynamic_indices(self):
24+
"""Test redundant ScatterND with dynamically constructed indices."""
25+
2326
@script()
2427
def model_script(
2528
data: FLOAT[8, "N", 16], updates: FLOAT[8, "N", 16]
@@ -62,9 +65,61 @@ def model_script(
6265
optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
6366
)
6467
optimized_outputs = optimized_session.run(None, inputs)
68+
# Compare outputs
6569
for output, optimized_output in zip(outputs, optimized_outputs):
6670
np.testing.assert_allclose(output, optimized_output, rtol=1e-6, atol=1e-6)
6771

72+
def test_redundant_scatter_nd_static_indices(self):
73+
"""Test redundant ScatterND with static indices (moved from collapse_slices_test.py)."""
74+
model_proto = onnx.parser.parse_model(
75+
"""
76+
<ir_version: 7, opset_import: [ "" : 17]>
77+
agraph (float[112, 16, 512] data, float[112, 16, 512] updates) => (float[112, 16, 512] output)
78+
{
79+
output = ScatterND (data, indices, updates)
80+
}
81+
"""
82+
)
83+
# Use inserted initializers to avoid manually coding the large constants
84+
indices = np.arange(112).reshape(112, 1).astype(np.int64)
85+
model = ir.serde.deserialize_model(model_proto)
86+
# from numpy to ir.Tensor
87+
indices_ir_tensor = ir.Tensor(
88+
name="indices",
89+
value=indices,
90+
)
91+
# assign the tensor to a value
92+
indices_value = model.graph[0].inputs[1]
93+
indices_value.const_value = indices_ir_tensor
94+
model.graph.initializers["indices"] = indices_value
95+
original_model_proto = ir.serde.serialize_model(model)
96+
97+
count = redundant_scatter_nd.rules.apply_to_model(model)
98+
self.assertEqual(count, 1)
99+
self.assertEqual(len(model.graph), 1)
100+
self.assertIn("Identity", [node.op_type for node in model.graph])
101+
102+
# Test numerical equivalence
103+
input_data = np.random.rand(112, 16, 512).astype(np.float32)
104+
inputs = {"data": input_data, "updates": input_data}
105+
106+
# Run original model
107+
session = onnxruntime.InferenceSession(
108+
original_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
109+
)
110+
original_outputs = session.run(None, inputs)
111+
112+
# Run optimized model
113+
optimized_model_proto = ir.serde.serialize_model(model)
114+
optimized_session = onnxruntime.InferenceSession(
115+
optimized_model_proto.SerializeToString(), providers=["CPUExecutionProvider"]
116+
)
117+
optimized_outputs = optimized_session.run(None, inputs)
118+
119+
# Compare outputs
120+
for original_output, optimized_output in zip(original_outputs, optimized_outputs):
121+
np.testing.assert_allclose(original_output, optimized_output, rtol=1e-6, atol=1e-6)
122+
68123

69124
if __name__ == "__main__":
70125
unittest.main()

0 commit comments

Comments
 (0)