@@ -279,8 +279,7 @@ Tensor& _exec_fft(
279
279
280
280
double _dft_scale (
281
281
IntArrayRef dim,
282
- IntArrayRef input_sizes,
283
- IntArrayRef out_sizes,
282
+ IntArrayRef norm_sizes,
284
283
int64_t normalization) {
285
284
const auto norm = static_cast <fft_norm_mode>(normalization);
286
285
double double_scale = 1.0 ;
@@ -289,21 +288,10 @@ double _dft_scale(
289
288
return double_scale;
290
289
}
291
290
292
- const int64_t signal_ndim = dim.size ();
293
291
int64_t signal_numel = 1 ;
294
-
295
- for (int64_t i = 0 ; i < signal_ndim; ++i) {
296
- auto in_size = input_sizes[dim[i]];
297
- auto out_size = out_sizes[dim[i]];
298
- auto signal_size = std::max (in_size, out_size);
299
-
300
- signal_numel *= signal_size;
301
- TORCH_INTERNAL_ASSERT (
302
- in_size == signal_size || in_size == (signal_size / 2 ) + 1 );
303
- TORCH_INTERNAL_ASSERT (
304
- out_size == signal_size || out_size == (signal_size / 2 ) + 1 );
292
+ for (const int64_t & d : dim) {
293
+ signal_numel *= norm_sizes[d];
305
294
}
306
-
307
295
if (norm == fft_norm_mode::by_root_n) {
308
296
double_scale = 1.0 / std::sqrt (signal_numel);
309
297
} else {
@@ -316,9 +304,9 @@ double _dft_scale(
316
304
const Tensor& _fft_apply_normalization (
317
305
const Tensor& self,
318
306
int64_t normalization,
319
- IntArrayRef sizes ,
307
+ IntArrayRef norm_sizes ,
320
308
IntArrayRef dims) {
321
- auto scale = _dft_scale (dims, sizes, self. sizes () , normalization);
309
+ auto scale = _dft_scale (dims, norm_sizes , normalization);
322
310
return (scale == 1.0 ) ? self : self.mul_ (scale);
323
311
}
324
312
0 commit comments