From 1ca42e39a5fb9ed5561a67785d25e68c597a7bd2 Mon Sep 17 00:00:00 2001 From: nouamanetazi Date: Mon, 14 Apr 2025 14:16:59 +0000 Subject: [PATCH 01/10] can only merge to main from dev --- .github/PULL_REQUEST_TEMPLATE.md | 7 ++++++- .github/workflows/pr-rules.yaml | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/pr-rules.yaml diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a776daf2..4a5ebde4 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,3 +1,8 @@ + + # What does this PR do? \ No newline at end of file + diff --git a/.github/workflows/pr-rules.yaml b/.github/workflows/pr-rules.yaml new file mode 100644 index 00000000..b82d61ba --- /dev/null +++ b/.github/workflows/pr-rules.yaml @@ -0,0 +1,15 @@ +name: Check PR Source Branch +on: + pull_request: + branches: + - main + +jobs: + check-branch: + runs-on: ubuntu-latest + steps: + - name: Check PR source branch + if: github.base_ref == 'main' && github.head_ref != 'dev' + run: | + echo "ERROR: PRs to main must come from dev branch" + exit 1 From 0dbf24dc8908ad1b734f70e79e5e6728f9bcfbef Mon Sep 17 00:00:00 2001 From: Connector Switch Date: Mon, 14 Apr 2025 22:20:19 +0800 Subject: [PATCH 02/10] Fix UnBoundLocalError in `clm_collator.py` (#339) * Update clm_collator.py * can only merge to main from dev (#348) --------- Co-authored-by: Nouamane Tazi --- src/nanotron/data/clm_collator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/data/clm_collator.py b/src/nanotron/data/clm_collator.py index 5c141adf..89fd0083 100644 --- a/src/nanotron/data/clm_collator.py +++ b/src/nanotron/data/clm_collator.py @@ -97,6 +97,7 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_) # Context Parallelism: Each CP rank gets a slice of the label_ids and label_mask + cp_rank, cp_size = dist.get_rank(self.parallel_context.cp_pg), self.parallel_context.context_parallel_size local_slice = slice( cp_rank * self.sequence_length // cp_size, (cp_rank + 1) * self.sequence_length // cp_size ) From 0a04f347d3ff7a6ca6eefa3890488028e7cedac0 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Fri, 18 Apr 2025 16:02:23 +0100 Subject: [PATCH 03/10] fix init and init scaling factor and run evals in background (#349) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * InitScalingMethod * InitScalingMethod * run evals in background (#352) * eval * try adding lightevalrunner to trainer * amend * amend * amend * amend * amend * amend * . * amend * amend * . 
* qos to low * add nanotron_path * some fix: logs, and config * cp instead of sync * eval_interval * serialize sanity checks * add output dir and s3_save path in the config * fix s3 only if define * fixes --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --- src/nanotron/config/config.py | 14 +- src/nanotron/config/lighteval_config.py | 48 +++- src/nanotron/config/models_config.py | 14 +- src/nanotron/config/utils_config.py | 9 + src/nanotron/eval/README.md | 13 + src/nanotron/eval/__init__.py | 3 + src/nanotron/eval/evaluation_tasks.py | 368 ++++++++++++++++++++++++ src/nanotron/eval/one_job_runner.py | 360 +++++++++++++++++++++++ src/nanotron/logging/base.py | 2 +- src/nanotron/models/qwen.py | 2 +- src/nanotron/s3_checkpoints/s3_mover.py | 2 +- src/nanotron/scaling/parametrization.py | 33 ++- src/nanotron/serialize/main.py | 2 +- src/nanotron/trainer.py | 112 ++++++-- 14 files changed, 933 insertions(+), 49 deletions(-) create mode 100644 src/nanotron/eval/README.md create mode 100644 src/nanotron/eval/__init__.py create mode 100644 src/nanotron/eval/evaluation_tasks.py create mode 100644 src/nanotron/eval/one_job_runner.py diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 4a847209..c16f076c 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -17,6 +17,7 @@ from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( + InitScalingMethod, RecomputeGranularity, cast_str_to_pipeline_engine, cast_str_to_torch_dtype, @@ -460,6 +461,13 @@ def __post_init__(self): if self.s3_upload is not None: self.s3_upload.__post_init__() + if self.lighteval is not None: + if self.lighteval.eval_interval is None: + self.lighteval.eval_interval = self.checkpoints.checkpoint_interval + else: + assert ( + self.lighteval.eval_interval % self.checkpoints.checkpoint_interval == 0 + ), f"eval_interval={self.lighteval.eval_interval} must be a multiple of checkpoint_interval={self.checkpoints.checkpoint_interval}" # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: @@ -542,14 +550,15 @@ def global_batch_size(self): def global_batch_size_in_tokens(self): return self.global_batch_size * self.tokens.sequence_length - def save_as_yaml(self, file_path: str): + def save_as_yaml(self, file_path: str, sanity_checks: bool = True): config_dict = serialize(self) file_path = str(file_path) with open(file_path, "w") as f: yaml.dump(config_dict, f) # Sanity test config can be reloaded - _ = get_config_from_file(file_path, config_class=self.__class__) + if sanity_checks: + _ = get_config_from_file(file_path, config_class=self.__class__) def get_yaml(self): config_dict = serialize(self) @@ -620,6 +629,7 @@ def get_config_from_dict( PipelineEngine: cast_str_to_pipeline_engine, TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()], RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()], + InitScalingMethod: lambda x: InitScalingMethod[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], }, # strict_unions_match=True, diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index b5f12059..363ee988 100644 --- 
a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -73,6 +73,22 @@ def __post_init__(self): assert self.wandb_project != "", "Please specify a wandb_project" +@dataclass +class LightEvalSlurm: + """Arguments related to SLURM configuration for LightEval""" + + gpus_per_node: int = 8 + partition: str = "hopper-prod" + hf_cache: str = "~/.cache/huggingface" + cpus_per_task: int = 88 + qos: str = "low" + time: str = "24:00:00" + reservation: Optional[str] = "smollm" + + def __post_init__(self): + self.hf_cache = str(Path(self.hf_cache).expanduser()) + + @dataclass class LightEvalConfig: """Arguments related to running LightEval on checkpoints. @@ -81,13 +97,37 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_template: Optional[str] = None - slurm_script_dir: Optional[str] = None - - checkpoints_path: Optional[str] = None + slurm_script_dir: Optional[Path] = Path("eval_results/launch-config") + logs_path: Optional[Path] = Path("eval_results/logs") + local_checkpoint_dir: Path = Path( + "/scratch" + ) # Base directory for temporary checkpoint storage, will store under {local_checkpoint_dir}/{run_name}/{step} parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None tasks: Optional[LightEvalTasksArgs] = None logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None + slurm: Optional[LightEvalSlurm] = None + s3_save_path: Optional[str] = None # should not be dependent of the run_name + output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override + nanotron_path: Optional[str] = "./" + eval_config_override: str = None + eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job + eval_interval: Optional[ + int + ] = None # Must be multiple of checkpoint_interval. If None, eval will be done after each checkpoint upload to s3 + eval_interval_file: Optional[ + Path + ] = None # If specified, eval_interval will be read from this file upon the next evaluation. + + def __post_init__(self): + if self.parallelism is None: + self.parallelism = ParallelismArgs(dp=1, pp=1, tp=1, tp_linear_async_communication=True) + if self.slurm is None: + self.slurm = LightEvalSlurm() + self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): + logger.warning( + f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." 
+            )
diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py
index 410634b8..dd575e39 100644
--- a/src/nanotron/config/models_config.py
+++ b/src/nanotron/config/models_config.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
+from nanotron.config.utils_config import InitScalingMethod
 from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, AttentionImplementation
 
 # The default attention implementation to use
@@ -11,6 +12,7 @@
 @dataclass
 class RandomInit:
     std: float
+    scaling_method: InitScalingMethod = InitScalingMethod.NUM_LAYERS
 
 
 @dataclass
@@ -141,11 +143,13 @@ class Qwen2Config:
     sliding_window_size: Optional[int] = None
     z_loss_enabled: bool = False  # Z-loss regularization https://www.jmlr.org/papers/volume24/22-1144/22-1144.pdf
     z_loss_coefficient: float = 0.0001  # Default from the paper (10^-4)
-    no_rope_layer: Optional[int] = None  # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4)
-    _fused_rotary_emb: bool = True
-    _fused_rms_norm: bool = True
-    _use_qkv_packed: bool = True
-    _use_doc_masking: bool = True
+    no_rope_layer: Optional[
+        int
+    ] = None  # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4)
+    _fused_rotary_emb: bool = False
+    _fused_rms_norm: bool = False
+    _use_qkv_packed: bool = False
+    _use_doc_masking: bool = False
 
     # MoE configuration
     moe_config: Optional[MoEConfig] = None
diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py
index f4c07146..84e8079a 100644
--- a/src/nanotron/config/utils_config.py
+++ b/src/nanotron/config/utils_config.py
@@ -18,6 +18,13 @@ class RecomputeGranularity(Enum):
     FULL = auto()
 
 
+class InitScalingMethod(Enum):
+    NONE = auto()
+    NUM_LAYERS = auto()
+    LAYER_INDEX = auto()
+    MODEL_SCALE = auto()
+
+
 def serialize(data) -> dict:
     """Recursively serialize a nested dataclass to a dict - do some type conversions along the way"""
     if data is None:
@@ -39,6 +46,8 @@ def serialize(data) -> dict:
             result[field.name] = value.name
         elif isinstance(value, RecomputeGranularity):
             result[field.name] = value.name
+        elif isinstance(value, InitScalingMethod):
+            result[field.name] = value.name
         elif isinstance(value, SamplerType):
             result[field.name] = value.name
         elif isinstance(value, torch.dtype):
diff --git a/src/nanotron/eval/README.md b/src/nanotron/eval/README.md
new file mode 100644
index 00000000..05bfe162
--- /dev/null
+++ b/src/nanotron/eval/README.md
@@ -0,0 +1,13 @@
+# Nanotron Evaluation
+
+This directory contains code for evaluating models trained with Nanotron.
+ +## Installation + +To use the evaluation functionality, you need to install the `lighteval` package: + +```bash +uv pip install lighteval[dev] +``` + +## Usage diff --git a/src/nanotron/eval/__init__.py b/src/nanotron/eval/__init__.py new file mode 100644 index 00000000..d7ea002c --- /dev/null +++ b/src/nanotron/eval/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa: F401 + +from .one_job_runner import LightEvalRunner diff --git a/src/nanotron/eval/evaluation_tasks.py b/src/nanotron/eval/evaluation_tasks.py new file mode 100644 index 00000000..2543df31 --- /dev/null +++ b/src/nanotron/eval/evaluation_tasks.py @@ -0,0 +1,368 @@ +from functools import partial + +from lighteval.metrics.dynamic_metrics import ( + loglikelihood_acc_metric, +) +from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm +from lighteval.tasks.default_prompts import LETTER_INDICES +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.multilingual.adapters import ( + winogrand_adapter, +) +from lighteval.tasks.multilingual.tasks import TASKS_TABLE as ML_TASKS_TABLE +from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation +from lighteval.tasks.templates.continuation import get_continuation_prompt_function +from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function +from lighteval.tasks.templates.multichoice import get_mcq_prompt_function +from lighteval.tasks.templates.utils.formulation import ( + CFFormulation, + HybridFormulation, + MCFFormulation, +) +from lighteval.utils.language import Language + +TASKS_TABLE = [] + +TASKS_TABLE.extend(ML_TASKS_TABLE) + +arc_tasks = [ + LightevalTaskConfig( + name=f"arc_{formulation.name.lower()}:{subset.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"]["text"], + "gold_idx": int(line["answerKey"]) - 1 + if line["answerKey"].isdigit() + else LETTER_INDICES.index(line["answerKey"]), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="allenai/ai2_arc", + hf_subset=f"ARC-{subset}", + hf_revision="210d026faf9955653af8916fad021475a3f00453", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="train", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for subset in ["Easy", "Challenge"] + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(arc_tasks) + +hellaswag_tasks = [ + LightevalTaskConfig( + name=f"hellaswag_{formulation.name.lower()}", + suite=["custom"], + prompt_function=get_hellaswag_prompt_function( + language=Language.ENGLISH, + adapter=lambda line: { + "activity_label": line["activity_label"], + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + hf_repo="Rowan/hellaswag", + hf_subset="default", + hf_revision="6002345709e0801764318f06bf06ce1e7d1a1fe3", + evaluation_splits=["validation"], + hf_avail_splits=["validation"], + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + trust_dataset=True, + ) + for formulation in 
[MCFFormulation(), CFFormulation(), HybridFormulation()] +] + +TASKS_TABLE.extend(hellaswag_tasks) + +commonsense_qa_tasks = [ + LightevalTaskConfig( + name=f"commonsenseqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"]["text"], + "gold_idx": line["choices"]["label"].index(line["answerKey"].strip()), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="tau/commonsense_qa", + hf_subset="default", + hf_revision="94630fe30dad47192a8546eb75f094926d47e155", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(commonsense_qa_tasks) + +openbook_qa_tasks = [ + LightevalTaskConfig( + name=f"openbookqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question_stem"], + "choices": line["choices"]["text"], + "gold_idx": LETTER_INDICES.index(line["answerKey"]), + }, + formulation=formulation, + ), + suite=["custom"], + hf_repo="allenai/openbookqa", + hf_subset="main", + hf_revision="388097ea7776314e93a529163e0fea805b8a6454", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(openbook_qa_tasks) + +winogrande_tasks = [ + LightevalTaskConfig( + name=f"winogrande_{formulation.name.lower()}", + suite=("custom",), + prompt_function=get_continuation_prompt_function( + Language.ENGLISH, partial(winogrand_adapter, Language.ENGLISH), formulation=formulation + ), + hf_repo="allenai/winogrande", + hf_subset="winogrande_xl", + trust_dataset=True, + hf_revision="85ac5b5a3b7a930e22d590176e39460400d19e41", + metric=[ + loglikelihood_acc_metric(normalization=None), + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(winogrande_tasks) + +piqa_tasks = [ + LightevalTaskConfig( + name=f"piqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["goal"], + "choices": [line["sol1"], line["sol2"]], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + suite=["custom"], + hf_repo="ybisk/piqa", + hf_revision="2e8ac2dffd59bac8c3c6714948f4c551a0848bb0", + hf_subset="plain_text", + trust_dataset=True, + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(piqa_tasks) + + +MMLU_SUBSETS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + 
"computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + +mmlu_tasks = [ + LightevalTaskConfig( + name=f"mmlu_{formulation.name.lower()}:{subset}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"], + "gold_idx": int(line["answer"]), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="cais/mmlu", + hf_subset=subset, + hf_revision="c30699e8356da336a370243923dbaf21066bb9fe", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="dev", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for subset in MMLU_SUBSETS + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(mmlu_tasks) + +mmlu_pro_tasks = [ + LightevalTaskConfig( + name=f"mmlu_pro_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["options"], + "gold_idx": line["answer_index"], + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="TIGER-Lab/MMLU-Pro", + hf_subset="default", + hf_revision="3373e0b32277875b8db2aa555a333b78a08477ea", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="validation", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(mmlu_pro_tasks) + + +if __name__ == "__main__": + print(t.name for t in TASKS_TABLE) + print(len(TASKS_TABLE)) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py new file mode 100644 index 00000000..43d1a765 --- /dev/null +++ b/src/nanotron/eval/one_job_runner.py @@ -0,0 +1,360 @@ +""" Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it +""" +import datetime +import math +import os +import subprocess +from typing import List, Optional, Tuple + +from datasets.download.streaming_download_manager import xPath + +from nanotron import logging +from nanotron.config import Config, 
LightEvalConfig +from nanotron.data.s3_utils import _get_s3_path_components +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext + +logger = logging.get_logger(__name__) + + +class LightEvalRunner: + def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = None): + self.config = config + assert config.lighteval is not None, "LightEval config is required" + self.lighteval_config = config.lighteval + self.parallel_context = parallel_context + + def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: + """Run light evaluation on uploaded files.""" + if ( + self.config.lighteval.eval_interval is not None + and self.config.general.step % self.config.lighteval.eval_interval != 0 + ): + logger.debug( + f"Skipping evaluation at step {self.config.general.step} because eval_interval is {self.config.lighteval.eval_interval}" + ) + return + config_files = [ + f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] + ] + # Sanity check on the config files len (we want only one) + if len(config_files) == 0: + log_rank( + "No config files founds in uploaded checkpoints. Not running evaluation.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + if len(config_files) > 1: + log_rank( + f"Found multiple config files in uploaded checkpoints: {config_files}", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") + logger.warning( + f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." 
+ ) + + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + + return slurm_job_id, slurm_log + + +def normalize_s3_path(path: str) -> str: + """Normalize S3 path using existing s3_utils""" + # Use existing utility to normalize path components + path = xPath(path) + bucket, prefix = _get_s3_path_components(path) + # Reconstruct normalized path + return f"s3://{bucket}/{prefix}".rstrip("/") + + +def run_slurm_one_job( + config: Config, + lighteval_config: LightEvalConfig, + model_checkpoint_path: str, + current_step: int, +): + """Launch a single job on Slurm with the given mapping""" + # Normalize S3 path if needed + if model_checkpoint_path.startswith(("s3:/", "s3://")): + model_checkpoint_path = normalize_s3_path(model_checkpoint_path) + logger.info(f"Normalized S3 path: {model_checkpoint_path}") + + # Use config values instead of hardcoded defaults + slurm_config = lighteval_config.slurm + + # Calculate the number of nodes based on parallelism config + total_gpus_needed = ( + lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + ) + nodes = math.ceil(total_gpus_needed / slurm_config.gpus_per_node) + + # Get timestamp for log files + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + general_run_name = config.general.run + run_name = f"{timestamp}-eval_{general_run_name}".replace(" ", "_") + + # Use lighteval config paths if available, otherwise use defaults + eval_launch_script_path = lighteval_config.slurm_script_dir + eval_logs_path = lighteval_config.logs_path + eval_launch_script_path = os.path.join(eval_launch_script_path, general_run_name, f"step-{current_step}") + eval_logs_path = os.path.join(eval_logs_path, general_run_name, f"step-{current_step}") + + # Create directories + os.makedirs(eval_launch_script_path, exist_ok=True) + os.makedirs(eval_logs_path, exist_ok=True) + + # Use configured local path instead of hardcoded /tmp + local_path = os.path.join(lighteval_config.local_checkpoint_dir, run_name, str(current_step)) + nanotron_path = lighteval_config.nanotron_path + # Create the SLURM script content + slurm_script = f"""#!/bin/bash +#SBATCH --job-name=eval_{current_step}_{run_name} +#SBATCH --partition={slurm_config.partition} +#SBATCH --nodes={nodes} +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task={slurm_config.cpus_per_task} +#SBATCH --gpus={slurm_config.gpus_per_node} +#SBATCH --exclusive +#SBATCH --qos={slurm_config.qos} +#SBATCH --time={slurm_config.time} +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out""" + + if slurm_config.reservation: + slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" + + # Rest of the script content + slurm_script += f""" + +set -x + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={local_path} + +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token setup +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + if TOKEN=$(cat 
~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." + exit 1 + fi +fi + +# Set environment variables +export CUDA_DEVICE_MAX_CONNECTIONS=1 +# export CUBLAS_WORKSPACE_CONFIG=":4096:8" + +# Set HuggingFace cache locations +export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} +export HF_DATASETS_CACHE={slurm_config.hf_cache} +export HF_MODULES_CACHE={slurm_config.hf_cache} +export HF_HOME={slurm_config.hf_cache} + +echo "Running on $COUNT_NODE nodes: $HOSTNAMES" + +# Create checkpoint directory +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +# Handle S3 paths +if [[ "{model_checkpoint_path}" == s3://* ]]; then + echo "Downloading checkpoint from S3: {model_checkpoint_path}" + + # First check if the S3 path exists + if ! s5cmd ls "{model_checkpoint_path}" &>/dev/null; then + echo "Error: S3 path {model_checkpoint_path} does not exist" + exit 1 + fi + + # Try sync command and check its exit status + s5cmd cp \\ + --concurrency=50 \\ + --exclude "optimizer/*" \\ + --exclude "random/*" \\ + --exclude "lr_scheduler/*" \\ + --part-size 100 \\ + "{model_checkpoint_path}/*" "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/" + + if [ $? -ne 0 ]; then + echo "Error: Failed to sync files from S3" + exit 1 + fi + + # Verify that config.yaml was downloaded + if [ ! -f "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml" ]; then + echo "Error: config.yaml not found in downloaded checkpoint" + echo "Contents of $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER:" + ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + exit 1 + fi +else + echo "Copying checkpoint files from {model_checkpoint_path} to $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" + rsync -av --progress --inplace --no-whole-file \\ + --exclude 'optimizer/' \\ + --exclude 'random/' \\ + --exclude 'lr_scheduler/' \\ + {model_checkpoint_path} $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + + if [ $? -ne 0 ]; then + echo "Error: Failed to copy files using rsync" + exit 1 + fi + + # Verify that config.yaml was copied + if [ ! -f "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml" ]; then + echo "Error: config.yaml not found in copied checkpoint" + echo "Contents of $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER:" + ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + exit 1 + fi +fi + +echo "Contents of checkpoint directory:" +ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + +# Add random sleep to avoid hub request conflicts +# sleep $(( RANDOM % 300 )) + +CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \\ + --nproc_per_node {slurm_config.gpus_per_node} \\ + --nnodes $COUNT_NODE \\ + --node_rank $SLURM_PROCID \\ + --master_addr $MASTER_ADDR \\ + --master_port $MASTER_PORT \\ + {nanotron_path}/run_evals.py \\ + --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ + --lighteval-override {lighteval_config.eval_config_override} + --cache-dir {slurm_config.hf_cache}""" + if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: + slurm_script += f""" +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} +""" + slurm_script += """ +echo "Cleaning up downloaded checkpoints..." 
+rm -rf "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" +echo "Cleanup completed" + +echo "END TIME: $(date)" +""" + + # Write the script to file + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") + os.makedirs(os.path.dirname(launch_script_path), exist_ok=True) + + with open(launch_script_path, "w") as f: + f.write(slurm_script) + + # Preserve important environment variables + env = { + "PATH": os.environ["PATH"], + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), + "HOME": os.path.expanduser("~"), + } + + try: + # Use subprocess.run instead of check_output for better error handling + result = subprocess.run(["sbatch", launch_script_path], env=env, check=True, capture_output=True, text=True) + output = result.stdout + job_ids = output.split()[-1] + + output_log = os.path.join(eval_logs_path, f"{timestamp}-{run_name}-{job_ids}.out") + + logger.warning( + f"""🚀 Slurm job launched successfully: + Job name: {run_name} + Job ID: {job_ids} + Launch script: {launch_script_path} + Log file: {output_log}""" + ) + except subprocess.CalledProcessError as e: + logger.error(f"Error while launching Slurm job: {e}") + logger.error(f"Command output: {e.output}") + logger.error(f"Command stderr: {e.stderr}") + job_ids = None + output_log = None + + return job_ids, output_log + + +if __name__ == "__main__": + + from nanotron.config.config import Config + + # Load existing config from checkpoint + # checkpoint_path = "checkpoints/smollm3-training-test-tps-48nn-seed-6-/10" + # config_path = os.path.join(checkpoint_path, "config.yaml") + checkpoint_path = "s3://smollm3/smollm3-3B-final/3B-final-GQA-noTP-2k-seq/20000/" + config_path = "checkpoints/smollm3-training-test-tps-48nn-seed-6-/10/config.yaml" + try: + # Load the existing config + print(f"\nLoading config from: {config_path}") + config = Config.load_from_yaml(config_path) + + # Print config details + print("\nConfig details:") + print(f"Project: {config.general.project}") + print(f"Run: {config.general.run}") + print(f"Step: {config.general.step}") + + if config.lighteval: + print("\nLightEval config:") + print( + f"Parallelism: dp={config.lighteval.parallelism.dp}, tp={config.lighteval.parallelism.tp}, pp={config.lighteval.parallelism.pp}" + ) + print(f"Batch size: {config.lighteval.batch_size}") + print(f"Slurm template: {config.lighteval.slurm_template}") + print(f"Checkpoints path: {config.lighteval.checkpoints_path}") + if config.lighteval.tasks: + print(f"Tasks: {config.lighteval.tasks.tasks}") + print(f"Custom tasks: {config.lighteval.tasks.custom_tasks}") + print(f"Max samples: {config.lighteval.tasks.max_samples}") + + # Create test files structure + test_files = [ + { + "destination": os.path.join(checkpoint_path, "config.yaml"), + "source": "existing_config", + } + ] + + if config.lighteval is None: + config.lighteval = LightEvalConfig() + + print("\nInitializing LightEvalRunner...") + runner = LightEvalRunner(config=config) + + print("\nTesting LightEvalRunner.eval_single_checkpoint()...") + job_id, log_path = runner.eval_single_checkpoint(test_files) + + except Exception as e: + print(f"\nError during test: {str(e)}") + import traceback + + traceback.print_exc() + + finally: + print("\nTest completed") diff --git a/src/nanotron/logging/base.py b/src/nanotron/logging/base.py index e84554ee..b14b94aa 100644 --- a/src/nanotron/logging/base.py +++ b/src/nanotron/logging/base.py @@ -265,7 +265,7 @@ def warn_once( def 
human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str:
     if abs(num) < 1:
         return "{:.3g}".format(num)
-    SIZES = ["", "K", "M", "G", "T", "P", "E"]
+    SIZES = ["", "K", "M", "B", "T", "P", "E"]
     num = float("{:.3g}".format(num))
     magnitude = 0
     i = 0
diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py
index 8115a9bb..eee5cba3 100644
--- a/src/nanotron/models/qwen.py
+++ b/src/nanotron/models/qwen.py
@@ -896,7 +896,7 @@ def init_model_randomly(self, config: Config):
             else:
                 raise ValueError(f"Unknown init method {init_method}")
 
-        parametrizator = parametrizator_cls(config=config.model)
+        parametrizator = parametrizator_cls(config=config)
 
         log_rank(
             f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
diff --git a/src/nanotron/s3_checkpoints/s3_mover.py b/src/nanotron/s3_checkpoints/s3_mover.py
index 483019d5..32842a90 100644
--- a/src/nanotron/s3_checkpoints/s3_mover.py
+++ b/src/nanotron/s3_checkpoints/s3_mover.py
@@ -225,7 +225,7 @@ def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None):
             dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
             dist.barrier()
             all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
-            time.sleep(1)
+            time.sleep(1)  # TODO @nouamane: make this configurable
 
     def is_previous_save_finished(self) -> bool:
         """Return True if a potential previous checkpoint has been fully uploaded to S3
diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py
index 187e76e0..8f3062a9 100644
--- a/src/nanotron/scaling/parametrization.py
+++ b/src/nanotron/scaling/parametrization.py
@@ -3,7 +3,8 @@
 from enum import Enum, auto
 from typing import Dict
 
-from nanotron.config import ModelArgs
+from nanotron.config import Config, ModelArgs
+from nanotron.config.models_config import InitScalingMethod
 from nanotron.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm
 from nanotron.parallel.tensor_parallel.nn import (
     TensorParallelColumnLinear,
@@ -31,7 +32,7 @@ def parametrize(self, param_name: str, module: nn.Module):
 
 
 class StandardParametrizator(Parametrizator):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: Config):
         super().__init__(config)
         self.MODULE_TO_PARAMETRIZE = {
             TensorParallelColumnLinear: self._parametrize_column_linear,
@@ -41,23 +42,42 @@ def __init__(self, config: ModelArgs):
             TensorParallelEmbedding: self._parametrize_embedding,
         }
 
-        self.std = config.init_method.std
-        self.num_layers = config.model_config.num_hidden_layers
+        self.std = config.model.init_method.std
+        self.num_layers = config.model.model_config.num_hidden_layers
+        self.tp = config.parallelism.tp
+        self.scaling_method = config.model.init_method.scaling_method
+        self.hidden_size = config.model.model_config.hidden_size
 
     def _parametrize_column_linear(self, param_name: str, module: nn.Module):
         assert param_name in ["weight", "bias"]
 
         if "weight" == param_name:
+            # TODO @nouamane: should we use trunc_normal_
             init.normal_(module.weight, mean=0.0, std=self.std)
         elif "bias" == param_name:
             module.bias.zero_()
 
+    def _compute_scaling_factor(self) -> float:
+        """Compute initialization scaling based on selected method"""
+        if self.scaling_method == InitScalingMethod.NONE:
+            return 1.0
+        elif self.scaling_method == InitScalingMethod.NUM_LAYERS:
+            # Scale based on total network depth
+            return math.sqrt(2 * self.num_layers)
+        elif self.scaling_method == InitScalingMethod.LAYER_INDEX:
+            # Scale based on layer position
+            raise NotImplementedError("Layer 
position scaling not yet implemented") + else: + raise ValueError(f"Invalid scaling method: {self.scaling_method}") + def _parametrize_row_linear(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: - std = self.std / math.sqrt(2 * self.num_layers) - init.normal_(module.weight, mean=0.0, std=std) + scaling = self._compute_scaling_factor() + adjusted_std = self.std / scaling + # TODO @nouamane: should we use trunc_normal_ + init.normal_(module.weight, mean=0.0, std=adjusted_std) elif "bias" == param_name: module.bias.zero_() @@ -65,7 +85,6 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 module.weight.fill_(1) elif "bias" == param_name: module.bias.zero_() diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index b1445b48..2b5d4558 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -64,7 +64,7 @@ def save( try: if should_save_config: - config.save_as_yaml(root_folder / "config.yaml") + config.save_as_yaml(root_folder / "config.yaml", sanity_checks=sanity_checks) except Exception as e: # TODO @nouamane: catch full disk error log_rank( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5110d6eb..00c26943 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -39,6 +39,7 @@ ) from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.data.dataloader import sanity_check_dataloader +from nanotron.eval import LightEvalRunner from nanotron.helpers import ( _vocab_size_with_padding, compute_remain_train_steps_of_a_data_stage_from_ckp, @@ -122,7 +123,7 @@ def get_size(bytes): """Convert bytes to human readable format""" - for unit in ["", "K", "M", "G", "T", "P"]: + for unit in ["", "K", "M", "B", "T", "P"]: if bytes < 1024: return f"{bytes:.2f}{unit}B" bytes /= 1024 @@ -185,7 +186,9 @@ def __init__( ######################################## # Set random states - set_random_seed(self.config.general.seed) + # Set different random seed for each TP rank to ensure diversity (especially at weight init) + tp_rank = dist.get_rank(self.parallel_context.tp_pg) + set_random_seed(self.config.general.seed + tp_rank) # Init model and build on pp ranks self.random_states = init_random_states( @@ -312,6 +315,14 @@ def post_init(self): else: self.s3_mover = None + # Initialize LightEval runner on rank 0 + if dist.get_rank(self.parallel_context.world_pg) == 0: + if self.config.lighteval is not None: + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + if self.s3_mover is not None: + # If we have S3 upload enabled, use the eval_single_checkpoint as post-upload callback + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + def pre_training(self, *args, **kwargs): if not self.config.general.ignore_sanity_checks: log_rank( @@ -523,8 +534,6 @@ def train( ], **kwargs, ) -> None: - self.pre_training(**kwargs) - if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: self.save_checkpoint() @@ -543,6 +552,7 @@ def train( self.initial_iter_step = self.metadata.last_train_step + 1 self.last_iter_step = self.config.tokens.train_steps + self.pre_training(**kwargs) prof = get_profiler(config=self.config) # free memory @@ -561,21 +571,23 @@ def train( outputs, loss_avg, z_loss_avg = 
self.training_step(dataloader=self.current_dataloader) # Update consumption tracking for current batch - self.current_base_dl.dataset.update_consumption_metrics( - start_idx=(self.iteration_step - 1) - * self.global_batch_size, # assumes we start from iteration_step=1 - end_idx=self.iteration_step * self.global_batch_size, - sequence_length=self.sequence_length, - ) + if hasattr(self.current_base_dl, "dataset"): + self.current_base_dl.dataset.update_consumption_metrics( + start_idx=(self.iteration_step - 1) + * self.global_batch_size, # assumes we start from iteration_step=1 + end_idx=self.iteration_step * self.global_batch_size, + sequence_length=self.sequence_length, + ) # Training Logs # Track consumed tokens for all dataset folders in current stage - consumption_stats = self.current_base_dl.dataset.get_consumption_stats() - current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] + if hasattr(self.current_base_dl, "dataset"): + consumption_stats = self.current_base_dl.dataset.get_consumption_stats() + current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] - # Update consumed tokens for all folders in the consumption stats - for folder_path, stats in consumption_stats.items(): - current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] + # Update consumed tokens for all folders in the consumption stats + for folder_path, stats in consumption_stats.items(): + current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] # Original consumption tracking self.metadata.consumed_train_samples += self.global_batch_size @@ -763,7 +775,8 @@ def train_step_logs( # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + self.metadata.consumed_train_samples + * self.config.tokens.sequence_length, # TODO: not true if we change seqlen "human_format", ), # , "12d"), LogItem("time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), @@ -863,12 +876,13 @@ def get_cpu_logitems(): assert self.current_base_dl is not None, "current_base_dl should be defined" # Log consumption statistics - for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): - basic_log_entries.extend( - [ - LogItem(f"dataloader/consumed_tokens/{dataset_name}", stats["tokens"], "human_format"), - ] - ) + if hasattr(self.current_base_dl, "dataset"): + for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): + basic_log_entries.extend( + [ + LogItem(f"dataloader/consumed_tokens/{dataset_name}", stats["tokens"], "human_format"), + ] + ) # WandB logging - determine if this rank should log to wandb should_log_to_wandb = wandb is not None and ( @@ -1160,26 +1174,69 @@ def setup_log_writers( return loggerwriter def pre_save_checkpoint(self) -> Path: + # Check if eval_interval should be updated from file + eval_interval_file = self.config.lighteval.eval_interval_file + if eval_interval_file is not None and Path(eval_interval_file).exists(): + try: + with open(eval_interval_file, "r") as f: + new_eval_interval = int(f.read().strip()) + + # Verify that the new interval is a multiple of checkpoint_interval + if new_eval_interval == self.config.lighteval.eval_interval: + pass + elif new_eval_interval % self.config.checkpoints.checkpoint_interval == 0: + log_rank( + f"Updating lighteval.eval_interval from {self.config.lighteval.eval_interval} to 
{new_eval_interval}", + logger=logger, + level=logging.INFO, + rank=0, + ) + self.config.lighteval.eval_interval = new_eval_interval + else: + log_rank( + f"New eval_interval={new_eval_interval} must be a multiple of checkpoint_interval={self.config.checkpoints.checkpoint_interval}. Keeping current value: {self.config.lighteval.eval_interval}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + except (ValueError, IOError) as e: + log_rank( + f"Error reading eval_interval from file: {e}. Keeping current value: {self.config.lighteval.eval_interval}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs - self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") + log_rank( + f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", + logger=logger, + level=logging.WARNING, + rank=0, + ) def post_save_checkpoint(self): # Upload to S3 if self.s3_mover is not None: self.s3_mover.start_uploading() - # free memory TODO: do we need this? - # gc.collect() - # torch.cuda.empty_cache() + if dist.get_rank(self.parallel_context.world_pg) == 0: + if self.config.lighteval is not None and self.s3_mover is None: + if ( + self.config.lighteval.eval_interval is None + or self.iteration_step % self.config.lighteval.eval_interval == 0 + ): + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" + self.lighteval_runner.eval_single_checkpoint(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() checkpoints_path = self.config.checkpoints.checkpoints_path - checkpoint_path = checkpoints_path / f"{self.iteration_step}" + checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 else: @@ -1210,6 +1267,7 @@ def save_checkpoint(self) -> Path: root_folder=checkpoint_path, training_metadata=self.metadata, config=self.config, + sanity_checks=not self.config.general.ignore_sanity_checks, ) save_random_states( random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path From d8b27179388bd024f4dadd9454d5aebd8d42efeb Mon Sep 17 00:00:00 2001 From: grewalsk <136873529+grewalsk@users.noreply.github.com> Date: Fri, 18 Apr 2025 11:06:34 -0400 Subject: [PATCH 04/10] =?UTF-8?q?[Feature]=20Implement=20CUDA=20event-base?= =?UTF-8?q?d=20timing=20for=20improved=20GPU=20performa=E2=80=A6=20(#346)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [Feature] Implement CUDA event-based timing for improved GPU performance measurement * can only merge to main from dev (#348) * Fix timer decorator logic: Support both CPU and CUDA timers and update docs * Fix timer decorator logic: support both CPU and CUDA; update docs --------- Co-authored-by: Kabir Grewal Co-authored-by: Nouamane Tazi Co-authored-by: Kabir Grewal --- README.md | 3 +- docs/cuda_event_timing.md | 92 ++++++++++++++++++++++++++++++++++ src/nanotron/logging/timers.py | 50 +++++++++++++----- src/nanotron/trainer.py | 10 ++-- test_timer_decorator.py | 63 +++++++++++++++++++++++ 5 files changed, 201 insertions(+), 17 deletions(-) create mode 100644 docs/cuda_event_timing.md create mode 100644 
test_timer_decorator.py
diff --git a/README.md b/README.md
index 719d0720..1b5df079 100644
--- a/README.md
+++ b/README.md
@@ -98,7 +98,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config-
 The model will be saved in the `checkpoints` directory as specified in the config file.
 
 > [!NOTE]
-> You can use `examples/config_tiny_llama.py` to generate your own training config 
+> You can use `examples/config_tiny_llama.py` to generate your own training config
 
 For detailed instructions on training your first model, check out our [Your First Training guide](docs/your-first-training.md). For multi-node training with Slurm, see our [Multi-Node Training guide](docs/multi-node-training.md).
 
@@ -173,6 +173,7 @@ We currently support the following features:
 - [x] Custom module checkpointing for large models
 - [x] Spectral µTransfer parametrization for scaling up neural networks
 - [x] Mamba example
+- [x] CUDA event-based timing for accurate GPU performance measurement
 
 And we have on our roadmap:
 - [ ] FP8 training
diff --git a/docs/cuda_event_timing.md b/docs/cuda_event_timing.md
new file mode 100644
index 00000000..9f8f4935
--- /dev/null
+++ b/docs/cuda_event_timing.md
@@ -0,0 +1,92 @@
+# CUDA Event-Based Timing in Nanotron
+
+## Overview
+
+Nanotron now uses CUDA events for timing GPU operations instead of CPU-based timing with `time.time()`. This change provides several benefits:
+
+1. **More accurate measurement of GPU execution time**: CUDA events are recorded directly on the GPU timeline, providing more precise timing of GPU operations.
+2. **Reduced need for explicit CUDA synchronization**: CPU-based timing requires synchronization between CPU and GPU to get accurate measurements, which can introduce overhead and affect performance.
+3. **Lower overhead**: CUDA event-based timing has minimal impact on the execution of GPU operations.
+4. **Better performance monitoring**: More accurate timing leads to better performance analysis and optimization.
+
+## Implementation Details
+
+The implementation uses `torch.cuda.Event` with `enable_timing=True` to create start and end events that are recorded on the GPU timeline. The elapsed time is then calculated using `start_event.elapsed_time(end_event)`, which returns the time in milliseconds.
+
+### Key Changes
+
+1. **Default Timer Type**: The default timer type in `nanotron/src/nanotron/logging/timers.py` has been changed from `TimerType.CPU` to `TimerType.CUDA`.
+
+2. **Iteration Timing**: The iteration timing in `trainer.py` now uses CUDA events instead of `time.time()`.
+
+3. **Synchronization Control**: By default, CUDA event-based timers do not force synchronization unless explicitly requested with `cuda_sync=True`.
+
+## Usage
+
+### Basic Usage
+
+```python
+# Create and use a CUDA timer (default)
+with nanotron_timer("my_operation"):
+    # Your GPU operation here
+    ...
+
+# Explicitly specify CUDA timing
+with nanotron_timer("my_operation", timer_type="cuda"):
+    # Your GPU operation here
+    ...
+
+# For CPU-only operations, you can still use CPU-based timing
+with nanotron_timer("cpu_operation", timer_type="cpu"):
+    # Your CPU operation here
+    ...
+
+# As a decorator with default CUDA timing
+@nanotron_timer
+def my_function():
+    # Your GPU operation here
+    ...
+
+# As a decorator with custom name
+@nanotron_timer("custom_name")
+def my_function():
+    # Your GPU operation here
+    ...
+
+# As a decorator with CPU timing
+@nanotron_timer(timer_type=TimerType.CPU)
+def my_cpu_function():
+    # Your CPU operation here
+    ...
+```
+
+### Advanced Usage
+
+```python
+# Start and end a timer manually
+timer = nanotron_timer("my_operation")
+timer.start()
+# Your operation here
+timer.end()
+
+# Get the elapsed time in seconds
+elapsed_time = timer.elapsed
+
+# Get the total time across all calls
+total_time = timer.total_time
+
+# Get the average time per call
+avg_time = timer.average_time
+```
+
+## Considerations
+
+1. **Synchronization**: By default, CUDA event-based timers do not force synchronization to avoid overhead. If you need more accurate timing at the cost of performance, you can set `cuda_sync=True`.
+
+2. **Units**: CUDA events measure time in milliseconds, but the timer API converts this to seconds for consistency with the previous CPU-based timing.
+
+3. **Fallback**: If CUDA is not available, the timer will automatically fall back to CPU-based timing.
+
+## Performance Impact
+
+Using CUDA events for timing instead of CPU-based timing with synchronization can significantly reduce overhead, especially in distributed training scenarios with thousands of GPUs.
diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py
index 1129b9c6..76f0aa0d 100644
--- a/src/nanotron/logging/timers.py
+++ b/src/nanotron/logging/timers.py
@@ -19,10 +19,17 @@ class TimerType(Enum):
 
 @dataclass
 class TimerRecord:
-    """Records timing information for a single timer."""
+    """
+    Records timing information for a single timer.
+
+    By default, uses CUDA events for timing GPU operations, which provides more accurate
+    measurements of GPU execution time without forcing CPU-GPU synchronization.
+
+    For CPU-only operations, you can use CPU-based timing by specifying timer_type=TimerType.CPU.
+    """
 
     name: str
-    timer_type: TimerType = TimerType.CPU
+    timer_type: TimerType = TimerType.CUDA
     start_time: float = 0.0
     end_time: float = 0.0
     running: bool = False
@@ -175,7 +182,17 @@ def average_time(self) -> float:
 
 
 class Timers:
-    """A collection of timers for tracking execution time in Nanotron."""
+    """
+    A collection of timers for tracking execution time in Nanotron.
+
+    By default, timers use CUDA events for timing GPU operations, which provides several benefits:
+    1. More accurate measurement of GPU execution time
+    2. Reduced need for explicit CUDA synchronization
+    3. Lower overhead compared to CPU-based timing with synchronization
+    4. Better performance monitoring for distributed training
+
+    For CPU-only operations, you can still use CPU-based timing by specifying timer_type=TimerType.CPU.
+    """
 
     _instance = None
     _enabled = os.environ.get("ENABLE_TIMERS", "0") == "1"  # Add global enable/disable flag
@@ -202,20 +219,23 @@ def is_enabled(cls) -> bool:
         return cls._enabled
 
     def __call__(
-        self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU, cuda_sync: bool = True
+        self, name: str, timer_type: Union[TimerType, str] = TimerType.CUDA, cuda_sync: bool = False
     ) -> TimerRecord:
         """Get or create a timer with the given name.
 
         Can be used as a decorator, context manager, or directly:
-        - @nanotron_timer("name")  # As decorator
+        - @nanotron_timer  # As decorator with default CUDA timing
+        - @nanotron_timer("my_function")  # As decorator with custom name
+        - @nanotron_timer(timer_type=TimerType.CPU)  # As decorator with CPU timing
         - with nanotron_timer("name"): ... 
# As context manager - nanotron_timer("name").start(); ...; nanotron_timer("name").end() # Direct use Args: name: Name of the timer - timer_type: Type of timer, either TimerType.CPU or TimerType.CUDA - (or 'cpu'/'cuda' strings) - cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing + timer_type: Type of timer, either TimerType.CUDA (default) or TimerType.CPU + (or 'cuda'/'cpu' strings) + cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing. + Default is False to avoid unnecessary synchronization overhead. """ if not self._enabled: # Return a dummy timer that does nothing when timing is disabled @@ -224,11 +244,11 @@ def __call__( if isinstance(timer_type, str): timer_type = TimerType(timer_type) - if callable(name) and timer_type == TimerType.CPU: - # Being used as a decorator with default settings + if callable(name): + # Being used as a decorator with specified or default settings func = name timer_name = func.__name__ - return self._create_timer_decorator(timer_name, TimerType.CPU, cuda_sync)(func) + return self._create_timer_decorator(timer_name, timer_type, cuda_sync)(func) if name not in self._timers: self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync) @@ -245,8 +265,12 @@ def __call__( # If we get here, we're being called as @nanotron_timer("name", timer_type) return self._create_timer_decorator(name, timer_type, cuda_sync) - def _create_timer_decorator(self, name, timer_type, cuda_sync=False): - """Create a decorator that times the execution of a function.""" + def _create_timer_decorator(self, name, timer_type=TimerType.CUDA, cuda_sync=False): + """Create a decorator that times the execution of a function. + + This method supports both CUDA and CPU timer types, allowing for flexible timing + of functions based on the specific needs (GPU operations vs CPU operations). + """ def decorator(func): import functools diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 00c26943..d4289ddd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -278,6 +278,7 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches self.current_dataloader: Optional[DataLoader] = None # used for the current training stage self.current_base_dl: Optional[DataLoader] = None # used for the current training stage + self.iteration_timer = None # Will be initialized during training log_libraries_versions(logger=logger) log_rank("Config:", logger=logger, level=logging.INFO, rank=0, is_separator=True) @@ -564,7 +565,9 @@ def train( logger.info(f"Profiler on for step {self.iteration_step}") prof.step() - self.iteration_start_time = time.time() + # Use CUDA event-based timing for more accurate GPU-side elapsed time measurement + self.iteration_timer = nanotron_timer("iteration_time", "cuda", cuda_sync=False) + self.iteration_timer.start() self._update_dataloader_based_on_training_stages(dataloader_or_dls) # Training step @@ -750,8 +753,9 @@ def train_step_logs( ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() - torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + # End the iteration timer and get elapsed time in milliseconds + self.iteration_timer.end() + elapsed_time_per_iteration_ms = self.iteration_timer.elapsed * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length diff --git a/test_timer_decorator.py b/test_timer_decorator.py new file mode 100644 index 00000000..b900f33d --- /dev/null +++ b/test_timer_decorator.py @@ -0,0 +1,63 @@ +"""Test script for the timer decorator with both CPU and CUDA timer types.""" + +from nanotron.logging.timers import nanotron_timer, TimerType +import time +import torch + +# Enable timers for testing +nanotron_timer.enable() + +# Test with default CUDA timing +@nanotron_timer +def test_default_decorator(): + """Test function with default CUDA timing.""" + # Simulate some work + time.sleep(0.1) + if torch.cuda.is_available(): + x = torch.randn(1000, 1000, device="cuda") + y = torch.matmul(x, x) + torch.cuda.synchronize() + return "Done" + +# Test with explicit CUDA timing +@nanotron_timer(timer_type=TimerType.CUDA) +def test_cuda_decorator(): + """Test function with explicit CUDA timing.""" + # Simulate some work + time.sleep(0.1) + if torch.cuda.is_available(): + x = torch.randn(1000, 1000, device="cuda") + y = torch.matmul(x, x) + torch.cuda.synchronize() + return "Done" + +# Test with CPU timing +@nanotron_timer(timer_type=TimerType.CPU) +def test_cpu_decorator(): + """Test function with CPU timing.""" + # Simulate some CPU work + time.sleep(0.2) + return "Done" + +# Test with custom name +@nanotron_timer("custom_name") +def test_custom_name_decorator(): + """Test function with custom name.""" + # Simulate some work + time.sleep(0.1) + return "Done" + +if __name__ == "__main__": + print("Testing timer decorators...") + + # Run the test functions + test_default_decorator() + test_cuda_decorator() + test_cpu_decorator() + test_custom_name_decorator() + + # Log all timers + print("\nTimer results:") + nanotron_timer.log_all(rank=None) # Log on all ranks + + print("\nTest completed successfully!") From 9095a9d59e024f41ed16fc777b2c73fcd50e56c7 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Fri, 18 Apr 2025 16:44:28 +0100 Subject: [PATCH 05/10] amend previous pr (#354) --- src/nanotron/logging/timers.py | 53 ++++++++++++++++++++-------------- src/nanotron/trainer.py | 3 +- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index 76f0aa0d..e3603f11 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -35,6 +35,7 @@ class TimerRecord: running: bool = False call_count: int = 0 cuda_sync: bool = False # Option to add CUDA synchronization for more accurate timings + enabled: bool = True # Allow individual timer to be enabled/disabled # For CPU timers we still track total_time _cpu_total_time: float = 0.0 @@ -55,7 +56,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def start(self) -> "TimerRecord": """Start the timer.""" - if self.name == "dummy": # disabled + if self.name == "dummy" or not self.enabled: # disabled return self if self.running: @@ -82,7 +83,7 @@ def start(self) -> "TimerRecord": def end(self) -> None: """End the timer, but don't compute elapsed 
time yet.""" - if self.name == "dummy": # disabled + if self.name == "dummy" or not self.enabled: # disabled return if not self.running: @@ -219,7 +220,11 @@ def is_enabled(cls) -> bool: return cls._enabled def __call__( - self, name: str, timer_type: Union[TimerType, str] = TimerType.CUDA, cuda_sync: bool = False + self, + name: str, + timer_type: Union[TimerType, str] = TimerType.CUDA, + cuda_sync: bool = False, + enabled: bool = bool(int(os.environ.get("ENABLE_TIMERS", "0"))), ) -> TimerRecord: """Get or create a timer with the given name. @@ -236,11 +241,11 @@ def __call__( (or 'cuda'/'cpu' strings) cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing. Default is False to avoid unnecessary synchronization overhead. - """ - if not self._enabled: - # Return a dummy timer that does nothing when timing is disabled - return TimerRecord(name="dummy", timer_type=TimerType.CPU) + enabled: Override default enabled setting from environment variable + Raises: + ValueError: If a timer with the same name already exists with different settings + """ if isinstance(timer_type, str): timer_type = TimerType(timer_type) @@ -248,13 +253,23 @@ def __call__( # Being used as a decorator with specified or default settings func = name timer_name = func.__name__ - return self._create_timer_decorator(timer_name, timer_type, cuda_sync)(func) + return self._create_timer_decorator(timer_name, timer_type, cuda_sync, enabled)(func) - if name not in self._timers: - self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync) - else: - # Update the cuda_sync option if the timer already exists - self._timers[name].cuda_sync = cuda_sync + if name in self._timers: + existing_timer = self._timers[name] + if ( + existing_timer.timer_type != timer_type + or existing_timer.cuda_sync != cuda_sync + or existing_timer.enabled != enabled + ): + raise ValueError( + f"Timer '{name}' already exists with different settings.\n" + f"Existing: type={existing_timer.timer_type}, cuda_sync={existing_timer.cuda_sync}, enabled={existing_timer.enabled}\n" + f"New: type={timer_type}, cuda_sync={cuda_sync}, enabled={enabled}" + ) + return existing_timer + + self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync, enabled=enabled) # Check if we're being called as a decorator if not callable(name): @@ -263,21 +278,17 @@ def __call__( return timer_record # If we get here, we're being called as @nanotron_timer("name", timer_type) - return self._create_timer_decorator(name, timer_type, cuda_sync) + return self._create_timer_decorator(name, timer_type, cuda_sync, enabled) - def _create_timer_decorator(self, name, timer_type=TimerType.CUDA, cuda_sync=False): - """Create a decorator that times the execution of a function. - - This method supports both CUDA and CPU timer types, allowing for flexible timing - of functions based on the specific needs (GPU operations vs CPU operations). 
- """ + def _create_timer_decorator(self, name, timer_type=TimerType.CUDA, cuda_sync=False, enabled=None): + """Create a decorator that times the execution of a function.""" def decorator(func): import functools @functools.wraps(func) def wrapper(*args, **kwargs): - with self(name, timer_type, cuda_sync): + with self(name, timer_type, cuda_sync, enabled): return func(*args, **kwargs) return wrapper diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d4289ddd..a64bb5db 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -4,7 +4,6 @@ import os import shutil import tempfile -import time from dataclasses import asdict from pathlib import Path from pprint import pformat @@ -566,7 +565,7 @@ def train( prof.step() # Use CUDA event-based timing for more accurate GPU-side elapsed time measurement - self.iteration_timer = nanotron_timer("iteration_time", "cuda", cuda_sync=False) + self.iteration_timer = nanotron_timer("iteration_time", "cuda", cuda_sync=False, enabled=True) self.iteration_timer.start() self._update_dataloader_based_on_training_stages(dataloader_or_dls) From fc003ac1244dd08b421db18e82ba9b71e04da989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?X=CE=BBRI-U5?= Date: Wed, 23 Apr 2025 22:44:22 +0200 Subject: [PATCH 06/10] MoE without token dropping (#355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * can only merge to main from dev (#348) * move moe from qwen modeling to src/nn * add groupedmlp * add token permute and unpermute * fix num_tokens_per_expert counting < num_experts * fix init and init scaling factor and run evals in background (#353) * can only merge to main from dev * Fix UnBoundLocalError in `clm_collator.py` (#339) * Update clm_collator.py * can only merge to main from dev (#348) --------- Co-authored-by: Nouamane Tazi * fix init and init scaling factor and run evals in background (#349) * InitScalingMethod * InitScalingMethod * run evals in background (#352) * eval * try adding lightevalrunner to trainer * amend * amend * amend * amend * amend * amend * . * amend * amend * . * qos to low * add nanotron_path * some fix: logs, and config * cp instead of sync * eval_interval * serialize sanity checks * add output dir and s3_save path in the config * fix s3 only if define * fixes --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --------- Co-authored-by: Connector Switch Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” * inference qwen moe seems to work inference seems good rn * update readme * fix router's weight initialization and wrong hidden size for non-moe mlp in qwen * add source for router weight and router logits in float32 * fixes * . * . * add parametrize grouped mlp in column and row linear * add logging per-param grad norm * fix conversation fail due to buffer on cpu * config_qwen * . * . 
* fix moe convert config --------- Co-authored-by: Nouamane Tazi Co-authored-by: Connector Switch Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” Co-authored-by: zzhhjjj Co-authored-by: nouamanetazi --- examples/config_qwen.py | 2 +- examples/config_qwen.yaml | 14 +- examples/config_qwen_with_moe.yaml | 132 +++++++++ examples/inference/qwen_moe/README.md | 27 ++ examples/inference/qwen_moe/convert.py | 329 +++++++++++++++++++++ pyproject.toml | 1 + src/nanotron/config/models_config.py | 9 +- src/nanotron/config/parallelism_config.py | 2 - src/nanotron/logging/base.py | 3 + src/nanotron/models/base.py | 57 ++-- src/nanotron/models/qwen.py | 192 +----------- src/nanotron/nn/moe.py | 212 +++++++++++++ src/nanotron/optim/gradient_accumulator.py | 3 +- src/nanotron/scaling/parametrization.py | 26 ++ src/nanotron/trainer.py | 2 +- tests/helpers/qwen_helper.py | 50 ++-- tests/test_moe.py | 37 +++ 17 files changed, 865 insertions(+), 233 deletions(-) create mode 100644 examples/config_qwen_with_moe.yaml create mode 100644 examples/inference/qwen_moe/README.md create mode 100644 examples/inference/qwen_moe/convert.py create mode 100644 src/nanotron/nn/moe.py create mode 100644 tests/test_moe.py diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 639ed2d6..8ca8487b 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -108,7 +108,7 @@ def get_model_config(model_size: str) -> Qwen2Config: is_qwen2_config=True, pad_token_id=None, _attn_implementation="flash_attention_2", - sliding_window_size=20, + # sliding_window_size=20, ) diff --git a/examples/config_qwen.yaml b/examples/config_qwen.yaml index 5fc8e48e..a2ce9bd1 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -32,7 +32,7 @@ general: consumed_train_samples: null ignore_sanity_checks: false project: debug - run: qwen_20250410_014907_16027793 + run: qwen_20250423_201000_16423158 seed: 42 step: null lighteval: null @@ -45,14 +45,15 @@ model: ddp_bucket_cap_mb: 25 dtype: bfloat16 init_method: + scaling_method: NUM_LAYERS std: 0.025 make_vocab_size_divisible_by: 1 model_config: _attn_implementation: flash_attention_2 - _fused_rms_norm: true - _fused_rotary_emb: true - _use_doc_masking: true - _use_qkv_packed: true + _fused_rms_norm: false + _fused_rotary_emb: false + _use_doc_masking: false + _use_qkv_packed: false attention_bias: false bos_token_id: 1 eos_token_id: 2 @@ -74,7 +75,7 @@ model: rope_interleaved: false rope_scaling: null rope_theta: 10000.0 - sliding_window_size: 20 + sliding_window_size: null tie_word_embeddings: true use_cache: true vocab_size: 128256 @@ -104,7 +105,6 @@ parallelism: context_parallel_size: 1 dp: 2 expert_parallel_size: 1 - moe_layer_recompute: false pp: 1 pp_engine: 1f1b recompute_layer: false diff --git a/examples/config_qwen_with_moe.yaml b/examples/config_qwen_with_moe.yaml new file mode 100644 index 00000000..5e51307f --- /dev/null +++ b/examples/config_qwen_with_moe.yaml @@ -0,0 +1,132 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: /fsx/phuc/new_workspace/experiments/qwen2_moe_test + checkpoints_path_is_shared_file_system: false + load_lr_scheduler: true + load_optimizer: true + resume_checkpoint_path: null + save_final_state: true + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: + - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged + dataset_max_tokens: null + dataset_read_path: null + dataset_weights: null + pad_samples_to_global_batch_size: false + 
return_positions: true + shuffle_files: false + skip_in_stream: false + token_size_in_bytes: 4 + tokenizer_name: meta-llama/Llama-3.2-1B + use_old_brrr_dataloader: false + vocab_size: 128256 + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: false + project: qwen_moe + run: qwen_20250410_014907_16027793 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +metrics_logging: null +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + _attn_implementation: flash_attention_2 + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true + attention_bias: false + bos_token_id: 1 + eos_token_id: 2 + flex_attention_mask: null + hidden_act: silu + hidden_size: 256 + initializer_range: 0.02 + intermediate_size: 768 + is_qwen2_config: true + max_position_embeddings: 4096 + moe_config: null + no_rope_layer: null + num_attention_heads: 4 + num_hidden_layers: 12 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-06 + rope_interleaved: false + rope_scaling: null + rope_theta: 10000.0 + sliding_window_size: 20 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 + z_loss_coefficient: 0.0001 + z_loss_enabled: false + moe_config: + num_experts: 8 + top_k: 1 + enable_shared_expert: true + token_dispatcher_type: alltoall +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 31998 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + weight_decay_exclude_named_params: [] + zero_stage: 0 +parallelism: + context_parallel_size: 1 + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER + tp_recompute_allgather: true +profiler: null +s3_upload: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-1B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 32000 + val_check_interval: -1 diff --git a/examples/inference/qwen_moe/README.md b/examples/inference/qwen_moe/README.md new file mode 100644 index 00000000..594b6ce5 --- /dev/null +++ b/examples/inference/qwen_moe/README.md @@ -0,0 +1,27 @@ +# Qwen-MoE Inference + +This guide explains how to convert Hugging face Qwen-MoE models to Nanotron format and run inference with them. 
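Note: the MoE modules import the grouped GEMM kernels at load time, and the conversion script builds the Nanotron model on a CUDA GPU in bfloat16, so both steps below assume a GPU machine with grouped_gemm installed. If it is missing, `src/nanotron/nn/moe.py` raises an error pointing to this install command (copied from that error message):

```bash
pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@main
```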
+ +## Convert Qwen-MoE to Nanotron Format + +Navigate to the `inference/qwen_moe` directory and run: + +```bash +torchrun --nproc-per-node 1 examples/inference/qwen_moe/convert.py \ + --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B \ + --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B +``` + +This command will save the converted model weights to the specified path in `nanotron_checkpoints` + +## Run Inference + +From the root directory of Nanotron, run: + +```bash +torchrun --rdzv_endpoint=localhost:29700 --rdzv-backend=c10d --nproc_per_node=1 \ + run_generate.py \ + --ckpt-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B +``` + +This command will load the converted model weights and run inference. diff --git a/examples/inference/qwen_moe/convert.py b/examples/inference/qwen_moe/convert.py new file mode 100644 index 00000000..419495da --- /dev/null +++ b/examples/inference/qwen_moe/convert.py @@ -0,0 +1,329 @@ +""" +torchrun --nproc-per-node 1 convert.py --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit, MoEConfig, Qwen2Config +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.qwen import Qwen2ForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2MoeConfig + +logger = logging.get_logger(__name__) + +# NOTE: We need to initialize the model on gpu, because RotaryEmbedding +# requires its buffer to be on gpu +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Qwen-MoE HF model + log_rank( + f"Loading pretrained qwen moe Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, 
attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config: Qwen2MoeConfig = hf_model.config + + # Set Nanotron Qwen2Config + nanotron_config = Qwen2Config( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_qwen2_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + attention_bias=True, # qwen-moe uses attention bias + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + moe_config=MoEConfig( + top_k=hf_config.num_experts_per_tok, + num_experts=hf_config.num_experts, + moe_intermediate_size=hf_config.moe_intermediate_size, + shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, + router_aux_loss_coef=hf_config.router_aux_loss_coef, + enable_shared_expert=True, + ), + ) + + # Init Nanotron Qwen-MoE model + log_rank("Init empty Nanotron Qwen Moe Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: Qwen2ForTraining( + config=nanotron_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context, parallel_config=parallel_config) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + + with torch.no_grad(): + # token embeddings + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## QKV bias + tmp_qkv_bias = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.bias, + hf_model.model.layers[i].self_attn.k_proj.bias, + hf_model.model.layers[i].self_attn.v_proj.bias, + ], + dim=0, + ) + assert tmp_qkv_bias.shape == 
nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.bias.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.bias.copy_(tmp_qkv_bias) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Router + assert ( + hf_model.model.layers[i].mlp.gate.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.router.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.router.weight.copy_(hf_model.model.layers[i].mlp.gate.weight) + + ## shared expert: Gate Up Proj + tmp_shared_expert = torch.cat( + [ + hf_model.model.layers[i].mlp.shared_expert.gate_proj.weight, + hf_model.model.layers[i].mlp.shared_expert.up_proj.weight, + ], + dim=0, + ) + assert ( + tmp_shared_expert.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.gate_up_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.gate_up_proj.weight.copy_(tmp_shared_expert) + + ## shared expert: Down Proj + assert ( + hf_model.model.layers[i].mlp.shared_expert.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.down_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.shared_expert.down_proj.weight + ) + + ## shared expert: Gate + assert ( + hf_model.model.layers[i].mlp.shared_expert_gate.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert_gate.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert_gate.weight.copy_( + hf_model.model.layers[i].mlp.shared_expert_gate.weight + ) + + ## experts: + # concatenate all gate_up_proj and down_proj weights for experts into merged_gate_up_proj and merged_down_proj + tmp_merged_gate_up_proj = torch.zeros( + nanotron_config.moe_config.num_experts, + nanotron_config.hidden_size, + 2 * nanotron_config.moe_config.moe_intermediate_size, + ) + tmp_merged_down_proj = torch.zeros( + nanotron_config.moe_config.num_experts, + nanotron_config.moe_config.moe_intermediate_size, + nanotron_config.hidden_size, + ) + + for j in range(nanotron_config.moe_config.num_experts): + ## Gate Up Proj + tmp_merged_gate_up_proj[j, :, : nanotron_config.moe_config.moe_intermediate_size] = ( + hf_model.model.layers[i].mlp.experts[j].gate_proj.weight.T + ) + tmp_merged_gate_up_proj[j, :, nanotron_config.moe_config.moe_intermediate_size :] = ( + hf_model.model.layers[i].mlp.experts[j].up_proj.weight.T + ) + + ## Down Proj + tmp_merged_down_proj[j] = hf_model.model.layers[i].mlp.experts[j].down_proj.weight.T + + # copy to merged_gate_up_proj and merged_down_proj + nanotron_model.model.decoder[i].pp_block.mlp.experts.merged_gate_up_proj.copy_(tmp_merged_gate_up_proj) + nanotron_model.model.decoder[i].pp_block.mlp.experts.merged_down_proj.copy_(tmp_merged_down_proj) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == 
hf_model.model.norm.weight.shape + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Qwen2-MoE"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_config, + ), + tokenizer=TokenizerArgs(tokenizer_name_or_path=args.pretrained_model_name_or_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/pyproject.toml b/pyproject.toml index 390c32c4..26171e9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "dacite", "tqdm", "datasets", + "torchtyping" ] [tool.setuptools.packages.find] diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index dd575e39..999d1337 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -38,6 +38,9 @@ class MoEConfig: num_experts: int = 8 # Total number of experts top_k: int = 2 # Number of experts to route each token to + moe_intermediate_size: int = 1408 # Intermediate size of the MoE layer + shared_expert_intermediate_size: int = 5632 # Intermediate size of the shared expert + router_aux_loss_coef: float = 0.01 # Coefficient for the router auxiliary loss layers: List[int] = field( default_factory=lambda: [-1] ) # Indices of layers that use MoE. -1 means all layers. 
Default is all layers @@ -146,9 +149,9 @@ class Qwen2Config: no_rope_layer: Optional[ int ] = None # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4) - _fused_rotary_emb: bool = False - _fused_rms_norm: bool = False - _use_qkv_packed: bool = False + _fused_rotary_emb: bool = True + _fused_rms_norm: bool = True + _use_qkv_packed: bool = True _use_doc_masking: bool = False # MoE configuration diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 48aa941e..40d95119 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -33,8 +33,6 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None recompute_layer: bool = False - moe_layer_recompute: bool = False - tp_recompute_allgather: bool = True expert_parallel_size: int = 1 diff --git a/src/nanotron/logging/base.py b/src/nanotron/logging/base.py index b14b94aa..5cde1bb1 100644 --- a/src/nanotron/logging/base.py +++ b/src/nanotron/logging/base.py @@ -229,6 +229,7 @@ def log_rank( rank: Optional[int] = None, category: Optional[str] = None, is_separator: bool = False, + main_rank_only: bool = False, **kwargs, ): """Log only if the current process is the rank specified.""" @@ -246,6 +247,8 @@ def log_rank( kwargs["extra"] = kwargs.get("extra", {}) kwargs["extra"]["separator"] = True + if main_rank_only: + rank = 0 # rank is None means everyone logs if rank is None or dist.get_rank(group) == rank: if is_separator: diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index af26c6da..6bb4d647 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -1,3 +1,4 @@ +import threading from abc import ABCMeta, abstractmethod from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple @@ -238,7 +239,34 @@ def build_model( return model -# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does. +@contextmanager +def ignore_init_on_device_and_dtype(): + """ + A context manager that temporarily disables dtype enforcement from init_on_device_and_dtype. + + Example: + ```python + with init_on_device_and_dtype(device=torch.device("cuda"), dtype=torch.float32): + with ignore_init_on_device_and_dtype(): + # This parameter will keep its specified dtype (float32) + self.weight = nn.Parameter(torch.randn(..., dtype=torch.float32)) + ``` + """ + # Create a thread-local storage for the ignore flag + if not hasattr(ignore_init_on_device_and_dtype, "_ignore_flag"): + ignore_init_on_device_and_dtype._ignore_flag = threading.local() + + # Set the ignore flag + old_value = getattr(ignore_init_on_device_and_dtype._ignore_flag, "value", False) + ignore_init_on_device_and_dtype._ignore_flag.value = True + + try: + yield + finally: + # Restore the previous value + ignore_init_on_device_and_dtype._ignore_flag.value = old_value + + @contextmanager def init_on_device_and_dtype( device: torch.device = torch.device("cpu"), @@ -250,35 +278,30 @@ def init_on_device_and_dtype( device (`torch.device` defaults to `cpu`): Device to initialize all parameters on. dtype (`torch.dtype` defaults to `torch.float`): - Dtype to initialize all parameters on. - include_buffers (`bool`, defaults to `False`): - Whether or not to also default all buffers constructors given previous arguments. 
- Example: - ```python - import torch.nn as nn - from accelerate import init_on_device - with init_on_device_and_dtype(device=torch.device("cuda")): - tst = nn.Liner(100, 100) # on `cuda` device - ``` + Dtype to initialize all parameters on. If specified, will override any dtype + set in parameter initialization with a warning, unless within an ignore_init_on_device_and_dtype context. """ old_register_parameter = nn.Module.register_parameter old_register_buffer = nn.Module.register_buffer + def should_ignore_init_on_device_and_dtype(): + if not hasattr(ignore_init_on_device_and_dtype, "_ignore_flag"): + return False + return getattr(ignore_init_on_device_and_dtype._ignore_flag, "value", False) + def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: - if isinstance(param, DTypeInvariantTensor): - # if param is DTypeInvariantTensor we should avoid updating it - param.data = param.data.to(device) + if should_ignore_init_on_device_and_dtype(): + pass else: param.data = param.data.to(device, dtype) def register_empty_buffer(module, name, buffer, persistent=True): old_register_buffer(module, name, buffer, persistent=persistent) if buffer is not None: - if isinstance(buffer, DTypeInvariantTensor): - # if buffer is DTypeInvariantTensor we should avoid updating it - buffer.data = buffer.data.to(device) + if should_ignore_init_on_device_and_dtype(): + pass else: module._buffers[name] = module._buffers[name].to(device, dtype) diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index eee5cba3..8aa6eb46 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -3,7 +3,6 @@ import torch from flash_attn.modules.mha import flash_attn_varlen_kvpacked_func from torch import nn -from torch.nn import functional as F from torch.utils.checkpoint import CheckpointFunction from nanotron import distributed as dist @@ -303,6 +302,7 @@ def __init__( config: Qwen2Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + intermediate_size: int, ) -> None: super().__init__() @@ -312,14 +312,14 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - # Define gate_up_proj as a merged layer for gate and up projections gate_up_contiguous_chunks = ( - config.intermediate_size, # shape of gate_linear - config.intermediate_size, # shape of up_linear + intermediate_size, # shape of gate_linear + intermediate_size, # shape of up_linear ) + self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, - 2 * config.intermediate_size, + 2 * intermediate_size, pg=tp_pg, mode=tp_mode, bias=False, # Qwen2 doesn't use bias for gate_up_proj @@ -330,7 +330,7 @@ def __init__( # Define down projection self.down_proj = TensorParallelRowLinear( - config.intermediate_size, + intermediate_size, config.hidden_size, pg=tp_pg, mode=tp_mode, @@ -355,183 +355,6 @@ def forward(self, hidden_states): return {"hidden_states": hidden_states} -class Qwen2MoELayer(nn.Module): - """Mixture of experts Layer for Qwen2 models.""" - - def __init__( - self, - config: Qwen2Config, - parallel_config: Optional[ParallelismArgs], - tp_pg: dist.ProcessGroup, - layer_idx: int = 0, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - # MoE specific configurations - self.num_experts = config.moe_config.num_experts # Total number of experts - self.num_experts_per_token = config.moe_config.top_k # Number of experts 
used per token (top-k) - self.expert_parallel_size = getattr(parallel_config, "expert_parallel_size", 1) - self.num_local_experts = self.num_experts // self.expert_parallel_size # Experts per device - - # Get TP mode configuration - tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - tp_linear_async_communication = ( - parallel_config.tp_linear_async_communication if parallel_config is not None else False - ) - - # Router for selecting experts - self.router = TensorParallelColumnLinear( - self.hidden_size, - self.num_experts, - pg=tp_pg, - mode=tp_mode, - bias=False, - async_communication=tp_linear_async_communication, - ) - - # Enable shared experts if configured - self.enable_shared_expert = getattr(config.moe_config, "enable_shared_expert", False) - if self.enable_shared_expert: - self.shared_expert = Qwen2MLP( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - ) - self.shared_expert_gate = TensorParallelColumnLinear( - self.hidden_size, - 1, - pg=tp_pg, - mode=tp_mode, - bias=False, - async_communication=tp_linear_async_communication, - ) - - # Create the expert MLPs - self.experts = nn.ModuleList( - [ - Qwen2MLP( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - ) - for _ in range(self.num_local_experts) - ] - ) - - # Whether to recompute MoE layer during backward pass for memory efficiency - self.recompute_layer = getattr(parallel_config, "recompute_layer", False) - - # Token dispatcher type - determines communication pattern - self.token_dispatcher_type = getattr(config.moe_config, "token_dispatcher_type", "alltoall") - # For more sophisticated implementations, we would add token dispatcher logic here - - def _compute_router_probabilities(self, hidden_states): - """Compute routing probabilities for each token to each expert.""" - router_logits = self.router(hidden_states) # [batch_size*seq_length, num_experts] - - # Get the top-k experts per token - routing_weights, routing_indices = torch.topk(router_logits, k=self.num_experts_per_token, dim=-1) - - # Apply softmax on the top-k values - routing_weights = F.softmax(routing_weights, dim=-1) - - return routing_weights, routing_indices - - def _dispatch_tokens(self, hidden_states, routing_weights, routing_indices): - """ - Dispatches tokens to their selected experts. - In a full implementation, this would handle the actual token routing logic - including communication between devices. 
- """ - # Simplified implementation - in a complete version this would handle - # all-to-all or all-gather communications for distributed experts - - hidden_states.shape[0] - dispatched_inputs = [] - expert_counts = [] - - # For each expert, gather the tokens assigned to it - for expert_idx in range(self.num_local_experts): - # Find tokens that have this expert in their top-k - expert_mask = (routing_indices == expert_idx).any(dim=-1) - tokens_for_expert = hidden_states[expert_mask] - - # Get the routing weights for this expert - expert_positions = (routing_indices == expert_idx).nonzero(as_tuple=True) - token_positions, k_positions = expert_positions - expert_weights = routing_weights[token_positions, k_positions].unsqueeze(-1) - - # Scale inputs by routing weights - scaled_inputs = tokens_for_expert * expert_weights - - dispatched_inputs.append(scaled_inputs) - expert_counts.append(len(tokens_for_expert)) - - return dispatched_inputs, expert_counts - - def _combine_expert_outputs(self, expert_outputs, routing_indices, original_shape): - """ - Combines outputs from different experts back to the original tensor layout. - """ - # Initialize output tensor with zeros - combined_output = torch.zeros(original_shape, device=expert_outputs[0].device) - - for expert_idx, expert_output in enumerate(expert_outputs): - if expert_output.shape[0] == 0: # Skip if no tokens were routed to this expert - continue - - # Find positions where this expert was in the top-k - expert_mask = (routing_indices == expert_idx).any(dim=-1) - combined_output[expert_mask] += expert_output - - return combined_output - - def _core_forward(self, hidden_states): - """Core forward logic for MoE layer.""" - # Get router probabilities - routing_weights, routing_indices = self._compute_router_probabilities(hidden_states) - - # Dispatch tokens to experts - dispatched_inputs, expert_counts = self._dispatch_tokens(hidden_states, routing_weights, routing_indices) - - # Process tokens with their assigned experts - expert_outputs = [] - for expert_idx, (inputs, count) in enumerate(zip(dispatched_inputs, expert_counts)): - if count == 0: # Skip computation if no tokens assigned - expert_outputs.append(torch.tensor([], device=hidden_states.device)) - continue - - # Forward through the expert - output = self.experts[expert_idx](hidden_states=inputs)["hidden_states"] - expert_outputs.append(output) - - # Combine expert outputs - output = self._combine_expert_outputs(expert_outputs, routing_indices, hidden_states.shape) - - # Add shared expert contribution if enabled - if self.enable_shared_expert: - shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] - shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) - output = output + shared_gate * shared_expert_output - - return output - - def _checkpointed_forward(self, hidden_states): - """Apply gradient checkpointing to save memory during training.""" - return CheckpointFunction.apply(self._core_forward, True, hidden_states) - - def forward(self, hidden_states): - """Forward pass for the MoE layer.""" - if self.recompute_layer and self.training: - hidden_states = self._checkpointed_forward(hidden_states) - else: - hidden_states = self._core_forward(hidden_states) - - return {"hidden_states": hidden_states} - - class Qwen2DecoderLayer(nn.Module): def __init__( self, @@ -559,6 +382,8 @@ def __init__( # Use MoE layer if this layer is in the MoE layers list if config.moe_config and layer_idx in config.moe_config.layers: + from nanotron.nn.moe import 
Qwen2MoELayer + self.mlp = Qwen2MoELayer( config=config, parallel_config=parallel_config, @@ -570,6 +395,7 @@ def __init__( config=config, parallel_config=parallel_config, tp_pg=tp_pg, + intermediate_size=config.intermediate_size, ) self.recompute_layer = parallel_config.recompute_layer diff --git a/src/nanotron/nn/moe.py b/src/nanotron/nn/moe.py new file mode 100644 index 00000000..40a296d0 --- /dev/null +++ b/src/nanotron/nn/moe.py @@ -0,0 +1,212 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ParallelismArgs +from nanotron.config.models_config import Qwen2Config +from nanotron.models.base import ignore_init_on_device_and_dtype +from nanotron.nn.activations import ACT2FN + +logger = logging.get_logger(__name__) + + +try: + import grouped_gemm.ops as ops +except ImportError: + raise RuntimeError( + "Grouped GEMM is not available. Please run `pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@main` (takes less than 5 minutes)" + ) + + +class Router(nn.Module): + def __init__( + self, config: Qwen2Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int + ): + super().__init__() + self.config = config + self.parallel_config = parallel_config + self.tp_pg = tp_pg + self.layer_idx = layer_idx + + self.num_experts = config.moe_config.num_experts + self.num_experts_per_token = config.moe_config.top_k + + # float32 routing weights + # NOTE: qwen keep the routing weights in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L608 + with ignore_init_on_device_and_dtype(): + self.weight = nn.Parameter( + torch.randn(self.num_experts, config.hidden_size, dtype=torch.float32, device="cuda") + ) + assert self.weight.dtype == torch.float32 + + def gating(self, x: torch.Tensor) -> torch.Tensor: + """Compute logits for all experts (no softmax).""" + # NOTE: qwen keep the routing logits in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L613 + return F.linear(x.to(torch.float32), self.weight, bias=None) + + def routing(self, logits: torch.Tensor): + """Top-k softmax-normalized routing weights and indices.""" + routing_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + routing_weights, routing_indices = torch.topk(routing_weights, k=self.num_experts_per_token, dim=-1) + routing_indices = routing_indices.to(torch.int32) # NOTE: ops.permute requires indices to be int32 + return routing_weights, routing_indices + + def forward(self, x: torch.Tensor): + logits = self.gating(x) + return self.routing(logits) + + +class GroupedMLP(nn.Module): + def __init__(self, config: Qwen2Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + num_local_experts = config.moe_config.num_experts // parallel_config.expert_parallel_size + self.merged_gate_up_proj = nn.Parameter( + torch.randn(num_local_experts, config.hidden_size, 2 * config.moe_config.moe_intermediate_size) + ) + self.merged_down_proj = nn.Parameter( + torch.randn(num_local_experts, config.moe_config.moe_intermediate_size, config.hidden_size) + ) + self.act = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + 
num_tokens_per_expert: torch.Tensor, + ): + """ + assume hidden_states is permuted + + grouped_gemm's notes: + ops.gemm expect the inputs to have the following criteria: + + expect a, b are in bfloat16 + + expect num_tokens_per_expert is a on cpu + """ + # NOTE: ops.gemm requires "batch_sizes" (aka: num_tokens_per_expert here) to be on cpu + num_tokens_per_expert = num_tokens_per_expert.to("cpu") + merged_states = ops.gmm(hidden_states, self.merged_gate_up_proj, num_tokens_per_expert, trans_b=False) + gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) + hidden_states = self.act(gate_states) * up_states + hidden_states = ops.gmm(hidden_states, self.merged_down_proj, num_tokens_per_expert, trans_b=False) + + return {"hidden_states": hidden_states} + + +class Qwen2MoELayer(nn.Module): + """Mixture of experts Layer for Qwen2 models.""" + + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int = 0, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # MoE specific configurations + self.num_experts = config.moe_config.num_experts # Total number of experts + self.num_local_experts = ( + config.moe_config.num_experts // parallel_config.expert_parallel_size + ) # Experts per device + self.num_experts_per_token = config.moe_config.top_k # Number of experts used per token (top-k) + self.expert_parallel_size = parallel_config.expert_parallel_size + self.num_local_experts = self.num_experts // self.expert_parallel_size # Experts per device + + # Get TP mode configuration + + # Router for selecting experts + self.router = Router(config, parallel_config, tp_pg, layer_idx) + + # Enable shared experts if configured + self.enable_shared_expert = config.moe_config.enable_shared_expert + if self.enable_shared_expert: + from nanotron.models.qwen import Qwen2MLP + + self.shared_expert = Qwen2MLP( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + intermediate_size=config.moe_config.shared_expert_intermediate_size, + ) + # TODO: duplicte the shared expert gate + self.shared_expert_gate = nn.Linear( + self.hidden_size, + 1, + bias=False, + ) # TODO: ensure shared_expert_gate is tied across TP + + # Create the expert MLPs + self.experts = GroupedMLP(config, parallel_config) + # Whether to recompute MoE layer during backward pass for memory efficiency + self.recompute_layer = parallel_config.recompute_layer + + def _dispatch_tokens( + self, + hidden_states: torch.Tensor, + routing_indices: torch.Tensor, + ): + """ + Dispatches tokens to their selected experts. + In a full implementation, this would handle the actual token routing logic + including communication between devices. + """ + # NOTE: start from expert 0 to expert n + num_tokens_per_expert = torch.bincount( + routing_indices.flatten(), minlength=self.num_local_experts + ) # [num_local_experts] + dispatched_inputs, inverse_permute_mapping = ops.permute(hidden_states, routing_indices) + return dispatched_inputs, inverse_permute_mapping, num_tokens_per_expert + + def _combine_expert_outputs(self, expert_outputs, inverse_mapping, routing_weights): + """ + Combines outputs from different experts back to the original tensor layout. 
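+        Note: `ops.unpermute` is expected to restore the original token order and apply the
+        top-k routing weights, so each token receives the probability-weighted sum of its
+        selected experts' outputs (the weights are not applied anywhere else in this layer).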
+ """ + hidden_states = ops.unpermute(expert_outputs, inverse_mapping, routing_weights) + return hidden_states + + def _core_forward(self, hidden_states): + """Core forward logic for MoE layer.""" + # Get top-k routing weights and indices + routing_weights, routing_indices = self.router(hidden_states) # [num_tokens, num_experts_per_token] + + # Dispatch tokens to experts + dispatched_inputs, inverse_permute_mapping, num_tokens_per_expert = self._dispatch_tokens( + hidden_states, routing_indices + ) + + expert_outputs = self.experts(dispatched_inputs, num_tokens_per_expert) + + output = self._combine_expert_outputs( + expert_outputs["hidden_states"], inverse_permute_mapping, routing_weights + ) + + # Add shared expert contribution if enabled + if self.enable_shared_expert: + shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) + output = output + shared_gate * shared_expert_output + + return output + + def _checkpointed_forward(self, hidden_states): + """Apply gradient checkpointing to save memory during training.""" + return CheckpointFunction.apply(self._core_forward, True, hidden_states) + + def forward(self, hidden_states): + """Forward pass for the MoE layer.""" + if self.recompute_layer and self.training: + hidden_states = self._checkpointed_forward(hidden_states) + else: + hidden_states = self._core_forward(hidden_states) + + return {"hidden_states": hidden_states} diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 8107b46e..088551b0 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -182,7 +182,8 @@ def build_grad_buffers( if not param.requires_grad: continue - assert param.dtype != torch.float, f"Expected {name} not to be float" + # MoE router weights are initialized in float32 + assert param.dtype != torch.float or "router.weight" in name, f"Expected {name} not to be float" assert param.is_contiguous(), f"Expected {name} to be contiguous" next_offset = offset + param.numel() * element_size diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 8f3062a9..8324eccf 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -6,6 +6,7 @@ from nanotron.config import Config, ModelArgs from nanotron.config.models_config import InitScalingMethod from nanotron.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm +from nanotron.nn.moe import GroupedMLP, Router from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -36,10 +37,15 @@ def __init__(self, config: Config): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { TensorParallelColumnLinear: self._parametrize_column_linear, + # TODO: double check if correct initialization for grouped MLP TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, LlamaRMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, + # NOTE: MoE's specific initialization + GroupedMLP: self._parametrize_grouped_mlp, + Router: self._parametrize_router, + nn.Linear: self._parametrize_column_linear, } self.std = config.model.init_method.std @@ -57,6 +63,26 @@ def _parametrize_column_linear(self, param_name: str, module: nn.Module): elif "bias" == param_name: module.bias.zero_() + def _parametrize_grouped_mlp(self, param_name: str, module: 
nn.Module): + for n, p in module.named_parameters(): + if n == "merged_gate_up_proj": + # NOTE: the same as parametrization of column linear + init.normal_(p, mean=0.0, std=self.std) + elif n == "merged_down_proj": + # NOTE: the same as parametrization of row linear + scaling = self._compute_scaling_factor() + adjusted_std = self.std / scaling + # TODO @nouamane: should we use trunc_normal_ + init.normal_(p, mean=0.0, std=adjusted_std) + else: + raise ValueError(f"Unknown parameter {n}") + + def _parametrize_router(self, param_name: str, module: nn.Module): + if "weight" == param_name: + init.normal_(module.weight, mean=0.0, std=self.std) + elif "bias" == param_name: + module.bias.zero_() + def _compute_scaling_factor(self) -> float: """Compute initialization scaling based on selected method""" if self.scaling_method == InitScalingMethod.NONE: diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a64bb5db..c94e97de 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -1178,7 +1178,7 @@ def setup_log_writers( def pre_save_checkpoint(self) -> Path: # Check if eval_interval should be updated from file - eval_interval_file = self.config.lighteval.eval_interval_file + eval_interval_file = self.config.lighteval.eval_interval_file if self.config.lighteval is not None else None if eval_interval_file is not None and Path(eval_interval_file).exists(): try: with open(eval_interval_file, "r") as f: diff --git a/tests/helpers/qwen_helper.py b/tests/helpers/qwen_helper.py index b333e2a7..7528b9d5 100644 --- a/tests/helpers/qwen_helper.py +++ b/tests/helpers/qwen_helper.py @@ -17,30 +17,44 @@ TokensArgs, ) from nanotron.config.config import PretrainDatasetsArgs +from nanotron.config.models_config import MoEConfig from nanotron.models import build_model from nanotron.models.qwen import Qwen2Config, Qwen2ForTraining from nanotron.parallel.context import ParallelContext from nanotron.trainer import mark_tied_parameters -TINY_QWEN_CONFIG = Qwen2Config( +QWEN_MOE_CONFIG = MoEConfig( + num_experts=8, + top_k=1, + enable_shared_expert=True, + token_dispatcher_type="alltoall", +) + +QWEN_MODEL_CONFIG = { + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 128, + "initializer_range": 0.02, + "intermediate_size": 128 * 4, + "max_position_embeddings": 128, + "num_attention_heads": 4, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "pad_token_id": None, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + "_attn_implementation": "flash_attention_2", +} + +TINY_QWEN_CONFIG = Qwen2Config(**QWEN_MODEL_CONFIG) +TINY_MOE_QWEN_CONFIG = Qwen2Config( **{ - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 128, - "initializer_range": 0.02, - "intermediate_size": 128 * 4, - "max_position_embeddings": 128, - "num_attention_heads": 4, - "num_hidden_layers": 4, - "num_key_value_heads": 2, - "pad_token_id": None, - "rms_norm_eps": 1e-06, - "rope_theta": 10000.0, - "tie_word_embeddings": False, - "use_cache": True, - "vocab_size": 4096, - "_attn_implementation": "flash_attention_2", + **QWEN_MODEL_CONFIG, + "moe_config": QWEN_MOE_CONFIG, } ) diff --git a/tests/test_moe.py b/tests/test_moe.py new file mode 100644 index 00000000..635757e1 --- /dev/null +++ b/tests/test_moe.py @@ -0,0 +1,37 @@ +import torch +from helpers.qwen_helper import TINY_MOE_QWEN_CONFIG +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.models.base import 
init_on_device_and_dtype +from nanotron.nn.moe import GroupedMLP + + +def test_grouped_mlp(): + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=1, + expert_parallel_size=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + ) + num_tokens_per_experts = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + NUM_TOKENS = num_tokens_per_experts.sum() + NUM_EXPERTS = TINY_MOE_QWEN_CONFIG.moe_config.num_experts + HIDDEN_SIZE = TINY_MOE_QWEN_CONFIG.hidden_size + permuted_hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda") + + assert len(num_tokens_per_experts) == NUM_EXPERTS + + with init_on_device_and_dtype(device=torch.device("cuda"), dtype=torch.bfloat16): + grouped_mlp = GroupedMLP(config=TINY_MOE_QWEN_CONFIG, parallel_config=parallel_config) + + output = grouped_mlp(permuted_hidden_states, num_tokens_per_experts) + + assert output["hidden_states"].shape == (NUM_TOKENS, HIDDEN_SIZE) + assert output["hidden_states"].dtype == torch.bfloat16 + assert output["hidden_states"].device.type == "cuda" + + +if __name__ == "__main__": + test_grouped_mlp() From ba2ba8495fb729cc215c845fdd63f5a2ed459d93 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 24 Apr 2025 13:14:43 +0100 Subject: [PATCH 07/10] Nouamane/lighteval (#356) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * InitScalingMethod * InitScalingMethod * eval * try adding lightevalrunner to trainer * amend * amend * amend * amend * amend * amend * . * amend * amend * . * qos to low * add nanotron_path * some fix: logs, and config * cp instead of sync * eval_interval * serialize sanity checks * add output dir and s3_save path in the config * fix s3 only if define * fixes * add requeue * add wandb with lighteval and fix eval interval * fix this little space :( * folder_path should always have s3 when using s3 (fix consumed tokens issue) * config qwen * . 
--------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --- examples/config_qwen.py | 22 +++++-- examples/config_qwen.yaml | 24 +++---- src/nanotron/config/lighteval_config.py | 15 ++++- src/nanotron/data/tokenized_bytes.py | 7 +- src/nanotron/eval/one_job_runner.py | 40 +++++++++--- src/nanotron/eval/upload_to_wandb.py | 87 +++++++++++++++++++++++++ 6 files changed, 162 insertions(+), 33 deletions(-) create mode 100644 src/nanotron/eval/upload_to_wandb.py diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 8ca8487b..a5d901b2 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -30,7 +30,7 @@ "410m": (24, 1024, 16, 16, 4096), # ~410M params # Small to medium models "1b": (16, 2048, 16, 16, 5632), # ~1B params - "3b": (28, 2048, 16, 2, 11008), # ~3B params + "3b": (36, 2048, 16, 4, 11008), # ~3B params # Standard sizes "7b": (32, 4096, 32, 32, 11008), # ~7B params "13b": (40, 5120, 40, 40, 13824), # ~13B params @@ -47,7 +47,7 @@ def get_args(): parser.add_argument( "--model", choices=MODEL_SIZES.keys(), - default="custom", + default="3b", help="Model size to generate config for (e.g., 7b, 13b)", ) parser.add_argument( @@ -76,6 +76,10 @@ def get_args(): tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size") tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica") + # checkpoints + checkpoints_group = parser.add_argument_group("checkpoints") + checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval") + args = parser.parse_args() return args @@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config: is_qwen2_config=True, pad_token_id=None, _attn_implementation="flash_attention_2", - # sliding_window_size=20, + _use_doc_masking=True, ) @@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str: def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config: learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0 ) parallelism = ParallelismArgs( dp=args.dp, @@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config ) optimizer = OptimizerArgs( zero_stage=args.zero, - weight_decay=0.01, + weight_decay=0.1, clip_grad=1.0, accumulate_grad_in_fp32=True, learning_rate_scheduler=learning_rate, @@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config return Config( general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"), @@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config world_size = args.dp * args.tp * args.pp * args.cp if world_size <= 8: print( - f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" + f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" ) + print("You can also use environment variables for more debugging:") + print(" - ENABLE_TIMERS=1: Enable detailed timing information") + print(" - DEBUG_CPU=1: Log CPU and memory usage statistics") + print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection") else: print("Checkout slurm_launcher.py to launch a multi-node job") diff --git a/examples/config_qwen.yaml b/examples/config_qwen.yaml index a2ce9bd1..cf6f40fa 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 100000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false load_lr_scheduler: true @@ -30,9 +30,9 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: false + ignore_sanity_checks: true project: debug - run: qwen_20250423_201000_16423158 + run: qwen_20250424_120835_16423158 seed: 42 step: null lighteval: null @@ -50,24 +50,24 @@ model: make_vocab_size_divisible_by: 1 model_config: _attn_implementation: flash_attention_2 - _fused_rms_norm: false - _fused_rotary_emb: false - _use_doc_masking: false - _use_qkv_packed: false + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true attention_bias: false bos_token_id: 1 eos_token_id: 2 flex_attention_mask: null hidden_act: silu - hidden_size: 256 + hidden_size: 2048 initializer_range: 0.02 - intermediate_size: 768 + intermediate_size: 11008 is_qwen2_config: true max_position_embeddings: 4096 moe_config: null no_rope_layer: null - num_attention_heads: 4 - num_hidden_layers: 12 + num_attention_heads: 16 + num_hidden_layers: 36 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -108,7 +108,7 @@ parallelism: pp: 1 pp_engine: 1f1b recompute_layer: false - tp: 1 + tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER tp_recompute_allgather: true diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 363ee988..0806acff 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -109,8 +109,13 @@ class LightEvalConfig: logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None - s3_save_path: Optional[str] = None # should not be dependent of the run_name - output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override + s3_save_path: Optional[str] = None # should not be dependent of the run_name + upload_to_wandb: Optional[bool] = False + wandb_project: Optional[str] = None + wandb_entity: Optional[str] = None + output_dir: Optional[ + str + ] = None # we should sanity check that it's the same as the one in the eval_config_override nanotron_path: Optional[str] = "./" eval_config_override: str = None eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job @@ -127,6 +132,12 @@ def __post_init__(self): if self.slurm is None: self.slurm = LightEvalSlurm() self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.upload_to_wandb: + assert ( + self.s3_save_path is not None + ), " We should have a s3_save_path if we want to upload to wandb" # todo: add the option to read from local folder i guess + assert self.wandb_project is not None, "wandb_project must be 
specified if upload_to_wandb is True" + assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True" if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): logger.warning( f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 4f9063eb..880983bb 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -369,12 +369,13 @@ def __init__( ) from datatrove.utils.dataset import url_to_fs - fs_folder, folder_path = url_to_fs(folder_path) + fs_folder, stripped_folder_path = url_to_fs(folder_path) matched_files = ( - fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None) + fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) if not filename_pattern else fs_folder.glob( - os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None + os.path.join(stripped_folder_path, filename_pattern), + maxdepth=1 if not recursive else None, ) ) matched_files = sorted(matched_files) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 43d1a765..6567ec94 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -60,13 +60,18 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: logger.warning( f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." ) - - slurm_job_id, slurm_log = run_slurm_one_job( - config=self.config, - lighteval_config=self.lighteval_config, - model_checkpoint_path=checkpoint_path, - current_step=self.config.general.step, - ) + if self.config.general.step % self.lighteval_config.eval_interval == 0: + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + else: + logger.warning( + f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}" + ) + return None, None return slurm_job_id, slurm_log @@ -130,7 +135,8 @@ def run_slurm_one_job( #SBATCH --exclusive #SBATCH --qos={slurm_config.qos} #SBATCH --time={slurm_config.time} -#SBATCH --output={eval_logs_path}/%j-{timestamp}.out""" +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out +#SBATCH --requeue""" if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" @@ -250,7 +256,23 @@ def run_slurm_one_job( --cache-dir {slurm_config.hf_cache}""" if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: slurm_script += f""" -s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/ +""" + if lighteval_config.upload_to_wandb: + gbs_tok = ( + config.parallelism.dp + * config.tokens.micro_batch_size + * config.tokens.sequence_length + * config.tokens.batch_accumulation_per_replica + ) + slurm_script += f""" +python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\ + --wandb_project {lighteval_config.wandb_project} \\ + --wandb_entity {lighteval_config.wandb_entity} \\ + 
--model_name {general_run_name} \\ + --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ + --train_step {current_step} \\ + --consumed_tokens {current_step*gbs_tok} """ slurm_script += """ echo "Cleaning up downloaded checkpoints..." diff --git a/src/nanotron/eval/upload_to_wandb.py b/src/nanotron/eval/upload_to_wandb.py new file mode 100644 index 00000000..aa8c12d4 --- /dev/null +++ b/src/nanotron/eval/upload_to_wandb.py @@ -0,0 +1,87 @@ +import json +import s3fs +import wandb +import re +import argparse +from wandb.sdk.lib.runid import generate_id + + +def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens): + s3 = s3fs.S3FileSystem(anon=False) + all_metrics = { + # basic X axis replacements for all metrics + "consumed_tokens": consumed_tokens, + "train_step": train_step, + } + + for result_file in sorted(s3.ls(results_path)): + if not result_file.endswith(".json"): + continue + + with s3.open(result_file, "r") as f: + results = json.loads(f.read())["results"] + + for benchmark, metrics in results.items(): + if benchmark == "all": + continue + + # extract dataset and config name + match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark) + if match: + dataset, subtask = match.groups() + + for metric_name, metric_value in metrics.items(): + if "_stderr" in metric_name: + continue + # wandb-friendly metric name + wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}" + all_metrics[wandb_metric] = metric_value + + run_id = f"{model_name}-{generate_id()}" + + # try to find the run in wandb and resume it + api = wandb.Api() + runs = api.runs(f"{wandb_entity}/{wandb_project}") + for run in runs: + if run.name == model_name: + run_id = run.id + break + + wandb.init( + project=wandb_project, + entity=wandb_entity, + name=model_name, + id=run_id, + config={ + "model_name": model_name, + }, + resume="allow", + ) + + # log all metrics for this checkpoint + wandb.log(all_metrics) + + wandb.finish() + +if __name__ == "__main__": + # Setup argument parser + parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.") + parser.add_argument("--wandb_project", type=str, required=True, help="WandB project name.") + parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.") + parser.add_argument("--model_name", type=str, required=True, help="Name of the model.") + parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.") + parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.") + parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.") + + # Parse arguments + args = parser.parse_args() + + # Call the main function with parsed arguments + push_to_wandb( + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + model_name=args.model_name, + results_path=args.results_path, + train_step=args.train_step, + consumed_tokens=args.consumed_tokens + ) From f66c1a0fb3d96c9f3e20fe6d3f43ed9501d4fb72 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 23 Jun 2025 18:01:31 +0200 Subject: [PATCH 08/10] =?UTF-8?q?SmoLM3=20training=20=F0=9F=9A=80=20(#375)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * InitScalingMethod * InitScalingMethod * eval * try adding lightevalrunner to trainer * amend * amend * 
amend * amend * amend * amend * . * amend * amend * . * qos to low * add nanotron_path * some fix: logs, and config * cp instead of sync * eval_interval * serialize sanity checks * add output dir and s3_save path in the config * fix s3 only if define * fixes * add requeue * add wandb with lighteval and fix eval interval * fix this little space :( * folder_path should always have s3 when using s3 (fix consumed tokens issue) * fix resuming with new data mixture * offsets must be in samples not tokens * sanity check local files when dataset_read_path * better error for new stage * rmsnorm * sliding window * causal SWA * Revert "rmsnorm" This reverts commit 17dad0a2f765af4f831260108002fab5e8b7c924. * rope_seq_len_interpolation_factor * logmixin for intermediate tensors + CP + consumed_token shenanigans when resuming training (#365) * logmixin * context parallelism (llama3 ring attn) + consumed_token shenanigans (#366) * training works * llama3 ring attn * llama3 ring attn * llama3 ring attn * fix position_ids (make them global) * rope_seq_len_interpolation_factor assert * . * . * fix rope and cp_pg * fixed consumed_tokens log --------- Co-authored-by: elie <97572401+eliebak@users.noreply.github.com> Co-authored-by: “eliebak” --- run_train.py | 125 ++- src/nanotron/config/config.py | 14 +- src/nanotron/config/models_config.py | 10 + src/nanotron/constants.py | 2 +- src/nanotron/data/clm_collator.py | 4 +- .../data/nemo_dataset/blendable_dataset.py | 30 +- src/nanotron/data/tokenized_bytes.py | 52 +- src/nanotron/helpers.py | 12 +- src/nanotron/logging/__init__.py | 3 + src/nanotron/logging/logmixin.py | 83 ++ src/nanotron/models/base.py | 4 +- src/nanotron/models/qwen.py | 346 ++++++-- src/nanotron/nn/attention.py | 3 +- src/nanotron/nn/llama3_ring_attention.py | 810 ++++++++++++++++++ src/nanotron/nn/rotary.py | 79 ++ src/nanotron/optim/gradient_accumulator.py | 13 +- src/nanotron/parallel/context.py | 9 + src/nanotron/sanity_checks.py | 18 +- src/nanotron/serialize/main.py | 8 +- src/nanotron/serialize/metadata.py | 49 +- src/nanotron/serialize/optimizer.py | 4 +- src/nanotron/serialize/weights.py | 5 +- src/nanotron/trainer.py | 98 ++- 23 files changed, 1572 insertions(+), 209 deletions(-) create mode 100644 src/nanotron/logging/logmixin.py create mode 100644 src/nanotron/nn/llama3_ring_attention.py diff --git a/run_train.py b/run_train.py index d00ef211..35bc62da 100644 --- a/run_train.py +++ b/run_train.py @@ -41,7 +41,8 @@ from nanotron.trainer import DistributedTrainer from nanotron.utils import main_rank_first from torch.utils.data import DataLoader - +from nanotron.trainer import DataStageMetadata +from collections import defaultdict try: from huggingface_hub import __version__ as hf_hub_version from transformers import AutoTokenizer @@ -60,8 +61,9 @@ def get_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, - consumed_train_samples: int, + consumed_train_samples_stage: int, consumed_tokens_per_dataset_folder: Dict[str, int], + last_stages_consumed_tokens_per_dataset_folder: Dict[str, int], num_remaining_train_steps: int, sanity_check_dataloader_interval: Optional[int] = None, ): @@ -69,10 +71,11 @@ def get_dataloader_from_data_stage( Returns a dataloader for a given data stage. data: The data configuration for the current stage. - consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). 
+ consumed_train_samples_stage: The number of samples consumed by the model in the this stage (each stage starts from zero). + consumed_tokens_per_dataset_folder: The number of tokens consumed by the model in previous stages to avoid reseeing them, because the sampler has restarted for this stage. num_remaining_train_steps: The number of remaining training steps for this stage. """ - assert consumed_train_samples >= 0, "consumed_train_samples should be greater than 0" + assert consumed_train_samples_stage >= 0, "consumed_train_samples_stage should be greater than 0" assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than 0" # First, we need to know which ranks to feed the dataloader to @@ -164,7 +167,7 @@ def get_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=consumed_train_samples, + consumed_train_samples_stage=consumed_train_samples_stage, dataloader_num_workers=data.num_loading_workers, seed_worker=data.seed, dataloader_drop_last=True, @@ -198,7 +201,8 @@ def get_dataloader_from_data_stage( level=logging.INFO, rank=0, ) - tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" start_time = time.time() @@ -207,11 +211,14 @@ def get_dataloader_from_data_stage( global_batch_size=trainer.global_batch_size, sequence_length=trainer.sequence_length, train_steps=trainer.config.tokens.train_steps, + current_iteration=trainer.iteration_step, parallel_context=trainer.parallel_context, shuffle=data.dataset.shuffle_files, eos_token_id=tokenizer.eos_token_id, seed=data.seed, + consumed_samples=consumed_train_samples_stage, consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder, ) dataloader = get_tb_dataloader( dataset=train_dataset, @@ -220,8 +227,8 @@ def get_dataloader_from_data_stage( global_batch_size=trainer.global_batch_size, num_workers=data.num_loading_workers, cfg=data.dataset, - consumed_samples=consumed_train_samples, - num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + consumed_samples=consumed_train_samples_stage, + num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, # TODO: this overshoots what's needed by the current stage, but it doesnt matter? 
parallel_context=trainer.parallel_context, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, @@ -315,46 +322,72 @@ def get_dataloader( full_log_message = f"There are {len(trainer.config.data_stages)} training stages \n{stages_info}" log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - for stage_idx, stage in enumerate(trainer.config.data_stages): - # NOTE: we only create the dataloader for the first stage, - # then we lazy initialize the dataloader for the other stages - stage = cast(DatasetStageArgs, stage) - ( - consumed_train_samples, - consumed_tokens_per_dataset_folder, - ) = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) - - num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( - stage, trainer.config, trainer.metadata - ) - log_rank( - f"Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples" - f"Consumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}", - logger=logger, - level=logging.INFO, - rank=0, - ) + current_stage = None + # WARNING: we assume we train on last stage + stage_idx = len(trainer.config.data_stages) - 1 + stage_args = trainer.config.data_stages[stage_idx] + if trainer.iteration_step+1 == stage_args.start_training_step: + log_rank(f"Starting new stage {stage_args.name}", logger=logger, level=logging.INFO, rank=0) + # we start a new stage + if stage_idx >= len(trainer.metadata.data_stages): + trainer.metadata.data_stages.append(DataStageMetadata( + name=stage_args.name, + start_training_step=stage_args.start_training_step, + consumed_train_samples=0, + consumed_tokens_per_dataset_folder={}, + sequence_length=trainer.sequence_length, + )) + elif len(trainer.metadata.data_stages) < len(trainer.config.data_stages): + raise ValueError(f"If you're trying to start a new stage, you need to set `start_training_step` to the step after the last stage's: {trainer.iteration_step+1}") + current_stage = trainer.metadata.data_stages[stage_idx] + cur_stage_consumed_train_samples = current_stage.consumed_train_samples + consumed_tokens_per_dataset_folder = current_stage.consumed_tokens_per_dataset_folder + stage_args_data = trainer.config.data_stages[stage_idx].data + + num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + current_stage, trainer.config, trainer.metadata + ) # TODO: check this + log_rank( + f"Current stage: {current_stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {cur_stage_consumed_train_samples} samples" + f"Consumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}", + logger=logger, + level=logging.INFO, + rank=0, + ) - dataloader = ( - get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, - num_remaining_train_steps=num_remaining_train_steps, - sanity_check_dataloader_interval=sanity_check_dataloader_interval, + # warn that if seqlen of stage - 1 has changed, consumed_train_samples=0 so we'll assume we're reading from new folder (so that we can resume training) + if current_stage.sequence_length != trainer.metadata.data_stages[-1].sequence_length: + raise NotImplementedError("We don't support changing sequence length between stages yet") + if current_stage.consumed_train_samples == 0: + log_rank( + f"Warning: The sequence length of the last stage has changed from 
{trainer.metadata.data_stages[-1].sequence_length} to {current_stage.sequence_length}. We'll assume we're reading from the beginning of the dataset folders.", + logger=logger, + level=logging.WARNING, + rank=0, ) - if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, - num_remaining_train_steps=num_remaining_train_steps, - sanity_check_dataloader_interval=sanity_check_dataloader_interval, - ) - ) - dataloaders[stage.name] = dataloader + else: + # we're resuming training, so that's fine + pass + cur_stage_consumed_train_samples = current_stage.consumed_train_samples + + else: + # Prepare last_stages_consumed_tokens_per_dataset_folder which will be used to offset BlendableDataset to avoid reseeing consumed tokens even when sampler has restarted for this stage + last_stages_consumed_tokens_per_dataset_folder = {} + for stage in trainer.metadata.data_stages[:-1]: + for folder_path, consumed_tokens in stage.consumed_tokens_per_dataset_folder.items(): + last_stages_consumed_tokens_per_dataset_folder[folder_path] = last_stages_consumed_tokens_per_dataset_folder.get(folder_path, 0) + consumed_tokens + + + + dataloaders[current_stage.name] = get_dataloader_from_data_stage( + trainer, + stage_args_data, + consumed_train_samples_stage=cur_stage_consumed_train_samples, + consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder, + num_remaining_train_steps=num_remaining_train_steps, + sanity_check_dataloader_interval=sanity_check_dataloader_interval, + ) return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c16f076c..6cb9e137 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -27,6 +27,7 @@ from nanotron.logging import get_logger, human_format from nanotron.parallel.pipeline_parallel.engine import PipelineEngine from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.config.models_config import Qwen2Config logger = get_logger(__name__) @@ -226,6 +227,7 @@ class DatasetStageArgs: name: str start_training_step: int data: DataArgs + sequence_length: Optional[int] = None # if None, we use the sequence length from the config def __post_init__(self): if self.start_training_step < 0: @@ -272,7 +274,7 @@ class GeneralArgs: run: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None - consumed_train_samples: Optional[int] = None + consumed_train_samples: Optional[int] = None # TODO: remove this benchmark_csv_path: Optional[Path] = None ignore_sanity_checks: bool = True @@ -300,7 +302,6 @@ class ProfilerArgs: with_stack: bool = True export_chrome_trace: bool = False - @dataclass class ModelArgs: """Arguments related to model architecture""" @@ -317,6 +318,9 @@ def __post_init__(self): if isinstance(self.dtype, str): self.dtype = cast_str_to_torch_dtype(self.dtype) + if isinstance(self.model_config, dict): + self.model_config = Qwen2Config(**self.model_config) + self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit) # if self.model_config.max_position_embeddings is None: @@ -542,6 +546,12 @@ def __post_init__(self): self.model.model_config.num_attention_heads % self.model.model_config.num_key_value_heads == 0 ), f"num_attention_heads ({self.model.model_config.num_attention_heads}) must 
be divisible by num_key_value_heads ({self.model.model_config.num_key_value_heads})" + # data_stages + if self.data_stages is not None: + for stage in self.data_stages: + if stage.sequence_length is None: + stage.sequence_length = self.tokens.sequence_length + @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 999d1337..a74630d9 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -137,6 +137,7 @@ class Qwen2Config: rope_scaling: Optional[dict] = None rope_theta: float = 10000.0 rope_interleaved: bool = False + rope_seq_len_interpolation_factor: Optional[float] = None # if not None, discrete positions will be interpolated by this factor via the trick in https://arxiv.org/abs/2306.15595 tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 @@ -154,6 +155,9 @@ class Qwen2Config: _use_qkv_packed: bool = True _use_doc_masking: bool = False + log_attn_probs: bool = True # Whether to log the attention probabilities + ring_attn_heads_k_stride: Optional[int] = None # Stride of the heads in the key tensor for llama3 ring attention + # MoE configuration moe_config: Optional[MoEConfig] = None @@ -181,6 +185,7 @@ def __post_init__(self): assert self._attn_implementation in [ "flex_attention", "flash_attention_2", + "llama3_ring_attention", ], "Sliding window is only supported for Flex Attention and Flash Attention 2" if self.flex_attention_mask is not None: assert ( @@ -196,6 +201,11 @@ def __post_init__(self): self.num_hidden_layers % self.no_rope_layer == 0 ), "no_rope_layer must be a multiple of num_hidden_layers" + if self._attn_implementation == "llama3_ring_attention": + assert self.ring_attn_heads_k_stride is not None, "ring_attn_heads_k_stride must be specified for llama3 ring attention" + else: + assert self.ring_attn_heads_k_stride is None, f"ring_attn_heads_k_stride must be None for non-llama3 ring attention, got attn_implementation={self._attn_implementation}" + @property def is_using_mup(self) -> bool: return self._is_using_mup diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..b31142de 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -2,7 +2,7 @@ from packaging.version import Version, parse -CHECKPOINT_VERSION = Version("1.4") +CHECKPOINT_VERSION = Version("1.5") PY_VERSION = parse(platform.python_version()) diff --git a/src/nanotron/data/clm_collator.py b/src/nanotron/data/clm_collator.py index 89fd0083..8886e312 100644 --- a/src/nanotron/data/clm_collator.py +++ b/src/nanotron/data/clm_collator.py @@ -149,6 +149,7 @@ class DataCollatorForCLMWithPositionIds: output_pp_rank: int parallel_context: ParallelContext use_doc_masking: bool = True + cp_return_global_position_ids: bool = True def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Process the case when current rank doesn't require data @@ -217,7 +218,8 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni cp_rank * self.sequence_length // cp_size, (cp_rank + 1) * self.sequence_length // cp_size ) result["input_ids"] = result["input_ids"][:, local_slice] # (b, s/cp_size) - result["positions"] = result["positions"][:, local_slice] # (b, s/cp_size) + if not self.cp_return_global_position_ids: + result["positions"] = 
result["positions"][:, local_slice] # (b, s/cp_size) result["position_ids"] = result.pop("positions") # Process labels diff --git a/src/nanotron/data/nemo_dataset/blendable_dataset.py b/src/nanotron/data/nemo_dataset/blendable_dataset.py index e4e99916..96c49fe0 100644 --- a/src/nanotron/data/nemo_dataset/blendable_dataset.py +++ b/src/nanotron/data/nemo_dataset/blendable_dataset.py @@ -25,6 +25,7 @@ from nanotron.logging import log_rank from nanotron.parallel import ParallelContext from nanotron.utils import main_rank_first +from pprint import pformat if TYPE_CHECKING: from . import GPTDataset, SubsetSplitLog @@ -48,6 +49,7 @@ def __init__( parallel_context: ParallelContext, seed: int, consumed_tokens_per_dataset_folder: Optional[Dict[str, int]] = None, + offsets_in_samples: Optional[Dict[str, int]] = None, ): self.datasets = datasets num_datasets = len(datasets) @@ -114,16 +116,26 @@ def __init__( # self.last_item_idx = np.full(self.history_size, -1, dtype=np.int64) # Initialize consumption tracking - self.consumed_tokens = {idx: 0 for idx in range(len(datasets))} + self.consumed_tokens = {idx: 0 for idx in range(len(datasets))} # current stage's consumed_tokens_per_dataset_folder if consumed_tokens_per_dataset_folder is not None: # find idx of dataset that matches the folder path for idx, dataset in enumerate(datasets): for folder_path, consumed_tokens in consumed_tokens_per_dataset_folder.items(): if dataset.folder_path == folder_path: self.consumed_tokens[idx] = consumed_tokens - break + log_rank(f"[BlendableDataset] Setting consumed_tokens for dataset {idx} ({dataset.folder_path}) to {consumed_tokens}", logger=logger, level=logging.INFO, rank=0) + self.sequence_length = None # Will be set when first batch is processed + # Setup offsets for already consumed tokens from previous stages + self.offsets_in_samples = {idx: 0 for idx in range(len(datasets))} # last stage's consumed_tokens_per_dataset_folder + if offsets_in_samples is not None: + for idx, dataset in enumerate(datasets): + for folder_path, offset in offsets_in_samples.items(): + if dataset.folder_path == folder_path: + self.offsets_in_samples[idx] = offset + log_rank(f"[BlendableDataset] Applying offset {offset} samples to dataset {idx} ({dataset.folder_path})", logger=logger, level=logging.INFO, rank=0) + def __len__(self): return self.size @@ -131,16 +143,7 @@ def __getitem__(self, idx): dataset_idx = self.dataset_index[idx] sample_idx = self.dataset_sample_index[idx] - # Shift history arrays and add new values at the end - # self.last_item_idx = np.roll(self.last_item_idx, -1) - # self.last_dataset_idx = np.roll(self.last_dataset_idx, -1) - # self.last_dataset_sample_idx = np.roll(self.last_dataset_sample_idx, -1) - - # self.last_item_idx[-1] = idx - # self.last_dataset_idx[-1] = dataset_idx - # self.last_dataset_sample_idx[-1] = sample_idx - - return self.datasets[dataset_idx][sample_idx] + return self.datasets[dataset_idx][sample_idx + self.offsets_in_samples[dataset_idx]] # TODO: is it okay to not respect dataset_sample_index? 
Since it's sequential it's okay for now # @property # def last_file_idx(self): @@ -181,6 +184,9 @@ def get_consumption_stats(self): """ stats = {} for dataset_idx, dataset in enumerate(self.datasets): + assert ( + "s3" in dataset.folder_path + ), "Only S3 paths are supported for consumption stats" # TODO: remove this stats[dataset.folder_path] = {"tokens": self.consumed_tokens[dataset_idx]} return stats diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 880983bb..3b8a5437 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -506,7 +506,7 @@ def build_dataset( seq_len=seq_length, recursive=False, token_size=token_size, - max_tokens=max_tokens, + max_tokens=max_tokens, # TODO: remove shuffle=shuffle, return_positions=return_positions, # if set to True, the position ids are directly read from datatrove eos_token_id=eos_token_id, @@ -522,11 +522,14 @@ def get_tb_datasets( sequence_length: int, global_batch_size: int, train_steps: int, + current_iteration: int, parallel_context: ParallelContext, eos_token_id: Optional[int] = None, shuffle: bool = False, seed: int = 6, + consumed_samples: int = 0, consumed_tokens_per_dataset_folder: Optional[Dict[str, int]] = None, + last_stages_consumed_tokens_per_dataset_folder: Optional[Dict[str, int]] = None, ) -> Tuple[DataLoader, TrainDataLog]: """Build TokenizedBytes datasets @@ -542,6 +545,7 @@ def get_tb_datasets( if dataset_max_tokens is None: dataset_max_tokens = [None] * len(config.dataset_folder) train_num_samples = train_steps * global_batch_size + last_stages_consumed_samples_per_dataset_folder = {k: v // sequence_length for k, v in last_stages_consumed_tokens_per_dataset_folder.items()} datasets = [ build_dataset( @@ -561,6 +565,49 @@ def get_tb_datasets( for i, (dataset_folder, max_tokens) in enumerate(zip(config.dataset_folder, dataset_max_tokens)) ] + # in case of dataset_read_path check we have enough files locally for the training + if config.dataset_read_path: + + weights = config.dataset_weights + if not weights: + weights = [1] * len(datasets) + + # Normalize weights + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # check we have enough files locally for the training + for i, dataset in enumerate(datasets): + # warmup datasets + estimate_current_sample = int(consumed_samples * weights[i]) + last_stages_consumed_samples_per_dataset_folder.get(dataset.folder_path, 0) + _ = dataset[estimate_current_sample] + # print which file we're currently reading from + log_rank(f"Dataset {i} ({dataset.folder_path}) is reading from file {dataset.current_file_path}", logger=logger, level=logging.INFO, rank=0) + # estimate number of tokens needed for this dataset + needed_num_samples_dataset = int((train_steps - current_iteration) * global_batch_size * weights[i]) + needed_num_tokens_dataset = needed_num_samples_dataset * sequence_length + needed_size_tokens_dataset = human_format(needed_num_tokens_dataset * config.token_size_in_bytes) + log_rank(f"Dataset {i} ({dataset.folder_path}) needs {needed_num_tokens_dataset} tokens (size: {needed_size_tokens_dataset}) for current stage", logger=logger, level=logging.INFO, rank=0) + + # NOTE: let's assume that s3 folder keep the same old files when resuming + # check that sum of lens of files in dataset is greater than needed_num_samples_dataset (use dataset.lens) + total_num_samples_dataset = int(train_steps * global_batch_size * weights[i]) 
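+            # NOTE (illustrative example with hypothetical numbers): with train_steps=1000,
+            # global_batch_size=64 and weights[i]=0.25, this dataset must expose at least
+            # int(1000 * 64 * 0.25) = 16000 samples in total. The assert below checks that
+            # against len(dataset), and the per-file loop then verifies that every local file
+            # overlapping [estimate_current_sample, estimate_end_sample] actually exists on disk.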
+ log_rank(f"Dataset {i} ({dataset.folder_path}) on s3 has {len(dataset) * sequence_length} tokens (size: {human_format(len(dataset) * sequence_length * config.token_size_in_bytes)}) and needs {total_num_samples_dataset * sequence_length} tokens (size: {human_format(total_num_samples_dataset * sequence_length * config.token_size_in_bytes)}) for all stages", logger=logger, level=logging.INFO, rank=0) + assert total_num_samples_dataset <= len(dataset), f"Not enough files on s3 for dataset {i} ({dataset.folder_path})" + # check that local files exist for the needed_num_samples_dataset + estimate_end_sample = estimate_current_sample + needed_num_samples_dataset + for file_idx, file in enumerate(dataset.files): + # intersection [start_sample, end_sample] with [dataset.lens[file_idx], dataset.lens[file_idx+1]] + a, b, c, d = estimate_current_sample, estimate_end_sample, dataset.lens[file_idx], dataset.lens[file_idx+1] + if max(a, c) < min(b, d): # ranges overlap + assert os.path.exists(file.file_path), f"Dataset {i} ({dataset.folder_path}) will need file {file.file_path} but it does not exist" + log_rank(f"Dataset {i} ({dataset.folder_path}) will need file {file.file_path} from sample {max(a, c)} to {min(b, d)} (offset: {last_stages_consumed_samples_per_dataset_folder.get(dataset.folder_path, 0)})", logger=logger, level=logging.INFO, rank=0) + else: + log_rank(f"Dataset {i} ({dataset.folder_path}) will not need file {file.file_path} to train from sample {estimate_current_sample} to {estimate_end_sample} (offset: {last_stages_consumed_samples_per_dataset_folder.get(dataset.folder_path, 0)})", logger=logger, level=logging.INFO, rank=0) + + if len(datasets) == 1 and False: outputs_dataset = datasets[0] else: @@ -583,6 +630,7 @@ def get_tb_datasets( parallel_context=parallel_context, seed=seed, consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + offsets_in_samples=last_stages_consumed_samples_per_dataset_folder, ) log_rank("Streamable datasets ready.", logger=logger, level=logging.INFO, rank=0) @@ -626,7 +674,7 @@ def get_tb_dataloader( dataset = EmptyInfiniteDataset(length=len(dataset)) log_rank( - f"Building dataloader with consumed samples: {consumed_samples}", logger=logger, level=logging.INFO, rank=0 + f"Building dataloader with consumed samples for current datastage: {consumed_samples}", logger=logger, level=logging.INFO, rank=0 ) # Megatron sampler # batch_sampler = MegatronPretrainingRandomSampler( diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index e83355e1..c8df3d21 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -8,7 +8,7 @@ from datetime import datetime from functools import partial from math import ceil -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -312,10 +312,9 @@ def merge_named_param_groups( return named_param_groups - def init_optimizer_and_grad_accumulator( parametrization_method: ParametrizationMethod, - model: nn.Module, + model: Union[nn.Module, DistributedDataParallel], optimizer_args: OptimizerArgs, parallel_context: ParallelContext, ) -> Tuple[BaseOptimizer, GradientAccumulator]: @@ -446,13 +445,13 @@ def grad_optimizer_builder(named_param_groups): assert isinstance(grad_accumulator, FP32GradientAccumulator) model.register_comm_hook( state=FP32GradBucketManager( - dp_pg=parallel_context.dp_pg, + dp_cp_pg=parallel_context.dp_cp_pg, accumulator=grad_accumulator, param_id_to_name={ id(param): 
param.get_tied_info().get_full_name_from_module_id_to_prefix( module_id_to_prefix=module_id_to_prefix ) - if param.is_tied + if param.is_tied # a tied param exists only once physically in memory else name for name, param in unwrapped_model.named_parameters() }, @@ -818,7 +817,7 @@ def is_resume_from_training(): return 0 else: last_train_steps = metadata.last_train_step if is_resume_from_training() else stage.start_training_step - return total_train_steps - last_train_steps + return total_train_steps - last_train_steps + 1 def get_consumed_train_samples_of_a_data_stage_from_ckp( @@ -826,6 +825,7 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp( ) -> Optional[int]: start_training_step = stage.start_training_step + # find the stage in the metadata using the start_training_step actual_stage = next( (s for s in metadata.data_stages if s.start_training_step == start_training_step), None, diff --git a/src/nanotron/logging/__init__.py b/src/nanotron/logging/__init__.py index 5529e9f6..c36e8fd1 100644 --- a/src/nanotron/logging/__init__.py +++ b/src/nanotron/logging/__init__.py @@ -29,6 +29,7 @@ # Export timer functionality from nanotron.logging.timers import TimerRecord, Timers, nanotron_timer +from nanotron.logging.logmixin import LogMixin, LoggingCollectorMixin __all__ = [ "CRITICAL", @@ -56,4 +57,6 @@ "TimerRecord", "Timers", "nanotron_timer", + "LogMixin", + "LoggingCollectorMixin", ] diff --git a/src/nanotron/logging/logmixin.py b/src/nanotron/logging/logmixin.py new file mode 100644 index 00000000..13c218e9 --- /dev/null +++ b/src/nanotron/logging/logmixin.py @@ -0,0 +1,83 @@ +import torch +from torch import nn +from typing import Dict, List, Any + +class LogMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._tbi_logs: List[Dict[str, torch.Tensor]] = [] + + def tbi_logger(self, log_data: Dict[str, torch.Tensor]): + """ + Logs a dictionary of named tensors. + Tensors are stored by reference (i.e., on their original device, e.g., CUDA). + If you need a snapshot at the exact moment of logging and the tensor might be + modified in-place later, consider storing {key: tensor.clone() for key, tensor in log_data.items()}. + """ + # Example check (optional): + # for tensor in log_data.values(): + # if not tensor.is_cuda: + # # Or raise an error, or move to CUDA, depending on desired behavior + # print(f"Warning: Tensor {list(log_data.keys())[list(log_data.values()).index(tensor)]} is not on CUDA.") + self._tbi_logs.append(log_data) + + def _get_internal_logs(self) -> List[Dict[str, torch.Tensor]]: + """ + Retrieves the logs stored by this module instance. + """ + return self._tbi_logs + + def _clear_internal_logs(self): + """ + Clears the logs stored by this module instance. + Important for managing memory, should be called after logs are processed. + """ + self._tbi_logs = [] + + +class LoggingCollectorMixin: + """ + A mixin class for nn.Module-based models to collect logs from submodules + that use LogMixin. + The class this is mixed into must be an nn.Module or its subclass. + """ + # No __init__ is strictly necessary here if the mixin itself doesn't have + # its own state to initialize. The methods operate on `self` which will be + # an instance of the class it's mixed into (e.g., Qwen2ForTraining). + # If an __init__ were added, it should also call super().__init__(*args, **kwargs). + + def get_tbi_logs(self, non_blocking: bool = False) -> Dict[str, List[Dict[str, torch.Tensor]]]: + """ + Collects all TBI logs from modules that use LogMixin. 
+ Returns a dictionary where keys are fully qualified module names and + values are lists of log entries (each entry being a dictionary of tensors). + Tensors remain on their original CUDA devices. + Assumes `self` is an nn.Module instance with `named_modules()` method. + """ + all_logs: Dict[str, List[Dict[str, torch.Tensor]]] = {} + # `self` refers to the instance of the class LoggingCollectorMixin is mixed into. + # This class is expected to be an nn.Module or subclass. + for name, module in self.named_modules(): + if isinstance(module, LogMixin): + module_logs = module._get_internal_logs() + if module_logs: # Only add if there are logs for this module + for entry in module_logs: + for k, v in entry.items(): + all_logs[name + "/" + k] = v.detach().to(device="cpu", non_blocking=non_blocking) + return all_logs + + def clear_all_tbi_logs(self): + """ + Clears TBI logs from all modules that use LogMixin. + This should be called after processing the logs (e.g., after a forward/backward pass) + to free up memory. + Assumes `self` is an nn.Module instance with `modules()` method. + """ + # `self` refers to the instance of the class LoggingCollectorMixin is mixed into. + for module in self.modules(): + if isinstance(module, LogMixin): + try: + module._clear_internal_logs() + except AttributeError: + # Similar to get_tbi_logs, handle cases where mixin might not be fully initialized. + pass diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index 6bb4d647..777210ff 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -13,7 +13,7 @@ from nanotron.logging import log_rank from nanotron.parallel.context import ParallelContext from nanotron.parallel.pipeline_parallel.block import PipelineBlock - +from nanotron.logging import LoggingCollectorMixin if TYPE_CHECKING: from nanotron.config import NanotronConfigs from nanotron.parallel.parameters import NanotronParameter @@ -21,7 +21,7 @@ logger = logging.get_logger(__name__) -class NanotronModel(nn.Module, metaclass=ABCMeta): +class NanotronModel(nn.Module, LoggingCollectorMixin, metaclass=ABCMeta): """Abstract class for Nanotron models We make the following assumptions: - When building PP blocks, we assume that the modules order are in the same order as the forward pass.""" diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index 8aa6eb46..a18ff704 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -29,7 +29,8 @@ ) from nanotron.random import RandomStates from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator - +from nanotron.logging import LogMixin +from nanotron.nn.llama3_ring_attention import llama3_flash_attn_varlen_kvpacked_func, llama3_flash_attn_prepare_cu_seqlens logger = logging.get_logger(__name__) @@ -131,8 +132,7 @@ def forward( -1, self.local_num_heads * self.head_dim ) # [b*s, num_heads, head_dim] -> [b*s, num_heads*head_dim] - -class Qwen2Attention(nn.Module): +class Qwen2Attention(LogMixin, nn.Module): def __init__( self, config: Qwen2Config, @@ -146,6 +146,8 @@ def __init__( self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.tp_pg_size = tp_pg.size() + self.cp_pg_size = cp_pg.size() + self.cp_pg = cp_pg # Head configuration self.num_heads = config.num_attention_heads @@ -192,12 +194,12 @@ def __init__( async_communication=tp_linear_async_communication, ) if config._use_qkv_packed: - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding - + from nanotron.nn.rotary 
import FlashRotaryEmbedding self.rotary_emb = FlashRotaryEmbedding( dim=self.head_dim, base=config.rope_theta, interleaved=config.rope_interleaved, + seq_len_interpolation_factor=config.rope_seq_len_interpolation_factor, ) else: self.rotary_emb = RotaryEmbedding( @@ -205,27 +207,29 @@ def __init__( max_seq_len=config.max_position_embeddings, base=config.rope_theta, interleaved=config.rope_interleaved, - seq_len_scaling_factor=None, + seq_len_scaling_factor=config.rope_seq_len_scaling_factor, fused=config._fused_rotary_emb, ) self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx) self.simple_causal_mask = True self._use_qkv_packed = config._use_qkv_packed - - # TODO: support doc masking / SWA / SFT / inference + self.sliding_window_size = config.sliding_window_size + self.log_attn_probs = config.log_attn_probs + self.heads_k_stride = config.ring_attn_heads_k_stride + # TODO: support SFT def forward( self, hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] position_ids: torch.Tensor, # [batch_size, seq_length] where -1 is padding - cu_seqlens: Optional[torch.Tensor] = None, # Added cu_seqlens argument + cu_seqlens: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None, # Added cu_seqlens argument ): # [0, 1, 2, 3, 4, 0, 1, 2, -1, -1, -1] # 2 documents with 5 and 3 tokens then padding # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 1 document with 11 tokens # [0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1] # 1 document with 10 tokens then padding # Replace -1 with 0 in position_ids to mark every padding token as a separate sequence. Ideally we want to get rid of padding tokens from qkv # position_ids = position_ids.masked_fill(position_ids == -1, 0) - seq_length = position_ids.shape[1] + seq_length = position_ids.shape[1] // self.cp_pg_size # in CP, position_ids are global # Keep original position_ids shape for return, flatten for internal use position_ids = position_ids.view(-1) # [batch_size*seq_length] @@ -233,28 +237,28 @@ def forward( if self._use_qkv_packed: attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens) - else: - q, k, v = qkv.split( - [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 - ) # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size] - q = q.view(-1, self.local_num_heads, self.head_dim) # [b*s, num_heads, head_dim] - k = k.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] - v = v.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] - if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: - rotary_pos_emb = self.rotary_emb( - position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length - ) # [b*s, dim] or [seq_length, dim] - q = self.rotary_emb.apply_rotary_pos_emb( - q, rotary_pos_emb, seq_length=seq_length - ) # [b*s, num_heads, head_dim] - k = self.rotary_emb.apply_rotary_pos_emb( - k, rotary_pos_emb, seq_length=seq_length - ) # [b*s, num_kv_heads, head_dim] - else: - log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) - attn_output = self.attention( - q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens - ) + # else: + # q, k, v = qkv.split( + # [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 + # ) # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size] + # q = q.view(-1, self.local_num_heads, self.head_dim) # [b*s, num_heads, head_dim] + # k = k.view(-1, 
self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + # v = v.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + # if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + # rotary_pos_emb = self.rotary_emb( + # position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length + # ) # [b*s, dim] or [seq_length, dim] + # q = self.rotary_emb.apply_rotary_pos_emb( + # q, rotary_pos_emb, seq_length=seq_length + # ) # [b*s, num_heads, head_dim] + # k = self.rotary_emb.apply_rotary_pos_emb( + # k, rotary_pos_emb, seq_length=seq_length + # ) # [b*s, num_kv_heads, head_dim] + # else: + # log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + # attn_output = self.attention( + # q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + # ) output = self.o_proj(attn_output) # Return original position_ids shape return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)} @@ -266,32 +270,62 @@ def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + seqlen_offset = dist.get_rank(self.cp_pg) * seq_length q, kv = self.rotary_emb( - q, kv, seqlen_offset=0, max_seqlen=None - ) # TODO: should we use position_ids here? flash_attn doesn't + q, kv, seqlen_offset=seqlen_offset, max_seqlen=seq_length*self.cp_pg_size + ) else: log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + self.sliding_window_size = None # WARNING: we skip sliding window for no-rope + q = q.view(-1, self.local_num_heads, self.head_dim) kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) max_seqlen = seq_length # TODO: should this be max position_ids? 
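+        # NOTE: both branches below call flash-attn's varlen kv-packed kernel; the
+        # llama3_ring_attention branch additionally exchanges key/value slices across the
+        # context-parallel group (self.cp_pg, using heads_k_stride and local_k_slice).
+        # window_size follows flash-attn's (left, right) convention: (sliding_window_size - 1, 0)
+        # gives causal sliding-window attention, while (-1, -1) disables windowing.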
- assert cu_seqlens.dtype == torch.int32 - assert max_seqlen is not None - assert isinstance(max_seqlen, int) - attn_output = flash_attn_varlen_kvpacked_func( - q, - kv, - cu_seqlens, - cu_seqlens, - max_seqlen, - max_seqlen, - 0.0, - softmax_scale=None, - causal=True, # TODO: double check - alibi_slopes=None, - window_size=(-1, -1), # TODO: fix - deterministic=False, - ) # Not contiguous, similar to flash_attn + + if self.config._attn_implementation == "llama3_ring_attention": + attn_output = llama3_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q=cu_seqlens["cu_seqlens_q"], + cu_seqlens_k=cu_seqlens["cu_seqlens_k"], + max_seqlen_q=cu_seqlens["max_seqlen_q"], + max_seqlen_k=cu_seqlens["max_seqlen_k"], + heads_k_stride=self.heads_k_stride, + local_k_slice=cu_seqlens["local_k_slice"], + dropout_p=0.0, + softmax_scale=None, + causal=True, + alibi_slopes=None, + window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1), + deterministic=False, + return_attn_probs=self.log_attn_probs, + group=self.cp_pg, + ) # Not contiguous, similar to flash_attn + else: + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + 0.0, + softmax_scale=None, + causal=True, + alibi_slopes=None, + window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1), + deterministic=False, + return_attn_probs=self.log_attn_probs, + ) # Not contiguous, similar to flash_attn + + if self.log_attn_probs: + attn_output, attn_probs, _ = attn_output + # log attn_probs + self.tbi_logger({"attn_probs": attn_probs}) # flash_attn use rearrange instead of reshape https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730 return attn_output.reshape(-1, self.local_num_heads * self.head_dim) # [b*s, num_heads*head_dim] @@ -355,6 +389,183 @@ def forward(self, hidden_states): return {"hidden_states": hidden_states} +class Qwen2MoELayer(nn.Module): + """Mixture of experts Layer for Qwen2 models.""" + + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int = 0, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # MoE specific configurations + self.num_experts = config.moe_config.num_experts # Total number of experts + self.num_experts_per_token = config.moe_config.top_k # Number of experts used per token (top-k) + self.expert_parallel_size = getattr(parallel_config, "expert_parallel_size", 1) + self.num_local_experts = self.num_experts // self.expert_parallel_size # Experts per device + + # Get TP mode configuration + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + # Router for selecting experts + self.router = TensorParallelColumnLinear( + self.hidden_size, + self.num_experts, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + + # Enable shared experts if configured + self.enable_shared_expert = getattr(config.moe_config, "enable_shared_expert", False) + if self.enable_shared_expert: + self.shared_expert = Qwen2MLP( + 
config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + ) + self.shared_expert_gate = TensorParallelColumnLinear( + self.hidden_size, + 1, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + + # Create the expert MLPs + self.experts = nn.ModuleList( + [ + Qwen2MLP( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + ) + for _ in range(self.num_local_experts) + ] + ) + + # Whether to recompute MoE layer during backward pass for memory efficiency + self.recompute_layer = parallel_config.recompute_layer + + # Token dispatcher type - determines communication pattern + self.token_dispatcher_type = getattr(config.moe_config, "token_dispatcher_type", "alltoall") + # For more sophisticated implementations, we would add token dispatcher logic here + + def _compute_router_probabilities(self, hidden_states): + """Compute routing probabilities for each token to each expert.""" + router_logits = self.router(hidden_states) # [batch_size*seq_length, num_experts] + + # Get the top-k experts per token + routing_weights, routing_indices = torch.topk(router_logits, k=self.num_experts_per_token, dim=-1) + + # Apply softmax on the top-k values + routing_weights = F.softmax(routing_weights, dim=-1) + + return routing_weights, routing_indices + + def _dispatch_tokens(self, hidden_states, routing_weights, routing_indices): + """ + Dispatches tokens to their selected experts. + In a full implementation, this would handle the actual token routing logic + including communication between devices. + """ + # Simplified implementation - in a complete version this would handle + # all-to-all or all-gather communications for distributed experts + + hidden_states.shape[0] + dispatched_inputs = [] + expert_counts = [] + + # For each expert, gather the tokens assigned to it + for expert_idx in range(self.num_local_experts): + # Find tokens that have this expert in their top-k + expert_mask = (routing_indices == expert_idx).any(dim=-1) + tokens_for_expert = hidden_states[expert_mask] + + # Get the routing weights for this expert + expert_positions = (routing_indices == expert_idx).nonzero(as_tuple=True) + token_positions, k_positions = expert_positions + expert_weights = routing_weights[token_positions, k_positions].unsqueeze(-1) + + # Scale inputs by routing weights + scaled_inputs = tokens_for_expert * expert_weights + + dispatched_inputs.append(scaled_inputs) + expert_counts.append(len(tokens_for_expert)) + + return dispatched_inputs, expert_counts + + def _combine_expert_outputs(self, expert_outputs, routing_indices, original_shape): + """ + Combines outputs from different experts back to the original tensor layout. 
+ """ + # Initialize output tensor with zeros + combined_output = torch.zeros(original_shape, device=expert_outputs[0].device) + + for expert_idx, expert_output in enumerate(expert_outputs): + if expert_output.shape[0] == 0: # Skip if no tokens were routed to this expert + continue + + # Find positions where this expert was in the top-k + expert_mask = (routing_indices == expert_idx).any(dim=-1) + combined_output[expert_mask] += expert_output + + return combined_output + + def _core_forward(self, hidden_states): + """Core forward logic for MoE layer.""" + # Get router probabilities + routing_weights, routing_indices = self._compute_router_probabilities(hidden_states) + + # Dispatch tokens to experts + dispatched_inputs, expert_counts = self._dispatch_tokens(hidden_states, routing_weights, routing_indices) + + # Process tokens with their assigned experts + expert_outputs = [] + for expert_idx, (inputs, count) in enumerate(zip(dispatched_inputs, expert_counts)): + if count == 0: # Skip computation if no tokens assigned + expert_outputs.append(torch.tensor([], device=hidden_states.device)) + continue + + # Forward through the expert + output = self.experts[expert_idx](hidden_states=inputs)["hidden_states"] + expert_outputs.append(output) + + # Combine expert outputs + output = self._combine_expert_outputs(expert_outputs, routing_indices, hidden_states.shape) + + # Add shared expert contribution if enabled + if self.enable_shared_expert: + shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) + output = output + shared_gate * shared_expert_output + + return output + + def _checkpointed_forward(self, hidden_states): + """Apply gradient checkpointing to save memory during training.""" + return CheckpointFunction.apply(self._core_forward, True, hidden_states) + + def forward(self, hidden_states): + """Forward pass for the MoE layer.""" + if self.recompute_layer and self.training: + hidden_states = self._checkpointed_forward(hidden_states) + else: + hidden_states = self._core_forward(hidden_states) + + return {"hidden_states": hidden_states} + + class Qwen2DecoderLayer(nn.Module): def __init__( self, @@ -551,13 +762,38 @@ def forward( ): output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) # Compute cu_seqlens + cu_seqlens: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None if position_ids.numel() > 0: start_indices = torch.where(position_ids.view(-1) == 0)[0] cu_seqlens = torch.cat( [start_indices, torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device)] ).to(torch.int32) - else: - cu_seqlens = None + + # llama3 ring attention + if self.config._attn_implementation == "llama3_ring_attention": + local_sequence_length = input_ids.shape[1] + sequence_length = position_ids.shape[1] + assert sequence_length == local_sequence_length * self.parallel_context.cp_pg.size(), f"sequence_length={sequence_length} must be equal to local_sequence_length={local_sequence_length} * cp_pg.size()={self.parallel_context.cp_pg.size()}" + assert sequence_length % (2 * self.parallel_context.cp_pg.size()) == 0, f"Sequence length {sequence_length} must be divisible by {2 * self.parallel_context.cp_pg.size()} when using llama3 ring attention" + ( + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + local_k_slice, + ) = llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens, # global cu_seqlens + causal=True, + 
rank=self.parallel_context.cp_pg.rank(), + world_size=self.parallel_context.cp_pg.size(), + ) + cu_seqlens = { + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "local_k_slice": local_k_slice, + } decoder_states = { "hidden_states": output["input_embeds"], @@ -656,8 +892,8 @@ def forward( z_loss = masked_mean(z_loss.detach(), label_mask, dtype=torch.float) return {"loss": loss, "z_loss": z_loss} - -class Qwen2ForTraining(NanotronModel): +from nanotron.logging import LoggingCollectorMixin +class Qwen2ForTraining(NanotronModel, LoggingCollectorMixin): def __init__( self, config: Qwen2Config, diff --git a/src/nanotron/nn/attention.py b/src/nanotron/nn/attention.py index 235fdcd4..350418b9 100644 --- a/src/nanotron/nn/attention.py +++ b/src/nanotron/nn/attention.py @@ -5,7 +5,7 @@ from packaging import version from nanotron.nn.ring_attention import ring_flash_attn_varlen_func - +from nanotron.nn.llama3_ring_attention import llama3_flash_attn_varlen_qkvpacked_func # Replace direct import with a function for lazy loading def get_ring_flash_attn_cuda(): @@ -217,6 +217,7 @@ def sdpa_attention_forward( "sdpa": sdpa_attention_forward, "ring_flash_triton": lambda *args, **kwargs: get_ring_flash_attn_cuda()(*args, **kwargs), "ring": ring_flash_attn_varlen_func, + "llama3_ring_attention": llama3_flash_attn_varlen_qkvpacked_func, } AttentionImplementation = Literal[tuple(ALL_ATTENTION_FUNCTIONS.keys())] diff --git a/src/nanotron/nn/llama3_ring_attention.py b/src/nanotron/nn/llama3_ring_attention.py new file mode 100644 index 00000000..0e0963ef --- /dev/null +++ b/src/nanotron/nn/llama3_ring_attention.py @@ -0,0 +1,810 @@ +"""Ring attention implementation using flash attention adapted from https://github.com/zhuzilin/ring-flash-attention/""" + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) + +def llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens: torch.Tensor, causal: bool, rank: int, world_size: int +): + """ + Args: + cu_seqlens: torch.Tensor, the cu_seqlens of all the sequences across the ring process group. + + Returns: + cu_seqlens_q: torch.Tensor, the cu_seqlens of the q slice for this rank. + cu_seqlens_k: torch.Tensor, the cu_seqlens of the k slice that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + local_k_slice: slice, the slice of the k that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + """ + total_length = cu_seqlens[-1] + assert total_length % world_size == 0 + length_per_rank = total_length // world_size + left = torch.searchsorted(cu_seqlens, rank * length_per_rank) + right = torch.searchsorted(cu_seqlens, (rank + 1) * length_per_rank) + length_per_rank = length_per_rank.item() + + # after this, cu_seqlens[left:right + 1] contains all the sequence for this rank + if cu_seqlens[left] != rank * length_per_rank: + left -= 1 + left = left.item() + right = right.item() + + # q is always the same. 
just calculate the cu_seqlens for the local slice + cu_seqlens_q = cu_seqlens[left : right + 1].clone() + cu_seqlens_q -= rank * length_per_rank + cu_seqlens_q[0] = 0 + cu_seqlens_q[-1] = length_per_rank + + cu_seqlens_k = cu_seqlens[left : right + 1].clone() + if causal: + # when causal, we hope + # - the last k seq is of the same length as the last q seq + slice_right = (rank + 1) * length_per_rank + cu_seqlens_k[-1] = slice_right + else: + # when not causal, we hope + # - the last k is full seq + slice_right = cu_seqlens[right].item() + + slice_left = cu_seqlens[left].item() + cu_seqlens_k -= slice_left + + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + local_k_slice = slice(slice_left, slice_right) + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, local_k_slice + + +def llama3_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + out_list = [] + lse_list = [] + + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0, f"nheads_k={nheads_k} must be divisible by heads_k_stride={heads_k_stride}" + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + kv_buffer_copy = torch.empty_like(kv_buffer) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm = Comm(process_group) + + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + q_i = q[:, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + + params = get_default_args(_flash_attn_varlen_forward).copy() + params.update( + { + "q": q_i, + "k": k_i, + "v": v_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_varlen_forward(**params) + if len(outputs) == 8: + out, _, _, _, _, lse, _, _ = outputs + else: + assert len(outputs) == 4 + out, lse, _, _ = outputs + out_list.append(out) + lse_list.append(lse) + + out = torch.cat(out_list, dim=1) + lse = torch.cat(lse_list, dim=-2) + return out, lse + + +def llama3_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + 
max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + kv_buffer_copy = torch.empty_like(kv_buffer) + + dkv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + if heads_k_stride != nheads_k: + kv_contiguous_buffer = torch.empty( + (2, total_k, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + comm = Comm(process_group) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + dkv_buffer.zero_() + + q_slice = slice( + i * nheads // nheads_k, (i + heads_k_stride) * nheads // nheads_k + ) + q_i = q[:, q_slice] + dout_i = dout[:, q_slice] + out_i = out[:, q_slice] + dq_i = dq[:, q_slice] + if softmax_lse.dim() == 3: + lse_i = softmax_lse[:, q_slice].contiguous() + else: + lse_i = softmax_lse[q_slice] + + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + dk_i = dkv_buffer[0][local_k_slice] + dv_i = dkv_buffer[1][local_k_slice] + + params = get_default_args(_flash_attn_varlen_backward).copy() + params.update( + { + "dout": dout_i, + "q": q_i, + "k": k_i, + "v": v_i, + "out": out_i, + "softmax_lse": lse_i, + "dq": dq_i, + "dk": dk_i, + "dv": dv_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_varlen_backward(**params) + + if heads_k_stride != nheads_k: + # reduce_scatter needs contiguous buffer + dk_i = kv_contiguous_buffer[0] + dv_i = kv_contiguous_buffer[1] + else: + dk_i = dk + dv_i = dv + + dist.reduce_scatter_tensor(dk_i, dkv_buffer[0], group=process_group) + dist.reduce_scatter_tensor(dv_i, dkv_buffer[1], group=process_group) + + if heads_k_stride != nheads_k: + dk[:, i : i + heads_k_stride] = dk_i + dv[:, i : i + heads_k_stride] = dv_i + + return dq, dk, dv + + +class Llama3FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + 
group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = llama3_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.heads_k_stride = heads_k_stride + ctx.local_k_slice = local_k_slice + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = llama3_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.heads_k_stride, + ctx.local_k_slice, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return (dq, dk, dv) + (None,) * 15 + + +def llama3_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + + + +## 
triton_utils.py +import torch +import triton +import triton.language as tl + + +@triton.jit +def flatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_nheads, + stride_out_seqlen, + stride_lse_batch, + stride_lse_nheads, + stride_lse_seqlen, + # meta-parameters + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads + OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + +def flatten_varlen_lse(lse, cu_seqlens): + """ + Arguments: + lse: (batch_size, nheads, max_seqlen) + cu_seqlens: (batch_size + 1,) + Return: + flatten_lse: (nheads, total_seqlen) + """ + total_seqlen = cu_seqlens[-1] + batch_size, nheads, max_seqlen = lse.shape + output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) + + def grid(META): + return triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads + + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + flatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + lse.stride(0), + lse.stride(1), + lse.stride(2), + BLOCK_M, + ) + return output + + +@triton.jit +def unflatten_kernel( + # pointers to matrices + OUT, + LSE, + CU_SEQLENS, + # strides + stride_out_batch, + stride_out_nheads, + stride_out_seqlen, + stride_lse_seqlen, + stride_lse_nheads, + # meta-parameters + BLOCK_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen + OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + + LSE = LSE + rm[:, None] * stride_lse_seqlen + x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) + + OUT = OUT + rm[:, None] * stride_out_seqlen + tl.store(OUT, x, mask=rm[:, None] < seqlen) + + +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + """ + Arguments: + lse: (total_seqlen, nheads, 1) + cu_seqlens: (batch_size + 1,) + max_seqlen: int + Return: + unflatten_lse: (batch_size, nheads, max_seqlen) + """ + lse = lse.unsqueeze(dim=-1) + batch_size = len(cu_seqlens) - 1 + nheads = lse.shape[1] + output = torch.empty( + (batch_size, nheads, max_seqlen), + dtype=lse.dtype, + device=lse.device, + ) + + def grid(META): + return triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads + + BLOCK_M = 4 + + with torch.cuda.device(lse.device.index): + unflatten_kernel[grid]( + output, + lse, + cu_seqlens, + # strides + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_M, + ) + return output + + +import inspect +from functools import cache +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if 
spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty((num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +## utils.py +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not 
None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + def send_recv_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() + return next_k, next_v + + +class AllGatherComm: + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + handle = dist.all_gather_into_tensor(output_tensor, input_tensor, group=self.group, async_op=True) + self.handles.append(handle) + + def wait(self): + for handle in self.handles: + handle.wait() + self.handles = [] + +Comm = AllGatherComm diff --git a/src/nanotron/nn/rotary.py b/src/nanotron/nn/rotary.py index 4e78849f..109e59f0 100644 --- a/src/nanotron/nn/rotary.py +++ b/src/nanotron/nn/rotary.py @@ -1,6 +1,11 @@ import torch from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb from torch import nn +from flash_attn.layers.rotary import RotaryEmbedding as OrigFlashRotaryEmbedding +from einops import rearrange +from nanotron import logging +from nanotron.logging import warn_once +logger = logging.get_logger(__name__) class RotaryEmbedding(nn.Module): @@ -140,3 +145,77 @@ def apply_rotary_pos_emb(self, tensor, freqs, multi_latent_attention=False, msca if pass_through_part is not None and pass_through_part.shape[-1] > 0: return torch.cat((rotated_tensor, pass_through_part), dim=-1) return rotated_tensor + +class FlashRotaryEmbedding(OrigFlashRotaryEmbedding): + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + seq_len_interpolation_factor=None, + ): + super().__init__( + dim, + base, + interleaved, + scale_base, + pos_idx_in_fp32, + device, + ) + self.seq_len_interpolation_factor = seq_len_interpolation_factor + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + + # fixed linear scaling + if self.seq_len_interpolation_factor is not None: + warn_once(f"seq_len_interpolation_factor is set to {self.seq_len_interpolation_factor}", logger, rank=0) + t *= 1 / self.seq_len_interpolation_factor + + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) \ No newline at end of file diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 088551b0..9a42f8e6 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -304,14 +304,14 @@ class FP32GradBucketManager: """Manages the fp32 gradient buckets. Attributes: - dp_pg: The process group to allreduce gradients across. + dp_cp_pg: The process group to allreduce gradients across. accumulator: The gradient accumulator which keeps the gradient buffers. bucket_id_to_fp32_grad_buckets_and_dependencies: A dictionary mapping bucket ids to: - fp32 grad bucket (torch.Tensor) - set of param ids that are in the bucket -> used to know when to delete the buffer param_id_to_bucket_id: A dictionary mapping param ids to bucket ids.""" - dp_pg: dist.ProcessGroup + dp_cp_pg: dist.ProcessGroup accumulator: FP32GradientAccumulator param_id_to_name: Dict[int, str] @@ -334,7 +334,7 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f # nonlocal s # DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation. # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details. 
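[Editor note] The `dp_pg` to `dp_cp_pg` renames in this file route the fp32 gradient all-reduce over the combined data- and context-parallel group: each CP rank only saw a slice of the sequence, so its gradients are partial and must be averaged with the other CP ranks as well as the DP replicas. A hedged, editor-added sketch of that reduction (helper and argument names are illustrative, not nanotron's):

```python
import torch.distributed as dist

def allreduce_fp32_grads(fp32_grad_buffers, dp_cp_pg):
    """Average accumulated fp32 gradient buffers across the combined DP x CP group.

    Sketch only: the real hook reduces DDP GradBuckets asynchronously via
    all_reduce_coalesced; this just shows which group the averaging runs over.
    """
    if dp_cp_pg.size() == 1:
        return  # single replica and no sequence slicing: nothing to sync
    for buf in fp32_grad_buffers:
        dist.all_reduce(buf, op=dist.ReduceOp.AVG, group=dp_cp_pg)
```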
- dp_pg = state.dp_pg + dp_cp_pg = state.dp_cp_pg accumulator = state.accumulator param_id_to_name = state.param_id_to_name @@ -346,12 +346,13 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f fp32_grad_buffer.add_(grad.view_as(fp32_grad_buffer)) # sync across dp - if dp_pg.size() == 1: + if dp_cp_pg.size() == 1: fut = torch.futures.Future() fut.set_result(bucket.buffer()) return fut if reduce_scatter: + raise NotImplementedError("Not implemented") assert hasattr(accumulator, "param_name_to_offsets") grad_buffer_tensor_list = [ accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters() @@ -372,7 +373,7 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f output_tensor_list=output_tensor_list, input_tensor_lists=input_tensor_lists, op=reduce_op, - group=dp_pg, + group=dp_cp_pg, async_op=True, ) else: @@ -380,7 +381,7 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters() ] accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced( - grad_buffer_tensor_list, group=dp_pg, async_op=True, op=reduce_op + grad_buffer_tensor_list, group=dp_cp_pg, async_op=True, op=reduce_op ) # we shouldn't wait for this future for the rest of the backward diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index 7820f8a2..04fb4c66 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -90,6 +90,15 @@ def _init_parallel_groups(self): ] ) + self.dp_cp_pg = self.create_new_group( + [ + ranks[ep_rank, pp_rank, :, :, tp_rank].reshape(-1) + for tp_rank in range(self.tensor_parallel_size) + for pp_rank in range(self.pipeline_parallel_size) + for ep_rank in range(self.expert_parallel_size) + ] + ) + self.tp_and_ep_pg = self.create_new_group( [ ranks[:, pp_rank, dp_rank, cp_rank, :].reshape(-1) diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index b4772262..4e59f782 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -62,12 +62,12 @@ def before_tbi_sanity_checks( lr_scheduler: torch.optim.lr_scheduler.LRScheduler, ) -> None: if not config.general.ignore_sanity_checks: - # SANITY CHECK: Check that the model params are synchronized across dp + # SANITY CHECK: Check that the model params are synchronized across dp_cp for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]): assert_tensor_synced_across_pg( tensor=param, - pg=parallel_context.dp_pg, - msg=lambda err: f"{name} are not synchronized across DP {err}", + pg=parallel_context.dp_cp_pg, + msg=lambda err: f"{name} are not synchronized across DP_CP {err}", ) # SANITY CHECK: Tied weights are synchronized @@ -208,16 +208,16 @@ def before_optim_step_sanity_checks( assert grad is not None, f"Grad is None for {name}" assert_tensor_synced_across_pg( tensor=grad, - pg=parallel_context.dp_pg, - msg=lambda err: f"[Before optimizer step] weights grads for {name} are not synchronized across DP. {err}", + pg=parallel_context.dp_cp_pg, + msg=lambda err: f"[Before optimizer step] weights grads for {name} are not synchronized across DP_CP. 
{err}", ) # SANITY CHECK: Check that the model params are synchronized across dp for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]): assert_tensor_synced_across_pg( tensor=param, - pg=parallel_context.dp_pg, - msg=lambda err: f"{name} are not synchronized across DP {err}", + pg=parallel_context.dp_cp_pg, + msg=lambda err: f"{name} are not synchronized across DP_CP {err}", ) # SANITY CHECK: Tied weights are synchronized @@ -234,8 +234,8 @@ def before_optim_step_sanity_checks( msg=lambda err: f"[Before optimizer step] Tied weights {name} are not synchronized. {err}", ) - # SANITY CHECK: Check that optimizer states are synchronized across DP - check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg) + # SANITY CHECK: Check that optimizer states are synchronized across DP_CP + check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_cp_pg) # SANITY CHECK: run model specific sanity checks unwrapped_model.before_optim_step_sanity_checks() diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 2b5d4558..e073aeae 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -122,13 +122,13 @@ def save( # TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs) ### - # SANITY CHECK: Check that the model params are synchronized across `parallel_context.dp_pg` + # SANITY CHECK: Check that the model params are synchronized across `parallel_context.dp_cp_pg` if sanity_checks: for name, param_or_buffer in sorted(model.state_dict().items(), key=lambda x: x[0]): assert_tensor_synced_across_pg( tensor=param_or_buffer, - pg=parallel_context.dp_pg, - msg=lambda err: f"{name} are not synced across DP {err}", + pg=parallel_context.dp_cp_pg, + msg=lambda err: f"{name} are not synced across DP_CP {err}", ) # SANITY CHECK: Check that the tied parameters are synchronized @@ -150,7 +150,7 @@ def save( tensor=tied_param, pg=group, msg=lambda err: f"Tied {tied_info.name} are not synced {err}" ) if not optimizer.inherit_from(optim.ZeroDistributedOptimizer): - check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_pg) + check_optim_state_in_sync(optimizer.state_dict(), parallel_context.dp_cp_pg) # SANITY CHECK: tied parameters have their optimizer states synchronized # Compute a mapping from id_ to index in the optimizer sense diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 5812bfd1..18865e96 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -13,23 +13,31 @@ from nanotron.constants import CHECKPOINT_FILE_NAME, CHECKPOINT_VERSION from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import SlicesPair - +from collections import defaultdict @dataclasses.dataclass class DataStageMetadata: """ - consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). - last_train_step: The last training step across all stages. - - # NOTE: we should allow people to change the name of the data stages in the config file. - # but not the start_training_step, because it could + consumed_train_samples: The number of samples consumed by the model in the this stage (resets at each stage). + consumed_tokens_per_dataset_folder: The number of tokens consumed by the model in the this stage for each dataset folder. 
(resets at each stage) """ name: str start_training_step: int - consumed_train_samples: int - consumed_tokens_per_dataset_folder: Dict[str, int] = dataclasses.field(default_factory=dict) + consumed_train_samples: int # We use this for sampler, and it's reset at each stage + sequence_length: Optional[int] = None # TODO: put back as non-optional + consumed_tokens_per_dataset_folder: Dict[str, int] = dataclasses.field(default_factory=dict) # this gets reset at each stage + + def __post_init__(self): + if self.sequence_length is None: + self.sequence_length = 4096 # TODO: temp + + def sanity_consumed_train_samples(self): + assert self.consumed_train_samples*self.sequence_length == sum(self.consumed_tokens_per_dataset_folder.values()), f"Mismatch between the total consumed samples and the sum of consumed samples across dataset folders! consumed_train_samples={self.consumed_train_samples}, sequence_length={self.sequence_length}, consumed_tokens_per_dataset_folder={self.consumed_tokens_per_dataset_folder}" + @property + def consumed_tokens_all_datasets(self): + return sum(self.consumed_tokens_per_dataset_folder.values()) @dataclasses.dataclass class TrainingMetadata: @@ -40,8 +48,9 @@ class TrainingMetadata: data_stages: The metadata for each stage. """ - consumed_train_samples: int + consumed_train_samples: int # TODO: Legacy. This assumed same sequence length across all stages. Not used anymore last_train_step: int + consumed_tokens_total: Optional[int] = None # TODO: put back as non-optional # TODO(xrsrke): make this not optional, once we entirely remove # the old checkpoint version @@ -50,15 +59,31 @@ class TrainingMetadata: def __post_init__(self): # NOTE: this is a sanity check after loading a trained checkpoint - total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) assert ( - self.consumed_train_samples == total_consumed_samples_across_stages + self.consumed_train_samples == sum(stage.consumed_train_samples for stage in self.data_stages) ), "Mismatch between the total consumed samples and the sum of consumed samples across stages! Something went wrong in the training." + if self.consumed_tokens_total is not None: + assert self.consumed_tokens_total == sum(stage.consumed_tokens_all_datasets for stage in self.data_stages), "Mismatch between the total consumed tokens and the sum of consumed tokens across stages! Something went wrong in the training." 
+ else: + self.consumed_tokens_total = sum(stage.consumed_tokens_all_datasets for stage in self.data_stages) + # TODO(xrsrke): remove this once we entirely remove non-data-stage training if self.last_stage_idx is not None: assert self.data_stages is not None, "data_stages should not be None if last_stage_idx is not None" + @property + def consumed_tokens_per_dataset_folder_total(self): + consumed = defaultdict(int) + for stage in self.data_stages: + for dataset_folder, tokens in stage.consumed_tokens_per_dataset_folder.items(): + consumed[dataset_folder] += tokens + return consumed + + @property + def current_stage(self) -> DataStageMetadata: + return self.data_stages[self.last_stage_idx] + @dataclasses.dataclass class CheckpointMetadata: @@ -66,6 +91,7 @@ class CheckpointMetadata: tp: int dp: int metas: TrainingMetadata + cp: int = 1 custom_metas: Optional[Dict[str, Any]] = None @@ -142,6 +168,7 @@ def save_meta(parallel_context: ParallelContext, root_folder: Path, training_met version=CHECKPOINT_VERSION, tp=parallel_context.tp_pg.size(), dp=parallel_context.dp_pg.size(), + cp=parallel_context.cp_pg.size(), metas=training_metadata, ) diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index fc71a237..7f49941d 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -47,7 +47,7 @@ def save_optimizer( - If Zero-0 is used, optimizer states are replicated across all DPs. Only DP-0 saves the states - If Zero-1 is used, optimizer states are sharded across all DPs. Each DP saves its own states """ - if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(parallel_context.dp_pg) > 0: + if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(parallel_context.dp_cp_pg) > 0: # this is Zero-0, so only DP-0 saves the optimizer states return @@ -113,7 +113,7 @@ def save_lr_scheduler( root_folder: Path, ): """Saves lr scheduler states""" - if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0: + if not is_zero and dist.get_rank(parallel_context.dp_cp_pg) > 0: # this is Zero-0, so only DP-0 saves the optimizer states return diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index 4c16d3a6..400a12eb 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -30,9 +30,8 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folder: Path): root_folder = root_folder / "model" - # We save only `dist.get_rank(parallel_context.dp_pg) == 0` - # TODO @thomasw21: Figure how this works with Zero-3 - if dist.get_rank(parallel_context.dp_pg) != 0: + # We save only `dist.get_rank(parallel_context.dp_cp_pg) == 0` + if dist.get_rank(parallel_context.dp_cp_pg) != 0: return module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index c94e97de..13ecf497 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -244,19 +244,19 @@ def __init__( assert isinstance(checkpoint_metadata.metas, TrainingMetadata) log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) self.metadata: TrainingMetadata = checkpoint_metadata.metas - # NOTE: we should not change data stages + # In case of a new datastage, metadata will be updated in `get_dataloader` assert ( self.config.tokens.train_steps > self.metadata.last_train_step ), f"Loaded checkpoint has already trained {self.metadata.last_train_step} batches, you need to specify a higher `config.tokens.train_steps`" else: data_stages = [ DataStageMetadata( - name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0 + name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0, sequence_length=stage.sequence_length ) for stage in self.config.data_stages ] self.metadata: TrainingMetadata = TrainingMetadata( - consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + consumed_train_samples=0, consumed_tokens_total=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages ) # Setup tensorboard write and log writers on output rank @@ -269,7 +269,7 @@ def __init__( self.n_micro_batches_per_batch = self.config.tokens.batch_accumulation_per_replica self.global_batch_size = ( self.micro_batch_size * self.n_micro_batches_per_batch * self.parallel_context.dp_pg.size() - ) + ) # in terms of samples self.sequence_length = ( self.config.tokens.sequence_length ) # Global sequence length not divided by context parallel size @@ -339,7 +339,7 @@ def pre_training(self, *args, **kwargs): log_rank("Start training", logger=logger, level=logging.INFO, rank=0, is_separator=True) log_rank( - f"mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | sequence_length: {self.sequence_length} | global_batch_size: {self.global_batch_size} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_train_samples: {metadata.consumed_train_samples}", # noqa + f"mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | cp: {self.parallel_context.cp_pg.size()} | sequence_length: {self.sequence_length} | global_batch_size: {self.global_batch_size} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_tokens_total: {metadata.consumed_tokens_total}", # noqa logger=logger, level=logging.INFO, rank=0, @@ -350,21 +350,13 @@ def pre_training(self, *args, **kwargs): # Initialize wandb for each TP group if TP > 1, but only for dp=0 ranks if wandb is not None: tp_size = self.parallel_context.tp_pg.size() - dp_rank = dist.get_rank(self.parallel_context.dp_pg) + dp_cp_rank = dist.get_rank(self.parallel_context.dp_cp_pg) tp_rank = dist.get_rank(self.parallel_context.tp_pg) world_rank = dist.get_rank(self.parallel_context.world_pg) - # Log all rank info for debugging purposes - log_rank( - f"Rank info - world_rank: {world_rank}, dp_rank: {dp_rank}, tp_rank: {tp_rank}, tp_size: {tp_size}, logger_ranks: {self.logger_ranks}", - logger=logger, - level=logging.INFO, - rank=world_rank, - ) - if tp_size > 1 and self.metrics_logging.log_level > 0: # Create one wandb logger per TP group for DP=0 ranks - if dp_rank == 0: + if dp_cp_rank == 0: # Create a run 
name that includes the TP group run_name = f"{current_time}_{self.config.general.run}_tp_group_{tp_rank}" @@ -420,7 +412,7 @@ def pre_training(self, *args, **kwargs): ) def post_train_step(self): - + self.unwrapped_model.clear_all_tbi_logs() # Update our background upload/removal of checkpoints if self.s3_mover is not None: self.s3_mover.update() @@ -450,9 +442,6 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da return assert len(dataloaders) > 0, "No dataloaders provided" - assert len(dataloaders) == len( - self.config.data_stages - ), "Number of dataloaders should match the number of dataset stages" def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): import gc @@ -570,13 +559,13 @@ def train( self._update_dataloader_based_on_training_stages(dataloader_or_dls) # Training step - outputs, loss_avg, z_loss_avg = self.training_step(dataloader=self.current_dataloader) + outputs, loss_avg, z_loss_avg, tbi_logs = self.training_step(dataloader=self.current_dataloader) # Update consumption tracking for current batch - if hasattr(self.current_base_dl, "dataset"): + if hasattr(self.current_base_dl, "dataset") and hasattr(self.current_base_dl.dataset, "update_consumption_metrics"): + # TODO: only works for BlendableDataset self.current_base_dl.dataset.update_consumption_metrics( - start_idx=(self.iteration_step - 1) - * self.global_batch_size, # assumes we start from iteration_step=1 + start_idx=(self.iteration_step - 1) * self.global_batch_size, # assumes we start from iteration_step=1 end_idx=self.iteration_step * self.global_batch_size, sequence_length=self.sequence_length, ) @@ -592,14 +581,14 @@ def train( current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] # Original consumption tracking - self.metadata.consumed_train_samples += self.global_batch_size + self.metadata.consumed_train_samples += self.global_batch_size # TODO: Legacy: idc abt this + self.metadata.consumed_tokens_total += self.global_batch_size * self.sequence_length self.metadata.last_train_step = self.iteration_step - self.metadata.data_stages[ - self.metadata.last_stage_idx - ].consumed_train_samples += self.global_batch_size + self.metadata.current_stage.consumed_train_samples += self.global_batch_size + assert self.metadata.current_stage.sequence_length == self.sequence_length, "Sequence length mismatch between the current stage and the global sequence length" if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg, z_loss_avg=z_loss_avg) + self.train_step_logs(outputs=outputs, loss_avg=loss_avg, z_loss_avg=z_loss_avg, tbi_logs=tbi_logs) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -652,6 +641,8 @@ def training_step( nanotron_timer("sync_gradients", "cuda").start() # Sync tied weights if not isinstance(self.model, DistributedDataParallel): + if self.parallel_context.context_parallel_size > 1: + raise NotImplementedError("Context parallel size > 1 is not supported yet without DDP") # Manually sync across DP if it's not handled by DDP sync_gradients_across_dp( module=self.model, @@ -687,7 +678,7 @@ def training_step( ) nanotron_timer("clip_gradients", "cuda").end() - # Compute DP average loss and overlap with optimizer step + # Compute DP-CP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. 
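[Editor note] The new `consumed_tokens_total` counter above advances by one global batch worth of tokens per step, where the global batch size is counted in samples and the sequence length is the global (pre-context-parallel) length. Editor-added arithmetic with hypothetical values, just to make the bookkeeping concrete:

```python
# Hypothetical configuration, for illustration only.
micro_batch_size = 4
batch_accumulation_per_replica = 8
dp_size = 16
sequence_length = 4096  # global sequence length, not divided by cp_size

global_batch_size = micro_batch_size * batch_accumulation_per_replica * dp_size  # samples per step
tokens_per_step = global_batch_size * sequence_length

assert global_batch_size == 512
assert tokens_per_step == 2_097_152  # consumed_tokens_total grows by ~2.1M tokens per step
```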
loss_avg = torch.stack( @@ -699,8 +690,8 @@ def training_step( ).sum() # already divided by n_micro_batches_per_batch else: z_loss_avg = None - # sync loss across DP (we should do the same for z_loss but it's only for logging so let's not sync it rn) - handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + # sync loss across DP-CP (we should do the same for z_loss but it's only for logging so let's not sync it rn) + handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_cp_pg, async_op=True, op=dist.ReduceOp.AVG) else: z_loss_avg = None loss_avg = None @@ -718,6 +709,9 @@ def training_step( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.optimizer ) + # get_tbi_logs in a non-blocking way since we don't need to wait for it + tbi_logs = self.unwrapped_model.get_tbi_logs(non_blocking=True) + # Apply gradient nanotron_timer("optimizer_step", "cuda").start() self.optimizer.step() @@ -734,7 +728,7 @@ def training_step( self.post_train_step() - return outputs, loss_avg, z_loss_avg + return outputs, loss_avg, z_loss_avg, tbi_logs def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: outputs = self.pipeline_engine.validate_batch_iter( @@ -749,6 +743,7 @@ def train_step_logs( outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], z_loss_avg: Optional[torch.Tensor], + tbi_logs: Optional[Dict[str, torch.Tensor]], ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() @@ -766,7 +761,7 @@ def train_step_logs( # Get rank information (used by both console and wandb logging) tp_size = self.parallel_context.tp_pg.size() - dp_rank = dist.get_rank(self.parallel_context.dp_pg) + dp_cp_rank = dist.get_rank(self.parallel_context.dp_cp_pg) tp_rank = dist.get_rank(self.parallel_context.tp_pg) world_rank = dist.get_rank(self.parallel_context.world_pg) @@ -776,12 +771,7 @@ def train_step_logs( eta_seconds = int(remaining_steps * (elapsed_time_per_iteration_ms / 1000)) basic_log_entries = [ # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), - LogItem( - "consumed_tokens", - self.metadata.consumed_train_samples - * self.config.tokens.sequence_length, # TODO: not true if we change seqlen - "human_format", - ), # , "12d"), + LogItem("consumed_tokens", self.metadata.consumed_tokens_total, "human_format"), LogItem("time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), LogItem( @@ -868,6 +858,20 @@ def get_cpu_logitems(): assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" self.loggerwriter.add_scalars_from_list(basic_log_entries, self.iteration_step) + if tbi_logs is not None: + for name, tensor in tbi_logs.items(): + # attn_probs is [num_local_heads, mbs * seq_len] + basic_log_entries.append(LogItem(f"tbi_logs/{name}_mean", tensor.mean().item(), "human_format")) + basic_log_entries.append(LogItem(f"tbi_logs/{name}_std", tensor.std().item(), "human_format")) + basic_log_entries.append(LogItem(f"tbi_logs/{name}_max", tensor.max().item(), "human_format")) + basic_log_entries.append(LogItem(f"tbi_logs/{name}_min", tensor.min().item(), 
"human_format")) + # per head logs + for head_idx in range(tensor.shape[0]): + basic_log_entries.append(LogItem(f"tbi_logs/{name}/head_{head_idx}_mean", tensor[head_idx].mean().item(), "human_format")) + basic_log_entries.append(LogItem(f"tbi_logs/{name}/head_{head_idx}_std", tensor[head_idx].std().item(), "human_format")) + basic_log_entries.append(LogItem(f"tbi_logs/{name}/head_{head_idx}_max", tensor[head_idx].max().item(), "human_format")) + basic_log_entries.append(LogItem(f"tbi_logs/{name}/head_{head_idx}_min", tensor[head_idx].min().item(), "human_format")) + if os.environ.get("DEBUG_CPU", "0") == "1": basic_log_entries.extend(get_cpu_logitems()) @@ -889,7 +893,7 @@ def get_cpu_logitems(): # WandB logging - determine if this rank should log to wandb should_log_to_wandb = wandb is not None and ( - (tp_size > 1 and dp_rank == 0 and self.metrics_logging.log_level > 0) + (tp_size > 1 and dp_cp_rank == 0 and self.metrics_logging.log_level > 0) or (tp_size > 1 and world_rank == self.logger_ranks[0] and self.metrics_logging.log_level == 0) or (tp_size == 1 and world_rank == self.logger_ranks[0]) # For TP>1, log from each TP group's dp=0 rank ) @@ -901,7 +905,7 @@ def get_cpu_logitems(): if should_log_detailed_metrics_to_wandb: assert not ( - wandb.run is None and tp_size > 1 and dp_rank == 0 + wandb.run is None and tp_size > 1 and dp_cp_rank == 0 ), f"WandB is not initialized for TP rank {tp_rank}, but logging was requested. Make sure that wandb is initialize before training." all_log_entries = list(basic_log_entries) @@ -1045,6 +1049,8 @@ def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: reloaded_from_checkpoint = True if not reloaded_from_checkpoint: log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) + if self.parallel_context.context_parallel_size > 1: + raise NotImplementedError("Init with Context parallel size > 1 not supported yet") if isinstance(self.config.model.init_method, ExistingCheckpointInit): # Initialize model from an pretrained model checkpoint (without optimizer, lr_scheduler...) 
self.param_shard_metadata = load_weights( @@ -1084,7 +1090,7 @@ def _init_model( parallel_context = self.parallel_context parallel_config = config.parallelism - make_ddp = parallel_context.data_parallel_size > 1 and not ( + make_ddp = parallel_context.data_parallel_size > 1 or parallel_context.context_parallel_size > 1 and not ( config.optimizer.accumulate_grad_in_fp32 and config.optimizer.zero_stage > 0 ) @@ -1149,7 +1155,7 @@ def _init_model( # TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway) model = DistributedDataParallel( model, - process_group=parallel_context.dp_pg, + process_group=parallel_context.dp_cp_pg, broadcast_buffers=False, bucket_cap_mb=config.model.ddp_bucket_cap_mb, ) @@ -1252,15 +1258,15 @@ def save_checkpoint(self) -> Path: # Update step/samples numbers before we save the config self.config.general.step = self.metadata.last_train_step - self.config.general.consumed_train_samples = self.metadata.consumed_train_samples + self.config.general.consumed_train_samples = self.metadata.consumed_train_samples # TODO: idc abt this save( model=self.unwrapped_model, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, should_save_model=bool( - dist.get_rank(self.parallel_context.dp_pg) == 0 - ), # We only save the weights on DP==0 + dist.get_rank(self.parallel_context.dp_cp_pg) == 0 + ), # We only save the weights on DP_CP==0 should_save_optimizer=True, should_save_lr_scheduler=True, should_save_config=bool( From 13796b1f6d589f80fae168164534c2a4a7ec0a31 Mon Sep 17 00:00:00 2001 From: Sultan <92447824+SulRash@users.noreply.github.com> Date: Mon, 30 Jun 2025 14:56:05 +0300 Subject: [PATCH 09/10] Change types to match configs correctly --- src/nanotron/scaling/parametrization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 8324eccf..71947fda 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -3,7 +3,7 @@ from enum import Enum, auto from typing import Dict -from nanotron.config import Config, ModelArgs +from nanotron.config import Config from nanotron.config.models_config import InitScalingMethod from nanotron.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm from nanotron.nn.moe import GroupedMLP, Router @@ -22,7 +22,7 @@ class ParametrizationMethod(Enum): class Parametrizator: - def __init__(self, config: ModelArgs): + def __init__(self, config: Config): self.config = config def parametrize(self, param_name: str, module: nn.Module): @@ -128,7 +128,7 @@ class SpectralMupParametrizator(Parametrizator): https://arxiv.org/abs/2310.17813 """ - def __init__(self, config: ModelArgs): + def __init__(self, config: Config): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { TensorParallelColumnLinear: self._parametrize_mup_weight, From d1fbe1a925460cf0000999acc4fb1a0ccb997f2d Mon Sep 17 00:00:00 2001 From: Sultan <92447824+SulRash@users.noreply.github.com> Date: Mon, 30 Jun 2025 14:57:08 +0300 Subject: [PATCH 10/10] Made llama model file use correct config for parameterization --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index db820644..b63197b7 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -1092,7 +1092,7 @@ def init_model_randomly(self, config: Config): else: raise ValueError(f"Unknown init method 
{init_method}")
 
-        parametrizator = parametrizator_cls(config=config.model)
+        parametrizator = parametrizator_cls(config=config)
 
         log_rank(
             f"Parametrizing model parameters using {parametrizator.__class__.__name__}",
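
Note on the `tbi_logs` block added to `train_step_logs` above: each logged tensor (e.g. `attn_probs`, shaped `[num_local_heads, mbs * seq_len]` per the in-diff comment) is reduced to scalar mean/std/max/min entries, both globally and per attention head. Below is a minimal, self-contained sketch of that reduction pattern; the helper name `per_head_stats` and the random input are illustrative assumptions, not nanotron API.

import torch

def per_head_stats(name, tensor):
    # Global summary statistics, mirroring the LogItem entries in the diff above.
    entries = [(f"tbi_logs/{name}_{stat}", getattr(tensor, stat)().item())
               for stat in ("mean", "std", "max", "min")]
    # Per-head statistics: one row of `tensor` per local attention head.
    for head_idx, head in enumerate(tensor):
        entries += [(f"tbi_logs/{name}/head_{head_idx}_{stat}", getattr(head, stat)().item())
                    for stat in ("mean", "std", "max", "min")]
    return entries

print(per_head_stats("attn_probs", torch.rand(4, 8)))  # hypothetical: 4 heads, 8 tokens each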
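
One detail of the `make_ddp` change in `_init_model` above: Python's `and` binds tighter than `or`, so the new condition parses as `data_parallel_size > 1 or (context_parallel_size > 1 and not (...))`, not `(data_parallel_size > 1 or context_parallel_size > 1) and not (...)`. If the intent is to keep skipping DDP whenever fp32 gradient accumulation is combined with ZeRO, the disjunction needs explicit parentheses. A small sketch with hypothetical sizes showing how the two groupings differ:

# Hypothetical values: DP=2, CP=1, fp32 grad accumulation with ZeRO stage 1.
dp_size, cp_size = 2, 1
accumulate_grad_in_fp32, zero_stage = True, 1

as_written = dp_size > 1 or cp_size > 1 and not (accumulate_grad_in_fp32 and zero_stage > 0)
grouped = (dp_size > 1 or cp_size > 1) and not (accumulate_grad_in_fp32 and zero_stage > 0)

print(as_written, grouped)  # True False: only the grouped form preserves the pre-CP behaviour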