
Commit e46faff

xrsrke, NouamaneTazi, c8ef, eliebak, and zzhhjjj authored and committed
MoE without token dropping (#355)
* can only merge to main from dev (#348)
* move moe from qwen modeling to src/nn
* add grouped MLP
* add token permute and unpermute
* fix num_tokens_per_expert counting when < num_experts
* fix init and init scaling factor and run evals in background (#353, #349, #352): InitScalingMethod, hook a LightEval runner into the trainer, qos set to low, add nanotron_path, fixes to logs and config, cp instead of sync, eval_interval, serialize sanity checks, add output dir and s3_save path to the config, only use s3 if defined
* Fix UnboundLocalError in `clm_collator.py` (#339)
* inference for Qwen-MoE works
* update readme
* fix router weight initialization and wrong hidden size for non-MoE MLP in qwen
* add source for router weight and compute router logits in float32
* parametrize grouped MLP in column and row linear
* add logging of per-param grad norm
* fix conversion failure due to buffer on cpu
* config_qwen
* fix moe convert config

Co-authored-by: Nouamane Tazi <[email protected]>
Co-authored-by: Connector Switch <[email protected]>
Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
Co-authored-by: zzhhjjj <[email protected]>
Co-authored-by: nouamanetazi <[email protected]>
1 parent 184dd0b commit e46faff
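
The headline change is dropless routing: every token is dispatched to its assigned expert rather than being discarded once an expert's capacity is full. The sketch below is illustrative only and is not the nanotron implementation; it shows the token permute/unpermute pattern named in the commit message (sort tokens by expert, count tokens per expert for the grouped MLP, restore the original order afterwards). All names are hypothetical.

```python
# Illustrative sketch (not nanotron's code): dropless token permutation for MoE.
import torch

def permute_tokens(hidden_states: torch.Tensor, expert_indices: torch.Tensor, num_experts: int):
    """Sort tokens by assigned expert; no capacity limit, so no token is dropped."""
    order = torch.argsort(expert_indices, stable=True)           # stable keeps intra-expert order
    permuted = hidden_states[order]                              # expert-contiguous layout
    num_tokens_per_expert = torch.bincount(expert_indices, minlength=num_experts)
    return permuted, order, num_tokens_per_expert

def unpermute_tokens(expert_output: torch.Tensor, order: torch.Tensor) -> torch.Tensor:
    """Restore the original token order after the expert MLPs ran."""
    restored = torch.empty_like(expert_output)
    restored[order] = expert_output
    return restored

# Top-1 routing over dummy data; router logits are kept in float32.
tokens = torch.randn(16, 64)                                     # (num_tokens, hidden_size)
router_logits = torch.randn(16, 8)                               # (num_tokens, num_experts)
expert_indices = router_logits.float().argmax(dim=-1)
permuted, order, counts = permute_tokens(tokens, expert_indices, num_experts=8)
assert torch.equal(unpermute_tokens(permuted, order), tokens)    # round trip is lossless
```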

File tree

17 files changed: +865 -233 lines

examples/config_qwen.py

Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ def get_model_config(model_size: str) -> Qwen2Config:
         is_qwen2_config=True,
         pad_token_id=None,
         _attn_implementation="flash_attention_2",
-        sliding_window_size=20,
+        # sliding_window_size=20,
     )

examples/config_qwen.yaml

Lines changed: 7 additions & 7 deletions

@@ -32,7 +32,7 @@ general:
   consumed_train_samples: null
   ignore_sanity_checks: false
   project: debug
-  run: qwen_20250410_014907_16027793
+  run: qwen_20250423_201000_16423158
   seed: 42
   step: null
 lighteval: null
@@ -45,14 +45,15 @@ model:
   ddp_bucket_cap_mb: 25
   dtype: bfloat16
   init_method:
+    scaling_method: NUM_LAYERS
     std: 0.025
   make_vocab_size_divisible_by: 1
   model_config:
     _attn_implementation: flash_attention_2
-    _fused_rms_norm: true
-    _fused_rotary_emb: true
-    _use_doc_masking: true
-    _use_qkv_packed: true
+    _fused_rms_norm: false
+    _fused_rotary_emb: false
+    _use_doc_masking: false
+    _use_qkv_packed: false
     attention_bias: false
     bos_token_id: 1
     eos_token_id: 2
@@ -74,7 +75,7 @@ model:
     rope_interleaved: false
     rope_scaling: null
     rope_theta: 10000.0
-    sliding_window_size: 20
+    sliding_window_size: null
     tie_word_embeddings: true
     use_cache: true
     vocab_size: 128256
@@ -104,7 +105,6 @@ parallelism:
   context_parallel_size: 1
   dp: 2
   expert_parallel_size: 1
-  moe_layer_recompute: false
   pp: 1
   pp_engine: 1f1b
   recompute_layer: false
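
The new `scaling_method: NUM_LAYERS` entry under `init_method` selects the depth-aware initialization added by this commit. The exact rule lives in the nanotron source; below is a minimal sketch of the usual convention this setting suggests (shrinking the init std of residual-output projections as the layer count grows), with hypothetical function names.

```python
# Hedged sketch, not nanotron's implementation: a common NUM_LAYERS-style scaling divides
# the init std of residual-output projections by sqrt(2 * num_layers) so the variance of
# the residual stream does not grow with depth (GPT-2-style convention).
import math
import torch.nn as nn

def init_residual_proj(layer: nn.Linear, std: float, num_layers: int) -> None:
    scaled_std = std / math.sqrt(2 * num_layers)  # assumption: two residual adds per layer
    nn.init.normal_(layer.weight, mean=0.0, std=scaled_std)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)

# Example with the config's std of 0.025 and a hypothetical 12-layer model.
init_residual_proj(nn.Linear(256, 256), std=0.025, num_layers=12)
```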

examples/config_qwen_with_moe.yaml

Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@ (entire file added)
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: /fsx/phuc/new_workspace/experiments/qwen2_moe_test
  checkpoints_path_is_shared_file_system: false
  load_lr_scheduler: true
  load_optimizer: true
  resume_checkpoint_path: null
  save_final_state: true
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder:
      - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged
      dataset_max_tokens: null
      dataset_read_path: null
      dataset_weights: null
      pad_samples_to_global_batch_size: false
      return_positions: true
      shuffle_files: false
      skip_in_stream: false
      token_size_in_bytes: 4
      tokenizer_name: meta-llama/Llama-3.2-1B
      use_old_brrr_dataloader: false
      vocab_size: 128256
    num_loading_workers: 1
    seed: 42
  name: Stable Training Stage
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: false
  project: qwen_moe
  run: qwen_20250410_014907_16027793
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
metrics_logging: null
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    _attn_implementation: flash_attention_2
    _fused_rms_norm: true
    _fused_rotary_emb: true
    _use_doc_masking: true
    _use_qkv_packed: true
    attention_bias: false
    bos_token_id: 1
    eos_token_id: 2
    flex_attention_mask: null
    hidden_act: silu
    hidden_size: 256
    initializer_range: 0.02
    intermediate_size: 768
    is_qwen2_config: true
    max_position_embeddings: 4096
    moe_config: null
    no_rope_layer: null
    num_attention_heads: 4
    num_hidden_layers: 12
    num_key_value_heads: 4
    pad_token_id: null
    pretraining_tp: 1
    rms_norm_eps: 1.0e-06
    rope_interleaved: false
    rope_scaling: null
    rope_theta: 10000.0
    sliding_window_size: 20
    tie_word_embeddings: true
    use_cache: true
    vocab_size: 128256
    z_loss_coefficient: 0.0001
    z_loss_enabled: false
  moe_config:
    num_experts: 8
    top_k: 1
    enable_shared_expert: true
    token_dispatcher_type: alltoall
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 31998
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  weight_decay_exclude_named_params: []
  zero_stage: 0
parallelism:
  context_parallel_size: 1
  dp: 2
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  recompute_layer: false
  tp: 1
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
  tp_recompute_allgather: true
profiler: null
s3_upload: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Llama-3.2-1B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 3
  sequence_length: 4096
  train_steps: 32000
  val_check_interval: -1
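
The `moe_config` block above (`num_experts: 8`, `top_k: 1`, shared expert enabled) implies top-1 routing plus a shared expert that processes every token. Below is a minimal, self-contained sketch of that routing pattern for one layer; it is not the nanotron module (which uses grouped GEMMs and an all-to-all token dispatcher), and every class and parameter name here is illustrative.

```python
# Minimal top-1 MoE layer with a shared expert: an illustrative sketch, not nanotron's code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMoELayer(nn.Module):
    def __init__(self, hidden_size=256, intermediate_size=768, num_experts=8):
        super().__init__()
        self.router = nn.Linear(hidden_size, num_experts, bias=False)

        def make_mlp():
            return nn.Sequential(
                nn.Linear(hidden_size, intermediate_size), nn.SiLU(),
                nn.Linear(intermediate_size, hidden_size),
            )

        self.experts = nn.ModuleList(make_mlp() for _ in range(num_experts))
        self.shared_expert = make_mlp()

    def forward(self, x: torch.Tensor) -> torch.Tensor:    # x: (num_tokens, hidden_size)
        probs = F.softmax(self.router(x).float(), dim=-1)  # router logits in float32
        top_p, top_idx = probs.max(dim=-1)                  # top_k = 1
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):           # loop form; grouped GEMM in practice
            mask = top_idx == e
            if mask.any():
                out[mask] = top_p[mask, None].to(x.dtype) * expert(x[mask])
        return out + self.shared_expert(x)                  # shared expert sees every token

print(TinyMoELayer()(torch.randn(4, 256)).shape)  # torch.Size([4, 256])
```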
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
# Qwen-MoE Inference

This guide explains how to convert Hugging Face Qwen-MoE models to Nanotron format and run inference with them.

## Convert Qwen-MoE to Nanotron Format

Navigate to the `inference/qwen_moe` directory and run:

```bash
torchrun --nproc-per-node 1 examples/inference/qwen_moe/convert.py \
    --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B \
    --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B
```

This command will save the converted model weights to the specified path in `nanotron_checkpoints`.

## Run Inference

From the root directory of Nanotron, run:

```bash
torchrun --rdzv_endpoint=localhost:29700 --rdzv-backend=c10d --nproc_per_node=1 \
    run_generate.py \
    --ckpt-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B
```

This command will load the converted model weights and run inference.
