Skip to content

Commit 15c4a23

Browse files
authored
Fix UnsqueezeParser for ONNX Opset 13+ (#119)
This PR updates the `Squeeze` and `Unsqueeze` operators to align with the ONNX Opset 13+ standard, where axes is now provided as an input instead of a node attribute. The change is backward compatible, as the old logic remains for single-input cases. A new test case, `CCT/CCT_2_32_32_128_Opset20`, has been added using models exported from PyTorch with ONNX Opset 20 to verify compatibility with recent versions. ## Added - Added support for ONNX Opset 13 and higher. ## Changed - UnsqueezeParser in Generic NodeParser - Check for the presence of `axes` in node attributes and use the old workflow otherwise check for exactly 2 inputs (data and axes). - Node context was changes accordingly; 1 single input follows the old workflow, 2 inputs uses the new 2 input Op. format. ## Fixed - Breaking compilation with ONNX Opset 13 and higher when using `Squeeze` Op.
1 parent 362033f commit 15c4a23

File tree

7 files changed

+35
-9
lines changed

7 files changed

+35
-9
lines changed

.github/workflows/ci-platform-generic.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,5 @@ jobs:
9494
miniMobileNet
9595
miniMobileNetv2
9696
CCT/CCT_1_16_16_8
97+
CCT/CCT_2_32_32_128_Opset20
9798
testFloatDemoTinyViT

.github/workflows/ci-platform-siracusa.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,6 @@ jobs:
8585
MLPerf/ImageClassification
8686
MLPerf/AnomalyDetection
8787
CCT/CCT_1_16_16_8
88+
CCT/CCT_2_32_32_128_Opset20
8889
testTrainCCT/CCT1_Classifier_Training/CCT_1_16_16_8
8990
num-cores: 8

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
1414
- Node Mangling to avoid duplication [#93](https://github.com/pulp-platform/Deeploy/pull/93)
1515
- Prepare Post v0.2.0 Release [#104](https://github.com/pulp-platform/Deeploy/pull/104)
1616
- Use Docker digests instead of arch-specific tags [#106](https://github.com/pulp-platform/Deeploy/pull/106)
17+
- Fix `Unsqueeze` Op. when using ONNX opset 13 or higher (from attribute to input) [#119](https://github.com/pulp-platform/Deeploy/pull/119)
1718

1819
### Added
1920
- Add manual type inference feature (CLI: `--input-type-map`/`--input-offset-map`) to resolve ambiguities when test inputs are not representative enough
@@ -81,6 +82,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
8182
- Fixed multiple typos in variable and method names, such as changing `includeGobalReferences` to `includeGlobalReferences` and `dicardedMappers` to `discardedMappers`
8283
- Corrected method usage in `importDeeployState` to call `NetworkContext.importNetworkContext` instead of the incorrect method name
8384
- Correctly return `signProp` from `setupDeployer` instead of hardcoding the value to `False` in `testMVP.py`
85+
- Fixed `Unsqueeze` Op. when using ONNX opset 13 or higher (from attribute to input)
8486

8587
### Removed
8688
- Delete outdated and unused `.gitlab-ci.yml` file

Deeploy/Targets/Generic/Parsers.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -940,10 +940,19 @@ def __init__(self):
940940

941941
def parseNode(self, node: gs.Node) -> (bool):
942942

943-
ret = all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1])
943+
# ONNX v11: 'axes' is a node attribute
944+
if 'axes' in node.attrs:
945+
ret = all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1])
946+
# ONNX v13+: 'axes' becomes an input with the data
947+
# Source: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html
948+
else:
949+
ret = all([len(node.inputs) == 2, len(node.outputs) == 1])
944950

945-
if ret:
946-
self.operatorRepresentation['axes'] = node.attrs['axes']
951+
if ret and 'axes' in node.attrs:
952+
axes_attr = node.attrs['axes']
953+
self.operatorRepresentation['axes'] = [int(axes_attr)] if isinstance(axes_attr, int) \
954+
else [int(a) for a in axes_attr]
955+
# For opset 13+, axes will be extracted from the second input in parseNodeCtxt
947956

948957
return ret
949958

@@ -952,13 +961,26 @@ def parseNodeCtxt(self,
952961
node: gs.Node,
953962
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
954963

955-
inputs = ['data_in']
956964
outputs = ['data_out']
957-
958-
for idx, inputNode in enumerate(node.inputs):
959-
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
960-
for idx, outputNode in enumerate(node.outputs):
961-
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
965+
if len(node.inputs) == 1:
966+
inputs = ['data_in']
967+
for idx, inputNode in enumerate(node.inputs):
968+
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
969+
for idx, outputNode in enumerate(node.outputs):
970+
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
971+
else:
972+
data_in = ctxt.lookup(node.inputs[0].name)
973+
data_out = ctxt.lookup(node.outputs[0].name)
974+
self.operatorRepresentation['data_in'] = data_in.name
975+
self.operatorRepresentation['data_out'] = data_out.name
976+
# axes must be a constant; extract values
977+
axes_buf = ctxt.lookup(node.inputs[1].name)
978+
assert hasattr(axes_buf, 'values'), "Unsqueeze: expected constant 'axes' input for opset 13+"
979+
axes_vals = np.array(axes_buf.values).astype(int).flatten().tolist()
980+
self.operatorRepresentation['axes'] = axes_vals
981+
# Do not deploy the axes tensor
982+
axes_buf._live = False
983+
axes_buf._deploy = False
962984

963985
return ctxt, True
964986

12.3 KB
Binary file not shown.
869 KB
Binary file not shown.
306 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)