Skip to content

Commit 28c4a6c

Browse files
committed
Move dft_scale to individual functions
1 parent e224c65 commit 28c4a6c

File tree

1 file changed

+5
-17
lines changed

1 file changed

+5
-17
lines changed

src/ATen/native/xpu/mkl/SpectralOps.cpp

+5-17
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,7 @@ Tensor& _exec_fft(
279279

280280
double _dft_scale(
281281
IntArrayRef dim,
282-
IntArrayRef input_sizes,
283-
IntArrayRef out_sizes,
282+
IntArrayRef norm_sizes,
284283
int64_t normalization) {
285284
const auto norm = static_cast<fft_norm_mode>(normalization);
286285
double double_scale = 1.0;
@@ -289,21 +288,10 @@ double _dft_scale(
289288
return double_scale;
290289
}
291290

292-
const int64_t signal_ndim = dim.size();
293291
int64_t signal_numel = 1;
294-
295-
for (int64_t i = 0; i < signal_ndim; ++i) {
296-
auto in_size = input_sizes[dim[i]];
297-
auto out_size = out_sizes[dim[i]];
298-
auto signal_size = std::max(in_size, out_size);
299-
300-
signal_numel *= signal_size;
301-
TORCH_INTERNAL_ASSERT(
302-
in_size == signal_size || in_size == (signal_size / 2) + 1);
303-
TORCH_INTERNAL_ASSERT(
304-
out_size == signal_size || out_size == (signal_size / 2) + 1);
292+
for (const int64_t& d : dim) {
293+
signal_numel *= norm_sizes[d];
305294
}
306-
307295
if (norm == fft_norm_mode::by_root_n) {
308296
double_scale = 1.0 / std::sqrt(signal_numel);
309297
} else {
@@ -316,9 +304,9 @@ double _dft_scale(
316304
const Tensor& _fft_apply_normalization(
317305
const Tensor& self,
318306
int64_t normalization,
319-
IntArrayRef sizes,
307+
IntArrayRef norm_sizes,
320308
IntArrayRef dims) {
321-
auto scale = _dft_scale(dims, sizes, self.sizes(), normalization);
309+
auto scale = _dft_scale(dims, norm_sizes, normalization);
322310
return (scale == 1.0) ? self : self.mul_(scale);
323311
}
324312

0 commit comments

Comments
 (0)