
Commit 5c37dbd

ONNX export (#176)
* Add files via upload
* Add files via upload
* Add files via upload
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3a83ea1 commit 5c37dbd

File tree

3 files changed, +1420 -0 lines changed


attentions_onnx.py

+378
@@ -0,0 +1,378 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

import commons
import logging

logger = logging.getLogger(__name__)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        # if isflow:
        #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
        #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
        #     self.cond_layer = weight_norm(cond_layer, name='weight')
        #     self.gin_channels = 256
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = (
                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                )
                logger.debug(
                    "gin_channels: %s, cond_layer_idx: %s",
                    self.gin_channels,
                    self.cond_layer_idx,
                )
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so as to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x
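For reference, below is a minimal usage and export sketch that is not part of this commit. It assumes the repository's commons module is importable next to attentions_onnx.py, uses placeholder hyperparameters chosen only for illustration, and follows the shapes expected by Encoder.forward, where x is [batch, hidden_channels, time] and x_mask is [batch, 1, time].

# Sketch only: trace the Encoder defined above to ONNX.
# Hyperparameter values here are illustrative assumptions.
import torch

from attentions_onnx import Encoder

hidden_channels, filter_channels = 192, 768
enc = Encoder(
    hidden_channels,
    filter_channels,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
)
enc.eval()

# x: [batch, hidden_channels, time]; x_mask: [batch, 1, time]
x = torch.randn(1, hidden_channels, 50)
x_mask = torch.ones(1, 1, 50)

with torch.no_grad():
    y = enc(x, x_mask)  # [1, hidden_channels, 50]

# Export via tracing; batch and time are declared dynamic, but Python
# branches inside the relative-attention helpers are baked in at the
# example length, so verify the exported graph at other lengths.
torch.onnx.export(
    enc,
    (x, x_mask),
    "encoder.onnx",
    input_names=["x", "x_mask"],
    output_names=["y"],
    dynamic_axes={
        "x": {0: "batch", 2: "time"},
        "x_mask": {0: "batch", 2: "time"},
        "y": {0: "batch", 2: "time"},
    },
    opset_version=16,
)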
