@@ -325,7 +325,7 @@ def _stft(
325325 We can write STFT in terms of convolutions with a DFT kernel.
326326 At the end:
327327 * The real part output is: cos_base * input_real + sin_base * input_imag
328- * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag)
328+ * The imaginary part output is: cos_base * input_imag - sin_base * input_real
329329 Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py
330330 """
331331 hop_length = hop_length or mb .floor_div (x = n_fft , y = 4 , before_op = before_op )
@@ -342,14 +342,13 @@ def _stft(
342342
343343 # create a window of centered 1s of the requested size
344344 if win_length :
345- window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
345+ window = _get_window (win_length = win_length , n_fft = n_fft , window = window , before_op = before_op )
346346
347347 # apply time window
348348 if window :
349349 cos_base = mb .mul (x = window , y = cos_base , before_op = before_op )
350350 sin_base = mb .mul (x = window , y = sin_base , before_op = before_op )
351351
352-
353352 # Expand
354353 cos_base = mb .expand_dims (x = cos_base , axes = (1 ,), before_op = before_op )
355354 sin_base = mb .expand_dims (x = sin_base , axes = (1 ,), before_op = before_op )
@@ -358,12 +357,13 @@ def _stft(
358357 if input_imaginary :
359358 signal_imaginary = mb .expand_dims (x = input_imaginary , axes = (1 ,), before_op = before_op )
360359
361- # conv with DFT kernel across the input signal
362- # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is:
363- # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i)
364- # If x is complex then x[n]=(a+i*b)
365- # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
366- # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
360+ # Convolve the DFT kernel with the input signal
361+ # DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n])
362+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k))
363+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k))
364+ # But because our DFT matrix is obtained with the conjugate --> e^(2π*n/N*k):
365+ # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k))
366+ # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k))
367367 cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
368368 sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
369369 if input_imaginary :
@@ -372,11 +372,11 @@ def _stft(
372372
373373 # add everything together
374374 if input_imaginary :
375- real_result = mb .add (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376- imag_result = mb .sub (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
375+ real_result = mb .sub (x = cos_windows_real , y = sin_windows_imag , before_op = before_op )
376+ imag_result = mb .add (x = cos_windows_imag , y = sin_windows_real , before_op = before_op )
377377 else :
378378 real_result = cos_windows_real
379- imag_result = mb . sub ( x = 0. , y = sin_windows_real , before_op = before_op )
379+ imag_result = sin_windows_real
380380
381381 # reduce the rank of the output
382382 if should_increase_rank :
@@ -417,17 +417,18 @@ def _istft(
417417 # By default, use the entire frame
418418 win_length = win_length or n_fft
419419
420- input_shape = mb .shape (x = x , before_op = before_op )
421- n_frames = input_shape .val [- 1 ]
422- fft_size = input_shape .val [- 2 ]
423- # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1)
420+ input_shape = mb .shape (x = input_real , before_op = before_op )
421+ channels = input_shape .val [0 ]
422+ fft_size = input_shape .val [1 ]
423+ n_frames = input_shape .val [2 ]
424+ expected_output_signal_len = n_fft .val + hop_length .val * (n_frames - 1 )
424425
425426 is_onesided = onesided .val if onesided else fft_size != n_fft
426427 cos_base , sin_base = _calculate_dft_matrix (n_fft , onesided = is_onesided , before_op = before_op )
427428
428429 # create a window of centered 1s of the requested size
429430 if win_length :
430- window = _get_window (win_length = win_length , n_fft = n_fft , before_op = before_op )
431+ window = _get_window (win_length = win_length , n_fft = n_fft , window = window , before_op = before_op )
431432
432433 # apply time window
433434 if window :
@@ -447,14 +448,13 @@ def _istft(
447448 signal_real = mb .mul (x = signal_real , y = multiplier , before_op = before_op )
448449 signal_imaginary = mb .mul (x = signal_imaginary , y = multiplier , before_op = before_op )
449450
450- # Conv with DFT kernel across the input signal
451- # We can describe the IDFT in terms of DFT just by swapping the input and output
451+ # Convolve the DFT kernel with the input signal
452+ # We can describe the IDFT in terms of DFT just by swapping the input and output.
452453 # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT
453- # So IDFT(x) = (1/N) * swap(DFT(swap(x)))
454- # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i)
455- # If x is complex then x[n]=(a+i*b)
456- # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N))
457- # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N))
454+ # IDFT(X[K]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N
455+ # So using the definition in stft function, we get:
456+ # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n))
457+ # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n))
458458 cos_windows_real = mb .conv (x = signal_real , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
459459 sin_windows_real = mb .conv (x = signal_real , weight = sin_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
460460 cos_windows_imag = mb .conv (x = signal_imaginary , weight = cos_base , strides = hop_size , pad_type = 'valid' , before_op = before_op )
@@ -519,6 +519,7 @@ def _overlap_add(
519519def _get_window (
520520 win_length : Var ,
521521 n_fft : Var ,
522+ window : Optional [Var ],
522523 before_op : Operation ,
523524) -> Var :
524525 n_left = (n_fft .val - win_length .val ) // 2
@@ -750,17 +751,21 @@ def _lower_complex_istft(op: Operation):
750751 is_complex = types .is_complex (op .input .dtype )
751752
752753 # check parameters for validity
754+ if is_complex :
755+ raise ValueError ("Only complex inputs are allowed" )
753756 if op .win_length and op .win_length .val > op .n_fft .val :
754757 raise ValueError ("Window length must be less than or equal to n_fft" )
755- if is_complex and op .onesided and op .onesided .val :
756- raise ValueError ("Onesided is only valid for real inputs " )
758+ if op . return_complex and op .onesided and op .onesided .val :
759+ raise ValueError ("Complex output is not compatible with onesided " )
757760
758761 real , imag = _istft (
759- op .input .real if is_complex else op .input ,
760- op .input .imag if is_complex else None ,
761- op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , before_op = op )
762+ op .input .real , op .input .imag ,
763+ op .n_fft , op .hop_length , op .win_length , op .window , op .normalized , op .onesided , op .length , before_op = op )
762764
763- return _wrap_complex_output (op .outputs [0 ], real , imag )
765+ if op .return_complex :
766+ return _wrap_complex_output (op .outputs [0 ], real , imag )
767+ else
768+ return real
764769
765770
766771@LowerComplex .register_lower_func (op_type = "complex_shape" )
0 commit comments