|
3 | 3 | #else
|
4 | 4 | #include <ATen/native/Resize.h>
|
5 | 5 | #include <ATen/ops/_fft_c2c_native.h>
|
| 6 | +#include <ATen/ops/_fft_c2r_native.h> |
6 | 7 | #endif // USE_ONEMKL
|
7 | 8 |
|
8 | 9 | namespace at::native {
|
@@ -49,18 +50,33 @@ Tensor _fft_c2r_xpu(
|
49 | 50 | int64_t last_dim_size) {
|
50 | 51 | TORCH_CHECK(self.is_complex());
|
51 | 52 |
|
| 53 | +#if defined(USE_ONEMKL) |
52 | 54 | return native::xpu::_fft_c2r_mkl(self, dim, normalization, last_dim_size);
|
| 55 | +#else |
| 56 | + Tensor out_cpu = native::_fft_c2r_mkl( |
| 57 | + self.to(Device(at::kCPU)), dim, normalization, last_dim_size); |
| 58 | + return out_cpu.to(Device(at::kXPU)); |
| 59 | +#endif // USE_ONEMKL |
53 | 60 | }
|
54 | 61 |
|
55 |
| -Tensor _fft_c2r_xpu_out( |
| 62 | +Tensor& _fft_c2r_xpu_out( |
56 | 63 | const Tensor& self,
|
57 | 64 | IntArrayRef dim,
|
58 | 65 | int64_t normalization,
|
59 | 66 | int64_t last_dim_size,
|
60 | 67 | Tensor& out) {
|
61 | 68 | TORCH_CHECK(self.is_complex());
|
62 | 69 |
|
63 |
| - return native::xpu::_fft_c2r_mkl_out(self, dim, normalization, last_dim_size, out); |
| 70 | +#if defined(USE_ONEMKL) |
| 71 | + return native::xpu::_fft_c2r_mkl_out( |
| 72 | + self, dim, normalization, last_dim_size, out); |
| 73 | +#else |
| 74 | + Tensor out_cpu = native::_fft_c2r_mkl( |
| 75 | + self.to(Device(at::kCPU)), dim, normalization, last_dim_size); |
| 76 | + at::native::resize_output(out, out_cpu.sizes()); |
| 77 | + out.copy_(out_cpu); |
| 78 | + return out; |
| 79 | +#endif // USE_ONEMKL |
64 | 80 | }
|
65 | 81 |
|
66 | 82 | } // namespace at::native
|
0 commit comments