diff --git a/.gitignore b/.gitignore
index feb4ef97..9ba52f53 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,9 @@ notebooks
 foobar*
 run.sh
 README.rst
+legacy
+notebooks
+run.sh
 pretrained_models
 deepvoice3_pytorch/version.py
 checkpoints*
diff --git a/README.md b/README.md
index 2971dd1d..6536aaa2 100644
--- a/README.md
+++ b/README.md
@@ -197,7 +197,7 @@ python preprocess.py nikl_s ${your_nikl_root_path} data/nikl_s --preset=presets/
 python train.py --data-root=./data/nikl_s --checkpoint-dir checkpoint_nikl_s --preset=presets/deepvoice3_nikls.json
 ```

-### 4. Monitor with Tensorboard
+### 3. Monitor with Tensorboard

 Logs are dumped in `./log` directory by default. You can monitor logs by tensorboard:

@@ -205,7 +205,7 @@ Logs are dumped in `./log` directory by default. You can monitor logs by tensorb
 tensorboard --logdir=log
 ```

-### 5. Synthesize from a checkpoint
+### 4. Synthesize from a checkpoint

 Given a list of text, `synthesis.py` synthesize audio signals from trained model. Usage is:

diff --git a/dump_hparams_to_json.py b/dump_hparams_to_json.py
index d67e88d3..f0554605 100644
--- a/dump_hparams_to_json.py
+++ b/dump_hparams_to_json.py
@@ -12,13 +12,9 @@
 import sys
 import os
 from os.path import dirname, join, basename, splitext
+import json

-import audio
-
-# The deepvoice3 model
-from deepvoice3_pytorch import frontend
 from hparams import hparams
-import json

 if __name__ == "__main__":
     args = docopt(__doc__)
diff --git a/generate_aligned_predictions.py b/generate_aligned_predictions.py
new file mode 100644
index 00000000..a72053a5
--- /dev/null
+++ b/generate_aligned_predictions.py
@@ -0,0 +1,177 @@
+# coding: utf-8
+"""
+Generate ground truth-aligned predictions
+
+usage: generate_aligned_predictions.py [options] <checkpoint> <in_dir> <out_dir>
+
+options:
+    --hparams=<params>        Hyper parameters [default: ].
+    --preset=<json>           Path of preset parameters (json).
+    --overwrite               Overwrite audio and mel outputs.
+    -h, --help                Show help message.
+"""
+from docopt import docopt
+import os
+from tqdm import tqdm
+import importlib
+from os.path import join
+from warnings import warn
+import sys
+
+import numpy as np
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+# The deepvoice3 model
+from deepvoice3_pytorch import frontend
+from hparams import hparams
+
+use_cuda = torch.cuda.is_available()
+_frontend = None  # to be set later
+
+
+def preprocess(model, in_dir, out_dir, text, audio_filename, mel_filename,
+               p=0, speaker_id=None,
+               fast=False):
+    """Generate ground truth-aligned prediction
+
+    The output of the network and corresponding audio are saved after time
+    resolution adjustment.
+    """
+    r = hparams.outputs_per_step
+    downsample_step = hparams.downsample_step
+
+    if use_cuda:
+        model = model.cuda()
+    model.eval()
+    if fast:
+        model.make_generation_fast_()
+
+    mel_org = np.load(join(in_dir, mel_filename))
+    # zero pad
+    b_pad = r  # imitates initial state
+    e_pad = r - len(mel_org) % r if len(mel_org) % r > 0 else 0
+    mel = np.pad(mel_org, [(b_pad, e_pad), (0, 0)],
+                 mode="constant", constant_values=0)
+
+    mel = Variable(torch.from_numpy(mel)).unsqueeze(0).contiguous()
+
+    # Downsample mel spectrogram
+    if downsample_step > 1:
+        mel = mel[:, 0::downsample_step, :].contiguous()
+
+    decoder_target_len = mel.shape[1] // r
+    s, e = 1, decoder_target_len + 1
+    frame_positions = torch.arange(s, e).long().unsqueeze(0)
+    frame_positions = Variable(frame_positions)
+
+    sequence = np.array(_frontend.text_to_sequence(text, p=p))
+    sequence = Variable(torch.from_numpy(sequence)).unsqueeze(0)
+    text_positions = torch.arange(1, sequence.size(-1) + 1).unsqueeze(0).long()
+    text_positions = Variable(text_positions)
+    speaker_ids = None if speaker_id is None else Variable(torch.LongTensor([speaker_id]))
+    if use_cuda:
+        sequence = sequence.cuda()
+        text_positions = text_positions.cuda()
+        speaker_ids = None if speaker_ids is None else speaker_ids.cuda()
+        mel = mel.cuda()
+        frame_positions = frame_positions.cuda()
+
+    # **Teacher forcing** decoding
+    mel_outputs, _, _, _ = model(
+        sequence, mel, text_positions=text_positions,
+        frame_positions=frame_positions, speaker_ids=speaker_ids)
+
+    mel_output = mel_outputs[0].data.cpu().numpy()
+    # **Time resolution adjustment**
+    mel_output = mel_output[:-(b_pad + e_pad)]
+
+    wav = np.load(join(in_dir, audio_filename))
+    assert len(wav) % hparams.hop_size == 0
+
+    # Coarse upsample just for convenience
+    # so that we can upsample conditional features by hop_size in wavenet
+    if downsample_step > 0:
+        mel_output = np.repeat(mel_output, downsample_step, axis=0)
+    # downsampling -> upsampling, then we should have length equal to or larger than
+    # the original mel length
+    assert mel_output.shape[0] >= mel_org.shape[0]
+
+    # Make sure we have correct lengths
+    assert mel_output.shape[0] * hparams.hop_size == len(wav)
+
+    timesteps = len(wav)
+
+    # save
+    np.save(join(out_dir, audio_filename), wav, allow_pickle=False)
+    np.save(join(out_dir, mel_filename), mel_output.astype(np.float32),
+            allow_pickle=False)
+
+    if speaker_id is None:
+        return (audio_filename, mel_filename, timesteps, text)
+    else:
+        return (audio_filename, mel_filename, timesteps, text, speaker_id)
+
+
+def write_metadata(metadata, out_dir):
+    with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f:
+        for m in metadata:
+            f.write('|'.join([str(x) for x in m]) + '\n')
+    frames = sum([m[2] for m in metadata])
+    sr = hparams.sample_rate
+    hours = frames / sr / 3600
+    print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours))
+    print('Max input length: %d' % max(len(m[3]) for m in metadata))
+    print('Max output length: %d' % max(m[2] for m in metadata))
+
+
+if __name__ == "__main__":
+    args = docopt(__doc__)
+    checkpoint_path = args["<checkpoint>"]
+    in_dir = args["<in_dir>"]
+    out_dir = args["<out_dir>"]
+    preset = args["--preset"]
+
+    # Load preset if specified
+    if preset is not None:
+        with open(preset) as f:
+            hparams.parse_json(f.read())
+    # Override hyper parameters
+    hparams.parse(args["--hparams"])
+    assert hparams.name == "deepvoice3"
+
+    _frontend = getattr(frontend, hparams.frontend)
+    import train
+    train._frontend = _frontend
+    from train import build_model
+
+    model = build_model()
+
+    # Load checkpoint
+    print("Load checkpoint from {}".format(checkpoint_path))
+    checkpoint = torch.load(checkpoint_path)
+    model.load_state_dict(checkpoint["state_dict"])
+
+    os.makedirs(out_dir, exist_ok=True)
+    results = []
+    with open(os.path.join(in_dir, "train.txt")) as f:
+        lines = f.readlines()
+
+    for idx in tqdm(range(len(lines))):
+        l = lines[idx]
+        l = l[:-1].split("|")
+        audio_filename, mel_filename, _, text = l[:4]
+        speaker_id = int(l[4]) if len(l) > 4 else None
+        if text == "N/A":
+            raise RuntimeError("No transcription available")
+
+        result = preprocess(model, in_dir, out_dir, text, audio_filename,
+                            mel_filename, p=0,
+                            speaker_id=speaker_id, fast=True)
+        results.append(result)
+
+    write_metadata(results, out_dir)
+
+    sys.exit(0)
diff --git a/hparams.py b/hparams.py
index 2373a050..729c2605 100644
--- a/hparams.py
+++ b/hparams.py
@@ -43,7 +43,7 @@
     # whether to rescale waveform or not.
     # Let x is an input waveform, rescaled waveform y is given by:
     # y = x / np.abs(x).max() * rescaling_max
-    rescaling=False,
+    rescaling=True,
     rescaling_max=0.999,
     # mel-spectrogram is normalized to [0, 1] for each utterance and clipping may
     # happen depends on min_level_db and ref_level_db, causing clipping noise.
@@ -51,25 +51,25 @@
     allow_clipping_in_normalization=True,

     # Model:
-    downsample_step=4,  # must be 4 when builder="nyanko"
-    outputs_per_step=1,  # must be 1 when builder="nyanko"
+    downsample_step=1,  # must be 4 when builder="nyanko"
+    outputs_per_step=4,  # must be 1 when builder="nyanko"
     embedding_weight_std=0.1,
     speaker_embedding_weight_std=0.01,
     padding_idx=0,
     # Maximum number of input text length
     # try setting larger value if you want to give very long text input
-    max_positions=512,
-    dropout=1 - 0.95,
-    kernel_size=3,
-    text_embed_dim=128,
-    encoder_channels=256,
-    decoder_channels=256,
+    max_positions=2048,
+    dropout=1 - 0.90,
+    kernel_size=5,
+    text_embed_dim=256,
+    encoder_channels=512,
+    decoder_channels=512,
     # Note: large converter channels requires significant computational cost
     converter_channels=256,
     query_position_rate=1.0,
     # can be computed by `compute_timestamp_ratio.py`.
     key_position_rate=1.385,  # 2.37 for jsut
-    key_projection=False,
+    key_projection=True,
     value_projection=False,
     use_memory_mask=True,
     trainable_positional_encodings=False,
@@ -99,7 +99,7 @@
     adam_beta1=0.5,
     adam_beta2=0.9,
     adam_eps=1e-6,
-    initial_learning_rate=5e-4,  # 0.001,
+    initial_learning_rate=1e-3,  # 0.001,
     lr_schedule="noam_learning_rate_decay",
     lr_schedule_kwargs={},
     nepochs=2000,
diff --git a/preprocess.py b/preprocess.py
index d76de83f..9a4eeac3 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -52,7 +52,6 @@ def write_metadata(metadata, out_dir):
     # Override hyper parameters
     hparams.parse(args["--hparams"])
     assert hparams.name == "deepvoice3"
-    print(hparams_debug_string())

     assert name in ["jsut", "ljspeech", "vctk", "nikl_m", "nikl_s", "json_meta"]
     mod = importlib.import_module(name)
diff --git a/presets/deepvoice3_ljspeech_wavenet.json b/presets/deepvoice3_ljspeech_wavenet.json
new file mode 100644
index 00000000..38c757cf
--- /dev/null
+++ b/presets/deepvoice3_ljspeech_wavenet.json
@@ -0,0 +1,65 @@
+{
+    "name": "deepvoice3",
+    "frontend": "en",
+    "replace_pronunciation_prob": 0.5,
+    "builder": "deepvoice3",
+    "n_speakers": 1,
+    "speaker_embed_dim": 16,
+    "num_mels": 80,
+    "fmin": 125,
+    "fmax": 7600,
+    "fft_size": 1024,
+    "hop_size": 256,
+    "sample_rate": 22050,
+    "preemphasis": 0.97,
+    "min_level_db": -100,
+    "ref_level_db": 20,
+    "rescaling": true,
+    "rescaling_max": 0.999,
+    "allow_clipping_in_normalization": true,
+    "downsample_step": 1,
+    "outputs_per_step": 4,
+    "embedding_weight_std": 0.1,
+    "speaker_embedding_weight_std": 0.01,
+    "padding_idx": 0,
+    "max_positions": 2048,
+    "dropout": 0.09999999999999998,
+    "kernel_size": 5,
+    "text_embed_dim": 256,
+    "encoder_channels": 512,
+    "decoder_channels": 512,
+    "converter_channels": 256,
+    "query_position_rate": 1.0,
+    "key_position_rate": 1.385,
+    "key_projection": true,
+    "value_projection": false,
+    "use_memory_mask": true,
+    "trainable_positional_encodings": false,
+    "freeze_embedding": false,
+    "use_decoder_state_for_postnet_input": true,
+    "pin_memory": true,
+    "num_workers": 2,
+    "masked_loss_weight": 0.5,
+    "priority_freq": 3000,
+    "priority_freq_weight": 0.0,
+    "binary_divergence_weight": 0.1,
+    "use_guided_attention": true,
+    "guided_attention_sigma": 0.2,
+    "batch_size": 16,
+    "adam_beta1": 0.5,
+    "adam_beta2": 0.9,
+    "adam_eps": 1e-06,
+    "initial_learning_rate": 0.001,
+    "lr_schedule": "noam_learning_rate_decay",
+    "lr_schedule_kwargs": {},
+    "nepochs": 2000,
+    "weight_decay": 0.0,
+    "clip_thresh": 0.1,
+    "checkpoint_interval": 10000,
+    "eval_interval": 10000,
+    "save_optimizer_state": true,
+    "force_monotonic_attention": true,
+    "window_ahead": 3,
+    "window_backward": 1,
+    "power": 1.4
+}
\ No newline at end of file
diff --git a/synthesis.py b/synthesis.py
index fbecdf2d..9a21629b 100644
--- a/synthesis.py
+++ b/synthesis.py
@@ -9,6 +9,7 @@
     --preset=<json>                   Path of preset parameters (json).
     --checkpoint-seq2seq=<path>       Load seq2seq model from checkpoint path.
     --checkpoint-postnet=<path>       Load postnet model from checkpoint path.
+    --checkpoint-wavenet=<path>       Load WaveNet vocoder.
     --file-name-suffix=<s>            File name suffix [default: ].
     --max-decoder-steps=<N>           Max decoder steps [default: 500].
     --replace_pronunciation_prob=<N>  Prob [default: 0.0].
@@ -39,7 +40,7 @@
 _frontend = None  # to be set later


-def tts(model, text, p=0, speaker_id=None, fast=False):
+def tts(model, text, p=0, speaker_id=None, fast=False, wavenet=None):
     """Convert text to speech waveform given a deepvoice3 model.

     Args:
@@ -73,7 +74,30 @@ def tts(model, text, p=0, speaker_id=None, fast=False):
     mel = audio._denormalize(mel)

     # Predicted audio signal
-    waveform = audio.inv_spectrogram(linear_output.T)
+    if wavenet is not None:
+        if use_cuda:
+            wavenet = wavenet.cuda()
+        wavenet.eval()
+        if fast:
+            wavenet.make_generation_fast_()
+
+        # TODO: assuming scalar input
+        initial_value = 0.0
+        initial_input = Variable(torch.zeros(1, 1, 1)).fill_(initial_value)
+        # (B, T, C) -> (B, C, T)
+        c = mel_outputs.transpose(1, 2).contiguous()
+        g = None
+        Tc = c.size(-1)
+        length = Tc * 256
+        if use_cuda:
+            initial_input = initial_input.cuda()
+            c = c.cuda()
+        waveform = wavenet.incremental_forward(
+            initial_input, c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True,
+            log_scale_min=float(np.log(1e-14)))
+        waveform = waveform.view(-1).cpu().data.numpy()
+    else:
+        waveform = audio.inv_spectrogram(linear_output.T)

     return waveform, alignment, spectrogram, mel

@@ -95,6 +119,7 @@ def _load(checkpoint_path):
     dst_dir = args["<dst_dir>"]
     checkpoint_seq2seq_path = args["--checkpoint-seq2seq"]
     checkpoint_postnet_path = args["--checkpoint-postnet"]
+    checkpoint_wavenet_path = args["--checkpoint-wavenet"]
     max_decoder_steps = int(args["--max-decoder-steps"])
     file_name_suffix = args["--file-name-suffix"]
     replace_pronunciation_prob = float(args["--replace_pronunciation_prob"])
@@ -132,6 +157,19 @@ def _load(checkpoint_path):
         model.load_state_dict(checkpoint["state_dict"])
         checkpoint_name = splitext(basename(checkpoint_path))[0]

+    # Load WaveNet vocoder
+    if checkpoint_wavenet_path is not None:
+        from wavenet_vocoder import builder
+        wavenet = builder.wavenet(out_channels=3 * 10, layers=24, stacks=4, residual_channels=512,
+                                  gate_channels=512, skip_out_channels=256, dropout=1 - 0.95,
+                                  kernel_size=3, weight_normalization=True, cin_channels=80,
+                                  upsample_conditional_features=True, upsample_scales=[4, 4, 4, 4],
+                                  freq_axis_kernel_size=3, gin_channels=-1, scalar_input=True)
+        checkpoint = torch.load(checkpoint_wavenet_path)
+        wavenet.load_state_dict(checkpoint["state_dict"])
+    else:
+        wavenet = None
+
     model.seq2seq.decoder.max_decoder_steps = max_decoder_steps

     os.makedirs(dst_dir, exist_ok=True)
@@ -141,7 +179,8 @@ def _load(checkpoint_path):
             text = line.decode("utf-8")[:-1]
             words = nltk.word_tokenize(text)
             waveform, alignment, _, _ = tts(
-                model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True)
+                model, text, p=replace_pronunciation_prob, speaker_id=speaker_id, fast=True,
+                wavenet=wavenet)
             dst_wav_path = join(dst_dir, "{}_{}{}.wav".format(
                 idx, checkpoint_name, file_name_suffix))
             dst_alignment_path = join(
diff --git a/train.py b/train.py
index b7918066..ce650257 100644
--- a/train.py
+++ b/train.py
@@ -993,6 +993,9 @@ def restore_parts(path, model):
             clip_thresh=hparams.clip_thresh,
             train_seq2seq=train_seq2seq, train_postnet=train_postnet)
     except KeyboardInterrupt:
+        print("Interrupted!")
+        pass
+    finally:
         save_checkpoint(
             model, optimizer, global_step, checkpoint_dir, global_epoch,
             train_seq2seq, train_postnet)
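
The new script and the extended synthesis.py above are command-line tools; the sketch below shows roughly how they would be driven. It is only an illustration, not part of the patch: checkpoint and directory names are placeholders, the `<checkpoint> <in_dir> <out_dir>` argument order follows the docopt usage reconstructed above, and it assumes a wavenet_vocoder checkpoint trained separately on the generated features.

```sh
# Sketch only: <...> names are placeholders, not files shipped with the patch.

# 1. Dump ground truth-aligned mel predictions for WaveNet vocoder training,
#    using a seq2seq checkpoint trained with the new preset.
python generate_aligned_predictions.py \
    --preset=presets/deepvoice3_ljspeech_wavenet.json \
    <deepvoice3_checkpoint.pth> ./data/ljspeech ./data/ljspeech_gta

# 2. Synthesize with the WaveNet vocoder instead of audio.inv_spectrogram,
#    via the new --checkpoint-wavenet option.
python synthesis.py \
    --preset=presets/deepvoice3_ljspeech_wavenet.json \
    --checkpoint-wavenet=<wavenet_checkpoint.pth> \
    <deepvoice3_checkpoint.pth> <text_list.txt> <output_dir>
```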