Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exported Llama Models Trained Using NeMo Generate The Same Token Repeatedly #12212

Open
aflah02 opened this issue Feb 17, 2025 · 1 comment
Open
Labels
bug Something isn't working

Comments

@aflah02
Copy link

aflah02 commented Feb 17, 2025

Hi

I trained some Llama models on a small chunk of DCLM (a pretraining corpus), tokenized with the OLMo tokenizer. I took two intermediate checkpoints, at steps 6849 and 2403, and exported them to HF. When I load them with HF to run inference, the model keeps generating the same |||IP_ADDRESS||| token over and over. What could be going wrong?

Update: see my comment below. The export seems to be broken, since the model works well when loaded via NeMo.

Train Script -

from typing import Callable, Optional

import lightning.pytorch as pl
import nemo_run as run
import torch
from lightning.pytorch.callbacks.callback import Callback
from megatron.core.distributed import DistributedDataParallelConfig
import time
from nemo import lightning as nl
from nemo.collections.llm.api import pretrain
from nemo.collections import llm
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.model.llama import Llama32Config1B, LlamaModel
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger, wandb_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from lightning.pytorch.callbacks import Callback
from nemo.utils import logging, timers
from nemo.utils.exp_manager import TimingCallback
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning import AutoResume
import wandb
import argparse
from dotenv import load_dotenv
import os

# Load variables from a local .env file into the process environment before
# anything below reads os.environ (the W&B key is read a few lines down).
load_dotenv()

# Global Batch Size Related
from nemo.utils.import_utils import safe_import_from


# safe_import_from yields a pair; the second element is presumably a flag that
# is True when the Megatron-Core symbol was importable (TODO confirm).
(
    get_current_global_batch_size,
    HAVE_MCORE_MBATCH_CALCULATOR,
) = safe_import_from(
    "megatron.core.num_microbatches_calculator",
    "get_current_global_batch_size",
)

# Authenticate with Weights & Biases; key is None when WANDB_API_KEY is unset.
wandb.login(key=os.getenv("WANDB_API_KEY"))

# --- Run hyper-parameters -------------------------------------------------
SEQUENCE_LENGTH = 2048  # tokens per training sample (model + data module)
PER_GPU_BATCH_SIZE = 16  # passed as the data module's micro_batch_size
GRADIENT_ACCUMULATION_STEPS = 4

NUM_GPUS = 8  # GPUs per node
NUM_NODES = 1

TIE_WEIGHTS = True  # share input-embedding and output-layer weights

# Global batch size handed to the data module.
GLOBAL_BATCH_SIZE = NUM_NODES * NUM_GPUS * PER_GPU_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS

# Run name encodes the key hyper-parameters; "-TW" marks tied weights.
NAME = (
    f"llama32_1b_dclm-SL-{SEQUENCE_LENGTH}-PGBS-{PER_GPU_BATCH_SIZE}"
    f"-GAS-{GRADIENT_ACCUMULATION_STEPS}-NGPU-{NUM_GPUS}-NNODES-{NUM_NODES}"
    + ("-TW" if TIE_WEIGHTS else "")
)

