[WIP][DeepSeek] DeepSeek training and component integration with Titan main components #1183
Open: lessw2020 wants to merge 30 commits into main from lessw2020/deepseek_training
Changes from all commits (30 commits, all authored by lessw2020):
edd0e76  start debug model toml
12be2ee  create models folder, move relevant files, update imports
8c99de1  add more files to models folder
bf0f925  move group gemm files to kernels.group_gemm and deepseek_gemm
9755152  start model integration
c1b6639  refactor model.py into components
cafc534  add ds TrainSpec
74a5c4d  add parallelize_deepseek
e1c8a55  update run_training with matching params as main titan
3ca50d6  start similar run_training.sh for integration
4a3f565  use parallelism from toml file instead of hardcoded
d562674  move parallelism into main loop
0457b9c  integrate bs and seqlen from config
fd96293  add world size vs parallel size
972c211  add hf_tokenizer and hf dataloader
fccc6d3  now generating c4 data batches for training
4f50c6f  now training with real c4 data
2fdabcc  cross entropy loss working (titan main style)
921a244  training working with real data, optimizer, lr scheduler
4c3209b  remove synthetic data generation
71ac5af  use toml for training steps
a1f1bfe  create metrics processor
16e44ef  add color
5ec3cae  metrics tracking integrated
a363e15  start expert token tracking
9a04c58  token tracking working
245d313  update csv report for token tracking
4ac7c34  export topk reports for csv
36ba7e6  current status for expert token tracking
c594d8a  small batch run
Modified file:

@@ -6,6 +6,7 @@ build
 outputs
 dist/*
 .vscode
+*.csv

 # data
 data
New file (@@ -0,0 +1,7 @@):

from torchtitan.tools.logging import logger
from transformers import AutoTokenizer


def get_hf_tokenizer(model_id: str):
    logger.info(f"Instantiating tokenizer for {model_id}")
    return AutoTokenizer.from_pretrained(model_id)
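For context, a minimal usage sketch of this helper. The model id below is the DeepSeek-V2-Lite proxy referenced later in this PR; treating it as an assumption, not part of the diff:

# Hypothetical usage of get_hf_tokenizer; requires the transformers
# package and network access to download the tokenizer files.
tokenizer = get_hf_tokenizer("deepseek-ai/DeepSeek-V2-Lite")
ids = tokenizer("Hello world", return_tensors="pt")["input_ids"]
print(ids.shape)  # (1, sequence_length)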
New file (@@ -0,0 +1,30 @@):

import cutlass
import cutlass.cute as cute
import torch


@cute.kernel
def kernel():
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf(">>> Hello world! from Kernel")


@cute.jit
def hello_world():
    cute.printf(">>> Hello world! from CPU")

    cutlass.cuda.initialize_cuda_context()
    kernel().launch(
        grid=(1, 1, 1),
        block=(32, 1, 1),
    )


print("running without compile...")
hello_world()
print("\n\nrunning with compile...")
compiled = cute.compile(hello_world)
print(f"\n\ncompiled function: {compiled}")
compiled()
print("done")
torchtitan/experiments/deepseek_v3/infra/parallelize_deepseek.py (127 additions, 0 deletions)

New file (@@ -0,0 +1,127 @@):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.pipelining import PipelineStage, Schedule1F1B

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

# from checkpoint import load_weights_from_hf
from torchtitan.experiments.deepseek_v3.models.model import DeepseekForCausalLM
from torchtitan.experiments.deepseek_v3.models.model_config import (
    deepseek_config_registry,
)

from torchtitan.models.llama3.parallelize_llama import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
    apply_tp,
)
from torchtitan.tools.logging import logger


# Use DeepSeek-V2-Lite as a proxy
model_id = "deepseek-ai/DeepSeek-V2-Lite"


# from ..model.moe import MoE


# Get the model parallel subgroup by name, e.g. "pp", "ep", or None.
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
    glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
    return glob.get_group(dim_name)


def parallelize_deepseek(
    # model: nn.Module,
    world_mesh: DeviceMesh,
    device: torch.device,
    model_args,
    rank: int,
    # parallel_dims: ParallelDims,
    # job_config: JobConfig,
):
    """
    Apply parallelism to the model.

    NOTE: The passed-in model should preferably be on meta device. Otherwise,
    the model must fit in GPU or CPU memory.
    """
    logger.info("Applying parallelism to the model...")
    world_size = int(os.environ["WORLD_SIZE"])

    pp_mesh = world_mesh["pp"]
    ep_mesh = world_mesh["ep"]
    pp_rank = pp_mesh.get_local_rank()
    ep_rank = ep_mesh.get_local_rank()
    pp_size = pp_mesh.size()
    ep_size = ep_mesh.size()

    # Apply data parallelism
    fsdp_mesh = world_mesh["fsdp"]
    hsdp_mesh = world_mesh["ep", "fsdp"]

    hsdp_size = hsdp_mesh.size()

    # Apply model parallelism
    model_args.ep_size = ep_size
    model_args.num_stages = pp_size
    model_args.stage_idx = pp_rank
    logger.info(
        f"Parallelism: {rank=}, {ep_size=}, {pp_size=}, "
        f"{model_args.ep_size=}, {model_args.num_stages=}, {model_args.stage_idx=}"
    )

    # Verify that the world size matches the parallelized total
    parallelized_world_size = pp_size * hsdp_size
    logger.info(f"Total parallelized world size: {parallelized_world_size}")
    assert (
        world_size == parallelized_world_size
    ), f"mismatch between total world size {world_size=} and parallelized total {parallelized_world_size}"

    # Instantiate the model
    with device, world_mesh:
        model = DeepseekForCausalLM(model_args)
    # Load weights
    # load_weights_from_hf(model, model_id, device)
    model.train()

    # Use `reshard_after_forward=False` to implement Zero-2, i.e. shard the
    # optimizer state (Zero-1) and gradients (Zero-2), but not the model
    # weights. Reason: the MoE is "sparsely activated" compared to the dense
    # model, so it would be uneconomical to re-gather the weights.
    for layer in model.model.layers.values():
        # Apply FSDP to experts
        if hasattr(layer.mlp, "experts"):
            for expert in layer.mlp.experts.values():
                fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False)
        # Apply HSDP to the other parts (attention, layernorm, etc.), since
        # they are replicated (DDP) along the EP dimension
        fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False)

    # Apply HSDP to the root model (lm_head, embeddings, etc.)
    fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

    return (
        model,
        pp_size,
        pp_rank,
        pp_mesh,
        ep_size,
        ep_rank,
    )
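For orientation, a hypothetical sketch of a call site for parallelize_deepseek. The mesh shape below is an assumption for a single-node 8-GPU run and is not part of the diff; the registry key follows the model_id defined above but is likewise assumed:

# Hypothetical call site: pp * ep * fsdp must equal WORLD_SIZE, since the
# function asserts world_size == pp_size * hsdp_size (hsdp = ep x fsdp).
import os
import torch
from torch.distributed.device_mesh import init_device_mesh

world_mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp")
)
model_args = deepseek_config_registry[model_id]  # assumed registry key
model, pp_size, pp_rank, pp_mesh, ep_size, ep_rank = parallelize_deepseek(
    world_mesh=world_mesh,
    device=torch.device("cuda"),
    model_args=model_args,
    rank=int(os.environ["RANK"]),
)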
Review comment: Since this is only for MoE-based models, how about we use https://github.com/pytorch/torchtitan/blob/main/docs/extension.md#extending-jobconfig and create a separate .py config file in the deepseek folder? Later we can see if we can reuse it for Llama 4 and DeepSeek.
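A rough sketch of the kind of model-specific config file the reviewer is suggesting, assuming the custom-JobConfig dataclass pattern described in the linked extension.md. The file name, section name, and fields below are all hypothetical:

# deepseek_job_config.py -- hypothetical; field names are illustrative only.
from dataclasses import dataclass, field


@dataclass
class DeepSeek:
    ep_size: int = 1
    """Degree of expert parallelism for the MoE layers."""

    track_expert_tokens: bool = False
    """Export per-expert token-routing counts to CSV (see commits above)."""


@dataclass
class JobConfig:
    deepseek: DeepSeek = field(default_factory=DeepSeek)

Keeping these MoE-specific knobs in a dataclass that merges into the main JobConfig would avoid hardcoding them in the trainer, and a Llama 4 MoE experiment could later share the same section if the fields generalize.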