Commit 319d8e8

We open source Metis in Amphion (#401)

* metis model and inference
* add metis in readme
* add example wav
* Update README.md
* Update README.md

Co-authored-by: Chaoren Wang <[email protected]>

1 parent a705139 · commit 319d8e8

34 files changed: +1,875 -14 lines

.gitignore (+3 -1)

```diff
@@ -61,4 +61,6 @@ logs
 source_audio
 result
 conversion_results
-get_available_gpu.py
+get_available_gpu.py
+
+*.safetensors
```

README.md (+2 -1)

```diff
@@ -34,7 +34,8 @@
 In addition to the specific generation tasks, Amphion includes several **vocoders** and **evaluation metrics**. A vocoder is an important module for producing high-quality audio signals, while evaluation metrics are critical for ensuring consistent metrics in generation tasks. Moreover, Amphion is dedicated to advancing audio generation in real-world applications, such as building **large-scale datasets** for speech synthesis.
 
 ## 🚀 News
-- **2025/02/24**: *The Emilia-Large dataset, featuring over 200,000 hours of data, is now available!!!* Emilia-Large combines the original 101k-hour Emilia dataset (licensed under `CC BY-NC 4.0`) with the brand-new 114k-hour **Emilia-YODAS dataset** (licensed under `CC BY 4.0`). Download at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset). Check details at [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2501.15907).
+- **2025/02/26**: We release [***Metis***](https://github.com/open-mmlab/Amphion/tree/main/models/tts/metis), a foundation model for unified speech generation. The system supports zero-shot text-to-speech, voice conversion, target speaker extraction, speech enhancement, and lip-to-speech. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/pdf/2502.03128) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/metis)
+- **2025/02/26**: *The Emilia-Large dataset, featuring over 200,000 hours of data, is now available!!!* Emilia-Large combines the original 101k-hour Emilia dataset (licensed under `CC BY-NC 4.0`) with the brand-new 114k-hour **Emilia-YODAS dataset** (licensed under `CC BY 4.0`). Download at [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Dataset-yellow)](https://huggingface.co/datasets/amphion/Emilia-Dataset). Check details at [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2501.15907).
 - **2025/01/30**: We release [Amphion v0.2 Technical Report](https://arxiv.org/abs/2501.15442), which provides a comprehensive overview of the Amphion updates in 2024. [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2501.15442)
 - **2025/01/23**: [MaskGCT](https://arxiv.org/abs/2409.00750) and [Vevo](https://openreview.net/pdf?id=anQDiQZhDP) got accepted by ICLR 2025! 🎉
 - **2024/12/22**: We release the reproduction of **Vevo**, a zero-shot voice imitation framework with controllable timbre and style. Vevo can be applied into a series of speech generation tasks, including VC, TTS, AC, and more. The released pre-trained models are trained on [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) dataset and achieve SOTA zero-shot VC performance. [![arXiv](https://img.shields.io/badge/OpenReview-Paper-COLOR.svg)](https://openreview.net/pdf?id=anQDiQZhDP) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/Vevo) [![WebPage](https://img.shields.io/badge/WebPage-Demo-red)](https://versavoice.github.io/) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](models/vc/vevo/README.md)
```

imgs/metis/fine-tune.png (230 KB)

imgs/metis/pre-train.png (249 KB)

imgs/metis/two-stage.png (261 KB)

models/tts/maskgct/README.md (+4)

```diff
@@ -21,6 +21,10 @@ MaskGCT (**Mask**ed **G**enerative **C**odec **T**ransformer) is *a fully non-au
 
 ## News
 
+- **2025/02/26**: We release [**Metis**](https://github.com/open-mmlab/Amphion/tree/main/models/tts/metis), an upgraded version of MaskGCT that supports multiple speech generation tasks (text-to-speech, voice conversion, target speaker extraction, speech enhancement, and lip2speech) within a unified framework.
+
+- **2025/01/25**: MaskGCT gets accepted by ICLR 2025.
+
 - **2024/10/19**: We release **MaskGCT**, a fully non-autoregressive TTS model that eliminates the need for explicit alignment information between text and speech supervision. MaskGCT is trained on [Emilia](https://huggingface.co/datasets/amphion/Emilia-Dataset) dataset and achieves SOTA zero-shot TTS performance.
 
 ## Issues
```

models/tts/maskgct/llama_nar.py (+18 -9)

```diff
@@ -430,10 +430,13 @@ def __init__(
         hidden_size=1024,
         num_heads=16,
         num_layers=16,
+        use_phone_cond=True,
         config=LlamaConfig(0, 256, 1024, 1, 1),
     ):
         super().__init__(config)
 
+        self.use_phone_cond = use_phone_cond
+
         self.layers = nn.ModuleList(
             [
                 LlamaNARDecoderLayer(
@@ -458,11 +461,12 @@ def __init__(
             nn.Linear(hidden_size * 4, hidden_size),
         )
 
-        self.cond_mlp = nn.Sequential(
-            nn.Linear(hidden_size, hidden_size * 4),
-            nn.SiLU(),
-            nn.Linear(hidden_size * 4, hidden_size),
-        )
+        if self.use_phone_cond:
+            self.cond_mlp = nn.Sequential(
+                nn.Linear(hidden_size, hidden_size * 4),
+                nn.SiLU(),
+                nn.Linear(hidden_size * 4, hidden_size),
+            )
 
         for layer in self.layers:
             layer.input_layernorm = LlamaAdaptiveRMSNorm(
@@ -535,10 +539,15 @@ def forward(
 
         # retrieve some shape info
 
-        phone_embedding = self.cond_mlp(phone_embedding)  # (B, T, C)
-        phone_length = phone_embedding.shape[1]
-        inputs_embeds = torch.cat([phone_embedding, x], dim=1)
-        attention_mask = torch.cat([phone_mask, x_mask], dim=1)
+        if self.use_phone_cond and phone_embedding is not None:
+            phone_embedding = self.cond_mlp(phone_embedding)  # (B, T, C)
+            phone_length = phone_embedding.shape[1]
+            inputs_embeds = torch.cat([phone_embedding, x], dim=1)
+            attention_mask = torch.cat([phone_mask, x_mask], dim=1)
+        else:
+            inputs_embeds = x
+            attention_mask = x_mask
+            phone_length = 0
 
         # diffusion step embedding
         diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
```
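This change makes the phoneme-conditioning prefix optional, so the same backbone can also run unconditionally (as needed for masked generative pre-training without text). A minimal, self-contained sketch of the same pattern (illustrative only; the class and argument names below are not the Amphion ones):

```python
import torch
import torch.nn as nn


class OptionalPrefixCond(nn.Module):
    """Illustrative only: mirrors the optional phone-conditioning branch shown above."""

    def __init__(self, hidden_size=1024, use_phone_cond=True):
        super().__init__()
        self.use_phone_cond = use_phone_cond
        if self.use_phone_cond:
            # The conditioning MLP is only built when phone conditioning is enabled.
            self.cond_mlp = nn.Sequential(
                nn.Linear(hidden_size, hidden_size * 4),
                nn.SiLU(),
                nn.Linear(hidden_size * 4, hidden_size),
            )

    def forward(self, x, x_mask, phone_embedding=None, phone_mask=None):
        if self.use_phone_cond and phone_embedding is not None:
            # Prefix the projected phone embeddings to the input sequence.
            phone_embedding = self.cond_mlp(phone_embedding)  # (B, T_phone, C)
            inputs_embeds = torch.cat([phone_embedding, x], dim=1)
            attention_mask = torch.cat([phone_mask, x_mask], dim=1)
            phone_length = phone_embedding.shape[1]
        else:
            # Unconditional path: the input sequence passes through untouched.
            inputs_embeds = x
            attention_mask = x_mask
            phone_length = 0
        return inputs_embeds, attention_mask, phone_length


# Unconditional usage, e.g. during pre-training: no cond_mlp is created or used.
module = OptionalPrefixCond(hidden_size=1024, use_phone_cond=False)
inputs, mask, n_prefix = module(torch.randn(2, 50, 1024), torch.ones(2, 50))
print(inputs.shape, n_prefix)  # torch.Size([2, 50, 1024]) 0
```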

models/tts/maskgct/maskgct_t2s.py (+11 -1)

```diff
@@ -41,6 +41,7 @@ def __init__(
         cfg_scale=0.2,
         cond_codebook_size=8192,
         cond_dim=1024,
+        use_phone_cond=True,
         cfg=None,
     ):
         super().__init__()
@@ -73,28 +74,37 @@ def __init__(
         cond_dim = (
             cfg.cond_dim if cfg is not None and hasattr(cfg, "cond_dim") else cond_dim
         )
+        use_phone_cond = (
+            cfg.use_phone_cond
+            if cfg is not None and hasattr(cfg, "use_phone_cond")
+            else use_phone_cond
+        )
 
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.cfg_scale = cfg_scale
         self.cond_codebook_size = cond_codebook_size
         self.cond_dim = cond_dim
+        self.use_phone_cond = use_phone_cond
 
         self.mask_emb = nn.Embedding(1, self.hidden_size)
 
         self.to_logit = nn.Linear(self.hidden_size, self.cond_codebook_size)
 
         self.cond_emb = nn.Embedding(cond_codebook_size, self.hidden_size)
 
-        self.phone_emb = nn.Embedding(1024, hidden_size, padding_idx=1023)
+        if self.use_phone_cond:
+            self.phone_emb = nn.Embedding(1024, hidden_size, padding_idx=1023)
+            torch.nn.init.normal_(self.phone_emb.weight, mean=0.0, std=0.02)
 
         self.reset_parameters()
 
         self.diff_estimator = DiffLlamaPrefix(
             hidden_size=hidden_size,
             num_heads=num_heads,
             num_layers=num_layers,
+            use_phone_cond=use_phone_cond,
         )
 
     def mask_prob(self, t):
```
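The constructor resolves `use_phone_cond` from the config when one is provided and otherwise keeps the keyword default, mirroring how the other hyperparameters are handled. A small standalone sketch of that fallback pattern (illustrative; `resolve_use_phone_cond` is not an Amphion function):

```python
from types import SimpleNamespace


def resolve_use_phone_cond(cfg=None, use_phone_cond=True):
    # The config value wins if present; otherwise the constructor default is kept.
    return (
        cfg.use_phone_cond
        if cfg is not None and hasattr(cfg, "use_phone_cond")
        else use_phone_cond
    )


print(resolve_use_phone_cond())                                       # True (default)
print(resolve_use_phone_cond(SimpleNamespace(use_phone_cond=False)))  # False (from cfg)
```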

models/tts/metis/README.md (new file, +240)

# *Metis*: A Foundation Speech Generation Model with Masked Generative Pre-training

[![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/pdf/2502.03128)
[![readme](https://img.shields.io/badge/README-Key%20Features-blue)](../../../models/tts/metis/README.md)
[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/metis)
[![ModelScope](https://img.shields.io/badge/ModelScope-model-cyan)](https://modelscope.cn/models/amphion/metis)

<!-- [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/metis) -->
<!-- [![ModelScope](https://img.shields.io/badge/ModelScope-space-purple)](https://modelscope.cn/studios/amphion/metis) -->

## Overview

We introduce ***Metis***, a foundation model for unified speech generation.
Unlike previous task-specific or multi-task models, Metis follows a pre-training and fine-tuning paradigm. It is pre-trained on large-scale unlabeled speech data using masked generative modeling and then fine-tuned to adapt to diverse speech generation tasks.
Specifically, (1) Metis utilizes two discrete speech representations: SSL tokens derived from speech self-supervised learning (SSL) features, and acoustic tokens directly quantized from waveforms. (2) Metis performs masked generative pre-training on SSL tokens, utilizing 300K hours of diverse speech data, without any additional condition. (3) Through fine-tuning with task-specific conditions, Metis achieves efficient adaptation to various speech generation tasks while supporting multimodal input, even when using limited data and trainable parameters.
Experiments demonstrate that Metis can serve as a foundation model for unified speech generation: Metis outperforms state-of-the-art task-specific or multi-task systems across five speech generation tasks, including zero-shot text-to-speech, voice conversion, target speaker extraction, speech enhancement, and lip-to-speech, even with fewer than 20M trainable parameters or 300 times less training data.
Audio samples are available on the [demo page](https://metis-demo.github.io/).

<div align="center">
<img src="../../../imgs/metis/pre-train.png" width="42%">
<img src="../../../imgs/metis/fine-tune.png" width="48%">
</div>
<div align="center">
<p><i>Pre-training (left) and fine-tuning (right).</i></p>
</div>
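The pre-training stage boils down to masked prediction over discrete SSL token sequences: a random subset of tokens is replaced by a mask token and the model learns to recover them. A toy, self-contained sketch of that objective on dummy tokens (illustrative only, not the Metis training code; the vocabulary size, model size, and fixed masking ratio are made up):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup: dummy "SSL tokens" and a tiny bidirectional Transformer encoder.
vocab_size, mask_id, hidden = 8192, 8192, 256   # reserve one extra id for [MASK]
embed = nn.Embedding(vocab_size + 1, hidden)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(hidden, nhead=4, batch_first=True), num_layers=2
)
to_logits = nn.Linear(hidden, vocab_size)

tokens = torch.randint(0, vocab_size, (2, 100))  # (batch, time) dummy SSL tokens

# Mask a random subset of positions (a fixed 50% here; real schedules vary per step).
mask = torch.rand(tokens.shape) < 0.5
inputs = tokens.masked_fill(mask, mask_id)

logits = to_logits(encoder(embed(inputs)))       # predict a token at every position
loss = F.cross_entropy(logits[mask], tokens[mask])  # loss only on masked positions
loss.backward()
print(f"masked-prediction loss: {loss.item():.3f}")
```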
## News

- **2025/02/26**: We release ***Metis***, a foundation model for unified speech generation. The system supports zero-shot text-to-speech, voice conversion, target speaker extraction, speech enhancement, and lip-to-speech.

<!-- ## Todo List

- [ ] Add inference code for lip2speech -->

## Model Introduction

Metis is fully compatible with MaskGCT and shares several key model components with it. These shared components are:

| Model Name | Description |
| --- | --- |
| [Semantic Codec](https://huggingface.co/amphion/MaskGCT/tree/main/semantic_codec) | Converts speech to semantic tokens. |
| [Acoustic Codec](https://huggingface.co/amphion/MaskGCT/tree/main/acoustic_codec) | Converts speech to acoustic tokens and reconstructs the waveform from acoustic tokens. |
| [Semantic2Acoustic](https://huggingface.co/amphion/MaskGCT/tree/main/s2a_model) | Predicts acoustic tokens conditioned on semantic tokens. |
<!-- | [MaskGCT-T2S](https://huggingface.co/amphion/MaskGCT/tree/main/t2s_model) | Predicting semantic tokens with text and prompt semantic tokens. | -->

We open-source the pretrained checkpoint of the first stage of Metis (with masked generative pre-training), as well as the fine-tuned models for speech enhancement (SE), target speaker extraction (TSE), voice conversion (VC), lip-to-speech (L2S), and the unified multi-task (Omni) model.

For zero-shot text-to-speech, you can download the text2semantic model from MaskGCT, which is compatible with the Metis framework.

| Model Name | Description |
| --- | --- |
| [Metis-Base](https://huggingface.co/amphion/metis/tree/main/metis_base) | The base model pre-trained with masked generative pre-training. |
| [Metis-TSE](https://huggingface.co/amphion/metis/tree/main/metis_tse) | Fine-tuned model for target speaker extraction. Available in both full-scale and LoRA ($r = 32$) versions. |
| [Metis-VC](https://huggingface.co/amphion/metis/tree/main/metis_vc) | Fine-tuned model for voice conversion. Available in a full-scale version. |
| [Metis-SE](https://huggingface.co/amphion/metis/tree/main/metis_se) | Fine-tuned model for speech enhancement. Available in both full-scale and LoRA ($r = 32$) versions. |
| [Metis-L2S](https://huggingface.co/amphion/metis/tree/main/metis_l2s) | Fine-tuned model for lip-to-speech. Available in a full-scale version. |
| [Metis-TTS](https://huggingface.co/amphion/MaskGCT/tree/main/t2s_model) | Zero-shot text-to-speech model (the same as the first stage of MaskGCT). |
| [Metis-Omni](https://huggingface.co/amphion/metis/tree/main/metis_omni) | Unified multi-task model supporting zero-shot TTS, VC, TSE, and SE. |

## Usage

To run this model, follow the steps below:

1. Clone the repository and install the environment.
2. Run the inference script.

### Clone and Environment

#### 1. Clone the repository

```bash
git clone https://github.com/open-mmlab/Amphion.git
cd Amphion
```

#### 2. Install the environment

Before you start installing, make sure you are in the `Amphion` directory. If not, use `cd` to enter it.

Since we use `phonemizer` to convert text to phonemes, you need to install `espeak-ng` first. More details can be found [here](https://bootphon.github.io/phonemizer/install.html). Choose the correct installation command according to your operating system:

```bash
# For Debian-like distributions (e.g. Ubuntu, Mint, etc.)
sudo apt-get install espeak-ng
# For RedHat-like distributions (e.g. CentOS, Fedora, etc.)
sudo yum install espeak-ng

# For Windows
# Please visit https://github.com/espeak-ng/espeak-ng/releases to download the .msi installer
```

**The environment used for Metis is the same as the one used for MaskGCT.**

Now we are going to install the environment. It is recommended to use conda to configure it:

```bash
conda create -n maskgct python=3.10
conda activate maskgct

pip install -r models/tts/maskgct/requirements.txt
```
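Once the requirements are installed, you can optionally confirm that `phonemizer` can reach the `espeak-ng` backend with a quick check (this uses the public `phonemize` API from the phonemizer package; it is not part of the Metis scripts):

```python
# Optional sanity check: should print a phoneme string rather than raise an error.
from phonemizer import phonemize

print(phonemize("speech generation", language="en-us", backend="espeak"))
```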
### Inference

#### 1. Inference Scripts

```bash
# Metis TSE
python -m models.tts.metis.metis_infer_tse

# Metis SE
python -m models.tts.metis.metis_infer_se

# Metis VC
python -m models.tts.metis.metis_infer_vc

# Metis Lip2Speech
python -m models.tts.metis.metis_infer_l2s
```

You can also use a similar framework for inference with MaskGCT:

```bash
# Metis TTS (MaskGCT)
python -m models.tts.maskgct.maskgct_infer_tts
```

You can also use a single model (Metis-Omni) to infer the TTS, VC, TSE, and SE tasks:

```bash
# Metis Omni
python -m models.tts.metis.metis_infer_omni
```

Running any of these scripts will automatically download the pretrained model from HuggingFace and start the inference process. We provide example audio files for inference. Please see the scripts for more details and parameter configurations.

#### 2. Example Usage

Taking Metis-TSE as an example, the inference script first downloads the model checkpoints:

```python
# download base model, lora weights, and adapter weights
base_ckpt_dir = snapshot_download(
    "amphion/metis",
    repo_type="model",
    local_dir="./models/tts/metis/ckpt",
    allow_patterns=["metis_base/model.safetensors"],
)
lora_ckpt_dir = snapshot_download(
    "amphion/metis",
    repo_type="model",
    local_dir="./models/tts/metis/ckpt",
    allow_patterns=["metis_tse/metis_tse_lora_32.safetensors"],
)
adapter_ckpt_dir = snapshot_download(
    "amphion/metis",
    repo_type="model",
    local_dir="./models/tts/metis/ckpt",
    allow_patterns=["metis_tse/metis_tse_lora_32_adapter.safetensors"],
)
```

Then, the script loads the checkpoints and initializes the fine-tuned Metis model:

```python
base_ckpt_path = os.path.join(base_ckpt_dir, "metis_base/model.safetensors")
lora_ckpt_path = os.path.join(
    lora_ckpt_dir, "metis_tse/metis_tse_lora_32.safetensors"
)
adapter_ckpt_path = os.path.join(
    adapter_ckpt_dir, "metis_tse/metis_tse_lora_32_adapter.safetensors"
)

metis = Metis(
    base_ckpt_path=base_ckpt_path,
    lora_ckpt_path=lora_ckpt_path,
    adapter_ckpt_path=adapter_ckpt_path,
    cfg=metis_cfg,
    device=device,
    model_type="tse",
)
```

Finally, the script generates the speech and saves it to `models/tts/metis/wav/tse/gen.wav`; you can change this path in the script.

```python
prompt_speech_path = "./models/tts/metis/wav/tse/prompt.wav"
source_speech_path = "./models/tts/metis/wav/tse/mix.wav"

n_timesteps = 10
cfg = 0.0

gen_speech = metis(
    prompt_speech_path=prompt_speech_path,
    source_speech_path=source_speech_path,
    cfg=cfg,
    n_timesteps=n_timesteps,
    model_type="tse",
)

sf.write("./models/tts/metis/wav/tse/gen.wav", gen_speech, 24000)
```
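As an optional sanity check, the generated file can be read back with the same `soundfile` package used above; the sample rate should match the 24000 Hz passed to `sf.write`:

```python
# Optional: verify the generated audio was written as expected.
import soundfile as sf

audio, sr = sf.read("./models/tts/metis/wav/tse/gen.wav")
print(f"generated {audio.shape[0] / sr:.2f} s of audio at {sr} Hz")  # expect sr == 24000
```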
## Citations

If you use Metis in your research, please cite the following papers:

```bibtex
@article{wang2025metis,
  title={Metis: A Foundation Speech Generation Model with Masked Generative Pre-training},
  author={Wang, Yuancheng and Zheng, Jiachen and Zhang, Junan and Zhang, Xueyao and Liao, Huan and Wu, Zhizheng},
  journal={arXiv preprint arXiv:2502.03128},
  year={2025}
}

@inproceedings{wang2024maskgct,
  author={Wang, Yuancheng and Zhan, Haoyue and Liu, Liwei and Zeng, Ruihong and Guo, Haotian and Zheng, Jiachen and Zhang, Qiang and Zhang, Xueyao and Zhang, Shunsi and Wu, Zhizheng},
  title={MaskGCT: Zero-Shot Text-to-Speech with Masked Generative Codec Transformer},
  booktitle={{ICLR}},
  publisher={OpenReview.net},
  year={2025}
}

@article{amphion_v0.2,
  title={Overview of the Amphion Toolkit (v0.2)},
  author={Jiaqi Li and Xueyao Zhang and Yuancheng Wang and Haorui He and Chaoren Wang and Li Wang and Huan Liao and Junyi Ao and Zeyu Xie and Yiqiao Huang and Junan Zhang and Zhizheng Wu},
  journal={arXiv preprint arXiv:2501.15442},
  year={2025}
}

@inproceedings{amphion,
  author={Zhang, Xueyao and Xue, Liumeng and Gu, Yicheng and Wang, Yuancheng and Li, Jiaqi and He, Haorui and Wang, Chaoren and Song, Ting and Chen, Xi and Fang, Zihao and Chen, Haopeng and Zhang, Junan and Tang, Tze Ying and Zou, Lexiao and Wang, Mingxuan and Han, Jun and Chen, Kai and Li, Haizhou and Wu, Zhizheng},
  title={Amphion: An Open-Source Audio, Music and Speech Generation Toolkit},
  booktitle={{IEEE} Spoken Language Technology Workshop, {SLT} 2024},
  year={2024}
}
```
