Skip to content

Commit 911524b

Browse files
committed
Enable fft_c2r
1 parent c9ed050 commit 911524b

File tree

5 files changed

+126
-3
lines changed

5 files changed

+126
-3
lines changed

src/ATen/native/xpu/SpectralOps.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,25 @@ Tensor& _fft_c2c_xpu_out(
4242
#endif // USE_ONEMKL
4343
}
4444

45+
// XPU dispatch target for `_fft_c2r` (complex-to-real inverse FFT).
//
// Validates that the input is complex and forwards every argument unchanged
// to the oneMKL-backed implementation `native::xpu::_fft_c2r_mkl`.
//
//   self           input tensor; must have a complex dtype.
//   dim            dimensions to transform (forwarded as is).
//   normalization  normalization mode (forwarded as is).
//   last_dim_size  forwarded as is to the implementation.
//
// Returns the tensor produced by `_fft_c2r_mkl`.
//
// NOTE(review): the oneMKL implementation file ends with
// `#endif // USE_ONEMKL`, while this call is unguarded -- confirm that a
// build without USE_ONEMKL still links.
Tensor _fft_c2r_xpu(
    const Tensor& self,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size) {
  // Give the caller an actionable message instead of the generic
  // "Expected self.is_complex() to be true" text.
  TORCH_CHECK(
      self.is_complex(),
      "_fft_c2r expects a complex input tensor, but got ",
      self.scalar_type());

  return native::xpu::_fft_c2r_mkl(self, dim, normalization, last_dim_size);
}
54+
55+
// XPU dispatch target for `_fft_c2r.out` (complex-to-real inverse FFT that
// writes into a caller-provided tensor).
//
// Validates that the input is complex and forwards every argument unchanged
// to the oneMKL-backed implementation `native::xpu::_fft_c2r_mkl_out`.
//
//   self           input tensor; must have a complex dtype.
//   dim            dimensions to transform (forwarded as is).
//   normalization  normalization mode (forwarded as is).
//   last_dim_size  forwarded as is to the implementation.
//   out            destination tensor (forwarded as is).
//
// Returns whatever `_fft_c2r_mkl_out` returns.
//
// NOTE(review): the oneMKL implementation file ends with
// `#endif // USE_ONEMKL`, while this call is unguarded -- confirm that a
// build without USE_ONEMKL still links.
Tensor _fft_c2r_xpu_out(
    const Tensor& self,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size,
    Tensor& out) {
  // Give the caller an actionable message instead of the generic
  // "Expected self.is_complex() to be true" text.
  TORCH_CHECK(
      self.is_complex(),
      "_fft_c2r.out expects a complex input tensor, but got ",
      self.scalar_type());

  return native::xpu::_fft_c2r_mkl_out(
      self, dim, normalization, last_dim_size, out);
}
65+
4566
} // namespace at::native

src/ATen/native/xpu/XPUFallback.template

-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
190190
"_cholesky_solve_helper",
191191
"dot",
192192
"_efficient_attention_forward",
193-
"_fft_c2r",
194193
"_fft_r2c",
195194
"_flash_attention_forward",
196195
"geqrf",

src/ATen/native/xpu/mkl/SpectralOps.cpp

+81-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
#include <ATen/native/Resize.h>
33
#include <ATen/native/SpectralOpsUtils.h>
44
#include <ATen/native/xpu/mkl/SpectralOps.h>
5+
#include <ATen/ops/complex.h>
6+
#include <ATen/ops/imag.h>
57
#include <ATen/ops/mul.h>
8+
#include <ATen/ops/real.h>
9+
#include <ATen/ops/zeros_like.h>
610
#include <comm/SYCLContext.h>
711
#include <comm/TensorInfo.h>
812
#include <oneapi/mkl.hpp>
@@ -84,8 +88,7 @@ void _mkl_dft(
8488
}
8589

