
Commit badb125

Change the utils.download_emo_models (#199)
* Change the utils.download_emo_models: change utils.download_emo_models(config.mirror, model_name, REPO_ID) to utils.download_emo_models(config.mirror, REPO_ID, model_name)

* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2c528ce commit badb125
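The fix only reorders the last two arguments at the call site in emo_gen.py: REPO_ID is the Hugging Face repository to pull from, and model_name is the local directory the files land in. Below is a minimal sketch of a downloader with that parameter order, assuming it wraps huggingface_hub.snapshot_download; the project's real utils.download_emo_models and its mirror handling may differ.

from huggingface_hub import snapshot_download


def download_emo_models(mirror: str, repo_id: str, model_name: str) -> None:
    """Hypothetical sketch: fetch `repo_id` into the local directory `model_name`."""
    # Mirror selection is project-specific and only hinted at here; the real
    # helper may map `mirror` values to alternative download endpoints.
    snapshot_download(repo_id, local_dir=model_name)


# With this parameter order, the corrected call site reads:
# utils.download_emo_models(config.mirror, REPO_ID, model_name)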

File tree: 1 file changed, emo_gen.py (+162 -161 lines)

emo_gen.py
@@ -1,162 +1,163 @@
 import argparse
 import os
 from pathlib import Path

 import librosa
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.utils.data import Dataset
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 from transformers import Wav2Vec2Processor
 from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2Model,
     Wav2Vec2PreTrainedModel,
 )

 import utils
 from config import config


 class RegressionHead(nn.Module):
     r"""Classification head."""

     def __init__(self, config):
         super().__init__()

         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.final_dropout)
         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

     def forward(self, features, **kwargs):
         x = features
         x = self.dropout(x)
         x = self.dense(x)
         x = torch.tanh(x)
         x = self.dropout(x)
         x = self.out_proj(x)

         return x


 class EmotionModel(Wav2Vec2PreTrainedModel):
     r"""Speech emotion classifier."""

     def __init__(self, config):
         super().__init__(config)

         self.config = config
         self.wav2vec2 = Wav2Vec2Model(config)
         self.classifier = RegressionHead(config)
         self.init_weights()

     def forward(
         self,
         input_values,
     ):
         outputs = self.wav2vec2(input_values)
         hidden_states = outputs[0]
         hidden_states = torch.mean(hidden_states, dim=1)
         logits = self.classifier(hidden_states)

         return hidden_states, logits


 class AudioDataset(Dataset):
     def __init__(self, list_of_wav_files, sr, processor):
         self.list_of_wav_files = list_of_wav_files
         self.processor = processor
         self.sr = sr

     def __len__(self):
         return len(self.list_of_wav_files)

     def __getitem__(self, idx):
         wav_file = self.list_of_wav_files[idx]
         audio_data, _ = librosa.load(wav_file, sr=self.sr)
         processed_data = self.processor(audio_data, sampling_rate=self.sr)[
             "input_values"
         ][0]
         return torch.from_numpy(processed_data)


 def process_func(
     x: np.ndarray,
     sampling_rate: int,
     model: EmotionModel,
     processor: Wav2Vec2Processor,
     device: str,
     embeddings: bool = False,
 ) -> np.ndarray:
     r"""Predict emotions or extract embeddings from raw audio signal."""
     model = model.to(device)
     y = processor(x, sampling_rate=sampling_rate)
     y = y["input_values"][0]
     y = torch.from_numpy(y).unsqueeze(0).to(device)

     # run through model
     with torch.no_grad():
         y = model(y)[0 if embeddings else 1]

     # convert to numpy
     y = y.detach().cpu().numpy()

     return y


 def get_emo(path):
     wav, sr = librosa.load(path, 16000)
     device = config.bert_gen_config.device
     return process_func(
-        np.expand_dims(wav, 0).astype(np.float64),
+        np.expand_dims(wav, 0).astype(np.float),
         sr,
         model,
         processor,
         device,
         embeddings=True,
     ).squeeze(0)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-c", "--config", type=str, default=config.bert_gen_config.config_path
     )
     parser.add_argument(
         "--num_processes", type=int, default=config.bert_gen_config.num_processes
     )
     args, _ = parser.parse_known_args()
     config_path = args.config
     hps = utils.get_hparams_from_file(config_path)

     device = config.bert_gen_config.device

     model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
     REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
     if not Path(model_name).joinpath("pytorch_model.bin").exists():
-        utils.download_emo_models(config.mirror, model_name, REPO_ID)
+        utils.download_emo_models(config.mirror, REPO_ID, model_name)

     processor = Wav2Vec2Processor.from_pretrained(model_name)
     model = EmotionModel.from_pretrained(model_name).to(device)

     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())

     with open(hps.data.validation_files, encoding="utf-8") as f:
         lines.extend(f.readlines())

     wavnames = [line.split("|")[0] for line in lines]
     dataset = AudioDataset(wavnames, 16000, processor)
     data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=16)

     with torch.no_grad():
         for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
             wavname = wavnames[i]
             emo_path = wavname.replace(".wav", ".emo.npy")
             if os.path.exists(emo_path):
                 continue
             emb = model(data.to(device))[0].detach().cpu().numpy()
             np.save(emo_path, emb)

 print("Emo vec 生成完毕!")
