Skip to content

Commit 4fe3310

Browse files
committed
Fix normalisation modes in _out functions
1 parent c753671 commit 4fe3310

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

src/ATen/native/xpu/mkl/SpectralOps.cpp

+12-14
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ void _mkl_dft(
5757
int64_t idist = istrides[0];
5858
int64_t odist = ostrides[0];
5959

60-
std::vector<int64_t> fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
61-
bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
62-
fwd_strides[0] = 0;
63-
bwd_strides[0] = 0;
60+
std::vector<int64_t> input_strides(
61+
istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
62+
output_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
63+
input_strides[0] = 0;
64+
output_strides[0] = 0;
6465

6566
auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
6667
desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE);
@@ -70,14 +71,14 @@ void _mkl_dft(
7071
desc.set_value(config_param::FWD_DISTANCE, idist);
7172
desc.set_value(config_param::BWD_DISTANCE, odist);
7273

73-
desc.set_value(config_param::FWD_STRIDES, fwd_strides.data());
74-
desc.set_value(config_param::BWD_STRIDES, bwd_strides.data());
74+
desc.set_value(config_param::FWD_STRIDES, input_strides.data());
75+
desc.set_value(config_param::BWD_STRIDES, output_strides.data());
7576
} else {
7677
desc.set_value(config_param::FWD_DISTANCE, odist);
7778
desc.set_value(config_param::BWD_DISTANCE, idist);
7879

79-
desc.set_value(config_param::FWD_STRIDES, bwd_strides.data());
80-
desc.set_value(config_param::BWD_STRIDES, fwd_strides.data());
80+
desc.set_value(config_param::FWD_STRIDES, output_strides.data());
81+
desc.set_value(config_param::BWD_STRIDES, input_strides.data());
8182
}
8283

8384
if (!complex_input || !complex_output) {
@@ -366,8 +367,7 @@ Tensor& _fft_c2c_mkl_out(
366367
int64_t normalization,
367368
bool forward,
368369
Tensor& out) {
369-
auto result = _fft_c2c_mkl(
370-
self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
370+
auto result = _fft_c2c_mkl(self, dim, normalization, forward);
371371
at::native::resize_output(out, result.sizes());
372372
out.copy_(result);
373373
return out;
@@ -442,8 +442,7 @@ Tensor& _fft_c2r_mkl_out(
442442
int64_t normalization,
443443
int64_t last_dim_size,
444444
Tensor& out) {
445-
auto result = _fft_c2r_mkl(
446-
self, dim, static_cast<int64_t>(fft_norm_mode::none), last_dim_size);
445+
auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
447446
at::native::resize_output(out, result.sizes());
448447
out.copy_(result);
449448
return out;
@@ -540,8 +539,7 @@ Tensor& _fft_r2c_mkl_out(
540539
int64_t normalization,
541540
bool onesided,
542541
Tensor& out) {
543-
auto result = _fft_r2c_mkl(
544-
self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
542+
auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
545543

546544
at::native::resize_output(out, result.sizes());
547545
out.copy_(result);

0 commit comments

Comments
 (0)