@@ -325,7 +325,7 @@ def _stft(
325325 We can write STFT in terms of convolutions with a DFT kernel.
326326 At the end:
327327 * The real part output is: cos_base * input_real - sin_base * input_imag
328- * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
328+ * The imaginary part output is: cos_base * input_imag + sin_base * input_real
329329 Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
330330 """
331331 hop_length = hop_length or mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
@@ -358,12 +358,13 @@ def _stft(
358358 if input_imaginary :
359359 signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360360
361- # conv with DFT kernel across the input signal
362- # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
363- # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
364- # If x is complex then x[n]=(a+i*b)
365- # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
366- # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
361+ # Convolve the DFT kernel with the input signal
362+ # DFT(x[n]) --> X[k] = Σx[n]*e^(-i*2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n])
363+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k))
364+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k))
365+ # But because our DFT matrix is obtained with the conjugate --> e^(i*2π*n/N*k):
366+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k))
367+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k))
367368 cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368369 sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369370 if input_imaginary :
@@ -372,11 +373,11 @@ def _stft(
372373
373374 # add everything together
374375 if input_imaginary :
375- real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376- imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
376+ real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
377+ imag_result = mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
377378 else :
378379 real_result = cos_windows_real
379- imag_result = mb . sub ( x = 0. , y = sin_windows_real , before_op = before_op )
380+ imag_result = sin_windows_real
380381
381382 # reduce the rank of the output
382383 if should_increase_rank :
@@ -417,10 +418,10 @@ def _istft(
417418 # By default, use the entire frame
418419 win_length = win_length or n_fft
419420
420- input_shape = mb .shape (x = x , before_op = before_op )
421+ input_shape = mb .shape (x = input_real , before_op = before_op )
421422 n_frames = input_shape .val [- 1 ]
422423 fft_size = input_shape .val [- 2 ]
423- # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
424+ expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
424425
425426 is_onesided = onesided .val if onesided else fft_size != n_fft
426427 cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
@@ -447,14 +448,13 @@ def _istft(
447448 signal_real = mb .mul (x = signal_real , y = multiplier , before_op = before_op )
448449 signal_imaginary = mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
449450
450- # Conv with DFT kernel across the input signal
451- # We can describe the IDFT in terms of DFT just by swapping the input and output
451+ # Convolve the DFT kernel with the input signal
452+ # We can describe the IDFT in terms of DFT just by swapping the input and output.
452453 # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
453- # So IDFT(x) = (1/N) * swap(DFT(swap(x)))
454- # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
455- # If x is complex then x[n]=(a+i*b)
456- # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
457- # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
454+ # IDFT(X[K]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N
455+ # So using the definition in stft function, we get:
456+ # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
457+ # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458458 cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459459 sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460460 cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
@@ -750,17 +750,21 @@ def _lower_complex_istft(op: Operation):
750750 is_complex = types .is_complex (op .input .dtype )
751751
752752 # check parameters for validity
753+ if is_complex :
754+ raise ValueError ("Only complex inputs are allowed" )
753755 if op .win_length and op .win_length .val > op .n_fft .val :
754756 raise ValueError ("Window length must be less than or equal to n_fft" )
755- if is_complex and op .onesided and op .onesided .val :
756- raise ValueError ("Onesided is only valid for real inputs " )
757+ if op . return_complex and op .onesided and op .onesided .val :
758+ raise ValueError ("Complex output is not compatible with onesided " )
757759
758760 real , imag = _istft (
759- op .input .real if is_complex else op .input ,
760- op .input .imag if is_complex else None ,
761- op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , before_op = op )
761+ op .input .real , op .input .imag ,
762+ op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
762763
763- return _wrap_complex_output (op .outputs [0 ], real , imag )
764+ if op .return_complex :
765+ return _wrap_complex_output (op .outputs [0 ], real , imag )
766+ else
767+ return real
764768
765769
766770@LowerComplex .register_lower_func (op_type = "complex_shape" )
0 commit comments