Skip to content

Commit 1ede8c0

Browse files
committed
Fix normalisation modes in _out functions
1 parent c753671 commit 1ede8c0

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/ATen/native/xpu/mkl/SpectralOps.cpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ void _mkl_dft(
5757
int64_t idist = istrides[0];
5858
int64_t odist = ostrides[0];
5959

60-
std::vector<int64_t> fwd_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
61-
bwd_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
62-
fwd_strides[0] = 0;
63-
bwd_strides[0] = 0;
60+
std::vector<int64_t> input_strides(istrides.cbegin(), istrides.cbegin() + signal_ndim + 1),
61+
output_strides(ostrides.cbegin(), ostrides.cbegin() + signal_ndim + 1);
62+
input_strides[0] = 0;
63+
output_strides[0] = 0;
6464

6565
auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
6666
desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE);
@@ -70,14 +70,14 @@ void _mkl_dft(
7070
desc.set_value(config_param::FWD_DISTANCE, idist);
7171
desc.set_value(config_param::BWD_DISTANCE, odist);
7272

73-
desc.set_value(config_param::FWD_STRIDES, fwd_strides.data());
74-
desc.set_value(config_param::BWD_STRIDES, bwd_strides.data());
73+
desc.set_value(config_param::FWD_STRIDES, input_strides.data());
74+
desc.set_value(config_param::BWD_STRIDES, output_strides.data());
7575
} else {
7676
desc.set_value(config_param::FWD_DISTANCE, odist);
7777
desc.set_value(config_param::BWD_DISTANCE, idist);
7878

79-
desc.set_value(config_param::FWD_STRIDES, bwd_strides.data());
80-
desc.set_value(config_param::BWD_STRIDES, fwd_strides.data());
79+
desc.set_value(config_param::FWD_STRIDES, output_strides.data());
80+
desc.set_value(config_param::BWD_STRIDES, input_strides.data());
8181
}
8282

8383
if (!complex_input || !complex_output) {
@@ -367,7 +367,7 @@ Tensor& _fft_c2c_mkl_out(
367367
bool forward,
368368
Tensor& out) {
369369
auto result = _fft_c2c_mkl(
370-
self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
370+
self, dim, normalization, forward);
371371
at::native::resize_output(out, result.sizes());
372372
out.copy_(result);
373373
return out;
@@ -443,7 +443,7 @@ Tensor& _fft_c2r_mkl_out(
443443
int64_t last_dim_size,
444444
Tensor& out) {
445445
auto result = _fft_c2r_mkl(
446-
self, dim, static_cast<int64_t>(fft_norm_mode::none), last_dim_size);
446+
self, dim, normalization, last_dim_size);
447447
at::native::resize_output(out, result.sizes());
448448
out.copy_(result);
449449
return out;
@@ -541,7 +541,7 @@ Tensor& _fft_r2c_mkl_out(
541541
bool onesided,
542542
Tensor& out) {
543543
auto result = _fft_r2c_mkl(
544-
self, dim, static_cast<int64_t>(fft_norm_mode::none), /*onesided=*/true);
544+
self, dim, normalization, /*onesided=*/true);
545545

546546
at::native::resize_output(out, result.sizes());
547547
out.copy_(result);

0 commit comments

Comments
 (0)