Skip to content

Commit 57c079f

Browse files
committed
New Tiny
1 parent 98ce91c commit 57c079f

File tree

7 files changed

+78
-152
lines changed

7 files changed

+78
-152
lines changed

configs_template/config_template.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"upsample_initial_channel": 512,
5555
"upsample_kernel_sizes": [16,16, 4, 4, 4],
5656
"n_layers_q": 3,
57+
"n_flow_layer": 4,
5758
"use_spectral_norm": false,
5859
"gin_channels": 768,
5960
"ssl_dim": 768,
@@ -63,7 +64,6 @@
6364
"speaker_embedding":false,
6465
"vol_embedding":false,
6566
"use_depthwise_conv":false,
66-
"use_depthwise_transposeconv":false,
6767
"use_automatic_f0_prediction": true
6868
},
6969
"spk": {

models.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def __init__(self,
322322
vol_embedding=False,
323323
vocoder_name = "nsf-hifigan",
324324
use_depthwise_conv = False,
325-
use_depthwise_transposeconv = False,
326325
use_automatic_f0_prediction = True,
326+
n_flow_layer = 4,
327327
**kwargs):
328328

329329
super().__init__()
@@ -372,8 +372,7 @@ def __init__(self,
372372
"upsample_initial_channel": upsample_initial_channel,
373373
"upsample_kernel_sizes": upsample_kernel_sizes,
374374
"gin_channels": gin_channels,
375-
"use_depthwise_conv":use_depthwise_conv,
376-
"use_depthwise_transposeconv":use_depthwise_transposeconv
375+
"use_depthwise_conv":use_depthwise_conv
377376
}
378377

379378
modules.set_Conv1dModel(self.use_depthwise_conv)
@@ -390,7 +389,7 @@ def __init__(self,
390389
self.dec = Generator(h=hps)
391390

392391
self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
393-
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
392+
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, n_flow_layer, gin_channels=gin_channels)
394393
if self.use_automatic_f0_prediction:
395394
self.f0_decoder = F0Decoder(
396395
1,

modules/modules.py

-41
Original file line numberDiff line numberDiff line change
@@ -66,47 +66,6 @@ def forward(self, x, x_mask):
6666
return x * x_mask
6767

6868

69-
class DDSConv(nn.Module):
70-
"""
71-
Dilated and Depth-Separable Convolution
72-
"""
73-
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74-
super().__init__()
75-
self.channels = channels
76-
self.kernel_size = kernel_size
77-
self.n_layers = n_layers
78-
self.p_dropout = p_dropout
79-
80-
self.drop = nn.Dropout(p_dropout)
81-
self.convs_sep = nn.ModuleList()
82-
self.convs_1x1 = nn.ModuleList()
83-
self.norms_1 = nn.ModuleList()
84-
self.norms_2 = nn.ModuleList()
85-
for i in range(n_layers):
86-
dilation = kernel_size ** i
87-
padding = (kernel_size * dilation - dilation) // 2
88-
self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89-
groups=channels, dilation=dilation, padding=padding
90-
))
91-
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92-
self.norms_1.append(LayerNorm(channels))
93-
self.norms_2.append(LayerNorm(channels))
94-
95-
def forward(self, x, x_mask, g=None):
96-
if g is not None:
97-
x = x + g
98-
for i in range(self.n_layers):
99-
y = self.convs_sep[i](x * x_mask)
100-
y = self.norms_1[i](y)
101-
y = F.gelu(y)
102-
y = self.convs_1x1[i](y)
103-
y = self.norms_2[i](y)
104-
y = F.gelu(y)
105-
y = self.drop(y)
106-
x = x + y
107-
return x * x_mask
108-
109-
11069
class WN(torch.nn.Module):
11170
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
11271
super(WN, self).__init__()

vdecoder/hifigan/models.py

+30-45
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,18 @@
1-
import os
21
import json
3-
from .env import AttrDict
2+
import os
3+
44
import numpy as np
55
import torch
6-
import torch.nn.functional as F
76
import torch.nn as nn
8-
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9-
from torch.nn.utils import weight_norm,spectral_norm
10-
from .utils import init_weights, get_padding
7+
import torch.nn.functional as F
8+
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
9+
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
1110

12-
from modules.DSConv import weight_norm_modules, remove_weight_norm_modules, Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D
11+
from .env import AttrDict
12+
from .utils import get_padding, init_weights
1313

1414
LRELU_SLOPE = 0.1
1515

16-
Conv1dModel = nn.Conv1d
17-
ConvTranspose1dModel = nn.ConvTranspose1d
18-
19-
def set_Conv1dModel(use_depthwise_conv):
20-
global Conv1dModel
21-
Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
22-
23-
def set_ConvTranspose1dModel(use_depthwise_transposeconv):
24-
global ConvTranspose1dModel
25-
ConvTranspose1dModel = Depthwise_Separable_TransposeConv1D if use_depthwise_transposeconv else nn.ConvTranspose1d
2616

2717
def load_model(model_path, device='cuda'):
2818
config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
@@ -48,21 +38,21 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
4838
super(ResBlock1, self).__init__()
4939
self.h = h
5040
self.convs1 = nn.ModuleList([
51-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
41+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
5242
padding=get_padding(kernel_size, dilation[0]))),
53-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
43+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
5444
padding=get_padding(kernel_size, dilation[1]))),
55-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
45+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
5646
padding=get_padding(kernel_size, dilation[2])))
5747
])
5848
self.convs1.apply(init_weights)
5949

