1
- # --------------------------------------------------------------------------
2
- # Copyright (c) Microsoft Corporation. All rights reserved.
1
+ # Copyright (c) Microsoft Corporation.
3
2
# Licensed under the MIT License.
4
- # --------------------------------------------------------------------------
5
3
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
6
4
"""torch.ops.aten operators under the `fft` module.
7
5
12
10
13
11
from __future__ import annotations
14
12
15
- from typing import Optional , Sequence
13
+ from typing import Literal , Optional , Sequence
16
14
17
15
from onnxscript import INT64
18
16
from onnxscript .function_libs .torch_lib .registration import torch_op
21
19
from onnxscript .onnx_types import TensorType
22
20
23
21
24
- @torch_op (
25
- ("aten::_fft_c2c" , "aten::_fft_c2r" , "aten::_fft_r2c" ),
26
- private = True ,
27
- complex = True ,
28
- trace_only = True ,
29
- )
22
+ # def _compute_signal_size(signal: TFloat, dims: Sequence[int], last_dim_size: Optional[INT64] = None) -> INT64:
23
+ # if last_dim_size is not None:
24
+ # all_other_dims = dims[:-1]
25
+ # if all_other_dims:
26
+ # signal_size = op.ReduceProd(signal, axes=all_other_dims, keepdims=False)
27
+ # signal_size = op.Mul(signal_size, last_dim_size)
28
+ # else:
29
+ # signal_size = last_dim_size
30
+ # else:
31
+ # signal_size = op.ReduceProd(signal, axes=dims, keepdims=False)
32
+ # return signal_size
33
+
34
+
35
+ # def _fftn_ortho_normalization(
36
+ # self: TFloat,
37
+ # dims: Sequence[int],
38
+ # forward: bool,
39
+ # onesided: bool,
40
+ # last_dim_size: Optional[INT64] = None,
41
+ # ) -> TFloat:
42
+ # transformed = self
43
+
44
+ # signal_size = _compute_signal_size(self, dims, last_dim_size)
45
+
46
+ # for dim in dims[:-1]:
47
+ # transformed = op.DFT(transformed, axis=dim, onesided=False)
48
+
49
+ # # Torch computes one-sided FFT on the last dimension only.
50
+ # if onesided:
51
+ # transformed = op.DFT(transformed, axis=dims[-1], onesided=True)
52
+ # # TODO: Update signal_size for one-sided FFT
53
+ # elif last_dim_size is not None:
54
+ # transformed = op.DFT(
55
+ # transformed, last_dim_size, axis=dims[-1], onesided=True
56
+ # )
57
+ # else:
58
+ # transformed = op.DFT(transformed, axis=dims[-1], onesided=False)
59
+
60
+
30
61
def _fftn_onnx_normalization (
31
- self ,
32
- transformed : TFloat ,
62
+ self : TFloat ,
33
63
normalization : int ,
34
- forward : bool ,
35
- dims : Sequence [int ],
64
+ signal_size : INT64 ,
36
65
) -> TFloat :
37
- # Obtain the total_sample_count (n) for normalization
38
- self_shape = op .Shape (self )
39
- total_sample_count = op .ReduceProd (op .Gather (self_shape , dims ), keepdims = 0 )
40
- total_sample_count = op .CastLike (total_sample_count , transformed )
41
-
42
- # Normalize the result
43
- # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
44
- # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
66
+ """
67
+ """
68
+ # TODO: Make more efficient - there should be a faster way to recalculate everything
69
+ # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
70
+ # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
71
+ # Modes:
72
+ # 0: no normalization (backward)
73
+ # 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
74
+ # 2: divide by signal_size (forward)
45
75
if normalization == 1 :
46
- # "forward" - normalize by 1/n
47
- if forward :
48
- result = op .Div (transformed , op .Sqrt (total_sample_count ))
49
- else :
50
- result = op .Mul (transformed , op .Sqrt (total_sample_count ))
76
+ self = op .Div (self , op .Sqrt (signal_size ))
51
77
elif normalization == 2 :
52
- # "ortho" - normalize by 1/sqrt(n)
53
- if forward :
54
- result = op .Div (transformed , total_sample_count )
55
- else :
56
- result = transformed
57
- else :
58
- # "backward" - no normalization
59
- if forward :
60
- result = transformed
61
- else :
62
- result = op .Mul (transformed , total_sample_count )
63
-
64
- return result
78
+ self = op .Div (self , signal_size )
79
+ return self
65
80
66
-
67
- @torch_op (
68
- ("aten::_fft_c2c" , "aten::_fft_c2r" , "aten::_fft_r2c" ),
69
- trace_only = True ,
70
- private = True ,
71
- complex = True ,
72
- )
73
- def _fftn_onnx (
74
- self : TFloat , dims : Sequence [int ], normalization : int , inverse : bool , onesided : bool
81
+ def _fftn_onnx_inverse_normalization (
82
+ self : TFloat ,
83
+ normalization : int ,
84
+ signal_size : INT64 ,
75
85
) -> TFloat :
76
- """Standard complex to complex or real to complex FFT (forward or backward).
77
-
78
- This is a private shared function for implementing the various FFT functions.
79
-
80
- Args:
81
- self: The input tensor.
82
- dims: The dimensions to apply FFT.
83
- normalization: The normalization mode.
84
- inverse: Whether to compute the inverse FFT.
85
- onesided: Whether to compute the one-sided FFT, which retains only the
86
- positive frequencies.
87
-
88
- Returns:
89
- The transformed tensor.
90
86
"""
91
-
92
- # NOTE: trace_only because we need to process each dimension in a loop
93
- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
94
- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
95
-
96
- # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
97
- # dimension at the beginning to represent the batch dimension.
98
- transformed = op .Unsqueeze (self , axes = [0 ])
99
-
100
- # Add 1 to account for the batch dimension when counting axes from the left
101
- new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims ]
102
-
103
- for dim in new_dims [:- 1 ]:
104
- transformed = op .DFT (transformed , axis = dim , inverse = inverse , onesided = False )
105
-
106
- # Torch computers one-sided FFT on the last dimension only.
107
- if onesided :
108
- transformed = op .DFT (transformed , axis = new_dims [- 1 ], inverse = inverse , onesided = True )
109
- else :
110
- transformed = op .DFT (transformed , axis = new_dims [- 1 ], inverse = inverse , onesided = False )
111
-
112
- # Remove the batch dimension
113
- transformed = op .Squeeze (transformed , axes = [0 ])
114
-
115
- return _fftn_onnx_normalization (self , transformed , normalization , not inverse , dims )
87
+ """
88
+ # TODO: Make more efficient - there should be a faster way to recalculate everything
89
+ # Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
90
+ # Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
91
+ # Modes:
92
+ # 0: no normalization (backward)
93
+ # 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
94
+ # 2: divide by signal_size (forward)
95
+ if normalization == 1 :
96
+ self = op .Mul (self , op .Sqrt (signal_size ))
97
+ elif normalization == 0 :
98
+ self = op .Mul (self , signal_size )
99
+ return self
100
+
101
+ # def _fftn_onnx(
102
+ # self: TFloat,
103
+ # dims: Sequence[int],
104
+ # normalization: int,
105
+ # forward: bool,
106
+ # onesided: bool,
107
+ # last_dim_size: Optional[INT64] = None,
108
+ # ) -> TFloat:
109
+ # """Standard complex to complex or real to complex FFT (forward or backward).
110
+
111
+ # This is a private shared function for implementing the various FFT functions.
112
+
113
+ # Args:
114
+ # self: The input tensor.
115
+ # dims: The dimensions to apply FFT.
116
+ # normalization: The normalization mode.
117
+ # forward: Whether to compute forward FFT or backward FFT.
118
+ # onesided: Whether to compute the one-sided FFT, which retains only the
119
+ # positive frequencies.
120
+ # last_dim_size: The size of the last specified dimension.
121
+
122
+ # Returns:
123
+ # The transformed tensor.
124
+ # """
125
+ # # NOTE: SymInt dim is not support because DFT-17 needs a static axis
126
+
127
+ # # If taking FFT along the 0-th dimension: Since
128
+ # # the 0-th dimension in ONNX DFT-17 is the batch dimension (cannot take DFT over),
129
+ # # we need to add a new dimension at the beginning to represent the batch dimension.
130
+ # unsqueeze_first_dim = 0 in dims
131
+ # if unsqueeze_first_dim:
132
+ # transformed = op.Unsqueeze(self, axes=[0])
133
+ # # Add 1 to account for the batch dimension when counting axes from the left
134
+ # dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
135
+ # else:
136
+ # transformed = self
137
+
138
+ # # Select inverse mode for ONNX based on the norm mode and forward/backward mode.
139
+ # # In ONNX the only difference between inverse=True/False is the 1/n normalization applied.
140
+ # #
141
+ # # If normalization is 1/n and we are in backward mode, we use the inverse
142
+ # # mode in ONNX to get the 1/n normalization.
143
+ # inverse = normalization == 2 and not forward
144
+ # ortho = normalization == 1
145
+
146
+ # for dim in dims[:-1]:
147
+ # transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
148
+
149
+ # # Torch computes one-sided FFT on the last dimension only.
150
+ # if onesided:
151
+ # transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=True)
152
+ # elif last_dim_size is not None:
153
+ # transformed = op.DFT(
154
+ # transformed, last_dim_size, axis=dims[-1], inverse=inverse, onesided=False
155
+ # )
156
+ # else:
157
+ # transformed = op.DFT(transformed, axis=dims[-1], inverse=inverse, onesided=False)
158
+
159
+ # if ortho or inverse:
160
+ # normalized = _fftn_onnx_normalization(
161
+ # transformed, ortho, dims, last_dim_size=last_dim_size
162
+ # )
163
+ # else:
164
+ # normalized = transformed
165
+ # # TODO: Merge to normalization mode and ONNX inverse mode
166
+ # # Be sure to normalize before squeezing the batch dimension, because dims would
167
+ # # have been shifted by 1 if the batch dimension was added.
168
+ # if unsqueeze_first_dim:
169
+ # # Remove the batch dimension
170
+ # normalized = op.Squeeze(normalized, axes=[0])
171
+
172
+ # return normalized
116
173
117
174
118
175
@torch_op ("aten::_fft_c2c" , trace_only = True , complex = True )
@@ -124,39 +181,74 @@ def aten__fft_c2c(
124
181
Standard complex to complex FFT (forward or backward).
125
182
"""
126
183
127
- # NOTE: trace_only because we need to negate forward
128
- # NOTE: SymInt dim is not support because DFT-17 needs a static axis
129
- # TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
184
+ # NOTE: SymInt dim is not supported because DFT-17 needs a static axis
130
185
131
186
# ONNX DFT input assumes the last dimension is the complex dimension.
132
187
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
133
- dim = [d - 1 if d < 0 else d for d in dim ]
134
- return _fftn_onnx (self , dim , normalization , inverse = not forward , onesided = False )
188
+ assert (dim [2 ] in dim == 2 , "Unexpected input size" )
189
+
190
+ signal = self
191
+ self_rank = len (self .shape )
192
+ signal_size = op .Size (signal )
193
+
194
+ # ONNX DFT input assumes the last dimension is the complex dimension.
195
+ # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
196
+ dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
197
+
198
+ transformed = signal
199
+
200
+ for dimension in reversed (dim ):
201
+ transformed = op .DFT (transformed , axis = dimension , inverse = not forward , onesided = False )
202
+ if forward :
203
+ transformed = _fftn_onnx_normalization (transformed , normalization , signal_size )
204
+ else :
205
+ transformed = _fftn_onnx_inverse_normalization (transformed , normalization , signal_size )
206
+
207
+ # Unsure if output format is correct
208
+ return transformed
135
209
136
210
137
211
@torch_op ("aten::_fft_c2r" , trace_only = True , complex = True )
138
212
def aten__fft_c2r (
139
213
self : TFloat ,
140
214
dim : Sequence [int ],
141
215
normalization : int ,
142
- last_dim_size : INT64 , # pylint: disable=unused-argument
216
+ last_dim_size : INT64 ,
143
217
) -> TFloat :
144
218
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
145
219
146
220
Complex to real inverse FFT.
147
221
"""
222
+ assert (dim [2 ] in dim == 2 , "Unexpected input size" )
148
223
149
- # TODO(justinchuby): Figure out what last_dim_size does
150
-
224
+ signal = self
151
225
self_rank = len (self .shape )
226
+ signal_size = op .Size (signal )
227
+
152
228
# ONNX DFT input assumes the last dimension is the complex dimension.
153
229
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
154
230
dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
155
- transformed = _fftn_onnx (self , dim , normalization , inverse = True , onesided = False )
156
- # Take only the real part
157
- real_part = op .Slice (transformed , axes = [- 1 ], starts = [0 ], ends = [1 ])
158
231
159
- return op .Squeeze (real_part , axes = [- 1 ])
232
+ transformed = signal
233
+ for dimension in reversed (dim ):
234
+ transformed = op .DFT (transformed , axis = dimension , inverse = True , onesided = False )
235
+ transformed = _fftn_onnx_inverse_normalization (transformed , normalization , signal_size )
236
+
237
+ # Unsure if output format is correct
238
+ transformed = op .Squeeze (transformed , axes = [- 1 ])
239
+
240
+ if transformed .shape [- 1 ] < last_dim_size :
241
+ pads = [0 , last_dim_size - transformed .shape [- 1 ]]
242
+ mode = 'constant'
243
+ constant_value = 0.0
244
+ transformed = op .Pad (mode = mode , data = transformed , pads = pads , constant_value = constant_value , axes = [- 1 ])
245
+ elif transformed .shape [- 1 ] > last_dim_size :
246
+ starts = [0 ]* (self_rank - 1 )
247
+ ends = list (self .shape )
248
+ ends [- 1 ] = last_dim_size
249
+ transformed = op .Slice (data = transformed , starts = starts , ends = ends )
250
+
251
+ return transformed
160
252
161
253
162
254
@torch_op ("aten::_fft_r2c" , trace_only = True )
@@ -174,12 +266,22 @@ def aten__fft_r2c(
174
266
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs
175
267
176
268
self_rank = len (self .shape )
269
+ signal_size = op .Size (signal )
270
+
177
271
# ONNX DFT input assumes the last dimension is the complex dimension.
178
272
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
179
273
dim = [(d - 1 ) + self_rank if d < 0 else d for d in dim ]
180
274
181
- return _fftn_onnx (signal , dim , normalization , inverse = False , onesided = onesided )
275
+ # Torch computes one-sided FFT on the last dimension only.
276
+ transformed = op .DFT (signal , axis = dim [- 1 ], inverse = False , onesided = onesided )
277
+ transformed = _fftn_onnx_normalization (transformed , normalization , signal_size )
278
+
279
+ for dimension in reversed (dim [:- 1 ]):
280
+ transformed = op .DFT (transformed , axis = dimension , inverse = False , onesided = False )
281
+ transformed = _fftn_onnx_normalization (transformed , normalization , signal_size )
182
282
283
+ # Unsure if output format is correct
284
+ return transformed
183
285
184
286
def aten_fft_fft (
185
287
self : TensorType , n : Optional [int ] = None , dim : int = - 1 , norm : Optional [str ] = None
0 commit comments