Remove duplicate normalisation in FFT methods and enable relevant tests #1469
base: main
Due to resource constraints, I have only tested the "usual" use cases (i.e.
and vice versa), and the results seem to be consistent with those from the original PyTorch run on CPU.
@Wetitpig Thanks for your PR. Since I have completed most of the debugging (it seems these cases are not fixed in this PR) based on my old branches,
at::_fft_r2c and at::_fft_c2r methods with OneMKL
Now targeting only the repeated normalisation in
Glad to see the refinement. Please ensure that the changes won't break the following cases:
Failed cases in op_ut with skip:
=========================================================================
test_ops_xpu.py::TestCommonXPU::test_out_fft_fft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_fft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_fftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_hfft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_hfft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_hfftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_ifft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_ifftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_irfft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_irfft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_irfftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_rfft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_rfft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_rfftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_fft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_fft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_fftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_hfft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_hfft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_hfftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_ifft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_ifftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_irfft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_irfft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_irfftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_rfft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_rfft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_rfftn_xpu
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_bool
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int16
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int32
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int64
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int8
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_uint8
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_stft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_stft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_all_strides_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_all_strides_stft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_bool
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int16
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int64
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int8
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_uint8
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_stft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_stft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_bool
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int16
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int32
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int64
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int8
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_uint8
yaml/xpu_functions.yaml
Outdated
- _fft_c2c
- _fft_c2c.out
- _fft_c2r
- _fft_c2r.out
- _fft_r2c
- _fft_r2c.out
Please remove these lines. yaml/xpu_functions.yaml is replaced by yaml/native/native_functions.yaml.
desc.set_value(config_param::FWD_STRIDES, bwd_strides.data());
desc.set_value(config_param::BWD_STRIDES, fwd_strides.data());
FWD_STRIDES <- bwd_strides.data() and BWD_STRIDES <- fwd_strides.data() may cause confusion. Please rename fwd_strides and bwd_strides, or try another approach.
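One possible way to avoid the crossed names, offered only as a sketch: name the stride vectors by role (input/output) and decide in one place which of them describes the forward domain. DomainStrides and assign_strides are hypothetical names that exist in neither oneMKL nor this repository, and the mapping assumes that a forward transform reads forward-domain data while an inverse transform reads backward-domain data.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper type, not part of oneMKL or torch-xpu-ops.
struct DomainStrides {
  std::vector<std::int64_t> fwd_domain;  // layout of the forward-domain data
  std::vector<std::int64_t> bwd_domain;  // layout of the backward-domain data
};

// Assumption: a forward transform reads forward-domain data and writes
// backward-domain data; an inverse transform does the opposite.
DomainStrides assign_strides(std::vector<std::int64_t> input_strides,
                             std::vector<std::int64_t> output_strides,
                             bool forward) {
  if (forward) {
    return {std::move(input_strides), std::move(output_strides)};
  }
  return {std::move(output_strides), std::move(input_strides)};
}
```

With a helper like this, the two set_value calls could pass fwd_domain.data() to FWD_STRIDES and bwd_domain.data() to BWD_STRIDES, so the names line up at the call site.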
auto result = _fft_c2c_mkl(
    self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
at::native::resize_output(out, result.sizes());
return impl::_fft_apply_normalization_out(
    out, result, normalization, result.sizes(), dim);
out.copy_(result);
return out;
The semantics before and after the modification are not equivalent. Please note that the normalization parameter above is fft_norm_mode::none. It seems you want to apply auto result = _fft_c2c_mkl(self, dim, normalization, forward);, right?
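To make the concern concrete, here is a minimal standalone sketch, not the torch-xpu-ops code. A naive 1-D DFT stands in for the oneMKL kernel, and NormMode, norm_scale, dft and dft_out are hypothetical names invented for illustration. It shows that scaling in the kernel or in the wrapper gives the same result, that scaling in neither drops the normalization, and that scaling in both applies it twice.

```cpp
#include <cmath>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the normalization modes discussed in this thread.
enum class NormMode { none, by_root_n, by_n };

using cvec = std::vector<std::complex<double>>;

// Scale factor implied by a normalization mode for a signal of length n.
double norm_scale(NormMode mode, std::size_t n) {
  switch (mode) {
    case NormMode::by_root_n:
      return 1.0 / std::sqrt(static_cast<double>(n));
    case NormMode::by_n:
      return 1.0 / static_cast<double>(n);
    default:
      return 1.0;
  }
}

// Naive O(n^2) DFT that scales its own output, playing the role of a
// backend kernel that accepts a normalization mode.
cvec dft(const cvec& x, NormMode mode, bool forward) {
  const std::size_t n = x.size();
  const double pi = std::acos(-1.0);
  const double sign = forward ? -1.0 : 1.0;
  const double scale = norm_scale(mode, n);
  cvec out(n);
  for (std::size_t k = 0; k < n; ++k) {
    std::complex<double> sum = 0.0;
    for (std::size_t j = 0; j < n; ++j) {
      const double angle = sign * 2.0 * pi * static_cast<double>(k * j) /
                           static_cast<double>(n);
      sum += x[j] * std::polar(1.0, angle);
    }
    out[k] = sum * scale;
  }
  return out;
}

// Wrapper in the style of an "_out" function: the kernel scales with
// kernel_mode, then the wrapper scales again with wrapper_mode.
cvec dft_out(const cvec& x, NormMode kernel_mode, NormMode wrapper_mode,
             bool forward) {
  cvec result = dft(x, kernel_mode, forward);
  const double s = norm_scale(wrapper_mode, x.size());
  for (auto& v : result) {
    v *= s;
  }
  return result;
}
```

For the input {1, 2, 3, 4} with by_n, the first output element is 2.5 when exactly one side normalizes, 10 when neither does, and 0.625 when both do.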
return impl::_fft_apply_normalization_out(
    out, result, normalization, result.sizes(), dim);
out.copy_(result);
return out;
Ditto.
auto result = _fft_r2c_mkl(
    self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);

if (onesided) {
  return impl::_fft_apply_normalization_out(
      out, result, normalization, self.sizes(), dim);
}

at::native::resize_output(out, self.sizes());

auto last_dim = dim.back();
auto last_dim_halfsize = result.sizes()[last_dim];
auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);

impl::_fft_apply_normalization_out(
    out_slice, result, normalization, self.sizes(), dim);
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
at::native::resize_output(out, result.sizes());
out.copy_(result);
Ditto.
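For context on the two-sided branch in the hunk above: a real-to-complex transform only needs to compute the first n/2 + 1 bins, and the rest follow from conjugate symmetry, X[n-k] = conj(X[k]). The following is a 1-D sketch of that fill step under stated assumptions; fill_conjugate_symmetry is a hypothetical name and is not how at::native::_fft_fill_with_conjugate_symmetry_ is implemented.

```cpp
#include <complex>
#include <cstddef>
#include <stdexcept>
#include <vector>

using cvec = std::vector<std::complex<double>>;

// Expand a one-sided spectrum of a real signal of length n (n/2 + 1 bins)
// into the full n-bin spectrum using X[n-k] = conj(X[k]).
cvec fill_conjugate_symmetry(const cvec& half, std::size_t n) {
  if (half.size() != n / 2 + 1) {
    throw std::invalid_argument("half spectrum must have n/2 + 1 bins");
  }
  cvec full(n);
  // Copy the bins that were actually computed.
  for (std::size_t k = 0; k < half.size() && k < n; ++k) {
    full[k] = half[k];
  }
  // Mirror the remaining bins from their conjugate partners.
  for (std::size_t k = half.size(); k < n; ++k) {
    full[k] = std::conj(half[n - k]);
  }
  return full;
}
```

Because the mirrored bins are copies of already scaled values, normalization only has to be applied once to the one-sided result before the fill.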
std::vector<int64_t> fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
    bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
Just a reminder of columns.
@Wetitpig Please also update the PR description, thanks.
The
The _fft_(c2c|c2r|r2c)_mkl_out functions apply normalisation redundantly; this PR is meant to fix that. Some loops and conditional branches have also been simplified.