
Commit 38871a5

Support dynamic shapes for aten_unfold (#2407)
While converting a new model that I'd like to add to Transformers.js, I ran into #2309, indicating that dynamic shapes aren't currently supported for `aten_unfold`:

```
File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/onnxscript/function_libs/torch_lib/ops/core.py", line 8662, in aten_unfold
    low_indices = range(0, dim_size, step)
TypeError: 'SymbolicDim' object cannot be interpreted as an integer
```

So, I dug a bit into the code and, with some help from Claude, I got a version which works for my use-case (output matches exactly)! 👍

Code to reproduce (adapted from pytorch/pytorch#112844 (comment)):

```py
import torch

class SpecMaker(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.unfold(x, -1, 512, 160)

specmodel = SpecMaker()
input = torch.rand(32000 * 10)
spec = specmodel(input)
input_batch = torch.stack([input, input])
spec_batch = specmodel(input_batch)
onnx_program = torch.onnx.export(
    specmodel,
    (input_batch,),
    f="/tmp/model.onnx",
    dynamic_shapes=[{0: "dim_x", 1: "length"}],
    input_names=["input"],
    output_names=["output"],
    dynamo=True,
    report=True,
)
```

## Logs (before)

```
(base) ➜  onnxscript git:(main) ✗ python testing/unfold.py
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ❌
[torch.onnx] Export report has been saved to 'onnx_export_2025-06-20_14-08-52-474773_conversion.md'.
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 519, in _handle_call_function_node_with_lowering
    outputs = onnx_function(*onnx_args, **onnx_kwargs)
  File ".../onnxscript/onnxscript/values.py", line 625, in __call__
    return self.func(*args, **kwargs)
           ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File ".../onnxscript/onnxscript/function_libs/torch_lib/ops/core.py", line 8660, in aten_unfold
    low_indices = range(0, dim_size, step)
TypeError: 'SymbolicDim' object cannot be interpreted as an integer

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 707, in _translate_fx_graph
    _handle_call_function_node_with_lowering(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<6 lines>...
        node_name_to_local_functions=node_name_to_local_functions,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 521, in _handle_call_function_node_with_lowering
    raise _errors.GraphConstructionError(
        f"Error when calling function '{onnx_function}' with args '{onnx_args}' and kwargs '{onnx_kwargs}'"
    ) from e
torch.onnx._internal.exporter._errors.GraphConstructionError: Error when calling function 'TracedOnnxFunction(<function aten_unfold at 0x120baa7a0>)' with args '[SymbolicTensor(name='x', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s0), SymbolicDim(s1)])), -1, 512, 160]' and kwargs '{}'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1373, in export
    onnx_program = _exported_program_to_onnx_program(
        decomposed_program, registry=registry
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1007, in _exported_program_to_onnx_program
    values = _translate_fx_graph(
        fx_graph,
    ...<4 lines>...
        registry=registry,
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 733, in _translate_fx_graph
    raise _errors.ConversionError(
        f"Error when translating node {node.format_node()}. See the stack trace for more information."
    ) from e
torch.onnx._internal.exporter._errors.ConversionError: Error when translating node %unfold : [num_users=1] = call_function[target=torch.ops.aten.unfold.default](args = (%x, -1, 512, 160), kwargs = {}). See the stack trace for more information.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../onnxscript/testing/unfold.py", line 15, in <module>
    onnx_program = torch.onnx.export(
        specmodel,
    ...<7 lines>...
        # verbose=True,
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/__init__.py", line 364, in export
    return _compat.export_compat(
           ~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<19 lines>...
        fallback=fallback,
        ^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_compat.py", line 120, in export_compat
    onnx_program = _core.export(
        model,
    ...<11 lines>...
        verbose=verbose,
    )
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/torch/onnx/_internal/exporter/_core.py", line 1419, in export
    raise _errors.ConversionError(
    ...<3 lines>...
    ) from e
torch.onnx._internal.exporter._errors.ConversionError: Failed to convert the exported program to an ONNX model. This is step 3/3 of exporting the model to ONNX. Next steps:
- If there is a missing ONNX function, implement it and register it to the registry.
- If there is an internal error during ONNX conversion, debug the error and summit a PR to PyTorch.
- Create an error report with `torch.onnx.export(..., report=True)`, and save the ExportedProgram as a pt2 file. Create an issue in the PyTorch GitHub repository against the *onnx* component. Attach the error report and the pt2 model.

Error report has been saved to 'onnx_export_2025-06-20_14-08-52-474773_conversion.md'.

## Exception summary

<class 'TypeError'>: 'SymbolicDim' object cannot be interpreted as an integer
⬆️
<class 'torch.onnx._internal.exporter._errors.GraphConstructionError'>: Error when calling function 'TracedOnnxFunction(<function aten_unfold at 0x120baa7a0>)' with args '[SymbolicTensor(name='x', type=Tensor(FLOAT), shape=Shape([SymbolicDim(s0), SymbolicDim(s1)])), -1, 512, 160]' and kwargs '{}'
⬆️
<class 'torch.onnx._internal.exporter._errors.ConversionError'>: Error when translating node %unfold : [num_users=1] = call_function[target=torch.ops.aten.unfold.default](args = (%x, -1, 512, 160), kwargs = {}). See the stack trace for more information.

(Refer to the full stack trace above for more information.)
```

## Logs (after)

```
(base) ➜  onnxscript git:(main) ✗ python testing/unfold.py
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SpecMaker()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
[torch.onnx] Export report has been saved to 'onnx_export_2025-06-20_14-11-27-804730_success.md'.
Applied 1 of general pattern rewrite rules.
```

Closes #2309.

cc @justinchuby
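As a side note (not from the PR itself), one way to verify the "output matches exactly" claim is to run the exported model with ONNX Runtime and compare it against eager PyTorch. A minimal sketch, assuming `onnxruntime` is installed and the repro script above has written `/tmp/model.onnx`:

```py
# Minimal sketch (not from the PR): compare the exported ONNX model against
# torch.ops.aten.unfold using ONNX Runtime. Assumes /tmp/model.onnx exists.
import numpy as np
import onnxruntime as ort
import torch

input_batch = torch.rand(2, 32000 * 10)
expected = torch.ops.aten.unfold(input_batch, -1, 512, 160)

session = ort.InferenceSession("/tmp/model.onnx", providers=["CPUExecutionProvider"])
(actual,) = session.run(None, {"input": input_batch.numpy()})  # "input" matches input_names above

np.testing.assert_allclose(actual, expected.numpy(), rtol=1e-5, atol=1e-6)
print("ONNX Runtime output matches torch.ops.aten.unfold")
```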
1 parent 03ab4c5 commit 38871a5

File tree

• onnxscript/function_libs/torch_lib/ops

1 file changed: +28, -21 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 28 additions & 21 deletions
```diff
@@ -8655,29 +8655,36 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
         # Handle negative dimension
         if dimension < 0:
             dimension = dimension + self_rank
-        dim_size = self.shape[dimension]
-
-        low_indices = range(0, dim_size, step)
-        hi_indices = range(size, dim_size + 1, step)
-        stack = [
-            op.Slice(
-                self,
-                op.Constant(value_ints=[low]),
-                op.Constant(value_ints=[hi]),
-                op.Constant(value_ints=[dimension]),
-            )
-            for low, hi in zip(low_indices, hi_indices)
-        ]

+        input_shape = op.Shape(self)
+        dim_size = op.Gather(input_shape, op.Constant(value_ints=[dimension]))
+
+        # Create indices for each window
+        window_starts = op.Range(0, op.Sub(dim_size, size - 1), step)
+
+        # Create the base indices for one window
+        window_indices = list(range(size))
+
+        # Broadcast to create all indices
+        starts_expanded = op.Unsqueeze(window_starts, [1])  # [num_windows, 1]
+        indices_expanded = op.Unsqueeze(window_indices, [0])  # [1, size]
+        all_indices = op.Add(starts_expanded, indices_expanded)  # [num_windows, size]
+
+        # Gather along the specified dimension
+        result = op.Gather(self, all_indices, axis=dimension)
+
+        # The result shape is now [..., num_windows, size, ...] with num_windows at position 'dimension'.
+        # We need to move the size dimension to the end:
+        # Current shape: [..., num_windows, size, ...]
+        # Target shape: [..., num_windows, ..., size]
+
+        # Move the size dimension (at position dimension+1) to the end
         # perm need to be list[int], so have to be generated in trace_only mode
-        perm = list(range(self_rank))
-        # from [0,1,2,3,4] -> [0,1,3,4,2] when dimension=1
-        perm.append(perm.pop(dimension))
-        unsqueeze = [
-            op.Unsqueeze(op.Transpose(t, perm=perm), op.Constant(value_ints=[dimension]))
-            for t in stack
-        ]
-        result = op.Concat(*unsqueeze, axis=dimension)
+        perm = list(range(self_rank + 1))
+        perm.append(perm.pop(dimension + 1))
+
+        result = op.Transpose(result, perm=perm)
+
     return result
```
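As a reading aid (not part of the commit), here is a rough PyTorch sketch of what the new index-based lowering computes: build window start offsets with a Range, add a broadcast `[0, size)` base index, gather along the unfold dimension, then move the `size` axis to the end. The shape, `dimension`, `size`, and `step` values below are made up for illustration; variable names mirror the diff.

```py
# Rough PyTorch sketch (not from the commit) of the index-based lowering above:
# Range -> broadcast Add -> Gather -> Transpose, checked against torch's own unfold.
import torch

x = torch.arange(2 * 12 * 3, dtype=torch.float32).reshape(2, 12, 3)
dimension, size, step = 1, 5, 3  # illustrative values only

dim_size = x.shape[dimension]                                   # op.Shape + op.Gather
window_starts = torch.arange(0, dim_size - (size - 1), step)    # op.Range(0, dim_size - (size - 1), step)
window_indices = torch.arange(size)                             # list(range(size))
all_indices = window_starts[:, None] + window_indices[None, :]  # op.Unsqueeze + op.Add -> [num_windows, size]

# op.Gather with 2-D indices at axis=dimension yields shape
# x.shape[:dimension] + [num_windows, size] + x.shape[dimension + 1:]
gathered = torch.index_select(x, dimension, all_indices.flatten()).reshape(
    *x.shape[:dimension], *all_indices.shape, *x.shape[dimension + 1 :]
)

perm = list(range(x.dim() + 1))       # rank grew by one after the gather
perm.append(perm.pop(dimension + 1))  # move the `size` axis to the end
result = gathered.permute(*perm)      # op.Transpose

torch.testing.assert_close(result, x.unfold(dimension, size, step))
```

Unlike the previous Slice/Concat approach, none of this needs `dim_size` as a Python integer at export time: the window count comes from `op.Range` over a runtime `Shape`, which is what makes symbolic dimensions work.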

0 commit comments