8690
if (!complex_input || !complex_output) {
87-
desc.set_value(
88-
config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
91+
desc.set_value(config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
8992
}
9093

9194
desc.set_value(
@@ -398,5 +401,81 @@ Tensor& _fft_c2c_mkl_out(
398401
out, result, normalization, result.sizes(), dim);
399402
}
400403

404+
// Replaces the slice of `input` at index `pos` along dimension `dim` with the
// same values whose imaginary part is set to zero.
//
//   input  complex tensor; modified in place through `index_put_`.
//   dim    dimension along which the slice is taken.
//   pos    index of the slice along `dim` (a negative value such as -1 is
//          passed straight to the indexing API).
void HermitSymmImpl(Tensor& input, int64_t dim, int pos) {
  // Full slice on every dimension, then pin `dim` to the single index `pos`.
  std::vector<at::indexing::TensorIndex> indices(
      input.dim(), at::indexing::Slice());
  indices[dim] = pos;

  // Index once and reuse the result for both the real and imaginary parts
  // (the slice was previously computed twice).
  const Tensor slice = input.index(indices);
  const Tensor values =
      at::complex(at::real(slice), at::zeros_like(at::imag(slice)));

  input.index_put_(indices, values);
}
416+
417+
// Zeroes the imaginary part of selected slices of `input` along `dim`:
// always the first slice (index 0), and additionally the last slice
// (index -1) when `out_size` is even.
//
//   input     complex tensor; modified in place.
//   dim       dimension along which the slices are selected.
//   out_size  only its parity is used here.
//
// NOTE(review): presumably index -1 is the Nyquist bin of a one-sided
// spectrum when the real output length is even -- confirm against the caller.
void HermitSymm(Tensor& input, int64_t dim, int64_t out_size) {
  HermitSymmImpl(input, dim, /*pos=*/0);

  // Odd output length: only the first slice needs adjusting.
  if (out_size % 2 != 0) {
    return;
  }

  HermitSymmImpl(input, dim, /*pos=*/-1);
}
423+
424+
// Complex-to-real inverse FFT over the dimensions in `dim`.
//
//   self           complex input tensor; not modified.
//   dim            dimensions to transform; the last entry is the one that
//                  goes through the complex-to-real step.
//   normalization  normalization mode applied to the final result.
//   last_dim_size  size of the output along `dim.back()`.
//
// Returns a newly allocated tensor with the real value type of `self`,
// shaped like the input except that `dim.back()` has size `last_dim_size`.
// When `dim` is empty, a plain copy of `self` is returned.
Tensor _fft_c2r_mkl(
    const Tensor& self,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size) {
  // No dimensions requested: nothing to transform.
  if (dim.empty()) {
    return self.clone();
  }

  const int64_t c2r_dim = dim.back();

  // All dimensions but the last are handled by an unnormalized inverse
  // complex-to-complex transform; with a single dimension the input is used
  // directly.
  Tensor spectrum = self;
  if (dim.size() > 1) {
    const auto leading_dims = dim.slice(0, dim.size() - 1);
    spectrum = _fft_c2c_mkl(
        self,
        leading_dims,
        static_cast<int64_t>(fft_norm_mode::none),
        /*forward=*/false);
  }

  // Output shape: same as the (partially transformed) input, except for the
  // complex-to-real dimension.
  const auto spectrum_sizes = spectrum.sizes();
  DimVector out_sizes(spectrum_sizes.begin(), spectrum_sizes.end());
  out_sizes[c2r_dim] = last_dim_size;

  const auto real_dtype = c10::toRealValueType(self.scalar_type());
  auto out = at::empty(out_sizes, self.options().dtype(real_dtype));

  // Work on a contiguous private copy: HermitSymm writes into its argument,
  // and `spectrum` may still alias `self` at this point.
  spectrum = spectrum.clone(MemoryFormat::Contiguous);
  HermitSymm(spectrum, c2r_dim, out_sizes[c2r_dim]);

  impl::_exec_fft(
      out,
      spectrum,
      out_sizes,
      c2r_dim,
      /*onesided=*/true,
      /*forward=*/false);

  return impl::_fft_apply_normalization(out, normalization, out_sizes, dim);
}
466+
467+
// Out-variant of the complex-to-real inverse FFT.
//
//   self           complex input tensor.
//   dim            dimensions to transform.
//   normalization  normalization mode applied while writing into `out`.
//   last_dim_size  size of the result along `dim.back()`.
//   out            destination tensor; resized to the result shape.
//
// Returns the reference produced by `impl::_fft_apply_normalization_out`.
Tensor& _fft_c2r_mkl_out(
    const Tensor& self,
    IntArrayRef dim,
    int64_t normalization,
    int64_t last_dim_size,
    Tensor& out) {
  // Run the unnormalized complex-to-real transform. This previously called
  // `_fft_c2c_mkl`, which silently converted `last_dim_size` to its
  // `bool forward` parameter and performed a complex-to-complex transform.
  auto result = _fft_c2r_mkl(
      self, dim, static_cast<int64_t>(fft_norm_mode::none), last_dim_size);

  at::native::resize_output(out, result.sizes());

  // Normalization is deferred to this step so it is applied exactly once.
  return impl::_fft_apply_normalization_out(
      out, result, normalization, result.sizes(), dim);
}
479+
401480
} // namespace at::native::xpu
402481
#endif // USE_ONEMKL

src/ATen/native/xpu/mkl/SpectralOps.h

+13
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,17 @@ TORCH_XPU_API Tensor& _fft_c2c_mkl_out(
1717
bool forward,
1818
Tensor& out);
1919

20+
// Complex-to-real inverse FFT backed by oneMKL.
//   self           input tensor.
//   dim            dimensions to transform.
//   normalization  normalization mode.
//   last_dim_size  size of the result along the last entry of `dim`.
// Returns the result as a new tensor.
TORCH_XPU_API Tensor _fft_c2r_mkl(
21+
const Tensor& self,
22+
IntArrayRef dim,
23+
int64_t normalization,
24+
int64_t last_dim_size);
25+
26+
// Out-variant of the complex-to-real inverse FFT backed by oneMKL.
// Takes the same arguments as `_fft_c2r_mkl` plus `out`, the destination
// tensor, and returns a tensor reference.
TORCH_XPU_API Tensor& _fft_c2r_mkl_out(
27+
const Tensor& self,
28+
IntArrayRef dim,
29+
int64_t normalization,
30+
int64_t last_dim_size,
31+
Tensor& out);
32+
2033
} // namespace at::native::xpu

yaml/native/native_functions.yaml

+11
Original file line numberDiff line numberDiff line change
@@ -9322,3 +9322,14 @@
93229322
variants: function
93239323
dispatch:
93249324
XPU: _fft_c2c_xpu_out
9325+
9326+
# Complex-to-real inverse FFT (functional variant).
# On XPU this dispatches to _fft_c2r_xpu.
9327+
- func: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
9328+
variants: function
9329+
dispatch:
9330+
XPU: _fft_c2r_xpu
9331+
9332+
# Out variant of the complex-to-real inverse FFT: takes the extra keyword-only
# `out` tensor. On XPU this dispatches to _fft_c2r_xpu_out.
- func: _fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)
9333+
variants: function
9334+
dispatch:
9335+
XPU: _fft_c2r_xpu_out

0 commit comments

Comments
 (0)