@@ -2,6 +2,7 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/SpectralOpsUtils.h>
 #include <ATen/native/xpu/mkl/SpectralOps.h>
+#include <ATen/native/xpu/sycl/FFTKernelFunctor.h>
 #include <ATen/ops/complex.h>
 #include <ATen/ops/imag.h>
 #include <ATen/ops/mul.h>
@@ -52,6 +53,7 @@ void _mkl_dft(
 
   auto istrides = input.strides();
   auto ostrides = output.strides();
+
   int64_t idist = istrides[0];
   int64_t odist = ostrides[0];
 
@@ -477,5 +479,116 @@ Tensor& _fft_c2r_mkl_out(
       out, result, normalization, result.sizes(), dim);
 }
 
+REGISTER_XPU_DISPATCH(
+    fft_fill_with_conjugate_symmetry_stub,
+    &_fft_fill_with_conjugate_symmetry_xpu);
+
+Tensor _fft_r2c_mkl(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided) {
+  if (dim.empty()) {
+    return self.clone();
+  }
+
+  auto input_sizes = self.sizes();
+  DimVector onesided_sizes(input_sizes.begin(), input_sizes.end());
+  auto last_dim = dim.back();
+  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
+  onesided_sizes[last_dim] = last_dim_halfsize;
+
+  IntArrayRef out_sizes = onesided ? onesided_sizes : input_sizes;
+
+  auto sorted_dims = impl::_sort_dims(self, dim, /*exclude_last=*/true);
+  auto out = at::empty(
+      out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
+
+  auto working_tensor = self.clone(MemoryFormat::Contiguous);
+
+  // First do the R2C transform on the last dimension
+  impl::_exec_fft(
+      out, working_tensor, out_sizes, last_dim, onesided, /*forward=*/true);
+
+  if (dim.size() > 1) {
+    working_tensor = at::empty(
+        out_sizes,
+        self.options().dtype(c10::toComplexType(self.scalar_type())));
+  }
+
+  sorted_dims.resize(sorted_dims.size() - 1);
+
+  while (!sorted_dims.empty()) {
+    if (working_tensor.is_same(self)) {
+      working_tensor = std::move(out);
+      out = at::empty(
+          out_sizes,
+          self.options().dtype(c10::toComplexType(self.scalar_type())));
+    } else {
+      std::swap(out, working_tensor);
+    }
+
+    const auto max_dims =
+        std::min(static_cast<size_t>(impl::mkl_max_ndim), sorted_dims.size());
+    auto fft_dims =
+        IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);
+    impl::_exec_fft(
+        out,
+        working_tensor,
+        out_sizes,
+        fft_dims,
+        onesided,
+        /*forward=*/true);
+    sorted_dims.resize(sorted_dims.size() - max_dims);
+
+    if (sorted_dims.empty()) {
+      break;
+    }
+
+    sorted_dims = impl::_sort_dims(self, sorted_dims);
+  }
+
+  // Only need to normalize the onesided slice since data in the other half is
+  // overwritten
+  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
+  working_tensor = self;
+  if (!onesided) {
+    if (out.sizes()[last_dim] != out_sizes[last_dim]) {
+      working_tensor.resize_(out_sizes, MemoryFormat::Contiguous);
+      working_tensor.slice(last_dim, 0, last_dim_halfsize).copy_(out);
+      out = std::move(working_tensor);
+    }
+    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
+  }
+
+  return impl::_fft_apply_normalization(out, normalization, input_sizes, dim);
+}
+
+Tensor& _fft_r2c_mkl_out(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool onesided,
+    Tensor& out) {
+  auto result = _fft_r2c_mkl(
+      self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
+
+  if (onesided) {
+    return impl::_fft_apply_normalization_out(
+        out, result, normalization, self.sizes(), dim);
+  }
+
+  at::native::resize_output(out, self.sizes());
+
+  auto last_dim = dim.back();
+  auto last_dim_halfsize = result.sizes()[last_dim];
+  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
+
+  impl::_fft_apply_normalization_out(
+      out_slice, result, normalization, self.sizes(), dim);
+  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
+  return out;
+}
+
 } // namespace at::native::xpu
 #endif // USE_ONEMKL
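
The dispatch registration and the `!onesided` paths above both rest on Hermitian symmetry: for a real-valued input of length N, the DFT satisfies X[N - k] == conj(X[k]), so the R2C transform only needs to compute the first N/2 + 1 bins, and the rest can be filled in by conjugation (the job of `fft_fill_with_conjugate_symmetry_stub`). Below is a minimal standalone sketch in plain C++ (not torch-xpu-ops code) that verifies this property with a naive DFT:

#include <complex>
#include <cstdio>
#include <vector>

int main() {
  const int N = 8;
  std::vector<double> x = {1, 2, 0, -1, 3, 0.5, -2, 4}; // arbitrary real input
  std::vector<std::complex<double>> X(N);

  // Naive O(N^2) DFT: X[k] = sum_n x[n] * exp(-2*pi*i*n*k/N)
  const double pi = 3.14159265358979323846;
  for (int k = 0; k < N; ++k) {
    for (int n = 0; n < N; ++n) {
      X[k] += x[n] * std::polar(1.0, -2.0 * pi * n * k / N);
    }
  }

  // Check Hermitian symmetry: the upper half of the spectrum is the
  // conjugate of the lower half, so storing N/2 + 1 bins is lossless.
  for (int k = 1; k < N / 2; ++k) {
    std::printf(
        "X[%d] = (%+.3f, %+.3f)   conj(X[%d]) = (%+.3f, %+.3f)\n",
        N - k, X[N - k].real(), X[N - k].imag(),
        k, std::conj(X[k]).real(), std::conj(X[k]).imag());
  }
}

Both printed columns match bin for bin, which is exactly why `_fft_r2c_mkl` always runs the onesided transform first and only materializes the full-size output when the caller asks for it.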
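The while-loop in `_fft_r2c_mkl` deals with transforms that span more dimensions than a single descriptor pass can handle: it consumes the sorted dims from the back in chunks of at most `impl::mkl_max_ndim`, swapping the roles of `out` and `working_tensor` between passes. A standalone C++ sketch of that chunking pattern, with the hypothetical `kMaxDimsPerPass` standing in for `impl::mkl_max_ndim` (value assumed for illustration):

#include <algorithm>
#include <cstdio>
#include <vector>

constexpr size_t kMaxDimsPerPass = 3; // stand-in for impl::mkl_max_ndim

int main() {
  std::vector<int> dims = {0, 1, 2, 4, 5, 6, 7}; // dims left to transform
  int pass = 0;
  while (!dims.empty()) {
    const size_t n = std::min(kMaxDimsPerPass, dims.size());
    // Take the last `n` dims, mirroring
    // IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims).
    std::printf("pass %d transforms dims:", pass++);
    for (size_t i = dims.size() - n; i < dims.size(); ++i) {
      std::printf(" %d", dims[i]);
    }
    std::printf("\n");
    dims.resize(dims.size() - n); // drop the dims handled this pass
    // (the real code also re-sorts the remaining dims via impl::_sort_dims
    // here, and ping-pongs the source/destination buffers with std::swap)
  }
}

Each pass reads from one buffer and writes into the other, which is why the real code allocates a second complex tensor up front whenever `dim.size() > 1`.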