From ba404c664eece1f5621e9d52c0b3fc812c4578e1 Mon Sep 17 00:00:00 2001 From: Wetitpig Date: Sun, 9 Mar 2025 14:07:07 +0800 Subject: [PATCH 1/6] Reduce conditional branches in FFT code --- src/ATen/native/xpu/mkl/SpectralOps.cpp | 76 ++++++------------- src/ATen/native/xpu/sycl/FFTKernelFunctor.cpp | 8 +- yaml/xpu_functions.yaml | 6 ++ 3 files changed, 34 insertions(+), 56 deletions(-) diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp index 4f1e028b4..0b04fce17 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.cpp +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -57,18 +57,10 @@ void _mkl_dft( int64_t idist = istrides[0]; int64_t odist = ostrides[0]; - std::vector fwd_strides(1 + signal_ndim, 0), - bwd_strides(1 + signal_ndim, 0); - - for (int64_t i = 1; i <= signal_ndim; i++) { - if (!inverse) { - fwd_strides[i] = istrides[i]; - bwd_strides[i] = ostrides[i]; - } else { - fwd_strides[i] = ostrides[i]; - bwd_strides[i] = istrides[i]; - } - } + std::vector fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1), + bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1); + fwd_strides[0] = 0; + bwd_strides[0] = 0; auto desc = descriptor(mkl_signal_sizes); desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE); @@ -77,16 +69,15 @@ void _mkl_dft( if (!inverse) { desc.set_value(config_param::FWD_DISTANCE, idist); desc.set_value(config_param::BWD_DISTANCE, odist); + + desc.set_value(config_param::FWD_STRIDES, fwd_strides.data()); + desc.set_value(config_param::BWD_STRIDES, bwd_strides.data()); } else { desc.set_value(config_param::FWD_DISTANCE, odist); desc.set_value(config_param::BWD_DISTANCE, idist); - } - if (!fwd_strides.empty()) { - desc.set_value(config_param::FWD_STRIDES, fwd_strides.data()); - } - if (!bwd_strides.empty()) { - desc.set_value(config_param::BWD_STRIDES, bwd_strides.data()); + desc.set_value(config_param::FWD_STRIDES, bwd_strides.data()); + desc.set_value(config_param::BWD_STRIDES, fwd_strides.data()); } if (!complex_input || !complex_output) { @@ -136,10 +127,10 @@ void _fft_with_size( // real/imag dimension must aligned when viewed as of complex type if (complex_input) { - bool need_contiguous = input_.stride(-1) != 1; - + const auto strides = input_.strides(); + bool need_contiguous = strides.back() != 1; for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) { - need_contiguous |= input_.stride(i) % 2 != 0; + need_contiguous |= strides[i] % 2; } if (need_contiguous) { @@ -230,12 +221,13 @@ Tensor& _exec_fft( batched_sizes.begin() + 1); input = input.reshape(batched_sizes); - const auto batch_size = input.sizes()[0]; + const auto in_sizes = input.sizes(); + const auto batch_size = in_sizes[0]; DimVector signal_size(signal_ndim + 1); signal_size[0] = batch_size; for (const auto i : c10::irange(signal_ndim)) { - auto in_size = input.sizes()[i + 1]; + auto in_size = in_sizes[i + 1]; auto out_size = out_sizes[dim[i]]; signal_size[i + 1] = std::max(in_size, out_size); TORCH_INTERNAL_ASSERT( @@ -272,12 +264,12 @@ Tensor& _exec_fft( int64_t batch_numel = 1; for (int64_t i = batch_dims - 1; i >= 0; --i) { - out_strides[dim_permute[i]] = batch_numel * out.strides()[0]; + out_strides[dim_permute[i]] = batch_numel * out.stride(0); batch_numel *= out_sizes[dim_permute[i]]; } for (const auto i : c10::irange(batch_dims, ndim)) { - out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)]; + out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims)); } out.as_strided_(out_sizes, out_strides, out.storage_offset()); @@ -330,16 +322,6 @@ const Tensor& _fft_apply_normalization( return (scale == 1.0) ? self : self.mul_(scale); } -Tensor& _fft_apply_normalization_out( - Tensor& out, - const Tensor& self, - int64_t normalization, - IntArrayRef sizes, - IntArrayRef dims) { - auto scale = _dft_scale(dims, sizes, self.sizes(), normalization); - return at::mul_out(out, self, c10::scalar_to_tensor(scale)); -} - } // namespace impl Tensor _fft_c2c_mkl( @@ -399,8 +381,8 @@ Tensor& _fft_c2c_mkl_out( auto result = _fft_c2c_mkl( self, dim, static_cast(fft_norm_mode::none), forward); at::native::resize_output(out, result.sizes()); - return impl::_fft_apply_normalization_out( - out, result, normalization, result.sizes(), dim); + out.copy_(result); + return out; } void HermitSymmImpl(Tensor& input, int64_t dim, int pos) { @@ -475,8 +457,8 @@ Tensor& _fft_c2r_mkl_out( auto result = _fft_c2r_mkl( self, dim, static_cast(fft_norm_mode::none), last_dim_size); at::native::resize_output(out, result.sizes()); - return impl::_fft_apply_normalization_out( - out, result, normalization, result.sizes(), dim); + out.copy_(result); + return out; } REGISTER_XPU_DISPATCH( @@ -573,20 +555,8 @@ Tensor& _fft_r2c_mkl_out( auto result = _fft_r2c_mkl( self, dim, static_cast(fft_norm_mode::none), /*onesided=*/true); - if (onesided) { - return impl::_fft_apply_normalization_out( - out, result, normalization, self.sizes(), dim); - } - - at::native::resize_output(out, self.sizes()); - - auto last_dim = dim.back(); - auto last_dim_halfsize = result.sizes()[last_dim]; - auto out_slice = out.slice(last_dim, 0, last_dim_halfsize); - - impl::_fft_apply_normalization_out( - out_slice, result, normalization, self.sizes(), dim); - at::native::_fft_fill_with_conjugate_symmetry_(out, dim); + at::native::resize_output(out, result.sizes()); + out.copy_(result); return out; } diff --git a/src/ATen/native/xpu/sycl/FFTKernelFunctor.cpp b/src/ATen/native/xpu/sycl/FFTKernelFunctor.cpp index 8cf2e6257..fb39630f8 100644 --- a/src/ATen/native/xpu/sycl/FFTKernelFunctor.cpp +++ b/src/ATen/native/xpu/sycl/FFTKernelFunctor.cpp @@ -28,11 +28,13 @@ struct HermitianSymmetryOffsetCalculator { TORCH_INTERNAL_ASSERT(sizes.size() <= XPU_MAX_TENSORINFO_DIMS); dims = sizes.size(); - for (dim_type i = 0; i < XPU_MAX_TENSORINFO_DIMS; ++i) { - if (i < dims) { + { + dim_type i; + for (i = 0; i < dims; ++i) { sizes_[i] = at::detail::IntDivider(sizes[i]); strides_[i] = strides[i] / element_size; - } else { + } + for (; i < XPU_MAX_TENSORINFO_DIMS; ++i) { sizes_[i] = at::detail::IntDivider(1); strides_[i] = 0; } diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 6af8143f0..6621e93e3 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -747,3 +747,9 @@ supported: - take.out - segment_reduce - _segment_reduce_backward + - _fft_c2c + - _fft_c2c.out + - _fft_c2r + - _fft_c2r.out + - _fft_r2c + - _fft_r2c.out From b945b50953934c532ddce4a5ed956a75459eac68 Mon Sep 17 00:00:00 2001 From: Wetitpig Date: Tue, 11 Mar 2025 10:04:43 +0800 Subject: [PATCH 2/6] Do not skip FFT tests --- test/xpu/skip_list_common.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 0de152f27..19585f2a7 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -635,25 +635,6 @@ "test_python_ref_torch_fallback__refs_div_trunc_rounding_xpu_float64", # TODO: passed from source code building version, investigate "test_python_ref__refs_log2_xpu_complex128", - # The following dtypes did not work in backward but are listed by the OpInfo: {torch.bfloat16}. - "test_dtypes_fft_fft2_xpu", - "test_dtypes_fft_fft_xpu", - "test_dtypes_fft_fftn_xpu", - "test_dtypes_fft_hfft2_xpu", - "test_dtypes_fft_hfft_xpu", - "test_dtypes_fft_hfftn_xpu", - "test_dtypes_fft_ifft2_xpu", - "test_dtypes_fft_ifft_xpu", - "test_dtypes_fft_ifftn_xpu", - "test_dtypes_fft_ihfft2_xpu", - "test_dtypes_fft_ihfft_xpu", - "test_dtypes_fft_ihfftn_xpu", - "test_dtypes_fft_irfft2_xpu", - "test_dtypes_fft_irfft_xpu", - "test_dtypes_fft_irfftn_xpu", - "test_dtypes_fft_rfft2_xpu", - "test_dtypes_fft_rfft_xpu", - "test_dtypes_fft_rfftn_xpu", ), "test_binary_ufuncs_xpu.py": ( "test_fmod_remainder_by_zero_integral_xpu_int64", # zero division is an undefined behavior: different handles on different backends From 8faa4c39c5380bcc543d59c7e01997b0b71730bc Mon Sep 17 00:00:00 2001 From: Wetitpig Date: Tue, 11 Mar 2025 22:56:39 +0800 Subject: [PATCH 3/6] Move dft_scale to individual functions --- src/ATen/native/xpu/mkl/SpectralOps.cpp | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp index 0b04fce17..ae986d1df 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.cpp +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -279,8 +279,7 @@ Tensor& _exec_fft( double _dft_scale( IntArrayRef dim, - IntArrayRef input_sizes, - IntArrayRef out_sizes, + IntArrayRef norm_sizes, int64_t normalization) { const auto norm = static_cast(normalization); double double_scale = 1.0; @@ -289,21 +288,10 @@ double _dft_scale( return double_scale; } - const int64_t signal_ndim = dim.size(); int64_t signal_numel = 1; - - for (int64_t i = 0; i < signal_ndim; ++i) { - auto in_size = input_sizes[dim[i]]; - auto out_size = out_sizes[dim[i]]; - auto signal_size = std::max(in_size, out_size); - - signal_numel *= signal_size; - TORCH_INTERNAL_ASSERT( - in_size == signal_size || in_size == (signal_size / 2) + 1); - TORCH_INTERNAL_ASSERT( - out_size == signal_size || out_size == (signal_size / 2) + 1); + for (const int64_t& d : dim) { + signal_numel *= norm_sizes[d]; } - if (norm == fft_norm_mode::by_root_n) { double_scale = 1.0 / std::sqrt(signal_numel); } else { @@ -316,9 +304,9 @@ double _dft_scale( const Tensor& _fft_apply_normalization( const Tensor& self, int64_t normalization, - IntArrayRef sizes, + IntArrayRef norm_sizes, IntArrayRef dims) { - auto scale = _dft_scale(dims, sizes, self.sizes(), normalization); + auto scale = _dft_scale(dims, norm_sizes, normalization); return (scale == 1.0) ? self : self.mul_(scale); } From 86c865b2b85a4209da0bf80ad47c8e7d20d45ca6 Mon Sep 17 00:00:00 2001 From: Wetitpig Date: Wed, 30 Apr 2025 16:49:37 +0800 Subject: [PATCH 4/6] Revert yaml/xpu_functions.yaml --- yaml/xpu_functions.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 6621e93e3..6af8143f0 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -747,9 +747,3 @@ supported: - take.out - segment_reduce - _segment_reduce_backward - - _fft_c2c - - _fft_c2c.out - - _fft_c2r - - _fft_c2r.out - - _fft_r2c - - _fft_r2c.out From afee28dd384d68983fe3c471c03a1c8d51a95589 Mon Sep 17 00:00:00 2001 From: Wetitpig Date: Wed, 30 Apr 2025 16:58:29 +0800 Subject: [PATCH 5/6] Fix normalisation modes in `_out` functions --- src/ATen/native/xpu/mkl/SpectralOps.cpp | 26 ++++++++++++------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp index ae986d1df..dbd976a44 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.cpp +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -57,10 +57,11 @@ void _mkl_dft( int64_t idist = istrides[0]; int64_t odist = ostrides[0]; - std::vector fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1), - bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1); - fwd_strides[0] = 0; - bwd_strides[0] = 0; + std::vector input_strides( + istrides.cbegin(), istrides.cbegin() + signal_ndim + 1), + output_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1); + input_strides[0] = 0; + output_strides[0] = 0; auto desc = descriptor(mkl_signal_sizes); desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE); @@ -70,14 +71,14 @@ void _mkl_dft( desc.set_value(config_param::FWD_DISTANCE, idist); desc.set_value(config_param::BWD_DISTANCE, odist); - desc.set_value(config_param::FWD_STRIDES, fwd_strides.data()); - desc.set_value(config_param::BWD_STRIDES, bwd_strides.data()); + desc.set_value(config_param::FWD_STRIDES, input_strides.data()); + desc.set_value(config_param::BWD_STRIDES, output_strides.data()); } else { desc.set_value(config_param::FWD_DISTANCE, odist); desc.set_value(config_param::BWD_DISTANCE, idist); - desc.set_value(config_param::FWD_STRIDES, bwd_strides.data()); - desc.set_value(config_param::BWD_STRIDES, fwd_strides.data()); + desc.set_value(config_param::FWD_STRIDES, output_strides.data()); + desc.set_value(config_param::BWD_STRIDES, input_strides.data()); } if (!complex_input || !complex_output) { @@ -366,8 +367,7 @@ Tensor& _fft_c2c_mkl_out( int64_t normalization, bool forward, Tensor& out) { - auto result = _fft_c2c_mkl( - self, dim, static_cast(fft_norm_mode::none), forward); + auto result = _fft_c2c_mkl(self, dim, normalization, forward); at::native::resize_output(out, result.sizes()); out.copy_(result); return out; @@ -442,8 +442,7 @@ Tensor& _fft_c2r_mkl_out( int64_t normalization, int64_t last_dim_size, Tensor& out) { - auto result = _fft_c2r_mkl( - self, dim, static_cast(fft_norm_mode::none), last_dim_size); + auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size); at::native::resize_output(out, result.sizes()); out.copy_(result); return out; @@ -540,8 +539,7 @@ Tensor& _fft_r2c_mkl_out( int64_t normalization, bool onesided, Tensor& out) { - auto result = _fft_r2c_mkl( - self, dim, static_cast(fft_norm_mode::none), /*onesided=*/true); + auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true); at::native::resize_output(out, result.sizes()); out.copy_(result); From 276f2e77e2c52423dec2aa982be961cd54886181 Mon Sep 17 00:00:00 2001 From: Wetitpig Date: Fri, 2 May 2025 23:20:29 +0800 Subject: [PATCH 6/6] Pass `onesided` to `_fft_r2c_mkl_out` --- src/ATen/native/xpu/mkl/SpectralOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp index dbd976a44..6ffdc2087 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.cpp +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -539,7 +539,7 @@ Tensor& _fft_r2c_mkl_out( int64_t normalization, bool onesided, Tensor& out) { - auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true); + auto result = _fft_r2c_mkl(self, dim, normalization, onesided); at::native::resize_output(out, result.sizes()); out.copy_(result);