diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a776daf24..4a5ebde4e 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,3 +1,8 @@ + + # What does this PR do? \ No newline at end of file + diff --git a/.github/workflows/pr-rules.yaml b/.github/workflows/pr-rules.yaml new file mode 100644 index 000000000..b82d61baf --- /dev/null +++ b/.github/workflows/pr-rules.yaml @@ -0,0 +1,15 @@ +name: Check PR Source Branch +on: + pull_request: + branches: + - main + +jobs: + check-branch: + runs-on: ubuntu-latest + steps: + - name: Check PR source branch + if: github.base_ref == 'main' && github.head_ref != 'dev' + run: | + echo "ERROR: PRs to main must come from dev branch" + exit 1 diff --git a/README.md b/README.md index 719d0720b..1e4e44c51 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ Nanotron is a library for pretraining transformer models. It provides a simple a 📚 **Check out our [Ultrascale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook)** - A comprehensive guide to efficiently scale LLM training with Nanotron! +📝 **AI generated docs thanks to [DeepWiki](https://deepwiki.com/huggingface/nanotron)** + ## Installation To run the code in this project, first create a Python virtual environment using e.g. `uv`: @@ -98,7 +100,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config- The model will be saved in the `checkpoints` directory as specified in the config file. > [!NOTE] -> You can use `examples/config_tiny_llama.py` to generate your own training config +> You can use `examples/config_tiny_llama.py` to generate your own training config For detailed instructions on training your first model, check out our [Your First Training guide](docs/your-first-training.md). For multi-node training with Slurm, see our [Multi-Node Training guide](docs/multi-node-training.md). @@ -108,7 +110,7 @@ For detailed instructions on training your first model, check out our [Your Firs torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/{checkpoint_number}/ --tp 1 --pp 1 ``` -Increase the value of `--tp` (tensor paralle) to accelerate generation with multiple GPUs and use a larger value of `--pp` (pipeline parallel) for very large models. +Increase the value of `--tp` (tensor parallel) to accelerate generation with multiple GPUs and use a larger value of `--pp` (pipeline parallel) for very large models. ### Debugging with VSCode To debug with VSCode, add the following configuration to your `launch.json` file: @@ -173,6 +175,7 @@ We currently support the following features: - [x] Custom module checkpointing for large models - [x] Spectral µTransfer parametrization for scaling up neural networks - [x] Mamba example +- [x] CUDA event-based timing for accurate GPU performance measurement And we have on our roadmap: - [ ] FP8 training diff --git a/docs/cuda_event_timing.md b/docs/cuda_event_timing.md new file mode 100644 index 000000000..9f8f49351 --- /dev/null +++ b/docs/cuda_event_timing.md @@ -0,0 +1,92 @@ +# CUDA Event-Based Timing in Nanotron + +## Overview + +Nanotron now uses CUDA events for timing GPU operations instead of CPU-based timing with `time.time()`. This change provides several benefits: + +1. **More accurate measurement of GPU execution time**: CUDA events are recorded directly on the GPU timeline, providing more precise timing of GPU operations. +2. 
**Reduced need for explicit CUDA synchronization**: CPU-based timing requires synchronization between CPU and GPU to get accurate measurements, which can introduce overhead and affect performance. +3. **Lower overhead**: CUDA event-based timing has minimal impact on the execution of GPU operations. +4. **Better performance monitoring**: More accurate timing leads to better performance analysis and optimization. + +## Implementation Details + +The implementation uses `torch.cuda.Event` with `enable_timing=True` to create start and end events that are recorded on the GPU timeline. The elapsed time is then calculated using `start_event.elapsed_time(end_event)`, which returns the time in milliseconds. + +### Key Changes + +1. **Default Timer Type**: The default timer type in `nanotron/src/nanotron/logging/timers.py` has been changed from `TimerType.CPU` to `TimerType.CUDA`. + +2. **Iteration Timing**: The iteration timing in `trainer.py` now uses CUDA events instead of `time.time()`. + +3. **Synchronization Control**: By default, CUDA event-based timers do not force synchronization unless explicitly requested with `cuda_sync=True`. + +## Usage + +### Basic Usage + +```python +# Create and use a CUDA timer (default) +with nanotron_timer("my_operation"): + # Your GPU operation here + ... + +# Explicitly specify CUDA timing +with nanotron_timer("my_operation", timer_type="cuda"): + # Your GPU operation here + ... + +# For CPU-only operations, you can still use CPU-based timing +with nanotron_timer("cpu_operation", timer_type="cpu"): + # Your CPU operation here + ... + +# As a decorator with default CUDA timing +@nanotron_timer +def my_function(): + # Your GPU operation here + ... + +# As a decorator with custom name +@nanotron_timer("custom_name") +def my_function(): + # Your GPU operation here + ... + +# As a decorator with CPU timing +@nanotron_timer(timer_type=TimerType.CPU) +def my_cpu_function(): + # Your CPU operation here + ... +``` + +### Advanced Usage + +```python +# Start and end a timer manually +timer = nanotron_timer("my_operation") +timer.start() +# Your operation here +timer.end() + +# Get the elapsed time in seconds +elapsed_time = timer.elapsed + +# Get the total time across all calls +total_time = timer.total_time + +# Get the average time per call +avg_time = timer.average_time +``` + +## Considerations + +1. **Synchronization**: By default, CUDA event-based timers do not force synchronization to avoid overhead. If you need more accurate timing at the cost of performance, you can set `cuda_sync=True`. + +2. **Units**: CUDA events measure time in milliseconds, but the timer API converts this to seconds for consistency with the previous CPU-based timing. + +3. **Fallback**: If CUDA is not available, the timer will automatically fall back to CPU-based timing. + +## Performance Impact + +Using CUDA events for timing instead of CPU-based timing with synchronization can significantly reduce overhead, especially in distributed training scenarios with thousands of GPUs. 
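
For reference, the raw PyTorch mechanism the doc above describes looks roughly like this. This is a minimal sketch using only `torch.cuda.Event`, not the `nanotron_timer` API itself, and the variable names are illustrative:

```python
import torch

# Record two CUDA events on the GPU timeline around the work we want to measure.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
# ... GPU work, e.g. a matmul ...
x = torch.randn(4096, 4096, device="cuda")
y = x @ x
end_event.record()

# elapsed_time() requires both events to have completed, so synchronize once at read time.
torch.cuda.synchronize()
elapsed_seconds = start_event.elapsed_time(end_event) / 1000.0  # elapsed_time() returns milliseconds
print(f"matmul took {elapsed_seconds:.6f}s")
```

The timer API wraps exactly this pattern, which is why it can defer the synchronization until the timing is actually read instead of stalling the GPU on every measurement.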
diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 639ed2d6b..a5d901b24 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -30,7 +30,7 @@ "410m": (24, 1024, 16, 16, 4096), # ~410M params # Small to medium models "1b": (16, 2048, 16, 16, 5632), # ~1B params - "3b": (28, 2048, 16, 2, 11008), # ~3B params + "3b": (36, 2048, 16, 4, 11008), # ~3B params # Standard sizes "7b": (32, 4096, 32, 32, 11008), # ~7B params "13b": (40, 5120, 40, 40, 13824), # ~13B params @@ -47,7 +47,7 @@ def get_args(): parser.add_argument( "--model", choices=MODEL_SIZES.keys(), - default="custom", + default="3b", help="Model size to generate config for (e.g., 7b, 13b)", ) parser.add_argument( @@ -76,6 +76,10 @@ def get_args(): tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size") tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica") + # checkpoints + checkpoints_group = parser.add_argument_group("checkpoints") + checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval") + args = parser.parse_args() return args @@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config: is_qwen2_config=True, pad_token_id=None, _attn_implementation="flash_attention_2", - sliding_window_size=20, + _use_doc_masking=True, ) @@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str: def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config: learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0 ) parallelism = ParallelismArgs( dp=args.dp, @@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config ) optimizer = OptimizerArgs( zero_stage=args.zero, - weight_decay=0.01, + weight_decay=0.1, clip_grad=1.0, accumulate_grad_in_fp32=True, learning_rate_scheduler=learning_rate, @@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config return Config( general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"), @@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config world_size = args.dp * args.tp * args.pp * args.cp if world_size <= 8: print( - f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" + f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" ) + print("You can also use environment variables for more debugging:") + print(" - ENABLE_TIMERS=1: Enable detailed timing information") + print(" - DEBUG_CPU=1: Log CPU and memory usage statistics") + print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection") else: print("Checkout slurm_launcher.py to launch a multi-node job") diff --git a/examples/config_qwen.yaml 
b/examples/config_qwen.yaml index 5fc8e48ea..cf6f40fac 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 100000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false load_lr_scheduler: true @@ -30,9 +30,9 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: false + ignore_sanity_checks: true project: debug - run: qwen_20250410_014907_16027793 + run: qwen_20250424_120835_16423158 seed: 42 step: null lighteval: null @@ -45,6 +45,7 @@ model: ddp_bucket_cap_mb: 25 dtype: bfloat16 init_method: + scaling_method: NUM_LAYERS std: 0.025 make_vocab_size_divisible_by: 1 model_config: @@ -58,15 +59,15 @@ model: eos_token_id: 2 flex_attention_mask: null hidden_act: silu - hidden_size: 256 + hidden_size: 2048 initializer_range: 0.02 - intermediate_size: 768 + intermediate_size: 11008 is_qwen2_config: true max_position_embeddings: 4096 moe_config: null no_rope_layer: null - num_attention_heads: 4 - num_hidden_layers: 12 + num_attention_heads: 16 + num_hidden_layers: 36 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -74,7 +75,7 @@ model: rope_interleaved: false rope_scaling: null rope_theta: 10000.0 - sliding_window_size: 20 + sliding_window_size: null tie_word_embeddings: true use_cache: true vocab_size: 128256 @@ -104,11 +105,10 @@ parallelism: context_parallel_size: 1 dp: 2 expert_parallel_size: 1 - moe_layer_recompute: false pp: 1 pp_engine: 1f1b recompute_layer: false - tp: 1 + tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER tp_recompute_allgather: true diff --git a/examples/config_qwen_with_moe.yaml b/examples/config_qwen_with_moe.yaml new file mode 100644 index 000000000..5e51307ff --- /dev/null +++ b/examples/config_qwen_with_moe.yaml @@ -0,0 +1,132 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: /fsx/phuc/new_workspace/experiments/qwen2_moe_test + checkpoints_path_is_shared_file_system: false + load_lr_scheduler: true + load_optimizer: true + resume_checkpoint_path: null + save_final_state: true + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: + - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged + dataset_max_tokens: null + dataset_read_path: null + dataset_weights: null + pad_samples_to_global_batch_size: false + return_positions: true + shuffle_files: false + skip_in_stream: false + token_size_in_bytes: 4 + tokenizer_name: meta-llama/Llama-3.2-1B + use_old_brrr_dataloader: false + vocab_size: 128256 + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: false + project: qwen_moe + run: qwen_20250410_014907_16027793 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +metrics_logging: null +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + _attn_implementation: flash_attention_2 + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true + attention_bias: false + bos_token_id: 1 + eos_token_id: 2 + flex_attention_mask: null + hidden_act: silu + hidden_size: 256 + initializer_range: 0.02 + intermediate_size: 768 + is_qwen2_config: true + max_position_embeddings: 4096 + moe_config: null + no_rope_layer: null + 
num_attention_heads: 4 + num_hidden_layers: 12 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-06 + rope_interleaved: false + rope_scaling: null + rope_theta: 10000.0 + sliding_window_size: 20 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 + z_loss_coefficient: 0.0001 + z_loss_enabled: false + moe_config: + num_experts: 8 + top_k: 1 + enable_shared_expert: true + token_dispatcher_type: alltoall +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 31998 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + weight_decay_exclude_named_params: [] + zero_stage: 0 +parallelism: + context_parallel_size: 1 + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER + tp_recompute_allgather: true +profiler: null +s3_upload: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-1B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 32000 + val_check_interval: -1 diff --git a/examples/inference/qwen_moe/README.md b/examples/inference/qwen_moe/README.md new file mode 100644 index 000000000..594b6ce53 --- /dev/null +++ b/examples/inference/qwen_moe/README.md @@ -0,0 +1,27 @@ +# Qwen-MoE Inference + +This guide explains how to convert Hugging face Qwen-MoE models to Nanotron format and run inference with them. + +## Convert Qwen-MoE to Nanotron Format + +Navigate to the `inference/qwen_moe` directory and run: + +```bash +torchrun --nproc-per-node 1 examples/inference/qwen_moe/convert.py \ + --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B \ + --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B +``` + +This command will save the converted model weights to the specified path in `nanotron_checkpoints` + +## Run Inference + +From the root directory of Nanotron, run: + +```bash +torchrun --rdzv_endpoint=localhost:29700 --rdzv-backend=c10d --nproc_per_node=1 \ + run_generate.py \ + --ckpt-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B +``` + +This command will load the converted model weights and run inference. 
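
Before launching generation, it can be worth sanity-checking that the conversion wrote what you expect. A minimal sketch, assuming the converted checkpoint contains the `model_config.json` and `config.yaml` files written by `convert.py` below; the nested key names are assumptions that mirror the Nanotron `Qwen2Config`/`MoEConfig` dataclasses:

```python
import json
from pathlib import Path

ckpt = Path("nanotron_checkpoints/Qwen1.5-MoE-A2.7B")

# model_config.json is a plain dump of the Nanotron Qwen2Config dataclass.
with open(ckpt / "model_config.json") as f:
    cfg = json.load(f)

print("hidden_size:", cfg["hidden_size"])
print("num_hidden_layers:", cfg["num_hidden_layers"])

# MoE settings should appear under the nested moe_config entry (assumed layout).
moe = cfg.get("moe_config") or {}
print("num_experts:", moe.get("num_experts"), "top_k:", moe.get("top_k"))
```

If these values do not match the Hugging Face config you converted from, re-run the conversion before spending GPU time on inference.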
diff --git a/examples/inference/qwen_moe/convert.py b/examples/inference/qwen_moe/convert.py new file mode 100644 index 000000000..419495da2 --- /dev/null +++ b/examples/inference/qwen_moe/convert.py @@ -0,0 +1,329 @@ +""" +torchrun --nproc-per-node 1 convert.py --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit, MoEConfig, Qwen2Config +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.qwen import Qwen2ForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2MoeConfig + +logger = logging.get_logger(__name__) + +# NOTE: We need to initialize the model on gpu, because RotaryEmbedding +# requires its buffer to be on gpu +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Qwen-MoE HF model + log_rank( + f"Loading pretrained qwen moe Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config: Qwen2MoeConfig = hf_model.config + + # Set Nanotron Qwen2Config + nanotron_config = Qwen2Config( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_qwen2_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + attention_bias=True, # qwen-moe uses attention bias + 
rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + moe_config=MoEConfig( + top_k=hf_config.num_experts_per_tok, + num_experts=hf_config.num_experts, + moe_intermediate_size=hf_config.moe_intermediate_size, + shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, + router_aux_loss_coef=hf_config.router_aux_loss_coef, + enable_shared_expert=True, + ), + ) + + # Init Nanotron Qwen-MoE model + log_rank("Init empty Nanotron Qwen Moe Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: Qwen2ForTraining( + config=nanotron_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context, parallel_config=parallel_config) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + + with torch.no_grad(): + # token embeddings + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## QKV bias + tmp_qkv_bias = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.bias, + hf_model.model.layers[i].self_attn.k_proj.bias, + hf_model.model.layers[i].self_attn.v_proj.bias, + ], + dim=0, + ) + assert tmp_qkv_bias.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.bias.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.bias.copy_(tmp_qkv_bias) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Router + assert ( + hf_model.model.layers[i].mlp.gate.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.router.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.router.weight.copy_(hf_model.model.layers[i].mlp.gate.weight) + + ## shared expert: Gate Up Proj + 
tmp_shared_expert = torch.cat( + [ + hf_model.model.layers[i].mlp.shared_expert.gate_proj.weight, + hf_model.model.layers[i].mlp.shared_expert.up_proj.weight, + ], + dim=0, + ) + assert ( + tmp_shared_expert.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.gate_up_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.gate_up_proj.weight.copy_(tmp_shared_expert) + + ## shared expert: Down Proj + assert ( + hf_model.model.layers[i].mlp.shared_expert.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.down_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.shared_expert.down_proj.weight + ) + + ## shared expert: Gate + assert ( + hf_model.model.layers[i].mlp.shared_expert_gate.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert_gate.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert_gate.weight.copy_( + hf_model.model.layers[i].mlp.shared_expert_gate.weight + ) + + ## experts: + # concatenate all gate_up_proj and down_proj weights for experts into merged_gate_up_proj and merged_down_proj + tmp_merged_gate_up_proj = torch.zeros( + nanotron_config.moe_config.num_experts, + nanotron_config.hidden_size, + 2 * nanotron_config.moe_config.moe_intermediate_size, + ) + tmp_merged_down_proj = torch.zeros( + nanotron_config.moe_config.num_experts, + nanotron_config.moe_config.moe_intermediate_size, + nanotron_config.hidden_size, + ) + + for j in range(nanotron_config.moe_config.num_experts): + ## Gate Up Proj + tmp_merged_gate_up_proj[j, :, : nanotron_config.moe_config.moe_intermediate_size] = ( + hf_model.model.layers[i].mlp.experts[j].gate_proj.weight.T + ) + tmp_merged_gate_up_proj[j, :, nanotron_config.moe_config.moe_intermediate_size :] = ( + hf_model.model.layers[i].mlp.experts[j].up_proj.weight.T + ) + + ## Down Proj + tmp_merged_down_proj[j] = hf_model.model.layers[i].mlp.experts[j].down_proj.weight.T + + # copy to merged_gate_up_proj and merged_down_proj + nanotron_model.model.decoder[i].pp_block.mlp.experts.merged_gate_up_proj.copy_(tmp_merged_gate_up_proj) + nanotron_model.model.decoder[i].pp_block.mlp.experts.merged_down_proj.copy_(tmp_merged_down_proj) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + log_rank("Storing 
Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Qwen2-MoE"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_config, + ), + tokenizer=TokenizerArgs(tokenizer_name_or_path=args.pretrained_model_name_or_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/pyproject.toml b/pyproject.toml index 390c32c40..26171e9f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "dacite", "tqdm", "datasets", + "torchtyping" ] [tool.setuptools.packages.find] diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 4a8472097..c16f076c1 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -17,6 +17,7 @@ from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( + InitScalingMethod, RecomputeGranularity, cast_str_to_pipeline_engine, cast_str_to_torch_dtype, @@ -460,6 +461,13 @@ def __post_init__(self): if self.s3_upload is not None: self.s3_upload.__post_init__() + if self.lighteval is not None: + if self.lighteval.eval_interval is None: + self.lighteval.eval_interval = self.checkpoints.checkpoint_interval + else: + assert ( + self.lighteval.eval_interval % self.checkpoints.checkpoint_interval == 0 + ), f"eval_interval={self.lighteval.eval_interval} must be a multiple of checkpoint_interval={self.checkpoints.checkpoint_interval}" # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: @@ -542,14 +550,15 @@ def global_batch_size(self): def global_batch_size_in_tokens(self): return self.global_batch_size * self.tokens.sequence_length - def save_as_yaml(self, file_path: str): + def save_as_yaml(self, file_path: str, sanity_checks: bool = True): config_dict = serialize(self) file_path = str(file_path) with open(file_path, "w") as f: yaml.dump(config_dict, f) # Sanity test config can be reloaded - _ = get_config_from_file(file_path, config_class=self.__class__) + if sanity_checks: + _ = get_config_from_file(file_path, config_class=self.__class__) def get_yaml(self): config_dict = serialize(self) @@ -620,6 +629,7 @@ def 
get_config_from_dict( PipelineEngine: cast_str_to_pipeline_engine, TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()], RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()], + InitScalingMethod: lambda x: InitScalingMethod[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], }, # strict_unions_match=True, diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index b5f12059a..0806acffe 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -73,6 +73,22 @@ def __post_init__(self): assert self.wandb_project != "", "Please specify a wandb_project" +@dataclass +class LightEvalSlurm: + """Arguments related to SLURM configuration for LightEval""" + + gpus_per_node: int = 8 + partition: str = "hopper-prod" + hf_cache: str = "~/.cache/huggingface" + cpus_per_task: int = 88 + qos: str = "low" + time: str = "24:00:00" + reservation: Optional[str] = "smollm" + + def __post_init__(self): + self.hf_cache = str(Path(self.hf_cache).expanduser()) + + @dataclass class LightEvalConfig: """Arguments related to running LightEval on checkpoints. @@ -81,13 +97,48 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_template: Optional[str] = None - slurm_script_dir: Optional[str] = None - - checkpoints_path: Optional[str] = None + slurm_script_dir: Optional[Path] = Path("eval_results/launch-config") + logs_path: Optional[Path] = Path("eval_results/logs") + local_checkpoint_dir: Path = Path( + "/scratch" + ) # Base directory for temporary checkpoint storage, will store under {local_checkpoint_dir}/{run_name}/{step} parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None tasks: Optional[LightEvalTasksArgs] = None logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None + slurm: Optional[LightEvalSlurm] = None + s3_save_path: Optional[str] = None # should not be dependent of the run_name + upload_to_wandb: Optional[bool] = False + wandb_project: Optional[str] = None + wandb_entity: Optional[str] = None + output_dir: Optional[ + str + ] = None # we should sanity check that it's the same as the one in the eval_config_override + nanotron_path: Optional[str] = "./" + eval_config_override: str = None + eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job + eval_interval: Optional[ + int + ] = None # Must be multiple of checkpoint_interval. If None, eval will be done after each checkpoint upload to s3 + eval_interval_file: Optional[ + Path + ] = None # If specified, eval_interval will be read from this file upon the next evaluation. 
+ + def __post_init__(self): + if self.parallelism is None: + self.parallelism = ParallelismArgs(dp=1, pp=1, tp=1, tp_linear_async_communication=True) + if self.slurm is None: + self.slurm = LightEvalSlurm() + self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.upload_to_wandb: + assert ( + self.s3_save_path is not None + ), " We should have a s3_save_path if we want to upload to wandb" # todo: add the option to read from local folder i guess + assert self.wandb_project is not None, "wandb_project must be specified if upload_to_wandb is True" + assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True" + if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): + logger.warning( + f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." + ) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 410634b87..999d13379 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Any, List, Optional, Union +from nanotron.config.utils_config import InitScalingMethod from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, AttentionImplementation # The default attention implementation to use @@ -11,6 +12,7 @@ @dataclass class RandomInit: std: float + scaling_method: InitScalingMethod = InitScalingMethod.NUM_LAYERS @dataclass @@ -36,6 +38,9 @@ class MoEConfig: num_experts: int = 8 # Total number of experts top_k: int = 2 # Number of experts to route each token to + moe_intermediate_size: int = 1408 # Intermediate size of the MoE layer + shared_expert_intermediate_size: int = 5632 # Intermediate size of the shared expert + router_aux_loss_coef: float = 0.01 # Coefficient for the router auxiliary loss layers: List[int] = field( default_factory=lambda: [-1] ) # Indices of layers that use MoE. -1 means all layers. 
Default is all layers @@ -141,11 +146,13 @@ class Qwen2Config: sliding_window_size: Optional[int] = None z_loss_enabled: bool = False # Z-loss regularization https://www.jmlr.org/papers/volume24/22-1144/22-1144.pdf z_loss_coefficient: float = 0.0001 # Default from the paper (10^-4) - no_rope_layer: Optional[int] = None # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4) + no_rope_layer: Optional[ + int + ] = None # Skip rope every no_rope_layer layers (see https://arxiv.org/abs/2501.18795 https://arxiv.org/abs/2305.19466 and Llama4) _fused_rotary_emb: bool = True _fused_rms_norm: bool = True _use_qkv_packed: bool = True - _use_doc_masking: bool = True + _use_doc_masking: bool = False # MoE configuration moe_config: Optional[MoEConfig] = None diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 48aa941e8..40d95119a 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -33,8 +33,6 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None recompute_layer: bool = False - moe_layer_recompute: bool = False - tp_recompute_allgather: bool = True expert_parallel_size: int = 1 diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index f4c071462..84e8079a4 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -18,6 +18,13 @@ class RecomputeGranularity(Enum): FULL = auto() +class InitScalingMethod(Enum): + NONE = auto() + NUM_LAYERS = auto() + LAYER_INDEX = auto() + MODEL_SCALE = auto() + + def serialize(data) -> dict: """Recursively serialize a nested dataclass to a dict - do some type conversions along the way""" if data is None: @@ -39,6 +46,8 @@ def serialize(data) -> dict: result[field.name] = value.name elif isinstance(value, RecomputeGranularity): result[field.name] = value.name + elif isinstance(value, InitScalingMethod): + result[field.name] = value.name elif isinstance(value, SamplerType): result[field.name] = value.name elif isinstance(value, torch.dtype): diff --git a/src/nanotron/data/clm_collator.py b/src/nanotron/data/clm_collator.py index 5c141adf0..89fd00830 100644 --- a/src/nanotron/data/clm_collator.py +++ b/src/nanotron/data/clm_collator.py @@ -97,6 +97,7 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_) # Context Parallelism: Each CP rank gets a slice of the label_ids and label_mask + cp_rank, cp_size = dist.get_rank(self.parallel_context.cp_pg), self.parallel_context.context_parallel_size local_slice = slice( cp_rank * self.sequence_length // cp_size, (cp_rank + 1) * self.sequence_length // cp_size ) diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 4f9063eb6..880983bb6 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -369,12 +369,13 @@ def __init__( ) from datatrove.utils.dataset import url_to_fs - fs_folder, folder_path = url_to_fs(folder_path) + fs_folder, stripped_folder_path = url_to_fs(folder_path) matched_files = ( - fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None) + fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) if not filename_pattern else fs_folder.glob( - 
os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None + os.path.join(stripped_folder_path, filename_pattern), + maxdepth=1 if not recursive else None, ) ) matched_files = sorted(matched_files) diff --git a/src/nanotron/debug_utils.py b/src/nanotron/debug_utils.py new file mode 100644 index 000000000..a0d576d90 --- /dev/null +++ b/src/nanotron/debug_utils.py @@ -0,0 +1,47 @@ +import logging +from typing import Optional + +import torch +from torch.utils.data import DataLoader +from transformers import PreTrainedTokenizer + +logger = logging.getLogger(__name__) + +def debug_dataloader_samples( + dataloader: DataLoader, + tokenizer: PreTrainedTokenizer, + num_samples: int = 2 +) -> None: + """ + Debug utility to inspect samples from a DataLoader. + + This function pulls the first batch from the given DataLoader, + detokenizes the 'input_ids' using the provided tokenizer, + and prints the decoded texts for a few samples. + + Args: + dataloader (torch.utils.data.DataLoader): The DataLoader to inspect. + tokenizer (PreTrainedTokenizer): Tokenizer used to decode input_ids. + num_samples (int): Number of samples to print from the first batch. + """ + try: + batch = next(iter(dataloader)) + except Exception as e: + logger.error("[debug] Failed to retrieve batch from dataloader: %s", e) + return + + input_ids = batch.get("input_ids") + if input_ids is None: + logger.warning("[debug] 'input_ids' not found in batch. Available keys: %s", list(batch.keys())) + return + + if hasattr(input_ids, "cpu"): + input_ids = input_ids.cpu() + + logger.info("\n[Debug] Printing detokenized samples from the first batch:\n") + for i in range(min(num_samples, len(input_ids))): + try: + decoded = tokenizer.decode(input_ids[i], skip_special_tokens=True) + logger.info("[Sample %d]:\n%s\n%s", i+1, decoded, "=" * 40) + except Exception as e: + logger.error("[debug] Failed to decode sample %d: %s", i+1, e) \ No newline at end of file diff --git a/src/nanotron/eval/README.md b/src/nanotron/eval/README.md new file mode 100644 index 000000000..05bfe1623 --- /dev/null +++ b/src/nanotron/eval/README.md @@ -0,0 +1,13 @@ +# Nanotron Evaluation + +This directory contains code for evaluating models trained with Nanotron. 
+ +## Installation + +To use the evaluation functionality, you need to install the `lighteval` package: + +```bash +uv pip install lighteval[dev] +``` + +## Usage diff --git a/src/nanotron/eval/__init__.py b/src/nanotron/eval/__init__.py new file mode 100644 index 000000000..d7ea002c5 --- /dev/null +++ b/src/nanotron/eval/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa: F401 + +from .one_job_runner import LightEvalRunner diff --git a/src/nanotron/eval/evaluation_tasks.py b/src/nanotron/eval/evaluation_tasks.py new file mode 100644 index 000000000..2543df313 --- /dev/null +++ b/src/nanotron/eval/evaluation_tasks.py @@ -0,0 +1,368 @@ +from functools import partial + +from lighteval.metrics.dynamic_metrics import ( + loglikelihood_acc_metric, +) +from lighteval.metrics.normalizations import LogProbCharNorm, LogProbPMINorm, LogProbTokenNorm +from lighteval.tasks.default_prompts import LETTER_INDICES +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.multilingual.adapters import ( + winogrand_adapter, +) +from lighteval.tasks.multilingual.tasks import TASKS_TABLE as ML_TASKS_TABLE +from lighteval.tasks.multilingual.utils.task_utils import get_metrics_for_formulation +from lighteval.tasks.templates.continuation import get_continuation_prompt_function +from lighteval.tasks.templates.hellaswag import get_hellaswag_prompt_function +from lighteval.tasks.templates.multichoice import get_mcq_prompt_function +from lighteval.tasks.templates.utils.formulation import ( + CFFormulation, + HybridFormulation, + MCFFormulation, +) +from lighteval.utils.language import Language + +TASKS_TABLE = [] + +TASKS_TABLE.extend(ML_TASKS_TABLE) + +arc_tasks = [ + LightevalTaskConfig( + name=f"arc_{formulation.name.lower()}:{subset.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"]["text"], + "gold_idx": int(line["answerKey"]) - 1 + if line["answerKey"].isdigit() + else LETTER_INDICES.index(line["answerKey"]), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="allenai/ai2_arc", + hf_subset=f"ARC-{subset}", + hf_revision="210d026faf9955653af8916fad021475a3f00453", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="train", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for subset in ["Easy", "Challenge"] + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(arc_tasks) + +hellaswag_tasks = [ + LightevalTaskConfig( + name=f"hellaswag_{formulation.name.lower()}", + suite=["custom"], + prompt_function=get_hellaswag_prompt_function( + language=Language.ENGLISH, + adapter=lambda line: { + "activity_label": line["activity_label"], + "ctx_a": line["ctx_a"], + "ctx_b": line["ctx_b"], + "continuations": line["endings"], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + hf_repo="Rowan/hellaswag", + hf_subset="default", + hf_revision="6002345709e0801764318f06bf06ce1e7d1a1fe3", + evaluation_splits=["validation"], + hf_avail_splits=["validation"], + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + trust_dataset=True, + ) + for formulation in 
[MCFFormulation(), CFFormulation(), HybridFormulation()] +] + +TASKS_TABLE.extend(hellaswag_tasks) + +commonsense_qa_tasks = [ + LightevalTaskConfig( + name=f"commonsenseqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"]["text"], + "gold_idx": line["choices"]["label"].index(line["answerKey"].strip()), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="tau/commonsense_qa", + hf_subset="default", + hf_revision="94630fe30dad47192a8546eb75f094926d47e155", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(commonsense_qa_tasks) + +openbook_qa_tasks = [ + LightevalTaskConfig( + name=f"openbookqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question_stem"], + "choices": line["choices"]["text"], + "gold_idx": LETTER_INDICES.index(line["answerKey"]), + }, + formulation=formulation, + ), + suite=["custom"], + hf_repo="allenai/openbookqa", + hf_subset="main", + hf_revision="388097ea7776314e93a529163e0fea805b8a6454", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(openbook_qa_tasks) + +winogrande_tasks = [ + LightevalTaskConfig( + name=f"winogrande_{formulation.name.lower()}", + suite=("custom",), + prompt_function=get_continuation_prompt_function( + Language.ENGLISH, partial(winogrand_adapter, Language.ENGLISH), formulation=formulation + ), + hf_repo="allenai/winogrande", + hf_subset="winogrande_xl", + trust_dataset=True, + hf_revision="85ac5b5a3b7a930e22d590176e39460400d19e41", + metric=[ + loglikelihood_acc_metric(normalization=None), + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(winogrande_tasks) + +piqa_tasks = [ + LightevalTaskConfig( + name=f"piqa_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["goal"], + "choices": [line["sol1"], line["sol2"]], + "gold_idx": int(line["label"]), + }, + formulation=formulation, + ), + suite=["custom"], + hf_repo="ybisk/piqa", + hf_revision="2e8ac2dffd59bac8c3c6714948f4c551a0848bb0", + hf_subset="plain_text", + trust_dataset=True, + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(piqa_tasks) + + +MMLU_SUBSETS = [ + "abstract_algebra", + "anatomy", + "astronomy", + "business_ethics", + "clinical_knowledge", + "college_biology", + "college_chemistry", + "college_computer_science", + "college_mathematics", + "college_medicine", + "college_physics", + 
"computer_security", + "conceptual_physics", + "econometrics", + "electrical_engineering", + "elementary_mathematics", + "formal_logic", + "global_facts", + "high_school_biology", + "high_school_chemistry", + "high_school_computer_science", + "high_school_european_history", + "high_school_geography", + "high_school_government_and_politics", + "high_school_macroeconomics", + "high_school_mathematics", + "high_school_microeconomics", + "high_school_physics", + "high_school_psychology", + "high_school_statistics", + "high_school_us_history", + "high_school_world_history", + "human_aging", + "human_sexuality", + "international_law", + "jurisprudence", + "logical_fallacies", + "machine_learning", + "management", + "marketing", + "medical_genetics", + "miscellaneous", + "moral_disputes", + "moral_scenarios", + "nutrition", + "philosophy", + "prehistory", + "professional_accounting", + "professional_law", + "professional_medicine", + "professional_psychology", + "public_relations", + "security_studies", + "sociology", + "us_foreign_policy", + "virology", + "world_religions", +] + +mmlu_tasks = [ + LightevalTaskConfig( + name=f"mmlu_{formulation.name.lower()}:{subset}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["choices"], + "gold_idx": int(line["answer"]), + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="cais/mmlu", + hf_subset=subset, + hf_revision="c30699e8356da336a370243923dbaf21066bb9fe", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="dev", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for subset in MMLU_SUBSETS + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(mmlu_tasks) + +mmlu_pro_tasks = [ + LightevalTaskConfig( + name=f"mmlu_pro_{formulation.name.lower()}", + prompt_function=get_mcq_prompt_function( + Language.ENGLISH, + lambda line: { + "question": line["question"], + "choices": line["options"], + "gold_idx": line["answer_index"], + }, + formulation=formulation, + ), + suite=("custom",), + hf_repo="TIGER-Lab/MMLU-Pro", + hf_subset="default", + hf_revision="3373e0b32277875b8db2aa555a333b78a08477ea", + trust_dataset=True, + evaluation_splits=("test",), + few_shots_split="validation", + metric=get_metrics_for_formulation( + formulation, + [ + loglikelihood_acc_metric(normalization=LogProbTokenNorm()), + loglikelihood_acc_metric(normalization=LogProbCharNorm()), + loglikelihood_acc_metric(normalization=LogProbPMINorm()), + ], + ), + ) + for formulation in [ + MCFFormulation(), + CFFormulation(), + HybridFormulation(), + ] +] + +TASKS_TABLE.extend(mmlu_pro_tasks) + + +if __name__ == "__main__": + print(t.name for t in TASKS_TABLE) + print(len(TASKS_TABLE)) diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py new file mode 100644 index 000000000..6567ec94a --- /dev/null +++ b/src/nanotron/eval/one_job_runner.py @@ -0,0 +1,382 @@ +""" Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it +""" +import datetime +import math +import os +import subprocess +from typing import List, Optional, Tuple + +from datasets.download.streaming_download_manager import xPath + +from nanotron import logging +from nanotron.config import Config, 
LightEvalConfig +from nanotron.data.s3_utils import _get_s3_path_components +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext + +logger = logging.get_logger(__name__) + + +class LightEvalRunner: + def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = None): + self.config = config + assert config.lighteval is not None, "LightEval config is required" + self.lighteval_config = config.lighteval + self.parallel_context = parallel_context + + def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: + """Run light evaluation on uploaded files.""" + if ( + self.config.lighteval.eval_interval is not None + and self.config.general.step % self.config.lighteval.eval_interval != 0 + ): + logger.debug( + f"Skipping evaluation at step {self.config.general.step} because eval_interval is {self.config.lighteval.eval_interval}" + ) + return + config_files = [ + f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] + ] + # Sanity check on the config files len (we want only one) + if len(config_files) == 0: + log_rank( + "No config files founds in uploaded checkpoints. Not running evaluation.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + if len(config_files) > 1: + log_rank( + f"Found multiple config files in uploaded checkpoints: {config_files}", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") + logger.warning( + f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." 
+ ) + if self.config.general.step % self.lighteval_config.eval_interval == 0: + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + else: + logger.warning( + f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}" + ) + return None, None + + return slurm_job_id, slurm_log + + +def normalize_s3_path(path: str) -> str: + """Normalize S3 path using existing s3_utils""" + # Use existing utility to normalize path components + path = xPath(path) + bucket, prefix = _get_s3_path_components(path) + # Reconstruct normalized path + return f"s3://{bucket}/{prefix}".rstrip("/") + + +def run_slurm_one_job( + config: Config, + lighteval_config: LightEvalConfig, + model_checkpoint_path: str, + current_step: int, +): + """Launch a single job on Slurm with the given mapping""" + # Normalize S3 path if needed + if model_checkpoint_path.startswith(("s3:/", "s3://")): + model_checkpoint_path = normalize_s3_path(model_checkpoint_path) + logger.info(f"Normalized S3 path: {model_checkpoint_path}") + + # Use config values instead of hardcoded defaults + slurm_config = lighteval_config.slurm + + # Calculate the number of nodes based on parallelism config + total_gpus_needed = ( + lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + ) + nodes = math.ceil(total_gpus_needed / slurm_config.gpus_per_node) + + # Get timestamp for log files + timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + general_run_name = config.general.run + run_name = f"{timestamp}-eval_{general_run_name}".replace(" ", "_") + + # Use lighteval config paths if available, otherwise use defaults + eval_launch_script_path = lighteval_config.slurm_script_dir + eval_logs_path = lighteval_config.logs_path + eval_launch_script_path = os.path.join(eval_launch_script_path, general_run_name, f"step-{current_step}") + eval_logs_path = os.path.join(eval_logs_path, general_run_name, f"step-{current_step}") + + # Create directories + os.makedirs(eval_launch_script_path, exist_ok=True) + os.makedirs(eval_logs_path, exist_ok=True) + + # Use configured local path instead of hardcoded /tmp + local_path = os.path.join(lighteval_config.local_checkpoint_dir, run_name, str(current_step)) + nanotron_path = lighteval_config.nanotron_path + # Create the SLURM script content + slurm_script = f"""#!/bin/bash +#SBATCH --job-name=eval_{current_step}_{run_name} +#SBATCH --partition={slurm_config.partition} +#SBATCH --nodes={nodes} +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task={slurm_config.cpus_per_task} +#SBATCH --gpus={slurm_config.gpus_per_node} +#SBATCH --exclusive +#SBATCH --qos={slurm_config.qos} +#SBATCH --time={slurm_config.time} +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out +#SBATCH --requeue""" + + if slurm_config.reservation: + slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" + + # Rest of the script content + slurm_script += f""" + +set -x + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={local_path} + +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export 
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token setup +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." + exit 1 + fi +fi + +# Set environment variables +export CUDA_DEVICE_MAX_CONNECTIONS=1 +# export CUBLAS_WORKSPACE_CONFIG=":4096:8" + +# Set HuggingFace cache locations +export HUGGINGFACE_HUB_CACHE={slurm_config.hf_cache} +export HF_DATASETS_CACHE={slurm_config.hf_cache} +export HF_MODULES_CACHE={slurm_config.hf_cache} +export HF_HOME={slurm_config.hf_cache} + +echo "Running on $COUNT_NODE nodes: $HOSTNAMES" + +# Create checkpoint directory +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +# Handle S3 paths +if [[ "{model_checkpoint_path}" == s3://* ]]; then + echo "Downloading checkpoint from S3: {model_checkpoint_path}" + + # First check if the S3 path exists + if ! s5cmd ls "{model_checkpoint_path}" &>/dev/null; then + echo "Error: S3 path {model_checkpoint_path} does not exist" + exit 1 + fi + + # Try sync command and check its exit status + s5cmd cp \\ + --concurrency=50 \\ + --exclude "optimizer/*" \\ + --exclude "random/*" \\ + --exclude "lr_scheduler/*" \\ + --part-size 100 \\ + "{model_checkpoint_path}/*" "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/" + + if [ $? -ne 0 ]; then + echo "Error: Failed to sync files from S3" + exit 1 + fi + + # Verify that config.yaml was downloaded + if [ ! -f "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml" ]; then + echo "Error: config.yaml not found in downloaded checkpoint" + echo "Contents of $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER:" + ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + exit 1 + fi +else + echo "Copying checkpoint files from {model_checkpoint_path} to $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" + rsync -av --progress --inplace --no-whole-file \\ + --exclude 'optimizer/' \\ + --exclude 'random/' \\ + --exclude 'lr_scheduler/' \\ + {model_checkpoint_path} $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + + if [ $? -ne 0 ]; then + echo "Error: Failed to copy files using rsync" + exit 1 + fi + + # Verify that config.yaml was copied + if [ ! 
-f "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml" ]; then + echo "Error: config.yaml not found in copied checkpoint" + echo "Contents of $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER:" + ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + exit 1 + fi +fi + +echo "Contents of checkpoint directory:" +ls -la $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/ + +# Add random sleep to avoid hub request conflicts +# sleep $(( RANDOM % 300 )) + +CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun \\ + --nproc_per_node {slurm_config.gpus_per_node} \\ + --nnodes $COUNT_NODE \\ + --node_rank $SLURM_PROCID \\ + --master_addr $MASTER_ADDR \\ + --master_port $MASTER_PORT \\ + {nanotron_path}/run_evals.py \\ + --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \\ + --lighteval-override {lighteval_config.eval_config_override} + --cache-dir {slurm_config.hf_cache}""" + if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: + slurm_script += f""" +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/ +""" + if lighteval_config.upload_to_wandb: + gbs_tok = ( + config.parallelism.dp + * config.tokens.micro_batch_size + * config.tokens.sequence_length + * config.tokens.batch_accumulation_per_replica + ) + slurm_script += f""" +python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\ + --wandb_project {lighteval_config.wandb_project} \\ + --wandb_entity {lighteval_config.wandb_entity} \\ + --model_name {general_run_name} \\ + --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ + --train_step {current_step} \\ + --consumed_tokens {current_step*gbs_tok} +""" + slurm_script += """ +echo "Cleaning up downloaded checkpoints..." +rm -rf "$LOCAL_DOWNLOAD_CHECKPOINT_FOLDER" +echo "Cleanup completed" + +echo "END TIME: $(date)" +""" + + # Write the script to file + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") + os.makedirs(os.path.dirname(launch_script_path), exist_ok=True) + + with open(launch_script_path, "w") as f: + f.write(slurm_script) + + # Preserve important environment variables + env = { + "PATH": os.environ["PATH"], + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), + "HOME": os.path.expanduser("~"), + } + + try: + # Use subprocess.run instead of check_output for better error handling + result = subprocess.run(["sbatch", launch_script_path], env=env, check=True, capture_output=True, text=True) + output = result.stdout + job_ids = output.split()[-1] + + output_log = os.path.join(eval_logs_path, f"{timestamp}-{run_name}-{job_ids}.out") + + logger.warning( + f"""🚀 Slurm job launched successfully: + Job name: {run_name} + Job ID: {job_ids} + Launch script: {launch_script_path} + Log file: {output_log}""" + ) + except subprocess.CalledProcessError as e: + logger.error(f"Error while launching Slurm job: {e}") + logger.error(f"Command output: {e.output}") + logger.error(f"Command stderr: {e.stderr}") + job_ids = None + output_log = None + + return job_ids, output_log + + +if __name__ == "__main__": + + from nanotron.config.config import Config + + # Load existing config from checkpoint + # checkpoint_path = "checkpoints/smollm3-training-test-tps-48nn-seed-6-/10" + # config_path = os.path.join(checkpoint_path, "config.yaml") + checkpoint_path = "s3://smollm3/smollm3-3B-final/3B-final-GQA-noTP-2k-seq/20000/" + config_path = 
"checkpoints/smollm3-training-test-tps-48nn-seed-6-/10/config.yaml" + try: + # Load the existing config + print(f"\nLoading config from: {config_path}") + config = Config.load_from_yaml(config_path) + + # Print config details + print("\nConfig details:") + print(f"Project: {config.general.project}") + print(f"Run: {config.general.run}") + print(f"Step: {config.general.step}") + + if config.lighteval: + print("\nLightEval config:") + print( + f"Parallelism: dp={config.lighteval.parallelism.dp}, tp={config.lighteval.parallelism.tp}, pp={config.lighteval.parallelism.pp}" + ) + print(f"Batch size: {config.lighteval.batch_size}") + print(f"Slurm template: {config.lighteval.slurm_template}") + print(f"Checkpoints path: {config.lighteval.checkpoints_path}") + if config.lighteval.tasks: + print(f"Tasks: {config.lighteval.tasks.tasks}") + print(f"Custom tasks: {config.lighteval.tasks.custom_tasks}") + print(f"Max samples: {config.lighteval.tasks.max_samples}") + + # Create test files structure + test_files = [ + { + "destination": os.path.join(checkpoint_path, "config.yaml"), + "source": "existing_config", + } + ] + + if config.lighteval is None: + config.lighteval = LightEvalConfig() + + print("\nInitializing LightEvalRunner...") + runner = LightEvalRunner(config=config) + + print("\nTesting LightEvalRunner.eval_single_checkpoint()...") + job_id, log_path = runner.eval_single_checkpoint(test_files) + + except Exception as e: + print(f"\nError during test: {str(e)}") + import traceback + + traceback.print_exc() + + finally: + print("\nTest completed") diff --git a/src/nanotron/eval/upload_to_wandb.py b/src/nanotron/eval/upload_to_wandb.py new file mode 100644 index 000000000..aa8c12d41 --- /dev/null +++ b/src/nanotron/eval/upload_to_wandb.py @@ -0,0 +1,87 @@ +import json +import s3fs +import wandb +import re +import argparse +from wandb.sdk.lib.runid import generate_id + + +def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens): + s3 = s3fs.S3FileSystem(anon=False) + all_metrics = { + # basic X axis replacements for all metrics + "consumed_tokens": consumed_tokens, + "train_step": train_step, + } + + for result_file in sorted(s3.ls(results_path)): + if not result_file.endswith(".json"): + continue + + with s3.open(result_file, "r") as f: + results = json.loads(f.read())["results"] + + for benchmark, metrics in results.items(): + if benchmark == "all": + continue + + # extract dataset and config name + match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark) + if match: + dataset, subtask = match.groups() + + for metric_name, metric_value in metrics.items(): + if "_stderr" in metric_name: + continue + # wandb-friendly metric name + wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}" + all_metrics[wandb_metric] = metric_value + + run_id = f"{model_name}-{generate_id()}" + + # try to find the run in wandb and resume it + api = wandb.Api() + runs = api.runs(f"{wandb_entity}/{wandb_project}") + for run in runs: + if run.name == model_name: + run_id = run.id + break + + wandb.init( + project=wandb_project, + entity=wandb_entity, + name=model_name, + id=run_id, + config={ + "model_name": model_name, + }, + resume="allow", + ) + + # log all metrics for this checkpoint + wandb.log(all_metrics) + + wandb.finish() + +if __name__ == "__main__": + # Setup argument parser + parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.") + parser.add_argument("--wandb_project", type=str, 
required=True, help="WandB project name.") + parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.") + parser.add_argument("--model_name", type=str, required=True, help="Name of the model.") + parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.") + parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.") + parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.") + + # Parse arguments + args = parser.parse_args() + + # Call the main function with parsed arguments + push_to_wandb( + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + model_name=args.model_name, + results_path=args.results_path, + train_step=args.train_step, + consumed_tokens=args.consumed_tokens + ) diff --git a/src/nanotron/logging/base.py b/src/nanotron/logging/base.py index e84554ee1..5cde1bb14 100644 --- a/src/nanotron/logging/base.py +++ b/src/nanotron/logging/base.py @@ -229,6 +229,7 @@ def log_rank( rank: Optional[int] = None, category: Optional[str] = None, is_separator: bool = False, + main_rank_only: bool = False, **kwargs, ): """Log only if the current process is the rank specified.""" @@ -246,6 +247,8 @@ def log_rank( kwargs["extra"] = kwargs.get("extra", {}) kwargs["extra"]["separator"] = True + if main_rank_only: + rank = 0 # rank is None means everyone logs if rank is None or dist.get_rank(group) == rank: if is_separator: @@ -265,7 +268,7 @@ def warn_once( def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str: if abs(num) < 1: return "{:.3g}".format(num) - SIZES = ["", "K", "M", "G", "T", "P", "E"] + SIZES = ["", "K", "M", "B", "T", "P", "E"] num = float("{:.3g}".format(num)) magnitude = 0 i = 0 diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index 1129b9c6c..e3603f118 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -19,15 +19,23 @@ class TimerType(Enum): @dataclass class TimerRecord: - """Records timing information for a single timer.""" + """ + Records timing information for a single timer. + + By default, uses CUDA events for timing GPU operations, which provides more accurate + measurements of GPU execution time without forcing CPU-GPU synchronization. + + For CPU-only operations, you can use CPU-based timing by specifying timer_type=TimerType.CPU. 
+ """ name: str - timer_type: TimerType = TimerType.CPU + timer_type: TimerType = TimerType.CUDA start_time: float = 0.0 end_time: float = 0.0 running: bool = False call_count: int = 0 cuda_sync: bool = False # Option to add CUDA synchronization for more accurate timings + enabled: bool = True # Allow individual timer to be enabled/disabled # For CPU timers we still track total_time _cpu_total_time: float = 0.0 @@ -48,7 +56,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def start(self) -> "TimerRecord": """Start the timer.""" - if self.name == "dummy": # disabled + if self.name == "dummy" or not self.enabled: # disabled return self if self.running: @@ -75,7 +83,7 @@ def start(self) -> "TimerRecord": def end(self) -> None: """End the timer, but don't compute elapsed time yet.""" - if self.name == "dummy": # disabled + if self.name == "dummy" or not self.enabled: # disabled return if not self.running: @@ -175,7 +183,17 @@ def average_time(self) -> float: class Timers: - """A collection of timers for tracking execution time in Nanotron.""" + """ + A collection of timers for tracking execution time in Nanotron. + + By default, timers use CUDA events for timing GPU operations, which provides several benefits: + 1. More accurate measurement of GPU execution time + 2. Reduced need for explicit CUDA synchronization + 3. Lower overhead compared to CPU-based timing with synchronization + 4. Better performance monitoring for distributed training + + For CPU-only operations, you can still use CPU-based timing by specifying timer_type=TimerType.CPU. + """ _instance = None _enabled = os.environ.get("ENABLE_TIMERS", "0") == "1" # Add global enable/disable flag @@ -202,39 +220,56 @@ def is_enabled(cls) -> bool: return cls._enabled def __call__( - self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU, cuda_sync: bool = True + self, + name: str, + timer_type: Union[TimerType, str] = TimerType.CUDA, + cuda_sync: bool = False, + enabled: bool = bool(int(os.environ.get("ENABLE_TIMERS", "0"))), ) -> TimerRecord: """Get or create a timer with the given name. Can be used as a decorator, context manager, or directly: - - @nanotron_timer("name") # As decorator + - @nanotron_timer # As decorator with default CUDA timing + - @nanotron_timer("my_function") # As decorator with custom name + - @nanotron_timer(timer_type=TimerType.CPU) # As decorator with CPU timing - with nanotron_timer("name"): ... # As context manager - nanotron_timer("name").start(); ...; nanotron_timer("name").end() # Direct use Args: name: Name of the timer - timer_type: Type of timer, either TimerType.CPU or TimerType.CUDA - (or 'cpu'/'cuda' strings) - cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing + timer_type: Type of timer, either TimerType.CUDA (default) or TimerType.CPU + (or 'cuda'/'cpu' strings) + cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing. + Default is False to avoid unnecessary synchronization overhead. 
+ enabled: Override default enabled setting from environment variable + + Raises: + ValueError: If a timer with the same name already exists with different settings """ - if not self._enabled: - # Return a dummy timer that does nothing when timing is disabled - return TimerRecord(name="dummy", timer_type=TimerType.CPU) - if isinstance(timer_type, str): timer_type = TimerType(timer_type) - if callable(name) and timer_type == TimerType.CPU: - # Being used as a decorator with default settings + if callable(name): + # Being used as a decorator with specified or default settings func = name timer_name = func.__name__ - return self._create_timer_decorator(timer_name, TimerType.CPU, cuda_sync)(func) + return self._create_timer_decorator(timer_name, timer_type, cuda_sync, enabled)(func) - if name not in self._timers: - self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync) - else: - # Update the cuda_sync option if the timer already exists - self._timers[name].cuda_sync = cuda_sync + if name in self._timers: + existing_timer = self._timers[name] + if ( + existing_timer.timer_type != timer_type + or existing_timer.cuda_sync != cuda_sync + or existing_timer.enabled != enabled + ): + raise ValueError( + f"Timer '{name}' already exists with different settings.\n" + f"Existing: type={existing_timer.timer_type}, cuda_sync={existing_timer.cuda_sync}, enabled={existing_timer.enabled}\n" + f"New: type={timer_type}, cuda_sync={cuda_sync}, enabled={enabled}" + ) + return existing_timer + + self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync, enabled=enabled) # Check if we're being called as a decorator if not callable(name): @@ -243,9 +278,9 @@ def __call__( return timer_record # If we get here, we're being called as @nanotron_timer("name", timer_type) - return self._create_timer_decorator(name, timer_type, cuda_sync) + return self._create_timer_decorator(name, timer_type, cuda_sync, enabled) - def _create_timer_decorator(self, name, timer_type, cuda_sync=False): + def _create_timer_decorator(self, name, timer_type=TimerType.CUDA, cuda_sync=False, enabled=None): """Create a decorator that times the execution of a function.""" def decorator(func): @@ -253,7 +288,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - with self(name, timer_type, cuda_sync): + with self(name, timer_type, cuda_sync, enabled): return func(*args, **kwargs) return wrapper diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index af26c6da0..6bb4d6472 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -1,3 +1,4 @@ +import threading from abc import ABCMeta, abstractmethod from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple @@ -238,7 +239,34 @@ def build_model( return model -# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does. +@contextmanager +def ignore_init_on_device_and_dtype(): + """ + A context manager that temporarily disables dtype enforcement from init_on_device_and_dtype. 
+ + Example: + ```python + with init_on_device_and_dtype(device=torch.device("cuda"), dtype=torch.float32): + with ignore_init_on_device_and_dtype(): + # This parameter will keep its specified dtype (float32) + self.weight = nn.Parameter(torch.randn(..., dtype=torch.float32)) + ``` + """ + # Create a thread-local storage for the ignore flag + if not hasattr(ignore_init_on_device_and_dtype, "_ignore_flag"): + ignore_init_on_device_and_dtype._ignore_flag = threading.local() + + # Set the ignore flag + old_value = getattr(ignore_init_on_device_and_dtype._ignore_flag, "value", False) + ignore_init_on_device_and_dtype._ignore_flag.value = True + + try: + yield + finally: + # Restore the previous value + ignore_init_on_device_and_dtype._ignore_flag.value = old_value + + @contextmanager def init_on_device_and_dtype( device: torch.device = torch.device("cpu"), @@ -250,35 +278,30 @@ def init_on_device_and_dtype( device (`torch.device` defaults to `cpu`): Device to initialize all parameters on. dtype (`torch.dtype` defaults to `torch.float`): - Dtype to initialize all parameters on. - include_buffers (`bool`, defaults to `False`): - Whether or not to also default all buffers constructors given previous arguments. - Example: - ```python - import torch.nn as nn - from accelerate import init_on_device - with init_on_device_and_dtype(device=torch.device("cuda")): - tst = nn.Liner(100, 100) # on `cuda` device - ``` + Dtype to initialize all parameters on. If specified, will override any dtype + set in parameter initialization with a warning, unless within an ignore_init_on_device_and_dtype context. """ old_register_parameter = nn.Module.register_parameter old_register_buffer = nn.Module.register_buffer + def should_ignore_init_on_device_and_dtype(): + if not hasattr(ignore_init_on_device_and_dtype, "_ignore_flag"): + return False + return getattr(ignore_init_on_device_and_dtype._ignore_flag, "value", False) + def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: - if isinstance(param, DTypeInvariantTensor): - # if param is DTypeInvariantTensor we should avoid updating it - param.data = param.data.to(device) + if should_ignore_init_on_device_and_dtype(): + pass else: param.data = param.data.to(device, dtype) def register_empty_buffer(module, name, buffer, persistent=True): old_register_buffer(module, name, buffer, persistent=persistent) if buffer is not None: - if isinstance(buffer, DTypeInvariantTensor): - # if buffer is DTypeInvariantTensor we should avoid updating it - buffer.data = buffer.data.to(device) + if should_ignore_init_on_device_and_dtype(): + pass else: module._buffers[name] = module._buffers[name].to(device, dtype) diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index 8115a9bb9..8aa6eb46a 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -3,7 +3,6 @@ import torch from flash_attn.modules.mha import flash_attn_varlen_kvpacked_func from torch import nn -from torch.nn import functional as F from torch.utils.checkpoint import CheckpointFunction from nanotron import distributed as dist @@ -303,6 +302,7 @@ def __init__( config: Qwen2Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + intermediate_size: int, ) -> None: super().__init__() @@ -312,14 +312,14 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - # Define gate_up_proj as a merged layer for gate and up projections gate_up_contiguous_chunks = ( 
- config.intermediate_size, # shape of gate_linear - config.intermediate_size, # shape of up_linear + intermediate_size, # shape of gate_linear + intermediate_size, # shape of up_linear ) + self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, - 2 * config.intermediate_size, + 2 * intermediate_size, pg=tp_pg, mode=tp_mode, bias=False, # Qwen2 doesn't use bias for gate_up_proj @@ -330,7 +330,7 @@ def __init__( # Define down projection self.down_proj = TensorParallelRowLinear( - config.intermediate_size, + intermediate_size, config.hidden_size, pg=tp_pg, mode=tp_mode, @@ -355,183 +355,6 @@ def forward(self, hidden_states): return {"hidden_states": hidden_states} -class Qwen2MoELayer(nn.Module): - """Mixture of experts Layer for Qwen2 models.""" - - def __init__( - self, - config: Qwen2Config, - parallel_config: Optional[ParallelismArgs], - tp_pg: dist.ProcessGroup, - layer_idx: int = 0, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - # MoE specific configurations - self.num_experts = config.moe_config.num_experts # Total number of experts - self.num_experts_per_token = config.moe_config.top_k # Number of experts used per token (top-k) - self.expert_parallel_size = getattr(parallel_config, "expert_parallel_size", 1) - self.num_local_experts = self.num_experts // self.expert_parallel_size # Experts per device - - # Get TP mode configuration - tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - tp_linear_async_communication = ( - parallel_config.tp_linear_async_communication if parallel_config is not None else False - ) - - # Router for selecting experts - self.router = TensorParallelColumnLinear( - self.hidden_size, - self.num_experts, - pg=tp_pg, - mode=tp_mode, - bias=False, - async_communication=tp_linear_async_communication, - ) - - # Enable shared experts if configured - self.enable_shared_expert = getattr(config.moe_config, "enable_shared_expert", False) - if self.enable_shared_expert: - self.shared_expert = Qwen2MLP( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - ) - self.shared_expert_gate = TensorParallelColumnLinear( - self.hidden_size, - 1, - pg=tp_pg, - mode=tp_mode, - bias=False, - async_communication=tp_linear_async_communication, - ) - - # Create the expert MLPs - self.experts = nn.ModuleList( - [ - Qwen2MLP( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - ) - for _ in range(self.num_local_experts) - ] - ) - - # Whether to recompute MoE layer during backward pass for memory efficiency - self.recompute_layer = getattr(parallel_config, "recompute_layer", False) - - # Token dispatcher type - determines communication pattern - self.token_dispatcher_type = getattr(config.moe_config, "token_dispatcher_type", "alltoall") - # For more sophisticated implementations, we would add token dispatcher logic here - - def _compute_router_probabilities(self, hidden_states): - """Compute routing probabilities for each token to each expert.""" - router_logits = self.router(hidden_states) # [batch_size*seq_length, num_experts] - - # Get the top-k experts per token - routing_weights, routing_indices = torch.topk(router_logits, k=self.num_experts_per_token, dim=-1) - - # Apply softmax on the top-k values - routing_weights = F.softmax(routing_weights, dim=-1) - - return routing_weights, routing_indices - - def _dispatch_tokens(self, hidden_states, routing_weights, routing_indices): - """ - Dispatches tokens to 
their selected experts. - In a full implementation, this would handle the actual token routing logic - including communication between devices. - """ - # Simplified implementation - in a complete version this would handle - # all-to-all or all-gather communications for distributed experts - - hidden_states.shape[0] - dispatched_inputs = [] - expert_counts = [] - - # For each expert, gather the tokens assigned to it - for expert_idx in range(self.num_local_experts): - # Find tokens that have this expert in their top-k - expert_mask = (routing_indices == expert_idx).any(dim=-1) - tokens_for_expert = hidden_states[expert_mask] - - # Get the routing weights for this expert - expert_positions = (routing_indices == expert_idx).nonzero(as_tuple=True) - token_positions, k_positions = expert_positions - expert_weights = routing_weights[token_positions, k_positions].unsqueeze(-1) - - # Scale inputs by routing weights - scaled_inputs = tokens_for_expert * expert_weights - - dispatched_inputs.append(scaled_inputs) - expert_counts.append(len(tokens_for_expert)) - - return dispatched_inputs, expert_counts - - def _combine_expert_outputs(self, expert_outputs, routing_indices, original_shape): - """ - Combines outputs from different experts back to the original tensor layout. - """ - # Initialize output tensor with zeros - combined_output = torch.zeros(original_shape, device=expert_outputs[0].device) - - for expert_idx, expert_output in enumerate(expert_outputs): - if expert_output.shape[0] == 0: # Skip if no tokens were routed to this expert - continue - - # Find positions where this expert was in the top-k - expert_mask = (routing_indices == expert_idx).any(dim=-1) - combined_output[expert_mask] += expert_output - - return combined_output - - def _core_forward(self, hidden_states): - """Core forward logic for MoE layer.""" - # Get router probabilities - routing_weights, routing_indices = self._compute_router_probabilities(hidden_states) - - # Dispatch tokens to experts - dispatched_inputs, expert_counts = self._dispatch_tokens(hidden_states, routing_weights, routing_indices) - - # Process tokens with their assigned experts - expert_outputs = [] - for expert_idx, (inputs, count) in enumerate(zip(dispatched_inputs, expert_counts)): - if count == 0: # Skip computation if no tokens assigned - expert_outputs.append(torch.tensor([], device=hidden_states.device)) - continue - - # Forward through the expert - output = self.experts[expert_idx](hidden_states=inputs)["hidden_states"] - expert_outputs.append(output) - - # Combine expert outputs - output = self._combine_expert_outputs(expert_outputs, routing_indices, hidden_states.shape) - - # Add shared expert contribution if enabled - if self.enable_shared_expert: - shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] - shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) - output = output + shared_gate * shared_expert_output - - return output - - def _checkpointed_forward(self, hidden_states): - """Apply gradient checkpointing to save memory during training.""" - return CheckpointFunction.apply(self._core_forward, True, hidden_states) - - def forward(self, hidden_states): - """Forward pass for the MoE layer.""" - if self.recompute_layer and self.training: - hidden_states = self._checkpointed_forward(hidden_states) - else: - hidden_states = self._core_forward(hidden_states) - - return {"hidden_states": hidden_states} - - class Qwen2DecoderLayer(nn.Module): def __init__( self, @@ -559,6 +382,8 @@ def __init__( # Use 
MoE layer if this layer is in the MoE layers list if config.moe_config and layer_idx in config.moe_config.layers: + from nanotron.nn.moe import Qwen2MoELayer + self.mlp = Qwen2MoELayer( config=config, parallel_config=parallel_config, @@ -570,6 +395,7 @@ def __init__( config=config, parallel_config=parallel_config, tp_pg=tp_pg, + intermediate_size=config.intermediate_size, ) self.recompute_layer = parallel_config.recompute_layer @@ -896,7 +722,7 @@ def init_model_randomly(self, config: Config): else: raise ValueError(f"Unknown init method {init_method}") - parametrizator = parametrizator_cls(config=config.model) + parametrizator = parametrizator_cls(config=config) log_rank( f"Parametrizing model parameters using {parametrizator.__class__.__name__}", diff --git a/src/nanotron/nn/moe.py b/src/nanotron/nn/moe.py new file mode 100644 index 000000000..40a296d02 --- /dev/null +++ b/src/nanotron/nn/moe.py @@ -0,0 +1,212 @@ +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ParallelismArgs +from nanotron.config.models_config import Qwen2Config +from nanotron.models.base import ignore_init_on_device_and_dtype +from nanotron.nn.activations import ACT2FN + +logger = logging.get_logger(__name__) + + +try: + import grouped_gemm.ops as ops +except ImportError: + raise RuntimeError( + "Grouped GEMM is not available. Please run `pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@main` (takes less than 5 minutes)" + ) + + +class Router(nn.Module): + def __init__( + self, config: Qwen2Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, layer_idx: int + ): + super().__init__() + self.config = config + self.parallel_config = parallel_config + self.tp_pg = tp_pg + self.layer_idx = layer_idx + + self.num_experts = config.moe_config.num_experts + self.num_experts_per_token = config.moe_config.top_k + + # float32 routing weights + # NOTE: qwen keep the routing weights in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L608 + with ignore_init_on_device_and_dtype(): + self.weight = nn.Parameter( + torch.randn(self.num_experts, config.hidden_size, dtype=torch.float32, device="cuda") + ) + assert self.weight.dtype == torch.float32 + + def gating(self, x: torch.Tensor) -> torch.Tensor: + """Compute logits for all experts (no softmax).""" + # NOTE: qwen keep the routing logits in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L613 + return F.linear(x.to(torch.float32), self.weight, bias=None) + + def routing(self, logits: torch.Tensor): + """Top-k softmax-normalized routing weights and indices.""" + routing_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + routing_weights, routing_indices = torch.topk(routing_weights, k=self.num_experts_per_token, dim=-1) + routing_indices = routing_indices.to(torch.int32) # NOTE: ops.permute requires indices to be int32 + return routing_weights, routing_indices + + def forward(self, x: torch.Tensor): + logits = self.gating(x) + return self.routing(logits) + + +class GroupedMLP(nn.Module): + def __init__(self, config: Qwen2Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + 
num_local_experts = config.moe_config.num_experts // parallel_config.expert_parallel_size
+        self.merged_gate_up_proj = nn.Parameter(
+            torch.randn(num_local_experts, config.hidden_size, 2 * config.moe_config.moe_intermediate_size)
+        )
+        self.merged_down_proj = nn.Parameter(
+            torch.randn(num_local_experts, config.moe_config.moe_intermediate_size, config.hidden_size)
+        )
+        self.act = ACT2FN[config.hidden_act]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        num_tokens_per_expert: torch.Tensor,
+    ):
+        """
+        Assumes hidden_states is already permuted into expert-contiguous order.
+
+        grouped_gemm notes: ops.gmm expects
+        + a and b to be in bfloat16
+        + num_tokens_per_expert to be on CPU
+        """
+        # NOTE: ops.gmm requires "batch_sizes" (aka num_tokens_per_expert here) to be on cpu
+        num_tokens_per_expert = num_tokens_per_expert.to("cpu")
+        merged_states = ops.gmm(hidden_states, self.merged_gate_up_proj, num_tokens_per_expert, trans_b=False)
+        gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
+        hidden_states = self.act(gate_states) * up_states
+        hidden_states = ops.gmm(hidden_states, self.merged_down_proj, num_tokens_per_expert, trans_b=False)
+
+        return {"hidden_states": hidden_states}
+
+
+class Qwen2MoELayer(nn.Module):
+    """Mixture of experts Layer for Qwen2 models."""
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        parallel_config: Optional[ParallelismArgs],
+        tp_pg: dist.ProcessGroup,
+        layer_idx: int = 0,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        # MoE specific configurations
+        self.num_experts = config.moe_config.num_experts  # Total number of experts
+        self.num_experts_per_token = config.moe_config.top_k  # Number of experts used per token (top-k)
+        self.expert_parallel_size = parallel_config.expert_parallel_size
+        self.num_local_experts = self.num_experts // self.expert_parallel_size  # Experts per device
+
+        # Router for selecting experts
+        self.router = Router(config, parallel_config, tp_pg, layer_idx)
+
+        # Enable shared experts if configured
+        self.enable_shared_expert = config.moe_config.enable_shared_expert
+        if self.enable_shared_expert:
+            from nanotron.models.qwen import Qwen2MLP
+
+            self.shared_expert = Qwen2MLP(
+                config=config,
+                parallel_config=parallel_config,
+                tp_pg=tp_pg,
+                intermediate_size=config.moe_config.shared_expert_intermediate_size,
+            )
+            # TODO: duplicate the shared expert gate
+            self.shared_expert_gate = nn.Linear(
+                self.hidden_size,
+                1,
+                bias=False,
+            )  # TODO: ensure shared_expert_gate is tied across TP
+
+        # Create the expert MLPs
+        self.experts = GroupedMLP(config, parallel_config)
+        # Whether to recompute MoE layer during backward pass for memory efficiency
+        self.recompute_layer = parallel_config.recompute_layer
+
+    def _dispatch_tokens(
+        self,
+        hidden_states: torch.Tensor,
+        routing_indices: torch.Tensor,
+    ):
+        """
+        Dispatches tokens to their selected experts.
+        Tokens are permuted into expert-contiguous order; the returned inverse mapping
+        is used to restore the original token order after expert computation.
+ """ + # NOTE: start from expert 0 to expert n + num_tokens_per_expert = torch.bincount( + routing_indices.flatten(), minlength=self.num_local_experts + ) # [num_local_experts] + dispatched_inputs, inverse_permute_mapping = ops.permute(hidden_states, routing_indices) + return dispatched_inputs, inverse_permute_mapping, num_tokens_per_expert + + def _combine_expert_outputs(self, expert_outputs, inverse_mapping, routing_weights): + """ + Combines outputs from different experts back to the original tensor layout. + """ + hidden_states = ops.unpermute(expert_outputs, inverse_mapping, routing_weights) + return hidden_states + + def _core_forward(self, hidden_states): + """Core forward logic for MoE layer.""" + # Get top-k routing weights and indices + routing_weights, routing_indices = self.router(hidden_states) # [num_tokens, num_experts_per_token] + + # Dispatch tokens to experts + dispatched_inputs, inverse_permute_mapping, num_tokens_per_expert = self._dispatch_tokens( + hidden_states, routing_indices + ) + + expert_outputs = self.experts(dispatched_inputs, num_tokens_per_expert) + + output = self._combine_expert_outputs( + expert_outputs["hidden_states"], inverse_permute_mapping, routing_weights + ) + + # Add shared expert contribution if enabled + if self.enable_shared_expert: + shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) + output = output + shared_gate * shared_expert_output + + return output + + def _checkpointed_forward(self, hidden_states): + """Apply gradient checkpointing to save memory during training.""" + return CheckpointFunction.apply(self._core_forward, True, hidden_states) + + def forward(self, hidden_states): + """Forward pass for the MoE layer.""" + if self.recompute_layer and self.training: + hidden_states = self._checkpointed_forward(hidden_states) + else: + hidden_states = self._core_forward(hidden_states) + + return {"hidden_states": hidden_states} diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 8107b46e6..088551b0f 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -182,7 +182,8 @@ def build_grad_buffers( if not param.requires_grad: continue - assert param.dtype != torch.float, f"Expected {name} not to be float" + # MoE router weights are initialized in float32 + assert param.dtype != torch.float or "router.weight" in name, f"Expected {name} not to be float" assert param.is_contiguous(), f"Expected {name} to be contiguous" next_offset = offset + param.numel() * element_size diff --git a/src/nanotron/s3_checkpoints/s3_mover.py b/src/nanotron/s3_checkpoints/s3_mover.py index 483019d5a..32842a909 100644 --- a/src/nanotron/s3_checkpoints/s3_mover.py +++ b/src/nanotron/s3_checkpoints/s3_mover.py @@ -225,7 +225,7 @@ def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None): dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False) dist.barrier() all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list) - time.sleep(1) + time.sleep(1) # TODO @nouamane: make this configurable def is_previous_save_finished(self) -> bool: """Return True if a potential previous checkpoint has been fully uploaded to S3 diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 187e76e09..8324eccf0 100644 --- a/src/nanotron/scaling/parametrization.py +++ 
b/src/nanotron/scaling/parametrization.py @@ -3,8 +3,10 @@ from enum import Enum, auto from typing import Dict -from nanotron.config import ModelArgs +from nanotron.config import Config, ModelArgs +from nanotron.config.models_config import InitScalingMethod from nanotron.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm +from nanotron.nn.moe import GroupedMLP, Router from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -31,33 +33,77 @@ def parametrize(self, param_name: str, module: nn.Module): class StandardParametrizator(Parametrizator): - def __init__(self, config: ModelArgs): + def __init__(self, config: Config): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { TensorParallelColumnLinear: self._parametrize_column_linear, + # TODO: double check if correct initialization for grouped MLP TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, LlamaRMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, + # NOTE: MoE's specific initialization + GroupedMLP: self._parametrize_grouped_mlp, + Router: self._parametrize_router, + nn.Linear: self._parametrize_column_linear, } - self.std = config.init_method.std - self.num_layers = config.model_config.num_hidden_layers + self.std = config.model.init_method.std + self.num_layers = config.model.model_config.num_hidden_layers + self.tp = config.parallelism.tp + self.scaling_method = config.model.init_method.scaling_method + self.hidden_size = config.model.model_config.hidden_size def _parametrize_column_linear(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] + if "weight" == param_name: + # TODO @nouamane: should we use trunc_normal_ + init.normal_(module.weight, mean=0.0, std=self.std) + elif "bias" == param_name: + module.bias.zero_() + + def _parametrize_grouped_mlp(self, param_name: str, module: nn.Module): + for n, p in module.named_parameters(): + if n == "merged_gate_up_proj": + # NOTE: the same as parametrization of column linear + init.normal_(p, mean=0.0, std=self.std) + elif n == "merged_down_proj": + # NOTE: the same as parametrization of row linear + scaling = self._compute_scaling_factor() + adjusted_std = self.std / scaling + # TODO @nouamane: should we use trunc_normal_ + init.normal_(p, mean=0.0, std=adjusted_std) + else: + raise ValueError(f"Unknown parameter {n}") + + def _parametrize_router(self, param_name: str, module: nn.Module): if "weight" == param_name: init.normal_(module.weight, mean=0.0, std=self.std) elif "bias" == param_name: module.bias.zero_() + def _compute_scaling_factor(self) -> float: + """Compute initialization scaling based on selected method""" + if self.scaling_method == InitScalingMethod.NONE: + return 1.0 + elif self.scaling_method == InitScalingMethod.NUM_LAYERS: + # Scale based on total network depth + return math.sqrt(2 * self.num_layers) + elif self.scaling_method == InitScalingMethod.LAYER_INDEX: + # Scale based on layer position + raise NotImplementedError("Layer position scaling not yet implemented") + else: + raise ValueError(f"Invalid scaling method: {self.scaling_method}") + def _parametrize_row_linear(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: - std = self.std / math.sqrt(2 * self.num_layers) - init.normal_(module.weight, mean=0.0, std=std) + scaling = self._compute_scaling_factor() + adjusted_std = self.std / scaling + # TODO @nouamane: should we use 
trunc_normal_ + init.normal_(module.weight, mean=0.0, std=adjusted_std) elif "bias" == param_name: module.bias.zero_() @@ -65,7 +111,6 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module): assert param_name in ["weight", "bias"] if "weight" == param_name: - # TODO @thomasw21: Sometimes we actually want 0 module.weight.fill_(1) elif "bias" == param_name: module.bias.zero_() diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index b1445b481..2b5d45585 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -64,7 +64,7 @@ def save( try: if should_save_config: - config.save_as_yaml(root_folder / "config.yaml") + config.save_as_yaml(root_folder / "config.yaml", sanity_checks=sanity_checks) except Exception as e: # TODO @nouamane: catch full disk error log_rank( diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5110d6eb2..c94e97de4 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -4,7 +4,6 @@ import os import shutil import tempfile -import time from dataclasses import asdict from pathlib import Path from pprint import pformat @@ -39,6 +38,7 @@ ) from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.data.dataloader import sanity_check_dataloader +from nanotron.eval import LightEvalRunner from nanotron.helpers import ( _vocab_size_with_padding, compute_remain_train_steps_of_a_data_stage_from_ckp, @@ -122,7 +122,7 @@ def get_size(bytes): """Convert bytes to human readable format""" - for unit in ["", "K", "M", "G", "T", "P"]: + for unit in ["", "K", "M", "B", "T", "P"]: if bytes < 1024: return f"{bytes:.2f}{unit}B" bytes /= 1024 @@ -185,7 +185,9 @@ def __init__( ######################################## # Set random states - set_random_seed(self.config.general.seed) + # Set different random seed for each TP rank to ensure diversity (especially at weight init) + tp_rank = dist.get_rank(self.parallel_context.tp_pg) + set_random_seed(self.config.general.seed + tp_rank) # Init model and build on pp ranks self.random_states = init_random_states( @@ -275,6 +277,7 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches self.current_dataloader: Optional[DataLoader] = None # used for the current training stage self.current_base_dl: Optional[DataLoader] = None # used for the current training stage + self.iteration_timer = None # Will be initialized during training log_libraries_versions(logger=logger) log_rank("Config:", logger=logger, level=logging.INFO, rank=0, is_separator=True) @@ -312,6 +315,14 @@ def post_init(self): else: self.s3_mover = None + # Initialize LightEval runner on rank 0 + if dist.get_rank(self.parallel_context.world_pg) == 0: + if self.config.lighteval is not None: + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + if self.s3_mover is not None: + # If we have S3 upload enabled, use the eval_single_checkpoint as post-upload callback + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + def pre_training(self, *args, **kwargs): if not self.config.general.ignore_sanity_checks: log_rank( @@ -523,8 +534,6 @@ def train( ], **kwargs, ) -> None: - self.pre_training(**kwargs) - if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: self.save_checkpoint() @@ -543,6 +552,7 @@ def train( self.initial_iter_step = self.metadata.last_train_step + 1 self.last_iter_step = self.config.tokens.train_steps + self.pre_training(**kwargs) prof = 
get_profiler(config=self.config) # free memory @@ -554,28 +564,32 @@ def train( logger.info(f"Profiler on for step {self.iteration_step}") prof.step() - self.iteration_start_time = time.time() + # Use CUDA event-based timing for more accurate GPU-side elapsed time measurement + self.iteration_timer = nanotron_timer("iteration_time", "cuda", cuda_sync=False, enabled=True) + self.iteration_timer.start() self._update_dataloader_based_on_training_stages(dataloader_or_dls) # Training step outputs, loss_avg, z_loss_avg = self.training_step(dataloader=self.current_dataloader) # Update consumption tracking for current batch - self.current_base_dl.dataset.update_consumption_metrics( - start_idx=(self.iteration_step - 1) - * self.global_batch_size, # assumes we start from iteration_step=1 - end_idx=self.iteration_step * self.global_batch_size, - sequence_length=self.sequence_length, - ) + if hasattr(self.current_base_dl, "dataset"): + self.current_base_dl.dataset.update_consumption_metrics( + start_idx=(self.iteration_step - 1) + * self.global_batch_size, # assumes we start from iteration_step=1 + end_idx=self.iteration_step * self.global_batch_size, + sequence_length=self.sequence_length, + ) # Training Logs # Track consumed tokens for all dataset folders in current stage - consumption_stats = self.current_base_dl.dataset.get_consumption_stats() - current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] + if hasattr(self.current_base_dl, "dataset"): + consumption_stats = self.current_base_dl.dataset.get_consumption_stats() + current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] - # Update consumed tokens for all folders in the consumption stats - for folder_path, stats in consumption_stats.items(): - current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] + # Update consumed tokens for all folders in the consumption stats + for folder_path, stats in consumption_stats.items(): + current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] # Original consumption tracking self.metadata.consumed_train_samples += self.global_batch_size @@ -738,8 +752,9 @@ def train_step_logs( ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() - torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + # End the iteration timer and get elapsed time in milliseconds + self.iteration_timer.end() + elapsed_time_per_iteration_ms = self.iteration_timer.elapsed * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -763,7 +778,8 @@ def train_step_logs( # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + self.metadata.consumed_train_samples + * self.config.tokens.sequence_length, # TODO: not true if we change seqlen "human_format", ), # , "12d"), LogItem("time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), @@ -863,12 +879,13 @@ def get_cpu_logitems(): assert self.current_base_dl is not None, "current_base_dl should be defined" # Log consumption statistics - for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): - basic_log_entries.extend( - [ - LogItem(f"dataloader/consumed_tokens/{dataset_name}", stats["tokens"], "human_format"), - ] - ) + if hasattr(self.current_base_dl, "dataset"): + for dataset_name, stats in self.current_base_dl.dataset.get_consumption_stats().items(): + basic_log_entries.extend( + [ + LogItem(f"dataloader/consumed_tokens/{dataset_name}", stats["tokens"], "human_format"), + ] + ) # WandB logging - determine if this rank should log to wandb should_log_to_wandb = wandb is not None and ( @@ -1160,26 +1177,69 @@ def setup_log_writers( return loggerwriter def pre_save_checkpoint(self) -> Path: + # Check if eval_interval should be updated from file + eval_interval_file = self.config.lighteval.eval_interval_file if self.config.lighteval is not None else None + if eval_interval_file is not None and Path(eval_interval_file).exists(): + try: + with open(eval_interval_file, "r") as f: + new_eval_interval = int(f.read().strip()) + + # Verify that the new interval is a multiple of checkpoint_interval + if new_eval_interval == self.config.lighteval.eval_interval: + pass + elif new_eval_interval % self.config.checkpoints.checkpoint_interval == 0: + log_rank( + f"Updating lighteval.eval_interval from {self.config.lighteval.eval_interval} to {new_eval_interval}", + logger=logger, + level=logging.INFO, + rank=0, + ) + self.config.lighteval.eval_interval = new_eval_interval + else: + log_rank( + f"New eval_interval={new_eval_interval} must be a multiple of checkpoint_interval={self.config.checkpoints.checkpoint_interval}. Keeping current value: {self.config.lighteval.eval_interval}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + except (ValueError, IOError) as e: + log_rank( + f"Error reading eval_interval from file: {e}. 
Keeping current value: {self.config.lighteval.eval_interval}", + logger=logger, + level=logging.WARNING, + rank=0, + ) + if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs - self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") + log_rank( + f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", + logger=logger, + level=logging.WARNING, + rank=0, + ) def post_save_checkpoint(self): # Upload to S3 if self.s3_mover is not None: self.s3_mover.start_uploading() - # free memory TODO: do we need this? - # gc.collect() - # torch.cuda.empty_cache() + if dist.get_rank(self.parallel_context.world_pg) == 0: + if self.config.lighteval is not None and self.s3_mover is None: + if ( + self.config.lighteval.eval_interval is None + or self.iteration_step % self.config.lighteval.eval_interval == 0 + ): + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" + self.lighteval_runner.eval_single_checkpoint(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() checkpoints_path = self.config.checkpoints.checkpoints_path - checkpoint_path = checkpoints_path / f"{self.iteration_step}" + checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 else: @@ -1210,6 +1270,7 @@ def save_checkpoint(self) -> Path: root_folder=checkpoint_path, training_metadata=self.metadata, config=self.config, + sanity_checks=not self.config.general.ignore_sanity_checks, ) save_random_states( random_states=self.random_states, parallel_context=self.parallel_context, root_folder=checkpoint_path diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index a70035d48..dcd1b66c4 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -1,8 +1,6 @@ import functools import inspect import os -import random -import socket from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional @@ -148,15 +146,3 @@ def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: to tensor = torch.empty([], dtype=dtype, device=device) tensor.set_(source=untyped_storage) return tensor - - -def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: - while True: - port = random.randint(min_port, max_port) - try: - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("localhost", port)) - return port - except OSError: - continue diff --git a/src/nanotron/utils_network.py b/src/nanotron/utils_network.py new file mode 100644 index 000000000..b9d8684bc --- /dev/null +++ b/src/nanotron/utils_network.py @@ -0,0 +1,13 @@ +import random +import socket + +def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: + while True: + port = random.randint(min_port, max_port) + try: + with socket.socket() as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", port)) + return port + except OSError: + continue \ No newline at end of file diff --git a/test_timer_decorator.py b/test_timer_decorator.py new file mode 100644 index 000000000..b900f33d9 --- /dev/null +++ b/test_timer_decorator.py @@ -0,0 +1,63 @@ +"""Test script for the timer decorator with both CPU and CUDA timer 
types.""" + +from nanotron.logging.timers import nanotron_timer, TimerType +import time +import torch + +# Enable timers for testing +nanotron_timer.enable() + +# Test with default CUDA timing +@nanotron_timer +def test_default_decorator(): + """Test function with default CUDA timing.""" + # Simulate some work + time.sleep(0.1) + if torch.cuda.is_available(): + x = torch.randn(1000, 1000, device="cuda") + y = torch.matmul(x, x) + torch.cuda.synchronize() + return "Done" + +# Test with explicit CUDA timing +@nanotron_timer(timer_type=TimerType.CUDA) +def test_cuda_decorator(): + """Test function with explicit CUDA timing.""" + # Simulate some work + time.sleep(0.1) + if torch.cuda.is_available(): + x = torch.randn(1000, 1000, device="cuda") + y = torch.matmul(x, x) + torch.cuda.synchronize() + return "Done" + +# Test with CPU timing +@nanotron_timer(timer_type=TimerType.CPU) +def test_cpu_decorator(): + """Test function with CPU timing.""" + # Simulate some CPU work + time.sleep(0.2) + return "Done" + +# Test with custom name +@nanotron_timer("custom_name") +def test_custom_name_decorator(): + """Test function with custom name.""" + # Simulate some work + time.sleep(0.1) + return "Done" + +if __name__ == "__main__": + print("Testing timer decorators...") + + # Run the test functions + test_default_decorator() + test_cuda_decorator() + test_cpu_decorator() + test_custom_name_decorator() + + # Log all timers + print("\nTimer results:") + nanotron_timer.log_all(rank=None) # Log on all ranks + + print("\nTest completed successfully!") diff --git a/tests/helpers/qwen_helper.py b/tests/helpers/qwen_helper.py index b333e2a71..7528b9d50 100644 --- a/tests/helpers/qwen_helper.py +++ b/tests/helpers/qwen_helper.py @@ -17,30 +17,44 @@ TokensArgs, ) from nanotron.config.config import PretrainDatasetsArgs +from nanotron.config.models_config import MoEConfig from nanotron.models import build_model from nanotron.models.qwen import Qwen2Config, Qwen2ForTraining from nanotron.parallel.context import ParallelContext from nanotron.trainer import mark_tied_parameters -TINY_QWEN_CONFIG = Qwen2Config( +QWEN_MOE_CONFIG = MoEConfig( + num_experts=8, + top_k=1, + enable_shared_expert=True, + token_dispatcher_type="alltoall", +) + +QWEN_MODEL_CONFIG = { + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 128, + "initializer_range": 0.02, + "intermediate_size": 128 * 4, + "max_position_embeddings": 128, + "num_attention_heads": 4, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "pad_token_id": None, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + "_attn_implementation": "flash_attention_2", +} + +TINY_QWEN_CONFIG = Qwen2Config(**QWEN_MODEL_CONFIG) +TINY_MOE_QWEN_CONFIG = Qwen2Config( **{ - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 128, - "initializer_range": 0.02, - "intermediate_size": 128 * 4, - "max_position_embeddings": 128, - "num_attention_heads": 4, - "num_hidden_layers": 4, - "num_key_value_heads": 2, - "pad_token_id": None, - "rms_norm_eps": 1e-06, - "rope_theta": 10000.0, - "tie_word_embeddings": False, - "use_cache": True, - "vocab_size": 4096, - "_attn_implementation": "flash_attention_2", + **QWEN_MODEL_CONFIG, + "moe_config": QWEN_MOE_CONFIG, } ) diff --git a/tests/test_moe.py b/tests/test_moe.py new file mode 100644 index 000000000..635757e1f --- /dev/null +++ b/tests/test_moe.py @@ -0,0 +1,37 @@ +import torch +from helpers.qwen_helper import 
TINY_MOE_QWEN_CONFIG +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.models.base import init_on_device_and_dtype +from nanotron.nn.moe import GroupedMLP + + +def test_grouped_mlp(): + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=1, + expert_parallel_size=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + ) + num_tokens_per_experts = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + NUM_TOKENS = num_tokens_per_experts.sum() + NUM_EXPERTS = TINY_MOE_QWEN_CONFIG.moe_config.num_experts + HIDDEN_SIZE = TINY_MOE_QWEN_CONFIG.hidden_size + permuted_hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda") + + assert len(num_tokens_per_experts) == NUM_EXPERTS + + with init_on_device_and_dtype(device=torch.device("cuda"), dtype=torch.bfloat16): + grouped_mlp = GroupedMLP(config=TINY_MOE_QWEN_CONFIG, parallel_config=parallel_config) + + output = grouped_mlp(permuted_hidden_states, num_tokens_per_experts) + + assert output["hidden_states"].shape == (NUM_TOKENS, HIDDEN_SIZE) + assert output["hidden_states"].dtype == torch.bfloat16 + assert output["hidden_states"].device.type == "cuda" + + +if __name__ == "__main__": + test_grouped_mlp()
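
The `test_grouped_mlp` test above only exercises the grouped GEMM path; the permute/unpermute round trip in `Qwen2MoELayer._dispatch_tokens` / `_combine_expert_outputs` is delegated to `grouped_gemm.ops`. As a rough mental model (and a possible CPU-side sanity check), here is a minimal pure-PyTorch sketch of the same dispatch/combine logic. The function names, shapes, and the identity-expert check are illustrative assumptions, not part of this PR or of grouped_gemm's API.

```python
import torch


def dispatch_tokens(hidden_states: torch.Tensor, routing_indices: torch.Tensor):
    """Group token copies so that all tokens routed to expert 0 come first, then expert 1, ..."""
    num_tokens, top_k = routing_indices.shape
    num_experts = int(routing_indices.max().item()) + 1
    flat_experts = routing_indices.reshape(-1).long()  # [num_tokens * top_k]
    token_ids = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(top_k)
    sort_order = torch.sort(flat_experts, stable=True).indices  # stable sort keeps token order per expert
    permuted = hidden_states[token_ids[sort_order]]  # expert-contiguous token copies
    num_tokens_per_expert = torch.bincount(flat_experts, minlength=num_experts)
    return permuted, sort_order, token_ids, num_tokens_per_expert


def combine_expert_outputs(expert_outputs, sort_order, token_ids, routing_weights):
    """Scatter expert outputs back to the original token order, weighted by routing probabilities."""
    num_tokens = routing_weights.shape[0]
    out = torch.zeros(num_tokens, expert_outputs.shape[-1], dtype=expert_outputs.dtype, device=expert_outputs.device)
    flat_weights = routing_weights.reshape(-1)[sort_order].unsqueeze(-1).to(expert_outputs.dtype)
    out.index_add_(0, token_ids[sort_order], expert_outputs * flat_weights)
    return out


if __name__ == "__main__":
    x = torch.randn(6, 4)  # 6 tokens, hidden size 4
    indices = torch.randint(0, 3, (6, 2), dtype=torch.int32)  # top-2 routing over 3 experts
    weights = torch.full((6, 2), 0.5)  # routing weights, here uniform over the top-2
    permuted, order, tok_ids, counts = dispatch_tokens(x, indices)
    combined = combine_expert_outputs(permuted, order, tok_ids, weights)  # pretend each expert is the identity
    assert torch.allclose(combined, x)  # 0.5 * x + 0.5 * x recovers x
    print("round trip OK, tokens per expert:", counts.tolist())
```

Keeping tokens contiguous per expert is what lets `ops.gmm` run a single grouped GEMM over all local experts instead of looping over per-expert MLPs, as the removed `Qwen2MoELayer` implementation in `qwen.py` did.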