Skip to content

Commit 36a700c

Browse files
committed
WIP
1 parent a63c282 commit 36a700c

File tree

4 files changed

+227
-123
lines changed

4 files changed

+227
-123
lines changed

onnxscript/function_libs/torch_lib/ops/fft.py

+203-101
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
# --------------------------------------------------------------------------
2-
# Copyright (c) Microsoft Corporation. All rights reserved.
1+
# Copyright (c) Microsoft Corporation.
32
# Licensed under the MIT License.
4-
# --------------------------------------------------------------------------
53
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
64
"""torch.ops.aten operators under the `fft` module.
75
@@ -12,7 +10,7 @@
1210

1311
from __future__ import annotations
1412

15-
from typing import Optional, Sequence
13+
from typing import Literal, Optional, Sequence
1614

1715
from onnxscript import INT64
1816
from onnxscript.function_libs.torch_lib.registration import torch_op
@@ -21,98 +19,157 @@
2119
from onnxscript.onnx_types import TensorType
2220

2321

24-
@torch_op(
25-
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
26-
private=True,
27-
complex=True,
28-
trace_only=True,
29-
)
22+
# def _compute_signal_size(signal: TFloat, dims: Sequence[int], last_dim_size: Optional[INT64] = None) -> INT64:
23+
# if last_dim_size is not None:
24+
# all_other_dims = dims[:-1]
25+
# if all_other_dims:
26+
# signal_size = op.ReduceProd(signal, axes=all_other_dims, keepdims=False)
27+
# signal_size = op.Mul(signal_size, last_dim_size)
28+
# else:
29+
# signal_size = last_dim_size
30+
# else:
31+
# signal_size = op.ReduceProd(signal, axes=dims, keepdims=False)
32+
# return signal_size
33+
34+
35+
# def _fftn_ortho_normalization(
36+
# self: TFloat,
37+
# dims: Sequence[int],
38+
# forward: bool,
39+
# onesided: bool,
40+
# last_dim_size: Optional[INT64] = None,
41+
# ) -> TFloat:
42+
# transformed = self
43+
44+
# signal_size = _compute_signal_size(self, dims, last_dim_size)
45+
46+
# for dim in dims[:-1]:
47+
# transformed = op.DFT(transformed, axis=dim, onesided=False)
48+
49+
# # Torch computes one-sided FFT on the last dimension only.
50+
# if onesided:
51+
# transformed = op.DFT(transformed, axis=dims[-1], onesided=True)
52+
# # TODO: Update signal_size for one-sided FFT
53+
# elif last_dim_size is not None:
54+
# transformed = op.DFT(
55+
# transformed, last_dim_size, axis=dims[-1], onesided=True
56+
# )
57+
# else:
58+
# transformed = op.DFT(transformed, axis=dims[-1], onesided=False)
59+
60+
3061
def _fftn_onnx_normalization(
31-
self,
32-
transformed: TFloat,
62+
self: TFloat,
3363
normalization: int,
34-
forward: bool,
35-
dims: Sequence[int],
64+
signal_size: INT64,
3665
) -> TFloat:
37-
# Obtain the total_sample_count (n) for normalization
38-
self_shape = op.Shape(self)
39-
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
40-
total_sample_count = op.CastLike(total_sample_count, transformed)
41-
42-
# Normalize the result
43-
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
44-
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
66+
"""
67+
"""
68+
# TODO: Make more efficient - there should be a faster way to recalculate everything
69+
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
70+
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
71+
# Modes:
72+
# 0: no normalization (backward)
73+
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
74+
# 2: divide by signal_size (forward)
4575
if normalization == 1:
46-
# "forward" - normalize by 1/n
47-
if forward:
48-
result = op.Div(transformed, op.Sqrt(total_sample_count))
49-
else:
50-
result = op.Mul(transformed, op.Sqrt(total_sample_count))
76+
self = op.Div(self, op.Sqrt(signal_size))
5177
elif normalization == 2:
52-
# "ortho" - normalize by 1/sqrt(n)
53-
if forward:
54-
result = op.Div(transformed, total_sample_count)
55-
else:
56-
result = transformed
57-
else:
58-
# "backward" - no normalization
59-
if forward:
60-
result = transformed
61-
else:
62-
result = op.Mul(transformed, total_sample_count)
63-
64-
return result
78+
self = op.Div(self, signal_size)
79+
return self
6580

66-
67-
@torch_op(
68-
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
69-
trace_only=True,
70-
private=True,
71-
complex=True,
72-
)
73-
def _fftn_onnx(
74-
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
81+
def _fftn_onnx_inverse_normalization(
82+
self: TFloat,
83+
normalization: int,
84+
signal_size: INT64,
7585
) -> TFloat:
76-
"""Standard complex to complex or real to complex FFT (forward or backward).
77-
78-
This is a private shared function for implementing the various FFT functions.
79-
80-
Args:
81-
self: The input tensor.
82-
dims: The dimensions to apply FFT.
83-
normalization: The normalization mode.
84-
inverse: Whether to compute the inverse FFT.
85-
onesided: Whether to compute the one-sided FFT, which retains only the
86-
positive frequencies.
87-
88-
Returns:
89-
The transformed tensor.
9086
"""
91-
92-
# NOTE: trace_only because we need to process each dimension in a loop
93-
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
94-
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
95-
96-
# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
97-
# dimension at the beginning to represent the batch dimension.
98-
transformed = op.Unsqueeze(self, axes=[0])
99-
100-
# Add 1 to account for the batch dimension when counting axes from the left
101-
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
102-
103-
for dim in new_dims[:-1]:
104-
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
105-
106-
# Torch computers one-sided FFT on the last dimension only.
107-
if onesided:
108-
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
109-
else:
110-
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)
111-
112-
# Remove the batch dimension
113-
transformed = op.Squeeze(transformed, axes=[0])
114-
115-
return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)
87+
"""
88+
# TODO: Make more efficient - there should be a faster way to recalculate everything
89+
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
90+
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
91+
# Modes:
92+
# 0: no normalization (backward)
93+
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
94+
# 2: divide by signal_size (forward)
95+
if normalization == 1:
96+
self = op.Mul(self, op.Sqrt(signal_size))
97+
elif normalization == 0:
98+
self = op.Mul(self, signal_size)
99+
return self
100+
101+
# def _fftn_onnx(
102+
# self: TFloat,
103+
# dims: Sequence[int],
104+
# normalization: int,
105+
# forward: bool,
106+
# onesided: bool,
107+
# last_dim_size: Optional[INT64] = None,
108+
# ) -> TFloat:
109+
# """Standard complex to complex or real to complex FFT (forward or backward).
110+
111+
# This is a private shared function for implementing the various FFT functions.
112+
113+
# Args:
114+
# self: The input tensor.
115+
# dims: The dimensions to apply FFT.
116+
# normalization: The normalization mode.
117+
# forward: Whether to compute forward FFT or backward FFT.
118+
# onesided: Whether to compute the one-sided FFT, which retains only the
119+
# positive frequencies.
120+
# last_dim_size: The size of the last specified dimension.
121+
122+
# Returns:
123+
# The transformed tensor.
124+
# """
125+
# # NOTE: SymInt dim is not support because DFT-17 needs a static axis
126+
127+
# # If taking FFT along the 0-th dimension: Since
128+
# # the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over),
129+
# # we need to add a new dimension at the beginning to represent the batch dimension.
130+
# unsqueeze_first_dim = 0 in dims
131+
# if unsqueeze_first_dim:
132+
# transformed = op.Unsqueeze(self, axes=[0])
133+
# # Add 1 to account for the batch dimension when counting axes from the left
134+
# dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
135+
# else:
136+
# transformed = self
137+
138+
# # Select inverse mode for ONNX based on the norm mode and forward/backward mode.
139+
# # In ONNX the only difference between inverse=True/False is the 1/n normalization applied.
140+
# #
141+
# # If normalization is 1/n and we are in backward mode, we use the inverse
142+
# # mode in ONNX to get the 1/n normalization.
143+
# inverse = normalization == 2 and not forward
144+
# ortho = normalization == 1
145+
146+
# for dim in dims[:-1]:
147+
# transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
148+
149+
# # Torch computes one-sided FFT on the last dimension only.
150+
# if onesided:
151+
# transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True)
152+
# elif last_dim_size is not None:
153+
# transformed = op.DFT(
154+
# transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False
155+
# )
156+
# else:
157+
# transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False)
158+
159+
# if ortho or inverse:
160+
# normalized = _fftn_onnx_normalization(
161+
# transformed, ortho, dims, last_dim_size=last_dim_size
162+
# )
163+
# else:
164+
# normalized = transformed
165+
# # TODO: Merge to normalization mode and ONNX inverse mode
166+
# # Be sure to normalize before squeezing the batch dimension, because dims would
167+
# # have been shifted by 1 if the batch dimension was added.
168+
# if unsqueeze_first_dim:
169+
# # Remove the batch dimension
170+
# normalized = op.Squeeze(normalized, axes=[0])
171+
172+
# return normalized
116173

