From f58474e679718effa76ebaf8d203eece3c596020 Mon Sep 17 00:00:00 2001
From: shaRk-033 <96977927+shaRk-033@users.noreply.github.com>
Date: Sun, 26 Jan 2025 13:57:47 +0530
Subject: [PATCH] introduce gradient clipping

---
 train_gpt.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/train_gpt.py b/train_gpt.py
index e7826ed..69eefdc 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -367,11 +367,11 @@ def generate(self, class_labels, max_tokens=1024, temperature=1.0, top_k=None):
 @click.option(
     "--global_batch_size", default=32 * 8, help="Global batch size across all GPUs"
 )
-@click.option("--per_gpu_batch_size", default=32, help="Per GPU batch size")
+@click.option("--per_gpu_batch_size", default=16, help="Per GPU batch size")
 @click.option("--num_iterations", default=6004, help="Number of training iterations")
-@click.option("--learning_rate", default=1e-3, help="Learning rate")
+@click.option("--learning_rate", default=3e-3, help="Learning rate")
 @click.option(
-    "--learning_rate_embed", default=1e-2, help="Learning rate for embeddings"
+    "--learning_rate_embed", default=3e-3, help="Learning rate for embeddings"
 )
 @click.option("--weight_decay", default=0.1, help="Weight decay")
 @click.option("--warmup_iters", default=10, help="Warmup iterations")
@@ -381,7 +381,10 @@ def generate(self, class_labels, max_tokens=1024, temperature=1.0, top_k=None):
 @click.option("--n_embed", default=768, help="Embedding dimension")
 @click.option("--init_ckpt", default=None, help="Path to initial checkpoint")
 @click.option("--vres", default=False, help="Use vres")
-@click.option("--n_layer", default=12, help="Number of layers")
+@click.option("--n_layer", default=8, help="Number of layers")
+@click.option(
+    "--max_grad_norm", default=1.0, help="Maximum gradient norm for clipping"
+)
 def train(
     run_name,
     train_data,
@@ -400,6 +403,7 @@ def train(
     init_ckpt,
     vres,
     n_layer,
+    max_grad_norm,  # Added parameter
 ):
     dist.init_process_group(backend="nccl")
     ddp_rank = int(os.environ["RANK"])
@@ -502,7 +506,7 @@ def train(
         weight_decay=weight_decay,
         fused=True,
     )
-
+
     # enable_cudnn_sdp(True)
     enable_flash_sdp(True)
     # enable_mem_efficient_sdp(False)
@@ -574,6 +578,7 @@ def get_lr(it):
         logits, loss = model(batch["input_ids"], batch["class_label"])

         loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # add gradient clipping
         optimizer.step()
         optimizer.zero_grad()

@@ -608,7 +613,6 @@ def get_lr(it):
         if master_process and step % save_every == 0:
             checkpoint = {
                 "model": model.module.state_dict(),
-                "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "step": step,
                 "config": model.module.config,
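
For reference, a minimal, self-contained sketch of where the new clipping call sits in a standard PyTorch training step: after loss.backward() and before optimizer.step(). The toy model, data, and loop below are assumptions for illustration only and are not the repository's code; only the clip_grad_norm_ call and the max_grad_norm value mirror the patch.

import torch

# Toy stand-ins (assumed for illustration, not train_gpt.py's model/data).
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.1)
max_grad_norm = 1.0

for _ in range(10):
    x = torch.randn(8, 16)
    y = torch.randint(0, 4, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    # Rescales all gradients in place so their global L2 norm is at most
    # max_grad_norm; it must run after backward() and before step().
    # The return value is the pre-clipping norm, useful to log for spotting spikes.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()

In the DDP setting of this script, backward() completes the gradient all-reduce, so clipping at this point acts on the synchronized gradients identically on every rank, and placing it before optimizer.step() is what bounds the size of each update.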