3030 "410m" : (24 , 1024 , 16 , 16 , 4096 ), # ~410M params
3131 # Small to medium models
3232 "1b" : (16 , 2048 , 16 , 16 , 5632 ), # ~1B params
33- "3b" : (28 , 2048 , 16 , 2 , 11008 ), # ~3B params
33+ "3b" : (36 , 2048 , 16 , 4 , 11008 ), # ~3B params
3434 # Standard sizes
3535 "7b" : (32 , 4096 , 32 , 32 , 11008 ), # ~7B params
3636 "13b" : (40 , 5120 , 40 , 40 , 13824 ), # ~13B params
@@ -47,7 +47,7 @@ def get_args():
4747 parser .add_argument (
4848 "--model" ,
4949 choices = MODEL_SIZES .keys (),
50- default = "custom " ,
50+ default = "3b " ,
5151 help = "Model size to generate config for (e.g., 7b, 13b)" ,
5252 )
5353 parser .add_argument (
@@ -76,6 +76,10 @@ def get_args():
7676 tokens_group .add_argument ("--mbs" , type = int , default = 3 , help = "Micro batch size" )
7777 tokens_group .add_argument ("--acc" , type = int , default = 1 , help = "Batch accumulation per replica" )
7878
79+ # checkpoints
80+ checkpoints_group = parser .add_argument_group ("checkpoints" )
81+ checkpoints_group .add_argument ("--ckpt-save" , type = int , default = 10 , help = "Checkpoint save interval" )
82+
7983 args = parser .parse_args ()
8084 return args
8185
@@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config:
108112 is_qwen2_config = True ,
109113 pad_token_id = None ,
110114 _attn_implementation = "flash_attention_2" ,
111- # sliding_window_size=20 ,
115+ _use_doc_masking = True ,
112116 )
113117
114118
@@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str:
154158
155159def create_config (model_config : Qwen2Config , args : argparse .Namespace ) -> Config :
156160 learning_rate = LRSchedulerArgs (
157- learning_rate = 3e-4 , lr_warmup_steps = 2 , lr_warmup_style = "linear" , lr_decay_style = "cosine" , min_decay_lr = 1e-5
161+ learning_rate = 3e-4 , lr_warmup_steps = 2000 , lr_warmup_style = "linear" , lr_decay_style = "cosine" , min_decay_lr = 0
158162 )
159163 parallelism = ParallelismArgs (
160164 dp = args .dp ,
@@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
175179 )
176180 optimizer = OptimizerArgs (
177181 zero_stage = args .zero ,
178- weight_decay = 0.01 ,
182+ weight_decay = 0.1 ,
179183 clip_grad = 1.0 ,
180184 accumulate_grad_in_fp32 = True ,
181185 learning_rate_scheduler = learning_rate ,
@@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
192196
193197 return Config (
194198 general = GeneralArgs (project = "debug" , run = args .run , seed = seed , ignore_sanity_checks = args .no_sanity ),
195- checkpoints = CheckpointsArgs (checkpoints_path = checkpoints_path , checkpoint_interval = 10 ),
199+ checkpoints = CheckpointsArgs (checkpoints_path = checkpoints_path , checkpoint_interval = args . ckpt_save ),
196200 parallelism = parallelism ,
197201 model = ModelArgs (init_method = RandomInit (std = 0.025 ), model_config = model_config ),
198202 # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"),
@@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config
219223 world_size = args .dp * args .tp * args .pp * args .cp
220224 if world_size <= 8 :
221225 print (
222- f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={ world_size } run_train.py --config-file { args .out } "
226+ f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={ world_size } run_train.py --config-file { args .out } "
223227 )
228+ print ("You can also use environment variables for more debugging:" )
229+ print (" - ENABLE_TIMERS=1: Enable detailed timing information" )
230+ print (" - DEBUG_CPU=1: Log CPU and memory usage statistics" )
231+ print (" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection" )
224232 else :
225233 print ("Checkout slurm_launcher.py to launch a multi-node job" )
0 commit comments