def dclm(
    gbs: int = 256,
    mbs: int = 4,
    seq_length: int = 8192,
) -> run.Config[pl.LightningDataModule]:
    """Build the data-module config for the pre-tokenized DCLM subset.

    Args:
        gbs: Global batch size passed to the data module.
        mbs: Micro-batch size (per GPU).
        seq_length: Number of tokens per training sample.

    Returns:
        A ``run.Config`` wrapping ``llm.PreTrainingDataModule`` that reads the
        Megatron-format dataset under ``Data/dclm_megatron`` and tokenizes with
        the ``allenai/OLMo-1B-hf`` tokenizer.
    """
    return run.Config(
        llm.PreTrainingDataModule,
        paths=["Data/dclm_megatron/concatenated.jsonl_text_document"],
        seq_length=seq_length,
        global_batch_size=gbs,
        micro_batch_size=mbs,
        # NOTE(review): ``vocab_file`` points at a HF tokenizer_config.json, not
        # a vocabulary file, while the tokenizer is already identified by
        # ``pretrained_model_name``. Confirm this argument is intended and that
        # the tokenizer saved with the checkpoint round-trips on HF export.
        tokenizer=run.Config(
            AutoTokenizer,
            pretrained_model_name="allenai/OLMo-1B-hf",
            vocab_file="Data/tokenizer/tokenizer_config.json",
            use_fast=True,
        ),
        # NOTE(review): these weights sum to 109, not 100. If they are
        # normalized, train/val/test is roughly 91%/7%/2% — confirm that is the
        # intended split (e.g. "98,1,1" may have been meant).
        split="99,8,2",
        num_workers=2,
        index_mapping_dir="Data/index_mapping",
    )


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
    """
    Factory function to create the Llama 3.2 1B model configuration for this run.

    Starts from ``Llama32Config1B``, overrides the sequence length with
    ``SEQUENCE_LENGTH`` and sets embedding/output weight tying from
    ``TIE_WEIGHTS``.

    Returns:
        run.Config[pl.LightningModule]: Configuration for the Llama 3.2 1B model.

    Examples:
        CLI usage (the factory is registered under the module-level ``NAME``):
            $ nemo llm pretrain model=<NAME> ...

        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    conf = run.Config(Llama32Config1B)
    conf.seq_length = SEQUENCE_LENGTH
    # NOTE(review): weight tying has to survive the HF export. Verify the
    # exported config.json sets ``tie_word_embeddings`` to match and that the
    # HF lm_head equals the input embeddings; a mismatch there is one plausible
    # cause of the repeated-token output seen only in the HF export.
    conf.share_embeddings_and_output_weights = TIE_WEIGHTS
    return run.Config(LlamaModel, config=conf)


def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Configure the NeMo Lightning Trainer for the Llama 3.2 1B DCLM pre-training run.

    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.

    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.

    Examples:
        Python API usage:
            >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1)
            >>> print(trainer_config)

    Note:
        All model-parallel degrees default to 1 (no tensor, pipeline or context
        parallelism), so by default the GPUs are used only for data parallelism.
        Validation runs every 500 steps on 32 batches, with bf16 mixed precision.
    """
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
            average_in_collective=True,
        ),
    )

    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        # NOTE(review): the data module's global_batch_size already includes
        # GRADIENT_ACCUMULATION_STEPS (GLOBAL_BATCH_SIZE = PER_GPU_BATCH_SIZE *
        # GRADIENT_ACCUMULATION_STEPS * NUM_GPUS * NUM_NODES, with
        # micro_batch_size = PER_GPU_BATCH_SIZE), so Megatron should already run
        # that many micro-batches per step. Confirm that also asking Lightning to
        # accumulate does not compound it; NeMo's bundled recipes typically pass 1.
        accumulate_grad_batches=GRADIENT_ACCUMULATION_STEPS,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=500,
    )

    return trainer


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn: Callable = pretrain,
) -> run.Partial:
    """
    Create a pre-training recipe for the Llama 3.2 1B model.

    This function sets up the model, trainer, logging, optimization, and
    resumption settings. It does not set a data module; the caller assigns
    ``recipe.data`` afterwards (for example with ``dclm(...)``).

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for pre-training.

    Examples:
        CLI usage (the factory is registered under the module-level ``NAME``):
            $ nemo llm pretrain --factory <NAME>
            $ nemo llm pretrain --factory "<NAME>(num_nodes=1, name='my_1b_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe(name="llama32_1b_pretrain", num_nodes=1)
            >>> print(recipe)

    Note:
        Checkpoints are written every 500 training steps, keeping the 2 best by
        ``val_loss``. Metrics go to TensorBoard and to the W&B project
        "Nemo_Testing". The optimizer is distributed fused Adam with cosine
        annealing, with a peak learning rate of 3e-4.
    """
    recipe = run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[
                run.Config(TimingCallback, log_tokens_per_sec=True),
                run.Config(ModelCheckpoint, save_last=False, monitor="val_loss", save_top_k=2, every_n_train_steps=500)
            ],
        ),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name), wandb_logger=wandb_logger(project="Nemo_Testing", name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        resume=default_resume(),
    )

    return recipe

# Executor for running pretraining
def local_executor_torchrun(devices: int = 2) -> run.LocalExecutor:
    """Return a local nemo_run executor that launches ``devices`` tasks per node via torchrun."""
    return run.LocalExecutor(launcher="torchrun", ntasks_per_node=devices)

# Guard the entry point so that worker processes which re-import this module
# (e.g. under Python's multiprocessing) do not re-launch the run.
if __name__ == "__main__":
    # Unique, timestamped run id; it is also the checkpoint directory name.
    run_id = f"{NAME}-{time.strftime('%Y-%m-%d-%H-%M-%S')}"
    print(run_id)
    recipe = pretrain_recipe(
        name=NAME,
        num_nodes=NUM_NODES,
        num_gpus_per_node=NUM_GPUS,
        dir=f"Checkpoints/{run_id}",
    )
    recipe.data = dclm(gbs=GLOBAL_BATCH_SIZE, mbs=PER_GPU_BATCH_SIZE, seq_length=SEQUENCE_LENGTH)
    run.run(recipe, executor=local_executor_torchrun(devices=NUM_GPUS))

