Skip to content

Remove duplicate normalisation in FFT methods and enable relevant tests #1469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

Wetitpig
Copy link

@Wetitpig Wetitpig commented Mar 14, 2025

There is a problem of redundant normalisations in the _fft_(c2c|c2r|r2c)_mkl_out functions, which should be fixed by this PR. Some loops/conditional branches have also been simplified.

@Wetitpig
Copy link
Author

Due to resource constraints I have been testing only the "usual" use cases (i.e.

  • torch.fft.rfft, torch.fft.rfft2 on XPU
  • torch.fft.irfft, torch.fft.irfft2 on CPU

and vice versa), and the results seem to be consistent with those from the original pytorch run on CPU.

@CuiYifeng
Copy link
Contributor

@Wetitpig Thanks for your PR. Since I have completed most of the debugging (seems these cases are not fixed in this PR) based on my old branches, fft_c2r and fft_r2c will be submitted by another PRs.
Glad to see the fixing of normalization. Changing this PR to modify this part may be a better choice. Thank you.

@Wetitpig Wetitpig marked this pull request as draft April 17, 2025 16:52
@Wetitpig Wetitpig changed the title Adding at::_fft_r2c and at::_fft_c2r methods with OneMKL Remove duplicate normalisation in FFT methods and enable relevant tests Apr 27, 2025
@Wetitpig Wetitpig marked this pull request as ready for review April 27, 2025 08:29
@Wetitpig
Copy link
Author

Now targeting only the repeated normalisation in *_out functions in FFT, while enabling tests for FFT and also some code refactoring

Copy link
Contributor

@CuiYifeng CuiYifeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad to see the refinement. Please ensure that the changes won't break the following cases:

Show Failed cases in op_ut with skip
=========================================================================
test_ops_xpu.py::TestCommonXPU::test_out_fft_fft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_fft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_fftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_hfft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_hfft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_hfftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_ifft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_ifftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_irfft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_irfft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_irfftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_rfft2_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_rfft_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_fft_rfftn_xpu_float32
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_fft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_fft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_fftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_hfft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_hfft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_hfftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_ifft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_ifftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_irfft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_irfft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_irfftn_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_rfft2_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_rfft_xpu
test_ops_xpu.py::TestCommonXPU::test_out_warning_fft_rfftn_xpu
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_bool
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int16
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int32
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int64
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_int8
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_fft_fft_xpu_uint8
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_stft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_meta_outplace_stft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_all_strides_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_all_strides_stft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_bool
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int16
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int64
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_int8
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_fft_fft_xpu_uint8
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_stft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_dispatch_symbolic_meta_outplace_stft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_bool
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_float32
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_float64
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int16
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int32
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int64
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_int8
test_meta_xpu.py::TestMetaXPU::test_meta_outplace_fft_fft_xpu_uint8

Comment on lines 750 to 755
- _fft_c2c
- _fft_c2c.out
- _fft_c2r
- _fft_c2r.out
- _fft_r2c
- _fft_r2c.out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove these lines. yaml/xpu_functions.yaml is replaced by yaml/native/native_functions.yaml.

Comment on lines 79 to 80
desc.set_value(config_param::FWD_STRIDES, bwd_strides.data());
desc.set_value(config_param::BWD_STRIDES, fwd_strides.data());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWD_STRIDES <- bwd_strides.data() and BWD_STRIDES <- fwd_strides.data() may cause confusion. Please rename fwd_strides and bwd_strides, or try other methods.

Comment on lines 369 to 373
auto result = _fft_c2c_mkl(
self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
at::native::resize_output(out, result.sizes());
return impl::_fft_apply_normalization_out(
out, result, normalization, result.sizes(), dim);
out.copy_(result);
return out;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semantics before and after modification are not equivalent. Please note that the normalization parameter above is fft_norm_mode::none. Seems you want to apply auto result = _fft_c2c_mkl(self, dim, normalization, forward);, right?

Comment on lines -478 to +448
return impl::_fft_apply_normalization_out(
out, result, normalization, result.sizes(), dim);
out.copy_(result);
return out;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment on lines 543 to 545
auto result = _fft_r2c_mkl(
self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);

if (onesided) {
return impl::_fft_apply_normalization_out(
out, result, normalization, self.sizes(), dim);
}

at::native::resize_output(out, self.sizes());

auto last_dim = dim.back();
auto last_dim_halfsize = result.sizes()[last_dim];
auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);

impl::_fft_apply_normalization_out(
out_slice, result, normalization, self.sizes(), dim);
at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
at::native::resize_output(out, result.sizes());
out.copy_(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment on lines 60 to 61
std::vector<int64_t> fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a reminder of columns.

@CuiYifeng
Copy link
Contributor

@Wetitpig Please also update PR description, thanks.

@Wetitpig Wetitpig requested a review from CuiYifeng April 30, 2025 13:26
@Wetitpig
Copy link
Author

Wetitpig commented May 3, 2025

The onesided parameter of _fft_r2c_mkl_out has also been modified to be passed onward to _fft_r2c_mkl.
I am unable to perform the tests myself due to RuntimeError: UR error on Intel Iris Xe Graphics (intel/intel-extension-for-pytorch#800)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants