From 7ff2983d5a0e321669b2fda6ddd2b40e55506fb0 Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Fri, 29 Aug 2025 11:35:01 +0300
Subject: [PATCH 1/4] Fix torchscript related test failures.

---
 src/torchaudio/functional/filtering.py   | 50 ++++++++++++++++--------
 src/torchaudio/functional/functional.py  | 37 ++++++++++++++++--
 src/torchaudio/transforms/_transforms.py |  3 +-
 3 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py
index 1a7aa3e37e..9aec37a5f8 100644
--- a/src/torchaudio/functional/filtering.py
+++ b/src/torchaudio/functional/filtering.py
@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
         b_coeff_flipped = b_coeffs.flip(1).contiguous()
         padded_waveform = F.pad(waveform, (n_order - 1, 0))
         output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
-        ctx.save_for_backward(waveform, b_coeffs, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, b_coeffs, output)
         return output
 
     @staticmethod
@@ -955,21 +956,28 @@ def backward(ctx, dy):
         n_batch = x.size(0)
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
-        db = (
-            F.conv1d(
-                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-                dy.view(n_batch * n_channel, 1, -1),
-                groups=n_batch * n_channel,
-            )
-            .view(n_batch, n_channel, -1)
-            .sum(0)
-            .flip(1)
-            if b_coeffs.requires_grad
-            else None
-        )
-        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
+
+        db = F.conv1d(
+            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+            dy.view(n_batch * n_channel, 1, -1),
+            groups=n_batch * n_channel
+        ).view(
+            n_batch, n_channel, -1
+        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+        dx = F.conv1d(
+            F.pad(dy, (0, n_order - 1)),
+            b_coeffs.unsqueeze(1),
+            groups=n_channel
+        ) if x.requires_grad else None
         return (dx, db)
 
+    @staticmethod
+    def ts_apply(waveform, b_coeffs):
+        if torch.jit.is_scripting():
+            return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
+        else:
+            return DifferentiableFIR.apply(waveform, b_coeffs)
+
 
 class DifferentiableIIR(torch.autograd.Function):
     @staticmethod
@@ -984,7 +992,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
         )
         _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
         output = padded_output_waveform[:, :, n_order - 1 :]
-        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, a_coeffs_normalized, output)
         return output
 
     @staticmethod
@@ -1006,10 +1015,17 @@ def backward(ctx, dy):
         )
         return (dx, da)
 
+    @staticmethod
+    def ts_apply(waveform, a_coeffs_normalized):
+        if torch.jit.is_scripting():
+            return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
+        else:
+            return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
+
 
 def _lfilter(waveform, a_coeffs, b_coeffs):
-    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
-    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
 
 
 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index cf9967c8f2..73aff71f14 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -847,7 +847,8 @@ def mask_along_axis_iid(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -919,7 +920,8 @@ def mask_along_axis(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -1731,6 +1733,35 @@ def backward(ctx, dy):
         result = grad * grad_out
         return (result, None, None, None, None, None, None, None)
 
+    @staticmethod
+    def ts_apply(
+        logits,
+        targets,
+        logit_lengths,
+        target_lengths,
+        blank: int,
+        clamp: float,
+        fused_log_softmax: bool):
+        if torch.jit.is_scripting():
+            output, saved = torch.ops.torchaudio.rnnt_loss_forward(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+            return output
+        else:
+            return RnntLoss.apply(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+
 
 def rnnt_loss(
     logits: Tensor,
@@ -1774,7 +1805,7 @@ def rnnt_loss(
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
-    costs = RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
+    costs = RnntLoss.ts_apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
 
     if reduction == "mean":
         return costs.mean()
diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py
index 43b0ab6495..4700be1669 100644
--- a/src/torchaudio/transforms/_transforms.py
+++ b/src/torchaudio/transforms/_transforms.py
@@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0
                 specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
             )
         else:
-            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
+            mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value
+            return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p)
 
 
 class FrequencyMasking(_AxisMasking):

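The pattern this patch introduces deserves a brief illustration: `torch.autograd.Function.apply` is not scriptable, so each `ts_apply` helper branches on `torch.jit.is_scripting()` and, under scripting, calls the static `forward` directly with a dummy tensor standing in for `ctx`; the matching `if not torch.jit.is_scripting()` guard in `forward` skips `save_for_backward`, which has no meaning without autograd. A minimal self-contained sketch of the same idea (`Square` and `square` are illustrative names, not torchaudio APIs):

    import torch

    class Square(torch.autograd.Function):
        # Minimal stand-in for DifferentiableFIR / DifferentiableIIR above.
        @staticmethod
        def forward(ctx, x):
            y = x * x
            if not torch.jit.is_scripting():
                # Under scripting, ctx is a placeholder tensor, so skip autograd state.
                ctx.save_for_backward(x)
            return y

        @staticmethod
        def backward(ctx, dy):
            (x,) = ctx.saved_tensors
            return 2 * x * dy

        @staticmethod
        def ts_apply(x):
            if torch.jit.is_scripting():
                # .apply is unscriptable; call forward with a placeholder ctx.
                return Square.forward(torch.empty(0), x)
            else:
                return Square.apply(x)

    def square(x: torch.Tensor) -> torch.Tensor:
        return Square.ts_apply(x)

    scripted = torch.jit.script(square)  # compiles the forward-only path

In eager mode the full custom backward is preserved; the scripted artifact loses it, which the series accepts since scripted use is inference-oriented.
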
From cb080659ec8ba7bc43dabf2425236627af9d2c3c Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Wed, 10 Sep 2025 13:32:45 +0300
Subject: [PATCH 2/4] Rebase against main

---
 src/torchaudio/functional/filtering.py  | 25 ++++++++++++-----------
 src/torchaudio/functional/functional.py | 27 ++++---------------------
 2 files changed, 17 insertions(+), 35 deletions(-)

diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py
index 9aec37a5f8..8f18b35de2 100644
--- a/src/torchaudio/functional/filtering.py
+++ b/src/torchaudio/functional/filtering.py
@@ -957,18 +957,19 @@ def backward(ctx, dy):
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
 
-        db = F.conv1d(
-            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-            dy.view(n_batch * n_channel, 1, -1),
-            groups=n_batch * n_channel
-        ).view(
-            n_batch, n_channel, -1
-        ).sum(0).flip(1) if b_coeffs.requires_grad else None
-        dx = F.conv1d(
-            F.pad(dy, (0, n_order - 1)),
-            b_coeffs.unsqueeze(1),
-            groups=n_channel
-        ) if x.requires_grad else None
+        db = (
+            F.conv1d(
+                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+                dy.view(n_batch * n_channel, 1, -1),
+                groups=n_batch * n_channel,
+            )
+            .view(n_batch, n_channel, -1)
+            .sum(0)
+            .flip(1)
+            if b_coeffs.requires_grad
+            else None
+        )
+        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
         return (dx, db)
 
     @staticmethod
diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 73aff71f14..75bd6c57eb 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1734,33 +1734,14 @@ def backward(ctx, dy):
         return (result, None, None, None, None, None, None, None)
 
     @staticmethod
-    def ts_apply(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank: int,
-        clamp: float,
-        fused_log_softmax: bool):
+    def ts_apply(logits, targets, logit_lengths, target_lengths, blank: int, clamp: float, fused_log_softmax: bool):
         if torch.jit.is_scripting():
             output, saved = torch.ops.torchaudio.rnnt_loss_forward(
-                logits,
-                targets,
-                logit_lengths,
-                target_lengths,
-                blank,
-                clamp,
-                fused_log_softmax)
+                logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax
+            )
             return output
         else:
-            return RnntLoss.apply(
-                logits,
-                targets,
-                logit_lengths,
-                target_lengths,
-                blank,
-                clamp,
-                fused_log_softmax)
+            return RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
 
 
 def rnnt_loss(

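Taken together with PATCH 1, this should make the previously failing entry points scriptable. A quick smoke test, assuming a torchaudio build that includes both patches (shapes and coefficient values here are illustrative):

    import torch
    import torchaudio.functional as FA

    # lfilter now routes through DifferentiableFIR.ts_apply / DifferentiableIIR.ts_apply,
    # so scripting it no longer hits an unscriptable autograd.Function.apply call.
    scripted_lfilter = torch.jit.script(FA.lfilter)

    waveform = torch.rand(1, 1024)         # (n_channel, time)
    a_coeffs = torch.tensor([[1.0, 0.1]])  # (n_channel, n_order)
    b_coeffs = torch.tensor([[0.4, 0.2]])
    filtered = scripted_lfilter(waveform, a_coeffs, b_coeffs)
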
From 6b471054a94cf10f742c3fcc1a44432f104d0f7b Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Wed, 10 Sep 2025 13:50:01 +0300
Subject: [PATCH 3/4] Enable torchscript tests in CI workflow

---
 .github/scripts/unittest-linux/run_test.sh   | 2 +-
 .github/scripts/unittest-windows/run_test.sh | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh
index 06e77dc6ae..c06e0f2ab6 100755
--- a/.github/scripts/unittest-linux/run_test.sh
+++ b/.github/scripts/unittest-linux/run_test.sh
@@ -29,5 +29,5 @@ fi
 export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_pytorch_lightning=true
 export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MULTIGPU_CUDA=true
 cd test
-  pytest torchaudio_unittest -k "not torchscript and not fairseq and not demucs ${PYTEST_K_EXTRA}"
+  pytest torchaudio_unittest -k "not fairseq and not demucs ${PYTEST_K_EXTRA}"
 )
diff --git a/.github/scripts/unittest-windows/run_test.sh b/.github/scripts/unittest-windows/run_test.sh
index 292fe1d2b0..bc5fc935a2 100644
--- a/.github/scripts/unittest-windows/run_test.sh
+++ b/.github/scripts/unittest-windows/run_test.sh
@@ -13,8 +13,8 @@ env | grep TORCHAUDIO || true
 cd test
 
 if [ -z "${CUDA_VERSION:-}" ] ; then
-    pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not torchscript and not fairseq and not demucs and not librosa"
+    pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not fairseq and not demucs and not librosa"
 else
-    pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not cpu and (cuda or gpu) and not torchscript and not fairseq and not demucs and not librosa"
+    pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not cpu and (cuda or gpu) and and not fairseq and not demucs and not librosa"
 fi
 coverage html

From f53790a102b661df016cb3e8d3aed7b5b8d71202 Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Fri, 12 Dec 2025 16:11:01 +0200
Subject: [PATCH 4/4] Fix typo

---
 .github/scripts/unittest-windows/run_test.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/scripts/unittest-windows/run_test.sh b/.github/scripts/unittest-windows/run_test.sh
index bc5fc935a2..8e77044489 100644
--- a/.github/scripts/unittest-windows/run_test.sh
+++ b/.github/scripts/unittest-windows/run_test.sh
@@ -15,6 +15,6 @@ cd test
 if [ -z "${CUDA_VERSION:-}" ] ; then
     pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not fairseq and not demucs and not librosa"
 else
-    pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not cpu and (cuda or gpu) and and not fairseq and not demucs and not librosa"
+    pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not cpu and (cuda or gpu) and not fairseq and not demucs and not librosa"
 fi
 coverage html
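
On the transforms side, the `_transforms.py` hunk in PATCH 1 is what the re-enabled torchscript tests exercise for masking: `F.mask_along_axis` declares `mask_value` as `float`, so under TorchScript a `Tensor` argument must be cast before the call, which is what the added `mask_value_` line does. An illustrative check, again assuming a build with the whole series applied:

    import torch
    import torchaudio.transforms as T

    masking = torch.jit.script(T.TimeMasking(time_mask_param=80))
    specgram = torch.rand(1, 201, 400)  # (channel, freq, time)
    # On the 3D (non-iid) path, a Tensor mask_value is cast to float internally.
    masked = masking(specgram, torch.tensor(0.0))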