Implement fft torchop #2141

Draft
Wants to merge 17 commits into base: main.
194 changes: 91 additions & 103 deletions onnxscript/function_libs/torch_lib/ops/fft.py
@@ -1,7 +1,5 @@
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
"""torch.ops.aten operators under the `fft` module.

@@ -12,108 +10,53 @@

from __future__ import annotations

from typing import Optional, Sequence
from typing import Literal, Optional, Sequence

Check notice (Code scanning / CodeQL): Unused import. Import of 'Literal' is not used.

from onnxscript import INT64
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType


@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
private=True,
complex=True,
trace_only=True,
)
def _fftn_onnx_normalization(
self,
transformed: TFloat,
self: TFloat,
normalization: int,
forward: bool,
dims: Sequence[int],
signal_size: INT64,
Collaborator:

Question for @justinchuby: what does the INT64 type signify? Specifically, is there a convention that tells us whether something is a statically known int value, a symint, or a dynamic int value?

Author:

I saw that op.Size returns an int64, so I was using INT64. However, from the test errors, I realized that I need to cast signal_size to the same type as self so that I can properly call op.Div/op.Mul with self and signal_size being the same type.

I am not sure if this means signal_size should be float or TFloat.

Contributor:

Is signal_size an output of an ONNX op? Output from an ONNX op should be a Tensor, so it should be TFloat.

Collaborator:

It's usually symint. We use the same INT64 for symint and tensors.
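For reference, a minimal sketch of the casting pattern described in this thread (a hypothetical standalone helper, not code from this PR): the INT64 scalar produced by op.Size is cast to the element type of the transformed tensor before it is used in op.Div/op.Mul.

```python
# Hypothetical helper illustrating the CastLike pattern discussed above;
# written in the trace-only style used by torch_lib, not taken from this diff.
from onnxscript.onnx_opset import opset18 as op


def _scale_by_signal_size(transformed, signal_size, forward: bool):
    # op.Size returns an INT64 scalar; CastLike converts it to the element
    # type of `transformed` so that Div/Mul receive matching input types.
    signal_size = op.CastLike(signal_size, transformed)
    scale = op.Sqrt(signal_size)  # "ortho"-style scaling, shown as an example
    if forward:
        return op.Div(transformed, scale)
    return op.Mul(transformed, scale)
```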

) -> TFloat:
# Obtain the total_sample_count (n) for normalization
self_shape = op.Shape(self)
total_sample_count = op.ReduceProd(op.Gather(self_shape, dims), keepdims=0)
total_sample_count = op.CastLike(total_sample_count, transformed)

# Normalize the result
# Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
# Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
"""
"""
# TODO: Make more efficient - there should be a faster way to recalculate everything
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
# Modes:
# 0: no normalization (backward)
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
# 2: divide by signal_size (forward)
if normalization == 1:
# "forward" - normalize by 1/n
if forward:
result = op.Div(transformed, op.Sqrt(total_sample_count))
else:
result = op.Mul(transformed, op.Sqrt(total_sample_count))
self = op.Div(self, op.Sqrt(signal_size))
elif normalization == 2:
# "ortho" - normalize by 1/sqrt(n)
if forward:
result = op.Div(transformed, total_sample_count)
else:
result = transformed
else:
# "backward" - no normalization
if forward:
result = transformed
else:
result = op.Mul(transformed, total_sample_count)

return result

self = op.Div(self, signal_size)
return self

@torch_op(
("aten::_fft_c2c", "aten::_fft_c2r", "aten::_fft_r2c"),
trace_only=True,
private=True,
complex=True,
)
def _fftn_onnx(
self: TFloat, dims: Sequence[int], normalization: int, inverse: bool, onesided: bool
def _fftn_onnx_inverse_normalization(
self: TFloat,
normalization: int,
signal_size: INT64,
) -> TFloat:
"""Standard complex to complex or real to complex FFT (forward or backward).

This is a private shared function for implementing the various FFT functions.

Args:
self: The input tensor.
dims: The dimensions to apply FFT.
normalization: The normalization mode.
inverse: Whether to compute the inverse FFT.
onesided: Whether to compute the one-sided FFT, which retains only the
positive frequencies.

Returns:
The transformed tensor.
"""

# NOTE: trace_only because we need to process each dimension in a loop
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support

# The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
# dimension at the beginning to represent the batch dimension.
transformed = op.Unsqueeze(self, axes=[0])

# Add 1 to account for the batch dimension when counting axes from the left
new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]

for dim in new_dims[:-1]:
transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)

# Torch computers one-sided FFT on the last dimension only.
if onesided:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
else:
transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)

# Remove the batch dimension
transformed = op.Squeeze(transformed, axes=[0])

return _fftn_onnx_normalization(self, transformed, normalization, not inverse, dims)

"""
# TODO: Make more efficient - there should be a faster way to recalculate everything
# Norm values defined in https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOps.cpp#L117-L131
# Norm modes: https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/aten/src/ATen/native/SpectralOpsUtils.h#L15-L19
# Modes:
# 0: no normalization (backward)
# 1: "ortho" - divide by 1/sqrt(signal_size) (ortho)
# 2: divide by signal_size (forward)
if normalization == 1:
self = op.Mul(self, op.Sqrt(signal_size))
elif normalization == 0:
self = op.Mul(self, signal_size)
return self

@torch_op("aten::_fft_c2c", trace_only=True, complex=True)
def aten__fft_c2c(
@@ -124,39 +67,74 @@
Standard complex to complex FFT (forward or backward).
"""

# NOTE: trace_only because we need to negate forward
# NOTE: SymInt dim is not support because DFT-17 needs a static axis
# TODO(justinchuby): Make dim dynamic and remove trace_only when ONNX provides support
# NOTE: SymInt dim is not supported because DFT-17 needs a static axis

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [d - 1 if d < 0 else d for d in dim]
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
assert(dim[2] in dim == 2, "Unexpected input size")

Check failure (Code scanning / CodeQL): Asserting a tuple. Assertion of non-empty tuple is always True.
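To illustrate the finding (a generic sketch, not a proposed fix; the condition below is a placeholder, since the intended check in the diff is itself unclear): parenthesizing both the condition and the message builds a two-element tuple, and a non-empty tuple is always truthy, so the assertion can never fail.

```python
dim = [0, 1, 2]

# Flagged pattern: the parentheses create a (condition, message) tuple,
# which is non-empty and therefore always truthy, so this never fails.
assert (len(dim) == 2, "Unexpected input size")

# Conventional form: the condition comes first, the message after a comma.
# With dim = [0, 1, 2] this line raises AssertionError.
assert len(dim) == 2, "Unexpected input size"
```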

signal = self
self_rank = len(self.shape)
signal_size = op.Size(signal)

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]

transformed = signal

for dimension in reversed(dim):
transformed = op.DFT(transformed, axis=dimension, inverse=not forward, onesided=False)
if forward:
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)
else:
transformed = _fftn_onnx_inverse_normalization(transformed, normalization, signal_size)

# Unsure if output format is correct
return transformed


@torch_op("aten::_fft_c2r", trace_only=True, complex=True)
def aten__fft_c2r(
self: TFloat,
dim: Sequence[int],
normalization: int,
last_dim_size: INT64, # pylint: disable=unused-argument
last_dim_size: INT64,
) -> TFloat:
"""_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor

Complex to real inverse FFT.
"""
assert(dim[2] in dim == 2, "Unexpected input size")

Check failure (Code scanning / CodeQL): Asserting a tuple. Assertion of non-empty tuple is always True.

# TODO(justinchuby): Figure out what last_dim_size does

signal = self
self_rank = len(self.shape)
signal_size = op.Size(signal)

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=False)
# Take only the real part
real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])

return op.Squeeze(real_part, axes=[-1])
transformed = signal
for dimension in reversed(dim):
transformed = op.DFT(transformed, axis=dimension, inverse=True, onesided=False)
transformed = _fftn_onnx_inverse_normalization(transformed, normalization, signal_size)

# Unsure if output format is correct
transformed = op.Squeeze(transformed, axes=[-1])

if transformed.shape[-1] < last_dim_size:
pads = [0, last_dim_size - transformed.shape[-1]]
mode = 'constant'
constant_value = 0.0
transformed = op.Pad(mode=mode, data=transformed, pads=pads, constant_value=constant_value, axes=[-1])

Check warning (Codecov / codecov/patch) on line 130 in onnxscript/function_libs/torch_lib/ops/fft.py: added lines #L127-L130 were not covered by tests.
elif transformed.shape[-1] > last_dim_size:
starts = [0]*(self_rank-1)
ends = list(self.shape)
ends[-1] = last_dim_size
transformed = op.Slice(data=transformed, starts=starts, ends=ends)

Check warning (Codecov / codecov/patch) on line 135 in onnxscript/function_libs/torch_lib/ops/fft.py: added lines #L132-L135 were not covered by tests.

return transformed

Check warning (Codecov / codecov/patch) on line 137 in onnxscript/function_libs/torch_lib/ops/fft.py: added line #L137 was not covered by tests.


@torch_op("aten::_fft_r2c", trace_only=True)
@@ -174,12 +152,22 @@
# https://onnx.ai/onnx/operators/onnx__DFT.html#inputs

self_rank = len(self.shape)
signal_size = op.Size(signal)

# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [(d - 1) + self_rank if d < 0 else d for d in dim]

return _fftn_onnx(signal, dim, normalization, inverse=False, onesided=onesided)
# Torch computes one-sided FFT on the last dimension only.
transformed = op.DFT(signal, axis=dim[-1], inverse=False, onesided=onesided)
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)

for dimension in reversed(dim[:-1]):
transformed = op.DFT(transformed, axis=dimension, inverse=False, onesided=False)
transformed = _fftn_onnx_normalization(transformed, normalization, signal_size)

# Unsure if output format is correct
return transformed

def aten_fft_fft(
self: TensorType, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None
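A quick shape check of the one-sided behaviour that aten__fft_r2c mirrors above (an illustrative snippet against the public torch API, not part of this PR): torch.fft.rfftn halves only the last transformed dimension, which is why the onesided DFT is applied to dim[-1] alone.

```python
import torch

x = torch.randn(4, 8)
# Only the last transformed dimension is halved: 8 // 2 + 1 = 5 bins,
# while dim 0 keeps its full length of 4.
print(torch.fft.rfftn(x, dim=(0, 1)).shape)  # torch.Size([4, 5])
```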
38 changes: 19 additions & 19 deletions onnxscript/ir/tensor_adapters_test.py
@@ -54,25 +54,25 @@ def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype):

@parameterized.parameterized.expand(
[
(torch.bfloat16),
(torch.bool),
(torch.complex128),
(torch.complex64),
(torch.float16),
(torch.float32),
(torch.float64),
(torch.float8_e4m3fn),
(torch.float8_e4m3fnuz),
(torch.float8_e5m2),
(torch.float8_e5m2fnuz),
(torch.int16),
(torch.int32),
(torch.int64),
(torch.int8),
(torch.uint16),
(torch.uint32),
(torch.uint64),
(torch.uint8),
(torch.bfloat16,),
(torch.bool,),
(torch.complex128,),
(torch.complex64,),
(torch.float16,),
(torch.float32,),
(torch.float64,),
(torch.float8_e4m3fn,),
(torch.float8_e4m3fnuz,),
(torch.float8_e5m2,),
(torch.float8_e5m2fnuz,),
(torch.int16,),
(torch.int32,),
(torch.int64,),
(torch.int8,),
(torch.uint16,),
(torch.uint32,),
(torch.uint64,),
(torch.uint8,),
],
)
def test_tobytes(self, dtype: torch.dtype):
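A side note on the parameterized.expand change above (illustrative, not from the PR): parentheses alone do not create a tuple in Python, so the added trailing commas turn each entry into an explicit one-element tuple.

```python
import torch

# Without the trailing comma this is just a parenthesized dtype value.
print(type((torch.bfloat16)))   # <class 'torch.dtype'>
# With the trailing comma it becomes a one-element tuple of parameters.
print(type((torch.bfloat16,)))  # <class 'tuple'>
```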
5 changes: 5 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -700,9 +700,14 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
(0, 1),
(0, 1, 2),
]:
# Slice
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=6
)
# Pad
yield opinfo_core.SampleInput(
nd_tensor(), dim=dim, normalization=normalization, last_dim_size=64
)


def _index_variable_bool(shape, max_indices, device):
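The two new sample inputs above exercise the slice and pad branches of aten__fft_c2r via last_dim_size. As a rough illustration with the public torch API (an analogy for the expected behaviour, not code from the PR), the requested real output length either crops or zero-pads relative to the natural length of the inverse transform:

```python
import torch

spectrum = torch.fft.rfft(torch.arange(10.0))  # half-spectrum of length 6
# The natural inverse length would be 2 * (6 - 1) = 10; n overrides it.
print(torch.fft.irfft(spectrum, n=6).shape)    # torch.Size([6])  -> cropped ("Slice" case)
print(torch.fft.irfft(spectrum, n=64).shape)   # torch.Size([64]) -> zero-padded ("Pad" case)
```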
3 changes: 0 additions & 3 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -452,9 +452,6 @@ def _where_input_wrangler(
fft_ops.aten__fft_c2r,
tolerance={torch.complex64: (3e-3, 1.8e-4)},
complex=True,
).xfail(
dtypes=(torch.complex64,),
reason="fixme: the result is wrong: https://github.com/microsoft/onnxscript/pull/926",
),
TorchLibOpInfo(
"ops.aten._fft_r2c", # Custom from extra_opinfo