@@ -9596,13 +9596,13 @@ def forward(self, x):
95969596 [None , 1 , 3 ], # channels
95979597 [16 , 32 ], # n_fft
95989598 [5 , 9 ], # num_frames
9599- [None , 4 , 5 ], # hop_length
9599+ [None , 5 ], # hop_length
96009600 [None , 10 , 8 ], # win_length
96019601 [None , torch .hann_window ], # window
96029602 [False , True ], # center
96039603 [False , True ], # normalized
96049604 [None , False , True ], # onesided
9605- [None , 30 , 40 ], # length
9605+ [None , "shorter" , "larger" ], # length
96069606 [False , True ], # return_complex
96079607 )
96089608 )
@@ -9613,9 +9613,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
96139613 if hop_length is None and win_length is not None :
96149614 pytest .skip ("If win_length is set then we must set hop_length and 0 < hop_length <= win_length" )
96159615
9616+ # Compute input_shape to generate test case
96169617 freq = n_fft // 2 + 1 if onesided else n_fft
96179618 input_shape = (channels , freq , num_frames ) if channels else (freq , num_frames )
96189619
9620+ # If not set,c ompute hop_length for capturing errors
9621+ if hop_length is None :
9622+ hop_length = n_fft // 4
9623+
9624+ if length == "shorter" :
9625+ length = n_fft // 2 + hop_length * (num_frames - 1 )
9626+ elif length == "larger" :
9627+ length = n_fft * 3 // 2 + hop_length * (num_frames - 1 )
9628+
96199629 class ISTFTModel (torch .nn .Module ):
96209630 def forward (self , x ):
96219631 applied_window = window (win_length ) if window and win_length else None
@@ -9635,7 +9645,7 @@ def forward(self, x):
96359645 else :
96369646 return torch .real (x )
96379647
9638- if win_length and center is False :
9648+ if ( center is False and win_length ) or ( center and win_length and length ) :
96399649 # For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
96409650 with pytest .raises (RuntimeError , match = "istft\(.*\) window overlap add min: 1" ):
96419651 TorchBaseTest .run_compare_torch (
@@ -9644,7 +9654,7 @@ def forward(self, x):
96449654 backend = backend ,
96459655 compute_unit = compute_unit
96469656 )
9647- elif length is not None and return_complex is True :
9657+ elif length and return_complex :
96489658 with pytest .raises (ValueError , match = "New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`" ):
96499659 TorchBaseTest .run_compare_torch (
96509660 input_shape ,
0 commit comments