|
2 | 2 | #include <ATen/native/Resize.h>
|
3 | 3 | #include <ATen/native/SpectralOpsUtils.h>
|
4 | 4 | #include <ATen/native/xpu/mkl/SpectralOps.h>
|
| 5 | +#include <ATen/ops/complex.h> |
| 6 | +#include <ATen/ops/imag.h> |
5 | 7 | #include <ATen/ops/mul.h>
|
| 8 | +#include <ATen/ops/real.h> |
| 9 | +#include <ATen/ops/zeros_like.h> |
6 | 10 | #include <comm/SYCLContext.h>
|
7 | 11 | #include <comm/TensorInfo.h>
|
8 | 12 | #include <oneapi/mkl.hpp>
|
@@ -84,8 +88,7 @@ void _mkl_dft(
|
84 | 88 | }
|
85 | 89 |
|
86 | 90 | if (!complex_input || !complex_output) {
|
87 |
| - desc.set_value( |
88 |
| - config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX); |
| 91 | + desc.set_value(config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX); |
89 | 92 | }
|
90 | 93 |
|
91 | 94 | desc.set_value(
|
@@ -398,5 +401,81 @@ Tensor& _fft_c2c_mkl_out(
|
398 | 401 | out, result, normalization, result.sizes(), dim);
|
399 | 402 | }
|
400 | 403 |
|
| 404 | +void HermitSymmImpl(Tensor& input, int64_t dim, int pos) { |
| 405 | + std::vector<at::indexing::TensorIndex> indices( |
| 406 | + input.dim(), at::indexing::Slice()); |
| 407 | + |
| 408 | + indices[dim] = pos; |
| 409 | + |
| 410 | + Tensor values = at::complex( |
| 411 | + at::real(input.index(indices)), |
| 412 | + at::zeros_like(at::imag(input.index(indices)))); |
| 413 | + |
| 414 | + input.index_put_(indices, values); |
| 415 | +} |
| 416 | + |
| 417 | +void HermitSymm(Tensor& input, int64_t dim, int64_t out_size) { |
| 418 | + HermitSymmImpl(input, dim, 0); |
| 419 | + |
| 420 | + if (out_size % 2 == 0) |
| 421 | + HermitSymmImpl(input, dim, -1); |
| 422 | +} |
| 423 | + |
| 424 | +Tensor _fft_c2r_mkl( |
| 425 | + const Tensor& self, |
| 426 | + IntArrayRef dim, |
| 427 | + int64_t normalization, |
| 428 | + int64_t last_dim_size) { |
| 429 | + if (dim.empty()) { |
| 430 | + return self.clone(); |
| 431 | + } |
| 432 | + |
| 433 | + auto input = self; |
| 434 | + |
| 435 | + if (dim.size() > 1) { |
| 436 | + auto c2c_dims = dim.slice(0, dim.size() - 1); |
| 437 | + input = _fft_c2c_mkl( |
| 438 | + self, |
| 439 | + c2c_dims, |
| 440 | + static_cast<int64_t>(fft_norm_mode::none), |
| 441 | + /*forward=*/false); |
| 442 | + } |
| 443 | + |
| 444 | + auto in_sizes = input.sizes(); |
| 445 | + DimVector out_sizes(in_sizes.begin(), in_sizes.end()); |
| 446 | + out_sizes[dim.back()] = last_dim_size; |
| 447 | + |
| 448 | + auto out = at::empty( |
| 449 | + out_sizes, |
| 450 | + self.options().dtype(c10::toRealValueType(self.scalar_type()))); |
| 451 | + |
| 452 | + input = input.clone(MemoryFormat::Contiguous); |
| 453 | + |
| 454 | + HermitSymm(input, dim.back(), out_sizes[dim.back()]); |
| 455 | + |
| 456 | + impl::_exec_fft( |
| 457 | + out, |
| 458 | + input, |
| 459 | + out_sizes, |
| 460 | + dim.back(), |
| 461 | + /*onesided=*/true, |
| 462 | + /*forward=*/false); |
| 463 | + |
| 464 | + return impl::_fft_apply_normalization(out, normalization, out_sizes, dim); |
| 465 | +} |
| 466 | + |
| 467 | +Tensor& _fft_c2r_mkl_out( |
| 468 | + const Tensor& self, |
| 469 | + IntArrayRef dim, |
| 470 | + int64_t normalization, |
| 471 | + int64_t last_dim_size, |
| 472 | + Tensor& out) { |
| 473 | + auto result = _fft_c2c_mkl( |
| 474 | + self, dim, static_cast<int64_t>(fft_norm_mode::none), last_dim_size); |
| 475 | + at::native::resize_output(out, result.sizes()); |
| 476 | + return impl::_fft_apply_normalization_out( |
| 477 | + out, result, normalization, result.sizes(), dim); |
| 478 | +} |
| 479 | + |
401 | 480 | } // namespace at::native::xpu
|
402 | 481 | #endif // USE_ONEMKL
|
0 commit comments