@@ -1,28 +1,18 @@
- import os
import json
- from .env import AttrDict
+ import os
+
import numpy as np
import torch
- import torch.nn.functional as F
import torch.nn as nn
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
- from torch.nn.utils import weight_norm, spectral_norm
- from .utils import init_weights, get_padding
+ import torch.nn.functional as F
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

- from modules.DSConv import weight_norm_modules, remove_weight_norm_modules, Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D
+ from .env import AttrDict
+ from .utils import get_padding, init_weights

LRELU_SLOPE = 0.1

- Conv1dModel = nn.Conv1d
- ConvTranspose1dModel = nn.ConvTranspose1d
-
- def set_Conv1dModel(use_depthwise_conv):
-     global Conv1dModel
-     Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
-
- def set_ConvTranspose1dModel(use_depthwise_transposeconv):
-     global ConvTranspose1dModel
-     ConvTranspose1dModel = Depthwise_Separable_TransposeConv1D if use_depthwise_transposeconv else nn.ConvTranspose1d

def load_model(model_path, device='cuda'):
    config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
@@ -48,21 +38,21 @@ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
@@ -78,19 +68,19 @@ def forward(self, x):

    def remove_weight_norm(self):
        for l in self.convs1:
-             remove_weight_norm_modules(l)
+             remove_weight_norm(l)
        for l in self.convs2:
-             remove_weight_norm_modules(l)
+             remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
-             weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)
@@ -104,7 +94,7 @@ def forward(self, x):

    def remove_weight_norm(self):
        for l in self.convs:
-             remove_weight_norm_modules(l)
+             remove_weight_norm(l)


def padDiff(x):
@@ -211,8 +201,6 @@ def forward(self, f0):
        output uv: tensor(batchsize=1, length, 1)
        """
        with torch.no_grad():
-             f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
-                                  device=f0.device)
            # fundamental component
            fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

@@ -289,28 +277,25 @@ class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
-
-         set_Conv1dModel(h["use_depthwise_conv"])
-         set_ConvTranspose1dModel(h["use_depthwise_transposeconv"])
-
+
        self.num_kernels = len(h["resblock_kernel_sizes"])
        self.num_upsamples = len(h["upsample_rates"])
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"]))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=h["sampling_rate"],
            harmonic_num=8)
        self.noise_convs = nn.ModuleList()
-         self.conv_pre = weight_norm_modules(Conv1dModel(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
+         self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3))
        resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])):
            c_cur = h["upsample_initial_channel"] // (2 ** (i + 1))
-             self.ups.append(weight_norm_modules(
-                 ConvTranspose1dModel(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
-                                      k, u, padding=(k - u + 1) // 2)))
+             self.ups.append(weight_norm(
+                 ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)),
+                                 k, u, padding=(k - u + 1) // 2)))
            if i + 1 < len(h["upsample_rates"]):  #
                stride_f0 = np.prod(h["upsample_rates"][i + 1:])
-                 self.noise_convs.append(Conv1dModel(
+                 self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0 + 1) // 2))
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
@@ -320,7 +305,7 @@ def __init__(self, h):
            for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])):
                self.resblocks.append(resblock(h, ch, k, d))

-         self.conv_post = weight_norm_modules(Conv1dModel(ch, 1, 7, 1, padding=3))
+         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1)
@@ -357,18 +342,18 @@ def forward(self, x, f0, g=None):
    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
-             remove_weight_norm_modules(l)
+             remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
-         remove_weight_norm_modules(self.conv_pre)
-         remove_weight_norm_modules(self.conv_post)
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
-         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
@@ -427,7 +412,7 @@ def forward(self, y, y_hat):
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
-         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
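
For reference, here is a minimal sketch (not part of the diff) of the torch.nn.utils.weight_norm / remove_weight_norm pairing that the file now calls directly in place of the weight_norm_modules wrappers: weight_norm reparameterizes a layer's weight into magnitude and direction tensors during training, and remove_weight_norm folds them back into a single plain weight before inference, mirroring the remove_weight_norm() methods above. Channel counts and tensor shapes below are illustrative only.

import torch
import torch.nn as nn
from torch.nn.utils import remove_weight_norm, weight_norm

# Wrap a plain Conv1d; weight_norm splits conv.weight into conv.weight_g
# (magnitude) and conv.weight_v (direction), recombined on each forward pass.
conv = weight_norm(nn.Conv1d(16, 16, kernel_size=3, padding=1))

x = torch.randn(1, 16, 100)  # (batch, channels, time), illustrative shape
y_train = conv(x)

# Before inference/export, fold the reparameterization back into conv.weight,
# as Generator.remove_weight_norm() does for every wrapped layer in this file.
remove_weight_norm(conv)
y_infer = conv(x)  # same computation, now with a single plain weight tensor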