117174

118175
@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
@@ -124,39 +181,74 @@ def aten__fft_c2c(
124181
Standard complex to complex FFT (forward or backward).
125182
"""
126183

127-
# NOTE: trace_only because we need to negate forward
128-
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
129-
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
184+
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis
130185

131186
# ONNX DFT input assumes the last dimension is the complex dimension.
132187
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
133-
dim = [d - 1 if d < 0 else d for d in dim]
134-
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
188+
assert(dim[2] in dim == 2, "Unexpected input size")
189+
190+
signal = self
191+
self_rank = len(self.shape)
192+
signal_size = op.Size(signal)
193+
194+
# ONNX DFT input assumes the last dimension is the complex dimension.
195+
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
196+
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
197+
198+
transformed = signal
199+
200+
for dimension in reversed(dim):
201+
transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
202+
if forward:
203+
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
204+
else:
205+
transformed = _fftn_onnx_inverse_normalization(transformed, normalization, signal_size)
206+
207+
# Unsure if output format is correct
208+
return transformed
135209

136210

137211
@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
138212
def aten__fft_c2r(
139213
self: TFloat,
140214
dim: Sequence[int],
141215
normalization: int,
142-
last_dim_size: INT64, # pylint: disable=unused-argument
216+
last_dim_size: INT64,
143217
) -> TFloat:
144218
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
145219
146220
Complex to real inverse FFT.
147221
"""
222+
assert(dim[2] in dim == 2, "Unexpected input size")
148223

149-
# TODO(justinchuby): Figure out what last_dim_size does
150-
224+
signal = self
151225
self_rank = len(self.shape)
226+
signal_size = op.Size(signal)
227+
152228
# ONNX DFT input assumes the last dimension is the complex dimension.
153229
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
154230
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
155-
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
156-
# Take only the real part
157-
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
158231

159-
return op.Squeeze(real_part, axes=[-1])
232+
transformed = signal
233+
for dimension in reversed(dim):
234+
transformed = op.DFT(transformed, axis=dimension, inverse=True, onesided=False)
235+
transformed = _fftn_onnx_inverse_normalization(transformed, normalization, signal_size)
236+
237+
# Unsure if output format is correct
238+
transformed = op.Squeeze(transformed, axes=[-1])
239+
240+
if transformed.shape[-1] < last_dim_size:
241+
pads = [0, last_dim_size - transformed.shape[-1]]
242+
mode = 'constant'
243+
constant_value = 0.0
244+
transformed = op.Pad(mode=mode, data=transformed, pads=pads, constant_value=constant_value, axes=[-1])
245+
elif transformed.shape[-1] > last_dim_size:
246+
starts = [0]*(self_rank-1)
247+
ends = list(self.shape)
248+
ends[-1] = last_dim_size
249+
transformed = op.Slice(data=transformed, starts=starts, ends=ends)
250+
251+
return transformed
160252

161253

162254
@torch_op("aten::_fft_r2c", trace_only=True)
@@ -174,12 +266,22 @@ def aten__fft_r2c(
174266
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
175267

176268
self_rank = len(self.shape)
269+
signal_size = op.Size(signal)
270+
177271
# ONNX DFT input assumes the last dimension is the complex dimension.
178272
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
179273
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
180274

181-
return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
275+
# Torch computes one-sided FFT on the last dimension only.
276+
transformed = op.DFT(signal, axis=dim[-1], inverse=False, onesided=onesided)
277+
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
278+
279+
for dimension in reversed(dim[:-1]):
280+
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
281+
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
182282

283+
# Unsure if output format is correct
284+
return transformed
183285

184286
def aten_fft_fft(
185287
self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None

0 commit comments

Comments
 (0)