Skip to content

Commit e016e68

Browse files
author
Alejandro Gaston Alvarez Franceschi
committed
More fixes
1 parent 8cb63ed commit e016e68

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9596,13 +9596,13 @@ def forward(self, x):
95969596
[None, 1, 3], # channels
95979597
[16, 32], # n_fft
95989598
[5, 9], # num_frames
9599-
[None, 4, 5], # hop_length
9599+
[None, 5], # hop_length
96009600
[None, 10, 8], # win_length
96019601
[None, torch.hann_window], # window
96029602
[False, True], # center
96039603
[False, True], # normalized
96049604
[None, False, True], # onesided
9605-
[None, 30, 40], # length
9605+
[None, "shorter", "larger"], # length
96069606
[False, True], # return_complex
96079607
)
96089608
)
@@ -9613,9 +9613,19 @@ def test_istft(self, compute_unit, backend, channels, n_fft, num_frames, hop_len
96139613
if hop_length is None and win_length is not None:
96149614
pytest.skip("If win_length is set then we must set hop_length and 0 < hop_length <= win_length")
96159615

9616+
# Compute input_shape to generate test case
96169617
freq = n_fft//2+1 if onesided else n_fft
96179618
input_shape = (channels, freq, num_frames) if channels else (freq, num_frames)
96189619

9620+
# If not set,c ompute hop_length for capturing errors
9621+
if hop_length is None:
9622+
hop_length = n_fft // 4
9623+
9624+
if length == "shorter":
9625+
length = n_fft//2 + hop_length * (num_frames - 1)
9626+
elif length == "larger":
9627+
length = n_fft*3//2 + hop_length * (num_frames - 1)
9628+
96199629
class ISTFTModel(torch.nn.Module):
96209630
def forward(self, x):
96219631
applied_window = window(win_length) if window and win_length else None
@@ -9635,7 +9645,7 @@ def forward(self, x):
96359645
else:
96369646
return torch.real(x)
96379647

9638-
if win_length and center is False:
9648+
if (center is False and win_length) or (center and win_length and length):
96399649
# For some reason Pytorch raises an error https://github.com/pytorch/audio/issues/427#issuecomment-1829593033
96409650
with pytest.raises(RuntimeError, match="istft\(.*\) window overlap add min: 1"):
96419651
TorchBaseTest.run_compare_torch(
@@ -9644,7 +9654,7 @@ def forward(self, x):
96449654
backend=backend,
96459655
compute_unit=compute_unit
96469656
)
9647-
elif length is not None and return_complex is True:
9657+
elif length and return_complex:
96489658
with pytest.raises(ValueError, match="New var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>` not a subtype of existing var type `<class 'coremltools.converters.mil.mil.types.type_tensor.tensor.<locals>.tensor'>`"):
96499659
TorchBaseTest.run_compare_torch(
96509660
input_shape,

0 commit comments

Comments
 (0)