I trained some Llama models on a small chunk of DCLM (a pretraining corpus), tokenized with the OLMo tokenizer. I took two intermediate checkpoints, at steps 6849 and 2403, and exported them to HF. When I load them via HF to run inference, the model keeps generating the same |||IP_ADDRESS||| token over and over. What could be going wrong?
Update: See my comment below. The export appears to be broken, since the model works well when loaded via NeMo.
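For context, the inference side is just a plain HF load-and-generate; a minimal sketch is below (the checkpoint path, prompt, and generation settings are placeholders, not the exact repro script) -
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path to one of the exported HF checkpoints (steps 6849 / 2403)
ckpt_path = "hf_export/step_6849"

tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
model = AutoModelForCausalLM.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16).to("cuda")

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
# With the broken export, the continuation is the same |||IP_ADDRESS||| token repeated
print(tokenizer.decode(out[0], skip_special_tokens=False))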
Train Script -
from typing import Callable, Optional
import lightning.pytorch as pl
import nemo_run as run
import torch
from lightning.pytorch.callbacks.callback import Callback
from megatron.core.distributed import DistributedDataParallelConfig
import time
from nemo import lightning as nl
from nemo.collections.llm.api import pretrain
from nemo.collections import llm
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.model.llama import Llama32Config1B, LlamaModel
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger, wandb_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from lightning.pytorch.callbacks import Callback
from nemo.utils import logging, timers
from nemo.utils.exp_manager import TimingCallback
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning import AutoResume
import wandb
import argparse
from dotenv import load_dotenv
import os
load_dotenv()
# Global Batch Size Related
from nemo.utils.import_utils import safe_import_from
get_current_global_batch_size, HAVE_MCORE_MBATCH_CALCULATOR = safe_import_from(
    "megatron.core.num_microbatches_calculator", "get_current_global_batch_size"
)
wandb.login(key=os.environ.get("WANDB_API_KEY"))
SEQUENCE_LENGTH = 2048
PER_GPU_BATCH_SIZE = 16
GRADIENT_ACCUMULATION_STEPS = 4
NUM_GPUS = 8
NUM_NODES = 1
TIE_WEIGHTS = True
GLOBAL_BATCH_SIZE = PER_GPU_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * NUM_GPUS * NUM_NODES
NAME = (
    f"llama32_1b_dclm-SL-{SEQUENCE_LENGTH}-PGBS-{PER_GPU_BATCH_SIZE}"
    f"-GAS-{GRADIENT_ACCUMULATION_STEPS}-NGPU-{NUM_GPUS}-NNODES-{NUM_NODES}"
)
if TIE_WEIGHTS:
    NAME += "-TW"
def dclm(
    gbs: int = 256,
    mbs: int = 4,
    seq_length: int = 8192,
) -> run.Config[pl.LightningDataModule]:
    return run.Config(
        llm.PreTrainingDataModule,
        paths=["Data/dclm_megatron/concatenated.jsonl_text_document"],
        seq_length=seq_length,
        global_batch_size=gbs,
        micro_batch_size=mbs,
        tokenizer=run.Config(
            AutoTokenizer,
            pretrained_model_name="allenai/OLMo-1B-hf",
            vocab_file="Data/tokenizer/tokenizer_config.json",
            use_fast=True,
        ),
        # tokenizer=AutoTokenizer(pretrained_model_name="allenai/OLMo-1B-hf", vocab_file="Data/tokenizer/tokenizer_config.json", use_fast=True),
        split="99,8,2",
        num_workers=2,
        index_mapping_dir="Data/index_mapping",
    )
@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
    """
    Factory function to create a Llama 3.2 1B model configuration.

    Returns:
        run.Config[pl.LightningModule]: Configuration for the Llama 3.2 1B model.

    Examples:
        CLI usage:
            $ nemo llm pretrain model=llama32_1b ...

        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    conf = run.Config(Llama32Config1B)
    conf.seq_length = SEQUENCE_LENGTH
    conf.share_embeddings_and_output_weights = TIE_WEIGHTS
    return run.Config(LlamaModel, config=conf)
def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Configure the NeMo Lightning Trainer for the Llama 3.2 1B model.

    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.

    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.

    Examples:
        CLI usage:
            $ nemo llm pretrain trainer=llama32_1b ...

        Python API usage:
            >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1)
            >>> print(trainer_config)

    Note:
        All model-parallel sizes default to 1; the run relies on data parallelism across the GPUs of a single node.
    """
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            average_in_collective=True,
        ),
    )
    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=GRADIENT_ACCUMULATION_STEPS,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        # limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=500,
    )
    return trainer
@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn: Callable = pretrain,
) -> run.Partial:
    """
    Create a pre-training recipe for the Llama 3.2 1B model.

    This function sets up a complete configuration for pre-training, including
    model, trainer, data, logging, optimization, and resumption settings.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory llama32_1b
            $ nemo llm pretrain --factory "llama32_1b(num_nodes=1, name='my_1b_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe(name="llama32_1b_pretrain", num_nodes=1)
            >>> print(recipe)

    Note:
        This recipe targets the 1B model and runs on a single 8-GPU node.
    """
    recipe = run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[
                run.Config(TimingCallback, log_tokens_per_sec=True),
                run.Config(
                    ModelCheckpoint,
                    save_last=False,
                    monitor="val_loss",
                    save_top_k=2,
                    every_n_train_steps=500,
                ),
            ],
        ),
        log=default_log(
            dir=dir,
            name=name,
            tensorboard_logger=tensorboard_logger(name=name),
            wandb_logger=wandb_logger(project="Nemo_Testing", name=name),
        ),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        resume=default_resume(),
    )
    return recipe
# Executor for running pretraining
def local_executor_torchrun(devices: int = 2) -> run.LocalExecutor:
    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun")
    return executor
# This condition is necessary for the script to be compatible with Python's multiprocessing module.
if __name__ == "__main__":
    curr_date_time = time.strftime("%Y-%m-%d-%H-%M-%S")
    print(NAME + "-" + curr_date_time)
    recipe = pretrain_recipe(
        name=NAME,
        num_nodes=NUM_NODES,
        num_gpus_per_node=NUM_GPUS,
        dir="Checkpoints/" + NAME + "-" + curr_date_time,
    )
    recipe.data = dclm(gbs=GLOBAL_BATCH_SIZE, mbs=PER_GPU_BATCH_SIZE, seq_length=SEQUENCE_LENGTH)
    executor = local_executor_torchrun(devices=NUM_GPUS)
    run.run(recipe, executor=executor)
Here are the training output and error logs - https://gist.github.com/aflah02/6a28bec7d907179eb2f75c5eac6bb913 & https://gist.github.com/aflah02/1fdadd4e10115de8c3de416f3a63e653
This is the conversion script that I use -
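(The script itself isn't pasted above. For context, a NeMo 2.0 checkpoint-to-HF export is typically a single llm.export_ckpt call along these lines; the paths here are placeholders and this sketch is only illustrative, not the exact script used) -
from pathlib import Path
from nemo.collections import llm

# Placeholder paths, not the actual ones used: the source is a NeMo checkpoint
# directory written by the training run above, the target is an HF-format folder.
nemo_ckpt = Path("Checkpoints/<run_name>/checkpoints/<checkpoint_at_step_6849>")
hf_out = Path("hf_export/step_6849")

llm.export_ckpt(path=nemo_ckpt, target="hf", output_path=hf_out)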