diff --git a/convert-to-coreml b/convert-to-coreml
index 27bfa55..cf3a003 100755
--- a/convert-to-coreml
+++ b/convert-to-coreml
@@ -26,18 +26,19 @@ def main():
     args = parser.parse_args()
 
     samplerate = 44100
-    estimator = Estimator(num_instruments=args.num_instruments, checkpoint_path=args.model)
+    estimator = Estimator(
+        num_instruments=args.num_instruments,
+        checkpoint_path=args.model,
+        use_torch_stft=False,
+    )
     estimator.eval()
 
     # Create sample 'audio' for tracing
     wav = torch.zeros(2, int(args.length * samplerate))
 
-    # Reproduce the STFT step (which we cannot convert to Core ML, unfortunately)
-    _, stft_mag = estimator.compute_stft(wav)
-
     print('==> Tracing model')
-    traced_model = torch.jit.trace(estimator.separator, stft_mag)
-    out = traced_model(stft_mag)
+    traced_model = torch.jit.trace(estimator, wav)
+    out = traced_model(wav)
 
     print('==> Converting to Core ML')
     mlmodel = ct.convert(
@@ -45,7 +46,7 @@ def main():
         convert_to='mlprogram',
         # TODO: Investigate whether we'd want to make the input shape flexible
         # See https://coremltools.readme.io/docs/flexible-inputs
-        inputs=[ct.TensorType(shape=stft_mag.shape)]
+        inputs=[ct.TensorType(shape=wav.shape)]
     )
 
     output_dir: Path = args.output
diff --git a/spleeter_pytorch/__init__.py b/spleeter_pytorch/__init__.py
index 00eba8f..1d11e2a 100644
--- a/spleeter_pytorch/__init__.py
+++ b/spleeter_pytorch/__init__.py
@@ -12,10 +12,15 @@ def main():
     parser.add_argument('-n', '--num-instruments', type=int, default=2, help='The number of stems.')
     parser.add_argument('-m', '--model', type=Path, default=ROOT / 'checkpoints' / '2stems' / 'model', help='The path to the model to use.')
     parser.add_argument('-o', '--output', type=Path, default=ROOT / 'output' / 'stems', help='The path to the output directory.')
+    parser.add_argument('--torch-stft', default=True, action=argparse.BooleanOptionalAction, help="Whether to use PyTorch's native STFT.")
     parser.add_argument('input', type=Path, help='The path to the input file to process')
     args = parser.parse_args()
 
-    estimator = Estimator(num_instruments=args.num_instruments, checkpoint_path=args.model)
+    estimator = Estimator(
+        num_instruments=args.num_instruments,
+        checkpoint_path=args.model,
+        use_torch_stft=args.torch_stft,
+    )
     estimator.eval()
 
     # Load wav audio
diff --git a/spleeter_pytorch/estimator.py b/spleeter_pytorch/estimator.py
index a6c042b..d76a30b 100644
--- a/spleeter_pytorch/estimator.py
+++ b/spleeter_pytorch/estimator.py
@@ -5,15 +5,21 @@
 from torch import nn
 
 from spleeter_pytorch.separator import Separator
+from spleeter_pytorch.util import overlap_and_add
 
 class Estimator(nn.Module):
-    def __init__(self, num_instruments: int, checkpoint_path: Path):
+    def __init__(
+        self,
+        num_instruments: int,
+        checkpoint_path: Path,
+        use_torch_stft: bool = True,
+    ):
         super().__init__()
 
         # stft config
         self.F = 1024
         self.T = 512
-        self.win_length = 4096
+        self.win_length = 4096  # should be a power of two, see https://github.com/tensorflow/tensorflow/blob/6935c8f706dde1906e388b3142906c92cdcc36db/tensorflow/python/ops/signal/spectral_ops.py#L48-L49
         self.hop_length = 1024
         self.win = nn.Parameter(
             torch.hann_window(self.win_length),
@@ -21,8 +27,9 @@ def __init__(self, num_instruments: int, checkpoint_path: Path):
         )
 
         self.separator = Separator(num_instruments=num_instruments, checkpoint_path=checkpoint_path)
+        self.use_torch_stft = use_torch_stft
 
-    def compute_stft(self, wav):
+    def compute_stft(self, wav: torch.Tensor):
         """
         Computes stft feature from wav
 
@@ -30,8 +37,22 @@ def compute_stft(self, wav):
             wav (Tensor): B x L
         """
-        stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
-                          center=True, return_complex=True, pad_mode='constant')
+        if self.use_torch_stft:
+            stft = torch.stft(
+                wav,
+                n_fft=self.win_length,
+                hop_length=self.hop_length,
+                window=self.win,
+                center=True,
+                return_complex=True,
+                pad_mode='constant'
+            )
+        else:
+            L = wav.shape[-1]
+            framed_wav = wav.unfold(-1, size=self.win_length, step=self.hop_length)
+            framed_wav = framed_wav * self.win  # out-of-place: unfold returns overlapping views, so *= would corrupt adjacent frames
+            stft = torch.fft.rfft(framed_wav, self.win_length)
+            stft = stft.transpose(1, 2)
 
         # only keep freqs smaller than self.F
         stft = stft[:, :self.F, :]
 
@@ -45,8 +66,19 @@ def inverse_stft(self, stft):
         pad = self.win_length // 2 + 1 - stft.size(1)
         stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
         stft = torch.view_as_complex(stft)
-        wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
-                          window=self.win)
+        if self.use_torch_stft:
+            wav = torch.istft(
+                stft,
+                self.win_length,
+                hop_length=self.hop_length,
+                center=True,
+                window=self.win
+            )
+        else:
+            stft = stft.transpose(1, 2)
+            wav: torch.Tensor = torch.fft.irfft(stft, self.win_length)
+            wav *= self.win
+            wav = overlap_and_add(wav, self.hop_length)
         return wav.detach()
 
     def forward(self, wav):
diff --git a/spleeter_pytorch/unet.py b/spleeter_pytorch/unet.py
index bd2bd02..4a1c285 100644
--- a/spleeter_pytorch/unet.py
+++ b/spleeter_pytorch/unet.py
@@ -5,7 +5,7 @@
 
 class CustomPad(nn.Module):
     def __init__(self, padding_setting=(1, 2, 1, 2)):
-        super(CustomPad, self).__init__()
+        super().__init__()
         self.padding_setting = padding_setting
 
     def forward(self, x):
@@ -14,7 +14,7 @@ def forward(self, x):
 
 class CustomTransposedPad(nn.Module):
     def __init__(self, padding_setting=(1, 2, 1, 2)):
-        super(CustomTransposedPad, self).__init__()
+        super().__init__()
         self.padding_setting = padding_setting
 
     def forward(self, x):
@@ -45,7 +45,7 @@ def up_block(in_filters, out_filters, dropout=False):
 
 class UNet(nn.Module):
     def __init__(self, in_channels=2):
-        super(UNet, self).__init__()
+        super().__init__()
         self.down1_conv, self.down1_act = down_block(in_channels, 16)
         self.down2_conv, self.down2_act = down_block(16, 32)
         self.down3_conv, self.down3_act = down_block(32, 64)
diff --git a/spleeter_pytorch/util.py b/spleeter_pytorch/util.py
index 2cd00a5..0a03f06 100644
--- a/spleeter_pytorch/util.py
+++ b/spleeter_pytorch/util.py
@@ -1,5 +1,7 @@
+import math
 import numpy as np
 import tensorflow as tf
+import torch
 
 from pathlib import Path
 
@@ -76,3 +78,44 @@ def tf2pytorch(checkpoint_path: Path, num_instruments: int):
             conv_idx += 1
 
     return outputs
+
+# Source: https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
+# MIT-licensed, Copyright (c) 2018 Kaituo XU
+
+def overlap_and_add(signal: torch.Tensor, frame_step: int):
+    '''
+    Reconstructs a signal from a framed representation.
+    Adds potentially overlapping frames of a signal with shape
+    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
+    The resulting tensor has shape `[..., output_size]` where
+    output_size = (frames - 1) * frame_step + frame_length
+
+    Args:
+        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
+        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
+
+    Returns:
+        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
+        output_size = (frames - 1) * frame_step + frame_length
+
+    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
+    '''
+    outer_dimensions = signal.size()[:-2]
+    frames, frame_length = signal.size()[-2:]
+
+    subframe_length = math.gcd(frame_length, frame_step)  # gcd = greatest common divisor
+    subframe_step = frame_step // subframe_length
+    subframes_per_frame = frame_length // subframe_length
+    output_size = frame_step * (frames - 1) + frame_length
+    output_subframes = output_size // subframe_length
+
+    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
+
+    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
+    frame = signal.new_tensor(frame).long()  # signal may be on GPU or CPU
+    frame = frame.contiguous().view(-1)
+
+    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
+    result.index_add_(-2, frame, subframe_signal)
+    result = result.view(*outer_dimensions, -1)
+    return result
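
Note (reviewer sketch, not part of the patch): the unfold-based branch in compute_stft frames the signal the way torch.stft does with center=False, while the torch.stft branch above passes center=True and therefore pads by win_length // 2 on each side, so the two branches are not sample-aligned. A minimal check of the framing arithmetic, using the window and hop sizes from this diff:

    import torch

    n_fft, hop = 4096, 1024
    win = torch.hann_window(n_fft)
    wav = torch.randn(2, 10 * n_fft)

    # Manual framing, as in the non-torch branch of compute_stft.
    framed = wav.unfold(-1, size=n_fft, step=hop)
    manual = torch.fft.rfft(framed * win, n_fft).transpose(1, 2)

    # torch.stft with center=False frames the signal identically, so the
    # two results agree up to floating-point error.
    reference = torch.stft(wav, n_fft=n_fft, hop_length=hop, window=win,
                           center=False, return_complex=True)
    assert torch.allclose(manual, reference, atol=1e-3)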
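
Similarly, a quick sanity check for the vendored overlap_and_add helper against a naive per-frame loop (assumes the patched spleeter_pytorch package is importable):

    import torch
    from spleeter_pytorch.util import overlap_and_add

    frames, frame_length, frame_step = 5, 8, 4
    signal = torch.randn(2, frames, frame_length)

    # Naive reference: place each frame at its hop offset and sum the overlaps.
    expected = torch.zeros(2, (frames - 1) * frame_step + frame_length)
    for i in range(frames):
        expected[:, i * frame_step:i * frame_step + frame_length] += signal[:, i]

    assert torch.allclose(overlap_and_add(signal, frame_step), expected)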