6050
self.convs2 = nn.ModuleList([
61-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
51+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
6252
padding=get_padding(kernel_size, 1))),
63-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
53+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
6454
padding=get_padding(kernel_size, 1))),
65-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
55+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
6656
padding=get_padding(kernel_size, 1)))
6757
])
6858
self.convs2.apply(init_weights)
@@ -78,19 +68,19 @@ def forward(self, x):
7868

7969
def remove_weight_norm(self):
8070
for l in self.convs1:
81-
remove_weight_norm_modules(l)
71+
remove_weight_norm(l)
8272
for l in self.convs2:
83-
remove_weight_norm_modules(l)
73+
remove_weight_norm(l)
8474

8575

8676
class ResBlock2(torch.nn.Module):
8777
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
8878
super(ResBlock2, self).__init__()
8979
self.h = h
9080
self.convs = nn.ModuleList([
91-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
81+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
9282
padding=get_padding(kernel_size, dilation[0]))),
93-
weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
83+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
9484
padding=get_padding(kernel_size, dilation[1])))
9585
])
9686
self.convs.apply(init_weights)
@@ -104,7 +94,7 @@ def forward(self, x):
10494

10595
def remove_weight_norm(self):
10696
for l in self.convs:
107-
remove_weight_norm_modules(l)
97+
remove_weight_norm(l)
10898

10999

110100
def padDiff(x):
@@ -211,8 +201,6 @@ def forward(self, f0):
211201
output uv: tensor(batchsize=1, length, 1)
212202
"""
213203
with torch.no_grad():
214-
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
215-
device=f0.device)
216204
# fundamental component
217205
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
218206

@@ -289,28 +277,25 @@ class Generator(torch.nn.Module):
289277
def __init__(self, h):
290278
super(Generator, self).__init__()
291279
self.h = h
292-
293-
set_Conv1dModel(h["use_depthwise_conv"])
294-
set_ConvTranspose1dModel(h["use_depthwise_transposeconv"])
295-
280+
296281
self.num_kernels = len(h["resblock_kernel_sizes"])
297282
self.num_upsamples = len(h["upsample_rates"])
298283
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
299284
self.m_source = SourceModuleHnNSF(
300285
sampling_rate=h["sampling_rate"],
301286
harmonic_num=8)
302287
self.noise_convs = nn.ModuleList()
303-
self.conv_pre = weight_norm_modules(Conv1dModel(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
288+
self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
304289
resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2
305290
self.ups = nn.ModuleList()
306291
for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])):
307292
c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
308-
self.ups.append(weight_norm_modules(
309-
ConvTranspose1dModel(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
310-
k, u, padding=(k - u + 1 ) // 2)))
293+
self.ups.append(weight_norm(
294+
ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
295+
k, u, padding=(k - u +1 ) // 2)))
311296
if i + 1 < len(h["upsample_rates"]): #
312297
stride_f0 = np.prod(h["upsample_rates"][i + 1:])
313-
self.noise_convs.append(Conv1dModel(
298+
self.noise_convs.append(Conv1d(
314299
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
315300
else:
316301
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
@@ -320,7 +305,7 @@ def __init__(self, h):
320305
for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])):
321306
self.resblocks.append(resblock(h, ch, k, d))
322307

323-
self.conv_post = weight_norm_modules(Conv1dModel(ch, 1, 7, 1, padding=3))
308+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
324309
self.ups.apply(init_weights)
325310
self.conv_post.apply(init_weights)
326311
self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)
@@ -357,18 +342,18 @@ def forward(self, x, f0, g=None):
357342
def remove_weight_norm(self):
358343
print('Removing weight norm...')
359344
for l in self.ups:
360-
remove_weight_norm_modules(l)
345+
remove_weight_norm(l)
361346
for l in self.resblocks:
362347
l.remove_weight_norm()
363-
remove_weight_norm_modules(self.conv_pre)
364-
remove_weight_norm_modules(self.conv_post)
348+
remove_weight_norm(self.conv_pre)
349+
remove_weight_norm(self.conv_post)
365350

366351

367352
class DiscriminatorP(torch.nn.Module):
368353
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
369354
super(DiscriminatorP, self).__init__()
370355
self.period = period
371-
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
356+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
372357
self.convs = nn.ModuleList([
373358
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
374359
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
@@ -427,7 +412,7 @@ def forward(self, y, y_hat):
427412
class DiscriminatorS(torch.nn.Module):
428413
def __init__(self, use_spectral_norm=False):
429414
super(DiscriminatorS, self).__init__()
430-
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
415+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
431416
self.convs = nn.ModuleList([
432417
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
433418
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),

vdecoder/hifigan/utils.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import glob
22
import os
3-
import matplotlib
4-
import torch
5-
from torch.nn.utils import weight_norm
3+
64
# matplotlib.use("Agg")
75
import matplotlib.pylab as plt
6+
import torch
7+
from torch.nn.utils import weight_norm
88

99

1010
def plot_spectrogram(spectrogram):
@@ -21,10 +21,7 @@ def plot_spectrogram(spectrogram):
2121

2222
def init_weights(m, mean=0.0, std=0.01):
2323
classname = m.__class__.__name__
24-
if "Depthwise_Separable" in classname:
25-
m.depth_conv.weight.data.normal_(mean, std)
26-
m.point_conv.weight.data.normal_(mean, std)
27-
elif classname.find("Conv") != -1:
24+
if classname.find("Conv") != -1:
2825
m.weight.data.normal_(mean, std)
2926

3027

0 commit comments

Comments
 (0)