Commit fb1d6e9

NouamaneTazi authored and eliebak committed
Nouamane/lighteval (#356)
* InitScalingMethod
* InitScalingMethod
* eval
* try adding lightevalrunner to trainer
* amend
* amend
* amend
* amend
* amend
* amend
* .
* amend
* amend
* .
* qos to low
* add nanotron_path
* some fix: logs, and config
* cp instead of sync
* eval_interval
* serialize sanity checks
* add output dir and s3_save path in the config
* fix s3 only if define
* fixes
* add requeue
* add wandb with lighteval and fix eval interval
* fix this little space :(
* folder_path should always have s3 when using s3 (fix consumed tokens issue)
* config qwen
* .

---------

Co-authored-by: elie <[email protected]>
Co-authored-by: “eliebak” <[email protected]>
1 parent e46faff commit fb1d6e9

File tree

examples/config_qwen.py
examples/config_qwen.yaml
src/nanotron/config/lighteval_config.py
src/nanotron/data/tokenized_bytes.py
src/nanotron/eval/one_job_runner.py
src/nanotron/eval/upload_to_wandb.py

6 files changed: +162 -33 lines changed

examples/config_qwen.py

Lines changed: 15 additions & 7 deletions
@@ -30,7 +30,7 @@
     "410m": (24, 1024, 16, 16, 4096),  # ~410M params
     # Small to medium models
     "1b": (16, 2048, 16, 16, 5632),  # ~1B params
-    "3b": (28, 2048, 16, 2, 11008),  # ~3B params
+    "3b": (36, 2048, 16, 4, 11008),  # ~3B params
     # Standard sizes
     "7b": (32, 4096, 32, 32, 11008),  # ~7B params
     "13b": (40, 5120, 40, 40, 13824),  # ~13B params
@@ -47,7 +47,7 @@ def get_args():
     parser.add_argument(
         "--model",
         choices=MODEL_SIZES.keys(),
-        default="custom",
+        default="3b",
         help="Model size to generate config for (e.g., 7b, 13b)",
     )
     parser.add_argument(
@@ -76,6 +76,10 @@ def get_args():
     tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size")
     tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica")
 
+    # checkpoints
+    checkpoints_group = parser.add_argument_group("checkpoints")
+    checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval")
+
     args = parser.parse_args()
     return args
 
@@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config:
         is_qwen2_config=True,
         pad_token_id=None,
         _attn_implementation="flash_attention_2",
-        # sliding_window_size=20,
+        _use_doc_masking=True,
     )
 
 
@@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str:
 
 def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config:
     learning_rate = LRSchedulerArgs(
-        learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
+        learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0
     )
     parallelism = ParallelismArgs(
         dp=args.dp,
@@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
     )
     optimizer = OptimizerArgs(
         zero_stage=args.zero,
-        weight_decay=0.01,
+        weight_decay=0.1,
         clip_grad=1.0,
         accumulate_grad_in_fp32=True,
         learning_rate_scheduler=learning_rate,
@@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
 
     return Config(
         general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity),
-        checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
+        checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save),
         parallelism=parallelism,
         model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
         # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"),
@@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
     world_size = args.dp * args.tp * args.pp * args.cp
     if world_size <= 8:
         print(
-            f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
+            f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}"
         )
+        print("You can also use environment variables for more debugging:")
+        print("  - ENABLE_TIMERS=1: Enable detailed timing information")
+        print("  - DEBUG_CPU=1: Log CPU and memory usage statistics")
+        print("  - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection")
     else:
         print("Checkout slurm_launcher.py to launch a multi-node job")

examples/config_qwen.yaml

Lines changed: 12 additions & 12 deletions
@@ -1,5 +1,5 @@
 checkpoints:
-  checkpoint_interval: 10
+  checkpoint_interval: 100000
   checkpoints_path: checkpoints
   checkpoints_path_is_shared_file_system: false
   load_lr_scheduler: true
@@ -30,9 +30,9 @@ data_stages:
 general:
   benchmark_csv_path: null
   consumed_train_samples: null
-  ignore_sanity_checks: false
+  ignore_sanity_checks: true
   project: debug
-  run: qwen_20250423_201000_16423158
+  run: qwen_20250424_120835_16423158
   seed: 42
   step: null
 lighteval: null
@@ -50,24 +50,24 @@ model:
   make_vocab_size_divisible_by: 1
   model_config:
     _attn_implementation: flash_attention_2
-    _fused_rms_norm: false
-    _fused_rotary_emb: false
-    _use_doc_masking: false
-    _use_qkv_packed: false
+    _fused_rms_norm: true
+    _fused_rotary_emb: true
+    _use_doc_masking: true
+    _use_qkv_packed: true
     attention_bias: false
     bos_token_id: 1
     eos_token_id: 2
     flex_attention_mask: null
     hidden_act: silu
-    hidden_size: 256
+    hidden_size: 2048
     initializer_range: 0.02
-    intermediate_size: 768
+    intermediate_size: 11008
     is_qwen2_config: true
     max_position_embeddings: 4096
     moe_config: null
     no_rope_layer: null
-    num_attention_heads: 4
-    num_hidden_layers: 12
+    num_attention_heads: 16
+    num_hidden_layers: 36
     num_key_value_heads: 4
     pad_token_id: null
     pretraining_tp: 1
@@ -108,7 +108,7 @@ parallelism:
   pp: 1
   pp_engine: 1f1b
   recompute_layer: false
-  tp: 1
+  tp: 2
   tp_linear_async_communication: true
   tp_mode: REDUCE_SCATTER
   tp_recompute_allgather: true
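For scale, a rough parameter count from the new model_config values above (hidden_size 2048, 36 layers, 16 attention heads, 4 KV heads, intermediate_size 11008). The vocabulary size is not visible in this hunk, so the 151,936 below is an assumed Qwen2-style value and embeddings are assumed tied; treat the result as an estimate only:

hidden, layers, heads, kv_heads, inter = 2048, 36, 16, 4, 11008
vocab = 151_936  # assumption: vocab_size is not shown in this hunk

head_dim = hidden // heads                        # 128
attn = 2 * hidden * hidden                        # q_proj + o_proj
attn += 2 * hidden * (kv_heads * head_dim)        # k_proj + v_proj (grouped-query attention)
mlp = 3 * hidden * inter                          # gate, up and down projections
block = attn + mlp + 2 * hidden                   # plus two RMSNorm weight vectors per layer
total = layers * block + vocab * hidden + hidden  # plus tied token embeddings and final norm
print(f"~{total / 1e9:.2f}B parameters")          # ~3.1B, consistent with the "3b" preset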

src/nanotron/config/lighteval_config.py

Lines changed: 13 additions & 2 deletions
@@ -109,8 +109,13 @@ class LightEvalConfig:
     logging: Optional[LightEvalLoggingArgs] = None
     wandb: Optional[LightEvalWandbLoggerConfig] = None
     slurm: Optional[LightEvalSlurm] = None
-    s3_save_path: Optional[str] = None  # should not be dependent of the run_name
-    output_dir: Optional[str] = None  # we should sanity check that it's the same as the one in the eval_config_override
+    s3_save_path: Optional[str] = None  # should not be dependent of the run_name
+    upload_to_wandb: Optional[bool] = False
+    wandb_project: Optional[str] = None
+    wandb_entity: Optional[str] = None
+    output_dir: Optional[
+        str
+    ] = None  # we should sanity check that it's the same as the one in the eval_config_override
     nanotron_path: Optional[str] = "./"
     eval_config_override: str = None
     eval_config_override: Path = None  # Previously hardcoded in run_slurm_one_job
@@ -127,6 +132,12 @@ def __post_init__(self):
         if self.slurm is None:
             self.slurm = LightEvalSlurm()
         self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser())
+        if self.upload_to_wandb:
+            assert (
+                self.s3_save_path is not None
+            ), " We should have a s3_save_path if we want to upload to wandb"  # todo: add the option to read from local folder i guess
+            assert self.wandb_project is not None, "wandb_project must be specified if upload_to_wandb is True"
+            assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True"
         if self.eval_interval_file is not None and Path(self.eval_interval_file).exists():
             logger.warning(
                 f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want."

src/nanotron/data/tokenized_bytes.py

Lines changed: 4 additions & 3 deletions
@@ -369,12 +369,13 @@ def __init__(
             )
             from datatrove.utils.dataset import url_to_fs
 
-            fs_folder, folder_path = url_to_fs(folder_path)
+            fs_folder, stripped_folder_path = url_to_fs(folder_path)
             matched_files = (
-                fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None)
+                fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None)
                 if not filename_pattern
                 else fs_folder.glob(
-                    os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None
+                    os.path.join(stripped_folder_path, filename_pattern),
+                    maxdepth=1 if not recursive else None,
                 )
             )
             matched_files = sorted(matched_files)
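Why the rename matters: url_to_fs() returns a filesystem object plus the path with the URL scheme stripped, so reassigning the result back to folder_path silently dropped the "s3://" prefix for any later consumer of that variable (the "folder_path should always have s3 when using s3" fix in the commit message). A small sketch, assuming datatrove's url_to_fs behaves like fsspec's and that s3fs is installed; the bucket name is made up:

from fsspec.core import url_to_fs  # datatrove's helper of the same name is assumed to behave like this

folder_path = "s3://my-bucket/tokenized/train"     # hypothetical dataset location
fs, stripped_folder_path = url_to_fs(folder_path)

print(stripped_folder_path)  # "my-bucket/tokenized/train" -- scheme stripped, for fs.find()/fs.glob() calls
print(folder_path)           # the original s3:// URL is preserved for code that needs the full path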

src/nanotron/eval/one_job_runner.py

Lines changed: 31 additions & 9 deletions
@@ -60,13 +60,18 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]:
         logger.warning(
             f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path."
         )
-
-        slurm_job_id, slurm_log = run_slurm_one_job(
-            config=self.config,
-            lighteval_config=self.lighteval_config,
-            model_checkpoint_path=checkpoint_path,
-            current_step=self.config.general.step,
-        )
+        if self.config.general.step % self.lighteval_config.eval_interval == 0:
+            slurm_job_id, slurm_log = run_slurm_one_job(
+                config=self.config,
+                lighteval_config=self.lighteval_config,
+                model_checkpoint_path=checkpoint_path,
+                current_step=self.config.general.step,
+            )
+        else:
+            logger.warning(
+                f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}"
+            )
+            return None, None
 
         return slurm_job_id, slurm_log
 
@@ -130,7 +135,8 @@ def run_slurm_one_job(
 #SBATCH --exclusive
 #SBATCH --qos={slurm_config.qos}
 #SBATCH --time={slurm_config.time}
-#SBATCH --output={eval_logs_path}/%j-{timestamp}.out"""
+#SBATCH --output={eval_logs_path}/%j-{timestamp}.out
+#SBATCH --requeue"""
 
     if slurm_config.reservation:
         slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}"
@@ -250,7 +256,23 @@
         --cache-dir {slurm_config.hf_cache}"""
     if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None:
         slurm_script += f"""
-    s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}
+    s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/
+    """
+    if lighteval_config.upload_to_wandb:
+        gbs_tok = (
+            config.parallelism.dp
+            * config.tokens.micro_batch_size
+            * config.tokens.sequence_length
+            * config.tokens.batch_accumulation_per_replica
+        )
+        slurm_script += f"""
+    python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\
+        --wandb_project {lighteval_config.wandb_project} \\
+        --wandb_entity {lighteval_config.wandb_entity} \\
+        --model_name {general_run_name} \\
+        --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\
+        --train_step {current_step} \\
+        --consumed_tokens {current_step*gbs_tok}
     """
     slurm_script += """
     echo "Cleaning up downloaded checkpoints..."

src/nanotron/eval/upload_to_wandb.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+import json
+import s3fs
+import wandb
+import re
+import argparse
+from wandb.sdk.lib.runid import generate_id
+
+
+def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens):
+    s3 = s3fs.S3FileSystem(anon=False)
+    all_metrics = {
+        # basic X axis replacements for all metrics
+        "consumed_tokens": consumed_tokens,
+        "train_step": train_step,
+    }
+
+    for result_file in sorted(s3.ls(results_path)):
+        if not result_file.endswith(".json"):
+            continue
+
+        with s3.open(result_file, "r") as f:
+            results = json.loads(f.read())["results"]
+
+        for benchmark, metrics in results.items():
+            if benchmark == "all":
+                continue
+
+            # extract dataset and config name
+            match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark)
+            if match:
+                dataset, subtask = match.groups()
+
+            for metric_name, metric_value in metrics.items():
+                if "_stderr" in metric_name:
+                    continue
+                # wandb-friendly metric name
+                wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}"
+                all_metrics[wandb_metric] = metric_value
+
+    run_id = f"{model_name}-{generate_id()}"
+
+    # try to find the run in wandb and resume it
+    api = wandb.Api()
+    runs = api.runs(f"{wandb_entity}/{wandb_project}")
+    for run in runs:
+        if run.name == model_name:
+            run_id = run.id
+            break
+
+    wandb.init(
+        project=wandb_project,
+        entity=wandb_entity,
+        name=model_name,
+        id=run_id,
+        config={
+            "model_name": model_name,
+        },
+        resume="allow",
+    )
+
+    # log all metrics for this checkpoint
+    wandb.log(all_metrics)
+
+    wandb.finish()
+
+if __name__ == "__main__":
+    # Setup argument parser
+    parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.")
+    parser.add_argument("--wandb_project", type=str, required=True, help="WandB project name.")
+    parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.")
+    parser.add_argument("--model_name", type=str, required=True, help="Name of the model.")
+    parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.")
+    parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.")
+    parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.")
+
+    # Parse arguments
+    args = parser.parse_args()
+
+    # Call the main function with parsed arguments
+    push_to_wandb(
+        wandb_project=args.wandb_project,
+        wandb_entity=args.wandb_entity,
+        model_name=args.model_name,
+        results_path=args.results_path,
+        train_step=args.train_step,
+        consumed_tokens=args.consumed_tokens
+    )
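For reference, what the benchmark-name regex in push_to_wandb() extracts: lighteval result keys typically look like "suite|task:subtask|num_fewshot", and the two capture groups give the dataset and optional subtask used to build the wandb metric name. The key and metric below are illustrative only:

import re

benchmark = "lighteval|mmlu:abstract_algebra|5"   # illustrative key, not taken from this commit
match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark)
dataset, subtask = match.groups()                 # ("mmlu", "abstract_algebra")

metric_name = "acc"                               # illustrative metric
wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}"
print(wandb_metric)                               # mmlu/abstract_algebra/acc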