Here are the training output and error logs - https://gist.github.com/aflah02/6a28bec7d907179eb2f75c5eac6bb913 & https://gist.github.com/aflah02/1fdadd4e10115de8c3de416f3a63e653

This is the conversion script that I use -

from pathlib import Path

from nemo.collections.llm import export_ckpt

if __name__ == "__main__":
    # NeMo 2 checkpoint to convert: step 6849 of the tied-weights run.
    nemo_checkpoint = Path(
        "Checkpoints/llama32_1b_dclm-SL-2048-PGBS-16-GAS-4-NGPU-8-NNODES-1-TW-2025-02-15-14-49-51/"
        "llama32_1b_dclm-SL-2048-PGBS-16-GAS-4-NGPU-8-NNODES-1-TW/checkpoints/"
        "model_name=0--val_loss=6.88-step=6849-consumed_samples=3507200.0-last"
    )
    # Destination directory for the Hugging Face formatted model.
    hf_output_dir = Path("Exported/llama32_1b_dclm-SL-2048-PGBS-16-GAS-4-NGPU-8-NNODES-1-TW-step=6849")
    export_ckpt(path=nemo_checkpoint, target="hf", output_path=hf_output_dir)
@aflah02 aflah02 added the bug Something isn't working label Feb 17, 2025
@aflah02 aflah02 changed the title Exported Llama Models Trained Using NeMo Generate The Same Token Exported Llama Models Trained Using NeMo Generate The Same Token Repeatedly Feb 17, 2025
@aflah02
Copy link
Author

aflah02 commented Feb 17, 2025

Note: if I run inference via NeMo, it works, so the HF export seems to be broken. My code for inference via NeMo:

import json
import warnings
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from nemo.collections.llm import api
import lightning.pytorch as pl
import nemo_run as run
import torch
from megatron.core import parallel_state
from rich.console import Console
from torch.distributed import all_gather_object
from typing_extensions import Annotated

import nemo.lightning as nl
from nemo.collections.llm.distillation import DistillationGPTModel
from nemo.collections.llm.evaluation.api import EvaluationConfig, EvaluationTarget
from nemo.collections.llm.gpt.model import GPTModel
from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
from nemo.lightning import (
    AutoResume,
    NeMoLogger,
    OptimizerModule,
    Trainer,
    configure_no_restart_validation_training_loop,
    io,
)
from nemo.lightning.base import NEMO_MODELS_CACHE
from nemo.lightning.pytorch.callbacks import PEFT, JitTransform, ModelTransform
from nemo.utils import logging
from nemo.utils.get_rank import is_global_rank_zero

from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest

# Single-GPU Megatron strategy for inference: no model parallelism, and
# optimizer setup / optimizer-state storage are disabled.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    context_parallel_size=1,
    sequence_parallel=False,
    setup_optimizers=False,
    store_optimizer_states=False,
)

# One GPU on one node, with bf16 parameters (matching the bf16 training run).
trainer = nl.Trainer(
    accelerator="gpu",
    devices=1,
    num_nodes=1,
    strategy=strategy,
    plugins=nl.MegatronMixedPrecision(
        precision="bf16-mixed",
        params_dtype=torch.bfloat16,
        pipeline_dtype=torch.bfloat16,
        autocast_enabled=False,
        grad_reduce_in_fp32=False,
    ),
)
# Sanity-check prompt; spelled correctly so the model is not fed a typo'd input.
prompts = [
    "The speed of light in vacuum is",
]

if __name__ == "__main__":
    # Generate from the same checkpoint that was exported to HF (step 6849).
    # Only then is the NeMo output directly comparable with the HF export.
    # Sampling from a different step (e.g. 7249) cannot show whether the
    # exporter itself is at fault.
    checkpoint_path = (
        "Checkpoints/llama32_1b_dclm-SL-2048-PGBS-16-GAS-4-NGPU-8-NNODES-1-TW-2025-02-15-14-49-51/"
        "llama32_1b_dclm-SL-2048-PGBS-16-GAS-4-NGPU-8-NNODES-1-TW/checkpoints/"
        "model_name=0--val_loss=6.88-step=6849-consumed_samples=3507200.0-last"
    )
    results = api.generate(
        path=checkpoint_path,
        prompts=prompts,
        trainer=trainer,
        # Low temperature and a small top-k give near-deterministic samples,
        # which are easier to compare with HF generation output.
        inference_params=CommonInferenceParams(temperature=0.1, top_k=10, num_tokens_to_generate=20),
        text_only=True,
    )

    print(results)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

1 participant