Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions convert-to-coreml
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,16 @@ def main():
# Create sample 'audio' for tracing
wav = torch.zeros(2, int(args.length * samplerate))

# Reproduce the STFT step (which we cannot convert to Core ML, unfortunately)
_, stft_mag = estimator.compute_stft(wav)

print('==> Tracing model')
traced_model = torch.jit.trace(estimator.separator, stft_mag)
out = traced_model(stft_mag)
traced_model = torch.jit.trace(estimator, wav)

print('==> Converting to Core ML')
mlmodel = ct.convert(
traced_model,
convert_to='mlprogram',
# TODO: Investigate whether we'd want to make the input shape flexible
# See https://coremltools.readme.io/docs/flexible-inputs
inputs=[ct.TensorType(shape=stft_mag.shape)]
inputs=[ct.TensorType(shape=wav.shape)]
)

output_dir: Path = args.output
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ description = "Spleeter implementation in PyTorch"
# and fail during model conversions e.g. noting that BlobWriter is not available.
requires-python = "<3.11"
dependencies = [
"coremltools >= 6.3, < 7",
"coremltools == 7.0b1",
"numpy >= 1.24, < 2",
"tensorflow >= 2.13.0rc0",
"torch >= 2.0, < 3",
Expand Down
16 changes: 12 additions & 4 deletions spleeter_pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,27 @@ def compute_stft(self, wav):

stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
center=True, return_complex=True, pad_mode='constant')

# implement torch.view_as_real(stft) manually since coremltools doesn't support it
stft = torch.stack((torch.real(stft), torch.imag(stft)), axis=-1)

# only keep freqs smaller than self.F
stft = stft[:, :self.F, :]
mag = stft.abs()
stft = stft[:, :self.F]

return torch.view_as_real(stft), mag
# implement torch.hypot manually since coremltools doesn't support it
mag = torch.sqrt(stft[..., 0] ** 2 + stft[..., 1] ** 2)

return stft, mag

def inverse_stft(self, stft):
    """Invert a real-valued STFT representation back to a waveform.

    Args:
        stft: Tensor whose last axis holds the (real, imaginary) pair and
            whose axis 1 is the (possibly truncated) frequency axis.
            NOTE(review): the padding spec below assumes a 4-D layout of
            (channel, freq, frame, 2) -- confirm against callers.

    Returns:
        The waveform produced by ``torch.istft``, detached from the
        autograd graph.
    """
    # Zero-pad the frequency axis back up to win_length // 2 + 1 bins, the
    # one-sided bin count torch.istft expects for n_fft == win_length.
    pad = self.win_length // 2 + 1 - stft.size(1)
    stft = F.pad(stft, (0, 0, 0, 0, 0, pad))

    # implement torch.view_as_complex(stft) manually since coremltools doesn't support it
    # (only ONE conversion must run: the slices below need the real-valued
    # tensor with its trailing (real, imag) axis still present)
    stft = torch.complex(stft[..., 0], stft[..., 1])

    wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
                      window=self.win)
    return wav.detach()
Expand Down