Implement fft torchop #2141
base: main
Conversation
@@ -12,108 +10,53 @@
 from __future__ import annotations

-from typing import Optional, Sequence
+from typing import Literal, Optional, Sequence
Check notice (Code scanning / CodeQL): Unused import
# ONNX DFT input assumes the last dimension is the complex dimension.
# Thus dim=-1 in PyTorch is dim=-2 in ONNX.
dim = [d - 1 if d < 0 else d for d in dim]
return _fftn_onnx(self, dim, normalization, inverse=not forward, onesided=False)
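A small worked example of that remapping (illustrative only), assuming the ONNX real representation appends a complex axis of size 2 at the end, so every negative PyTorch dim shifts left by one:

```python
# A PyTorch complex tensor of logical shape (4, 8) maps to an ONNX real
# tensor of shape (4, 8, 2). The logical last axis (dim=-1 in PyTorch)
# is therefore axis -2 in ONNX; non-negative dims are unchanged.
dim = [0, -1]
dim = [d - 1 if d < 0 else d for d in dim]
assert dim == [0, -2]
```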
assert(dim[2] in dim == 2, "Unexpected input size")

Check failure (Code scanning / CodeQL): Asserting a tuple
) -> TFloat:
    """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor

    Complex to real inverse FFT.
    """
    assert(dim[2] in dim == 2, "Unexpected input size")

Check failure (Code scanning / CodeQL): Asserting a tuple
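On the two CodeQL failures above: `assert(cond, msg)` parses as asserting a two-element tuple, which is always truthy, so the check can never fail. A short sketch of the pitfall and the conventional form (the intended condition is an assumption; the original expression is hard to read):

```python
dim = [0, 1]

# Flagged pattern: the parentheses make the assert's single operand a
# non-empty tuple, which is always truthy, so it never fires.
# assert(len(dim) == 3, "Unexpected input size")   # silently passes

# Conventional form: condition and message separated by a comma, with no
# wrapping parentheses (assuming the intent was a length check).
assert len(dim) == 2, "Unexpected input size"
```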
❌ 2 Tests Failed:
View the full list of 2 ❄️ flaky tests.
To view more test analytics, go to the Test Analytics Dashboard.
The error looks like a mismatch between the complex representation (torch) and the real representation (ONNX). Maybe explore …, or using ….
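A minimal sketch of the representation mismatch, assuming the standard mapping between a PyTorch complex tensor and the real layout that ONNX's DFT works on (a trailing axis of size 2 holding the real and imaginary parts):

```python
import torch

x = torch.randn(4, 8, dtype=torch.complex64)  # complex repr (torch)

# Real representation: the same data with a trailing (real, imag) axis
# of size 2 -- the layout ONNX DFT consumes and produces.
x_real = torch.view_as_real(x)
assert x_real.shape == (4, 8, 2)

# Round-trip back to the complex view.
assert torch.equal(torch.view_as_complex(x_real), x)
```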
    normalization: int,
    forward: bool,
    dims: Sequence[int],
    signal_size: INT64,
Question for @justinchuby: what does the INT64 type signify? Specifically, is there a convention that tells us whether something is a statically known int value, a symint, or a dynamic int value?
I saw that op.Size returns an int64, so I was using INT64. However, from the test errors, I realized that I need to cast signal_size to be the same type as self so that I can properly call op.Div/op.Mul with self and signal_size having the same type. I am not sure if this means signal_size should be float or TFloat.
Is signal_size an output of an ONNX op? Output from an ONNX op should be a Tensor, so it should be TFloat.
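A hedged sketch of the cast being discussed, assuming onnxscript opset 18: op.CastLike converts signal_size (an int64 scalar from op.Size) to self's dtype so that op.Div sees matching input types. The helper name mirrors the thread and is illustrative, not the final implementation:

```python
from onnxscript import opset18 as op

def _normalize(self, signal_size):
    # Sketch only: signal_size comes back from op.Size as int64, so cast
    # it to self's element type before the elementwise division.
    signal_size = op.CastLike(signal_size, self)
    return op.Div(self, signal_size)
```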
We check https://github.com/pytorch/pytorch/blob/68414512e6fe641b02a6fe217fd516b7b776ea0d/aten/src/ATen/native/native_functions.yaml#L2991 for the signature of functions.
It’s usually a symint. We use the same INT64 for symints and tensors.

The output doesn't need to be converted. We always use the real representation for complex values in ONNX; the torch exporter keeps track of this information.
@bmehta001 please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.
@@ -21,6 +21,12 @@
 from onnxscript.onnx_types import TensorType

+@torch_op(
+    ("aten::_fft_c2c", "aten::_fft_r2c"),
+    private=True,
No need to add the torch_op decorator to private functions now. Thanks
Sorry, that was me who told him to do so. Sorry, @bmehta001.
WIP

- r2c = forward, can be one-sided
- c2r = backward/inverse, never one-sided
- c2c can be either forward or backward, never one-sided
- Must respect the normalization method provided. However, op.DFT applies "backward" normalization when `inverse` is set to True, so we need to account for the normalization already done by op.DFT.
- When running the above functions across multiple axes, the FFT needs to be applied through op.DFT one axis at a time, iterating the axes in reverse order (see the sketch below).
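A minimal sketch of that axis-by-axis loop, assuming onnxscript opset 18 (where DFT takes the axis as an attribute and transforms one axis per call). The helper name and the normalization note are assumptions, not the final implementation:

```python
from onnxscript import opset18 as op

def _fftn_axis_by_axis(self, dims, inverse: bool):
    # Sketch only: apply op.DFT once per requested axis, iterating the
    # axes in reverse order, since each call transforms a single axis.
    transformed = self
    for dim in reversed(dims):
        transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
    # Assumption from the notes above: with inverse=True, op.DFT already
    # applies the "backward" 1/n scaling, so any requested normalization
    # ("ortho", "forward", ...) must compensate for that built-in factor.
    return transformed
```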
Currently have issues with: