@@ -57,10 +57,11 @@ void _mkl_dft(
57
57
int64_t idist = istrides[0 ];
58
58
int64_t odist = ostrides[0 ];
59
59
60
- std::vector<int64_t > fwd_strides (istrides.cbegin (), istrides.cbegin () + signal_ndim + 1 ),
61
- bwd_strides (ostrides.cbegin (), ostrides.cbegin () + signal_ndim + 1 );
62
- fwd_strides[0 ] = 0 ;
63
- bwd_strides[0 ] = 0 ;
60
+ std::vector<int64_t > input_strides (
61
+ istrides.cbegin (), istrides.cbegin () + signal_ndim + 1 ),
62
+ output_strides (ostrides.cbegin (), ostrides.cbegin () + signal_ndim + 1 );
63
+ input_strides[0 ] = 0 ;
64
+ output_strides[0 ] = 0 ;
64
65
65
66
auto desc = descriptor<prec, signal_type>(mkl_signal_sizes);
66
67
desc.set_value (config_param::PLACEMENT, config_value::NOT_INPLACE);
@@ -70,14 +71,14 @@ void _mkl_dft(
70
71
desc.set_value (config_param::FWD_DISTANCE, idist);
71
72
desc.set_value (config_param::BWD_DISTANCE, odist);
72
73
73
- desc.set_value (config_param::FWD_STRIDES, fwd_strides .data ());
74
- desc.set_value (config_param::BWD_STRIDES, bwd_strides .data ());
74
+ desc.set_value (config_param::FWD_STRIDES, input_strides .data ());
75
+ desc.set_value (config_param::BWD_STRIDES, output_strides .data ());
75
76
} else {
76
77
desc.set_value (config_param::FWD_DISTANCE, odist);
77
78
desc.set_value (config_param::BWD_DISTANCE, idist);
78
79
79
- desc.set_value (config_param::FWD_STRIDES, bwd_strides .data ());
80
- desc.set_value (config_param::BWD_STRIDES, fwd_strides .data ());
80
+ desc.set_value (config_param::FWD_STRIDES, output_strides .data ());
81
+ desc.set_value (config_param::BWD_STRIDES, input_strides .data ());
81
82
}
82
83
83
84
if (!complex_input || !complex_output) {
@@ -366,8 +367,7 @@ Tensor& _fft_c2c_mkl_out(
366
367
int64_t normalization,
367
368
bool forward,
368
369
Tensor& out) {
369
- auto result = _fft_c2c_mkl (
370
- self, dim, static_cast <int64_t >(fft_norm_mode::none), forward);
370
+ auto result = _fft_c2c_mkl (self, dim, normalization, forward);
371
371
at::native::resize_output (out, result.sizes ());
372
372
out.copy_ (result);
373
373
return out;
@@ -442,8 +442,7 @@ Tensor& _fft_c2r_mkl_out(
442
442
int64_t normalization,
443
443
int64_t last_dim_size,
444
444
Tensor& out) {
445
- auto result = _fft_c2r_mkl (
446
- self, dim, static_cast <int64_t >(fft_norm_mode::none), last_dim_size);
445
+ auto result = _fft_c2r_mkl (self, dim, normalization, last_dim_size);
447
446
at::native::resize_output (out, result.sizes ());
448
447
out.copy_ (result);
449
448
return out;
@@ -540,8 +539,7 @@ Tensor& _fft_r2c_mkl_out(
540
539
int64_t normalization,
541
540
bool onesided,
542
541
Tensor& out) {
543
- auto result = _fft_r2c_mkl (
544
- self, dim, static_cast <int64_t >(fft_norm_mode::none), /* onesided=*/ true );
542
+ auto result = _fft_r2c_mkl (self, dim, normalization, /* onesided=*/ true );
545
543
546
544
at::native::resize_output (out, result.sizes ());
547
545
out.copy_ (result);
0 commit comments