
[WIP][DeepSeek] DeepSeek training and component integration with Titan main components #1183

Open. Wants to merge 30 commits into main.

Commits (30)

edd0e76
start debug model toml
lessw2020 May 8, 2025
12be2ee
create models folder, move relevant files, update imports
lessw2020 May 9, 2025
8c99de1
add more files to models folder
lessw2020 May 9, 2025
bf0f925
move group gemm files to kernels.group_gemm and deepseek_gemm
lessw2020 May 9, 2025
9755152
start model integration
lessw2020 May 9, 2025
c1b6639
refactor model.py into components
lessw2020 May 10, 2025
cafc534
add ds TrainSpec
lessw2020 May 12, 2025
74a5c4d
add parallelize_deepseek
lessw2020 May 12, 2025
e1c8a55
update run_training with matching params as main titan
lessw2020 May 12, 2025
3ca50d6
start similar run_training.sh for integration
lessw2020 May 12, 2025
4a3f565
use parallelism from toml file instead of hardcoded
lessw2020 May 12, 2025
d562674
move parallelism into main loop
lessw2020 May 12, 2025
0457b9c
integrate bs and seqlen from config
lessw2020 May 12, 2025
fd96293
add world size vs parallel size
lessw2020 May 12, 2025
972c211
add hf_tokenizer and hf dataloader
lessw2020 May 12, 2025
fccc6d3
now generating c4 data batches for training
lessw2020 May 13, 2025
4f50c6f
now training with real c4 data
lessw2020 May 13, 2025
2fdabcc
cross entropy loss working (titan main style)
lessw2020 May 14, 2025
921a244
training working with real data, optimizer, lr scheduler
lessw2020 May 14, 2025
4c3209b
remove synthetic data generation
lessw2020 May 15, 2025
71ac5af
use toml for training steps
lessw2020 May 15, 2025
a1f1bfe
create metrics processor
lessw2020 May 15, 2025
16e44ef
add color
lessw2020 May 15, 2025
5ec3cae
metrics tracking integrated
lessw2020 May 15, 2025
a363e15
start expert token tracking
lessw2020 May 15, 2025
9a04c58
token tracking working
lessw2020 May 15, 2025
245d313
update csv report for token tracking
lessw2020 May 15, 2025
4ac7c34
export topk reports for csv
lessw2020 May 15, 2025
36ba7e6
current status for expert token tracking
lessw2020 May 20, 2025
c594d8a
small batch run
lessw2020 May 20, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ build
outputs
dist/*
.vscode
*.csv

# data
data
2 changes: 2 additions & 0 deletions run_train.sh
@@ -13,6 +13,8 @@ set -ex
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
#CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/deepseek_v3/train_configs/debug_model.toml"}


overrides=""
if [ $# -ne 0 ]; then
5 changes: 4 additions & 1 deletion torchtitan/config_manager.py
@@ -129,7 +129,7 @@ class Optimizer:
eps: float = 1e-8
"""Epsilon value to use"""

implementation: Literal["for-loop", "foreach", "fused"] = "fused"
implementation: Literal["for-loop", "foreach", "fused"] = "foreach"
"""
Specify which optimizer implementation to use:
- 'fused': Use fused implementation (CUDA only) for best performance.
@@ -341,6 +341,9 @@ class Parallelism:
The default value is 'allgather'.
"""

expert_parallel_degree: int = 1
"""Expert parallelism degree. 1 means disabled."""

Comment on lines +344 to +346
Contributor

Since this is only for MoE-based models, how about we use https://github.com/pytorch/torchtitan/blob/main/docs/extension.md#extending-jobconfig and create a separate .py config file in the deepseek folder? Later we can see if we can reuse these configs for Llama 4 and DeepSeek.
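A minimal sketch of what such a per-experiment config module could look like, assuming the dataclass-merging pattern described in docs/extension.md (the `MoE` section name, field placement, and merge mechanism here are illustrative assumptions, not part of this PR):

```python
# Hypothetical extra-config module that could live under the deepseek_v3 folder.
# How it gets merged into the main JobConfig is an assumption; see
# docs/extension.md#extending-jobconfig for the actual extension hook.
from dataclasses import dataclass, field


@dataclass
class MoE:
    expert_parallel_degree: int = 1
    """Expert parallelism degree. 1 means disabled."""


@dataclass
class JobConfig:
    """Extra config sections layered on top of the default JobConfig for MoE models."""

    moe: MoE = field(default_factory=MoE)
```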


@dataclass
class Checkpoint:
4 changes: 3 additions & 1 deletion torchtitan/datasets/hf_datasets.py
@@ -114,7 +114,9 @@ def __iter__(self):
for sample in self._get_data_iter():
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
sample_tokens = self._tokenizer.encode(
sample_text,
) # TODO: temporary override for DeepSeek; the original call passed bos=True, eos=True
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1

7 changes: 7 additions & 0 deletions torchtitan/datasets/tokenizer/hf_tokenizer.py
@@ -0,0 +1,7 @@
from torchtitan.tools.logging import logger
from transformers import AutoTokenizer


def get_hf_tokenizer(model_id: str):
logger.info(f"Instantiating tokenizer for {model_id}")
return AutoTokenizer.from_pretrained(model_id)
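For reference, a usage sketch of the new helper (the model id is illustrative and requires access to the HuggingFace Hub; unlike the titan tokenizer, the HF `encode` call takes no `bos`/`eos` kwargs and controls special tokens via `add_special_tokens`, which defaults to True):

```python
from torchtitan.datasets.tokenizer.hf_tokenizer import get_hf_tokenizer

# Illustrative model id; downloads the tokenizer from the HuggingFace Hub.
tokenizer = get_hf_tokenizer("deepseek-ai/DeepSeek-V2-Lite")

# BOS/EOS insertion is governed by add_special_tokens (default True),
# not by bos=/eos= kwargs as in the titan tokenizer.
tokens = tokenizer.encode("Hello from torchtitan", add_special_tokens=True)
```
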
1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -4,5 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torchtitan.experiments.deepseek_v3 # noqa: F401
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.simple_fsdp # noqa: F401
8 changes: 6 additions & 2 deletions torchtitan/experiments/deepseek_v3/README.md
@@ -9,7 +9,8 @@ and scripts needed to run it.

You will need to download a DeepSeek model's weights if you want to run a
pre-trained checkpoint. We provided a script to download the weights from
HuggingFace Model Hub:
HuggingFace Model Hub in /model/download.py.
You can run it with the following:
```bash
python download.py [vX]
```
@@ -23,6 +24,7 @@ command:
```bash
torchrun --standalone --nproc-per-node 4 generate.py
```

This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by
default.

@@ -31,10 +33,12 @@ followed by your prompt.

## Training

The training script is in `train.py`. You can run it by the following command:
The training script is in `train.py`. You can run it with the following command:
```bash
torchrun --standalone --nproc-per-node 8 train.py
```

This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by
default, with pipeline parallel, expert parallel, and data parallel enabled.

Alternatively, you can run training with `bash run_training.sh` and modify the script to your needs.
50 changes: 50 additions & 0 deletions torchtitan/experiments/deepseek_v3/__init__.py
@@ -3,3 +3,53 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.hf_tokenizer import get_hf_tokenizer

# TODO: this is not suitable for DeepSeek, but using it for now
from torchtitan.models.llama3 import pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize_deepseek import parallelize_deepseek

from .models.model import DeepseekForCausalLM

from .models.model_args import TransformerModelArgs


__all__ = [
"TransformerModelArgs",
"DeepseekForCausalLM",
"deepseek_configs",
]


deepseek_configs = {
"debugmodel": TransformerModelArgs(
dim=256,
n_layers=6,
n_heads=16,
rope_theta=500000,
),
}


register_train_spec(
TrainSpec(
name="deepseek3",
cls=DeepseekForCausalLM,
config=deepseek_configs,
parallelize_fn=parallelize_deepseek,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=get_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
30 changes: 30 additions & 0 deletions torchtitan/experiments/deepseek_v3/cute.py
@@ -0,0 +1,30 @@
import cutlass
import cutlass.cute as cute
import torch


@cute.kernel
def kernel():
tidx, _, _ = cute.arch.thread_idx()
if tidx == 0:
cute.printf(">>> Hello world! from Kernel")


@cute.jit
def hello_world():
cute.printf(">>> Hello world! from CPU")

cutlass.cuda.initialize_cuda_context()
kernel().launch(
grid=(1, 1, 1),
block=(32, 1, 1),
)


print("running with out compile...")
hello_world()
print(f"\n\nrunning with compile...")
compiled = cute.compile(hello_world)
print(f"\n\ncompiled function: {compiled}")
compiled()
print("done")
12 changes: 7 additions & 5 deletions torchtitan/experiments/deepseek_v3/generate.py
@@ -6,7 +6,7 @@

# torchrun --standalone --nproc-per-node 4 generate.py

# use inference.sh "Your Question Here?" to run inference with a single prompt.
# use bash inference.sh "Your Question Here?" to run inference with a single prompt.

import sys
from dataclasses import dataclass
@@ -15,11 +15,13 @@
import torch.distributed as dist

from checkpoint import load_weights_from_hf
from model import DeepseekForCausalLM
from model_config import deepseek_config_registry
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from transformers import AutoTokenizer
from torchtitan.datasets.tokenizer.hf_tokenizer import get_hf_tokenizer
from torchtitan.experiments.deepseek_v3.models.model import DeepseekForCausalLM
from torchtitan.experiments.deepseek_v3.models.model_config import (
deepseek_config_registry,
)

from torchtitan.tools.utils import Color

@@ -367,7 +369,7 @@ def generate_with_cuda_graph(

dist_config = create_dist_config(mesh)
model, pp_schedule = create_model(dist_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer = get_hf_tokenizer(model_id)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
127 changes: 127 additions & 0 deletions torchtitan/experiments/deepseek_v3/infra/parallelize_deepseek.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, Schedule1F1B

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

# from checkpoint import load_weights_from_hf
from torchtitan.experiments.deepseek_v3.models.model import DeepseekForCausalLM
from torchtitan.experiments.deepseek_v3.models.model_config import (
deepseek_config_registry,
)

from torchtitan.models.llama3.parallelize_llama import (
apply_ac,
apply_compile,
apply_ddp,
apply_fsdp,
apply_tp,
)
from torchtitan.tools.logging import logger


# Use DeepSeek-V2-Lite as a proxy
model_id = "deepseek-ai/DeepSeek-V2-Lite"


# from ..model.moe import MoE


# Get model parallel subgroup by name:
# e.g. "pp", "ep", None
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
return glob.get_group(dim_name)


def parallelize_deepseek(
# model: nn.Module,
world_mesh: DeviceMesh,
device: torch.device,
model_args,
rank: int,
# parallel_dims: ParallelDims,
# job_config: JobConfig,
):
"""
Apply parallelism to the model.

NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
"""
logger.info("Applying parallelism to the model...")
world_size = int(os.environ["WORLD_SIZE"])

pp_mesh = world_mesh["pp"]
ep_mesh = world_mesh["ep"]
pp_rank = pp_mesh.get_local_rank()
ep_rank = ep_mesh.get_local_rank()
pp_size = pp_mesh.size()
ep_size = ep_mesh.size()

# Apply data parallelism
fsdp_mesh = world_mesh["fsdp"]
hsdp_mesh = world_mesh["ep", "fsdp"]

hsdp_size = hsdp_mesh.size()

# Apply model parallelism
model_args.ep_size = ep_size
model_args.num_stages = pp_size
model_args.stage_idx = pp_rank
logger.info(
f"Parallelism: {rank=}, {ep_size=}, {pp_size=}, {model_args.ep_size=}, {model_args.num_stages=}, {model_args.stage_idx=}"
)
# print(model_args)
# verify world size matches parallelized total
parallelized_world_size = pp_size * hsdp_size
logger.info(f"Total Parallelized World size {parallelized_world_size}")
assert (
world_size == parallelized_world_size
), f"mismatch between total world size {world_size=} and parallelized total {parallelized_world_size}"

# Instantiate model
with device, world_mesh:
model = DeepseekForCausalLM(model_args)
# Load weights
# load_weights_from_hf(model, model_id, device)
model.train()

# Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the
# optimizer (Zero-1) and gradients (Zero-2), but not the model weights.
# Reason: the MoE is "sparsely activated" compared to the dense model, thus
# it would be uneconomical to re-gather the weights.
for layer in model.model.layers.values():
# Apply FSDP to experts
if hasattr(layer.mlp, "experts"):
for expert in layer.mlp.experts.values():
fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
# Apply HSDP to other parts such as attention, layernorm, because they
# are doing DDP on EP dimension
fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

# Apply HSDP on root model (lm_head, embeddings, etc)
fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

return (
model,
pp_size,
pp_rank,
pp_mesh,
ep_size,
ep_rank,
)
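
For context, a hedged sketch of how a caller might construct a world mesh with the dimension names this function expects ("pp", "ep", "fsdp"); the degrees are illustrative, and the real mesh setup lives in the training script, not in this file:

```python
# Hypothetical caller-side mesh setup; must be launched via torchrun with
# exactly pp * ep * fsdp ranks so the assertion in parallelize_deepseek holds.
from torch.distributed.device_mesh import init_device_mesh

pp, ep, fsdp = 2, 2, 2  # illustrative degrees for an 8-GPU run
world_mesh = init_device_mesh(
    "cuda",
    (pp, ep, fsdp),
    mesh_dim_names=("pp", "ep", "fsdp"),
)

# parallelize_deepseek then slices this mesh by name: world_mesh["pp"],
# world_mesh["fsdp"], and the combined HSDP view world_mesh["ep", "fsdp"].
```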