
[DeepSeek] Move seqlen from model config to setup_symm_mem #1017


Open: wants to merge 8 commits into base: symm_bwd
10 changes: 8 additions & 2 deletions torchtitan/experiments/deepseek_v3/generate.py
@@ -40,12 +40,18 @@ def create_model(dist_config: DistConfig):
model_args.ep_size = dist_config.ep_size
model_args.num_stages = dist_config.pp_size
model_args.stage_idx = dist_config.pp_rank
model_args.max_seq_len = 16384
max_seq_len = 16384

with dist_config.device, dist_config.mesh:
model = DeepseekForCausalLM(model_args)
load_weights_from_hf(model, model_id, dist_config.device)
model.train()
model.eval()
model.setup_symm_mem(
max_seq_len,
torch.bfloat16,
dist_config.device,
microbatches=dist_config.pp_size,
)

return model, PipelineStage(
model,
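One implication of taking the capacity here rather than from ModelArgs is that the caller must ensure a single forward pass never routes more than max_seq_len tokens per microbatch. A hedged caller-side guard, not part of the PR (input_ids is a hypothetical (batch, seq) tensor fed in one forward pass):

# Hypothetical guard a caller of generate.py might add; not in this PR.
tokens_per_forward = input_ids.numel()  # batch size * sequence length, flattened
assert tokens_per_forward <= max_seq_len, (
    f"{tokens_per_forward} tokens exceed the {max_seq_len}-token symm-mem capacity"
)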
143 changes: 76 additions & 67 deletions torchtitan/experiments/deepseek_v3/model.py
@@ -40,7 +40,7 @@

from attn_mask_utils import _prepare_4d_causal_attention_mask
from model_config import ModelArgs
from symm_mem_recipes import on_device_all_to_all_v
from symm_mem_recipes import OnDeviceAllToAllV
from torch import nn
from torch.distributed._functional_collectives import all_to_all_single_autograd

@@ -446,11 +446,14 @@ class MoE(nn.Module):
"""

# Class attributes:
# Two shuffle methods supported:
# 1. "torch_all_to_all"
# 2. "symm_mem" (see `setup_symm_mem` below)
shuffle_method = "torch_all_to_all"

# Symmetric memory buffers shared by all MoE instances across layers
token_send_buf: Optional[torch.Tensor] = None
token_gather_buf: Optional[torch.Tensor] = None
input_splits: Optional[torch.Tensor] = None
output_splits: Optional[torch.Tensor] = None
token_send_buf: list[torch.Tensor] = []
token_gather_buf: list[torch.Tensor] = []

def __init__(self, config):
super().__init__()
@@ -483,55 +486,70 @@ def __init__(self, config):
self.shared_experts = MLP(
config=config, intermediate_size=intermediate_size
)
# Two shuffle method supported:
# 1. "torch_all_to_all"
# 2. "symm_mem" (see `setup_symm_mem` below)
self.shuffle_method = "torch_all_to_all"

# This function is used to create a symm mem buffer for MoE. It is for
# This function is used to create a symm mem buffer for MoE's. It is for
# shuffling tokens fully "on-device", as compared to traditional torch
# all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
# calls this function, the `shuffle_method` would switch from
# `torch_all_to_all` to `symm_mem`.

# Status: supports inference. For training, this is disabled for now. Reason
# is that autograd requires tensors not be modified a second time, this
# conflicts with our wish of sharing the symm mem across layers and/or
# PP microbatches.
def setup_symm_mem(self, dtype, device):
def setup_symm_mem(
self,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
microbatches: int,
):
# Switch shuffle method
self.shuffle_method = "symm_mem"
self.microbatches = microbatches
self.curr_send = 0
self.curr_gather = 0

# Input buffer: seq len * top k (flattened)
input_len = max_seq_len * self.num_experts_per_tok
# Assuming worst case, 2x tokens are routed to one EP rank
max_output_len = 2 * input_len
# This value is needed by `OnDeviceAllToAllV` to prepare the output buffer
OnDeviceAllToAllV.max_output_len = max_output_len

# Symmetric memory buffers are shared by all MoE instances across
# layers, so we only need to initialize them once
if MoE.token_send_buf is not None:
# layers, so we only need to initialize them once
if len(MoE.token_send_buf):
return

# But they are not shared across microbatches
# Input buffer for DP-to-EP shuffle
MoE.token_send_buf = symm_mem.empty(
self.config.max_seq_len
* self.num_experts_per_tok, # seq len * top k (flattened)
self.config.hidden_size, # hidden dim
dtype=dtype,
device=device,
)
# Number of tokens to send to EP peers, aka. input splits
MoE.input_splits = symm_mem.empty(
self.ep_size, dtype=torch.int64, device=device
)
# Number of tokens to receive from EP peers, aka. output splits
MoE.output_splits = symm_mem.empty(
self.ep_size, dtype=torch.int64, device=device
)
for _ in range(self.microbatches):
MoE.token_send_buf.append(
symm_mem.empty(
input_len,
self.config.hidden_size, # hidden dim
dtype=dtype,
device=device,
)
)
# Input buffer for EP-to-DP shuffle
MoE.token_gather_buf = symm_mem.empty(
# worst case, all tokens are routed to one EP rank
MoE.token_send_buf.shape[0] * self.ep_size,
self.config.hidden_size, # hidden dim
dtype=dtype,
device=device,
)
for _ in range(self.microbatches):
MoE.token_gather_buf.append(
symm_mem.empty(
max_output_len,
self.config.hidden_size, # hidden dim
dtype=dtype,
device=device,
)
)
print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

def get_next_send_buf(self):
rv = self.token_send_buf[self.curr_send]
self.curr_send = (self.curr_send + 1) % self.microbatches
return rv

def get_next_gather_buf(self):
rv = self.token_gather_buf[self.curr_gather]
self.curr_gather = (self.curr_gather + 1) % self.microbatches
return rv

def forward(self, hidden_states):
identity = hidden_states
orig_shape = hidden_states.shape
@@ -574,35 +592,28 @@ def moe_forward(self, x, topk_ids, topk_weight):
dist.all_to_all_single(
tokens_per_expert_group, tokens_per_expert, group=self.ep_group
)
input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

# DP to EP token shuffle. This part needs gradient.
if self.shuffle_method == "symm_mem":
# Move input to the `token_send_buf` symm mem
MoE.token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
token_send_buf = self.get_next_send_buf()
token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
# Note: `out=` avoids copy, but it is not differentiable
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=MoE.token_send_buf[: idxs.shape[0]])
with torch.no_grad():
torch.sum(
tokens_per_expert.view(self.ep_size, -1),
dim=1,
out=MoE.input_splits,
)
on_device_all_to_all_v(
MoE.token_gather_buf,
MoE.output_splits,
MoE.token_send_buf,
MoE.input_splits,
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
token_send_buf,
input_splits,
self.ep_group,
)
with torch.no_grad():
# Received tokens from all other ranks. TODO: use mask instead
received = MoE.output_splits.sum()
received = output_splits.sum()
# TODO: don't use `received`
gathered_tokens = MoE.token_gather_buf[:received]
gathered_tokens = token_gather_buf[:received]
else: # "torch_all_to_all"
# Prepare input and output splits
with torch.no_grad():
input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
dim=1
)
@@ -627,9 +638,9 @@ def moe_forward(self, x, topk_ids, topk_weight):

# Prepare buffer for tokens processed by experts
if self.shuffle_method == "symm_mem":
# Take necessary space from `token_send_buf` symm mem because we are
# Take necessary space from `token_gather_buf` symm mem because we are
# going to send them out after expert processing
processed_tokens = MoE.token_send_buf[: gathered_tokens.shape[0]]
processed_tokens = self.get_next_gather_buf()[: gathered_tokens.shape[0]]
else: # "torch_all_to_all"
processed_tokens = torch.empty_like(gathered_tokens)

@@ -643,15 +654,12 @@ def moe_forward(self, x, topk_ids, topk_weight):
# Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
# The input/output splits are just a reverse of the previous shuffle.
if self.shuffle_method == "symm_mem":
# Take necessary space from `token_gather_buf` symm mem to receive processed tokens
returned_tokens = MoE.token_gather_buf[: sorted_tokens_shape[0]]
on_device_all_to_all_v(
returned_tokens,
MoE.input_splits, # unused
token_return_buf, _ = OnDeviceAllToAllV.apply(
processed_tokens,
MoE.output_splits,
output_splits,
self.ep_group,
)
returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
else: # "torch_all_to_all"
returned_tokens = all_to_all_single_autograd(
processed_tokens,
@@ -1188,9 +1196,10 @@ def _reorder_cache(past_key_values, beam_idx):
return reordered_past

# Setup Symmetric Memory for MoE token shuffle.
# Supports inference currently.
def setup_symm_mem(self, dtype, device):
def setup_symm_mem(
self, seq_len: int, dtype: torch.dtype, device: torch.device, microbatches: int
):
for layer in self.model.layers.values():
if not isinstance(layer.mlp, MoE):
continue
layer.mlp.setup_symm_mem(dtype, device)
layer.mlp.setup_symm_mem(seq_len, dtype, device, microbatches)
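To make the new buffering scheme easier to follow in one place, here is a minimal standalone sketch of the round-robin pool behind setup_symm_mem, get_next_send_buf and get_next_gather_buf. It substitutes plain torch.empty for symm_mem.empty so it runs without the symmetric-memory runtime; the class and parameter names are illustrative, not from this PR:

import torch

class TokenBufferPool:
    """Standalone sketch of the per-microbatch round-robin buffer pool.
    Plain torch.empty stands in for symm_mem.empty; sizes follow the same
    rule as above: send capacity = max_seq_len * top_k rows, gather
    capacity = 2x that (worst case, 2x tokens routed to one EP rank)."""

    def __init__(self, max_seq_len, top_k, hidden, microbatches,
                 dtype=torch.bfloat16, device="cpu"):
        input_len = max_seq_len * top_k   # seq len * top k, flattened
        max_output_len = 2 * input_len
        self.microbatches = microbatches
        self.curr_send = 0
        self.curr_gather = 0
        self.send_bufs = [
            torch.empty(input_len, hidden, dtype=dtype, device=device)
            for _ in range(microbatches)
        ]
        self.gather_bufs = [
            torch.empty(max_output_len, hidden, dtype=dtype, device=device)
            for _ in range(microbatches)
        ]

    def next_send(self):
        # Hand out buffers round-robin so each in-flight microbatch gets its
        # own tensor and autograd never sees a buffer written twice per step.
        buf = self.send_bufs[self.curr_send]
        self.curr_send = (self.curr_send + 1) % self.microbatches
        return buf

    def next_gather(self):
        buf = self.gather_bufs[self.curr_gather]
        self.curr_gather = (self.curr_gather + 1) % self.microbatches
        return buf

pool = TokenBufferPool(max_seq_len=512, top_k=6, hidden=2048, microbatches=4)
print(pool.next_send().shape)  # torch.Size([3072, 2048])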
3 changes: 0 additions & 3 deletions torchtitan/experiments/deepseek_v3/model_config.py
@@ -150,9 +150,6 @@ class ModelArgs:
attention_bias: bool = False
attention_dropout: float = 0.0
pad_token_id = None
# Added for symmetric memory
max_seq_len: int = 4096
dtype: str = "bfloat16"
# Added for pipeline parallel
num_stages: int = 1
stage_idx: int = 0
21 changes: 16 additions & 5 deletions torchtitan/experiments/deepseek_v3/run.py
@@ -74,11 +74,19 @@ def run_full_model(
# Apply HSDP on root model (lm_head, embeddings, etc)
fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False)

# Example inputs
bs = 2
# Synthetic setting
microbatches = pp_size * 2
bs = 4 # microbatch size
seqlen = 128
x = torch.randint(model_args.vocab_size, (bs, seqlen), device=device)
label = torch.rand(bs, seqlen, model_args.vocab_size, device=device)

# Use Symmetric Memory for MoE token shuffle. The number of tokens in each
# buffer is microbatch size * seq_len (batch and sequence dims flattened).
model.setup_symm_mem(bs * seqlen, torch.bfloat16, device, microbatches)

# Example inputs
torch.manual_seed(ep_rank)
x = torch.randint(model_args.vocab_size, (microbatches * bs, seqlen), device=device)
label = torch.rand(microbatches * bs, seqlen, model_args.vocab_size, device=device)

# Create loss function
loss_fn = torch.nn.functional.cross_entropy
@@ -95,7 +103,6 @@
)

# Create pipeline schedule
microbatches = 2
losses = []
pp_schedule = Schedule1F1B(stage, microbatches, loss_fn=loss_fn)

@@ -115,6 +122,10 @@
print(f"logits: {y.shape}")
print(f"{loss=}")

if pp_rank == 0:
param = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
print(f"{torch.linalg.norm(param.grad)=}")

print("Backward done")


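For a rough sense of what this synthetic setting reserves, a back-of-envelope sketch; top_k = 6 and hidden = 2048 are assumed placeholder values (not taken from this PR), while bs, seqlen and microbatches follow the code above with an illustrative pp_size of 2:

pp_size = 2                    # illustrative pipeline size
microbatches = pp_size * 2     # as in run.py
bs, seqlen = 4, 128            # microbatch size and sequence length, as in run.py
top_k, hidden = 6, 2048        # assumed model values, not from this PR
bytes_per_elem = 2             # bfloat16

send_rows = bs * seqlen * top_k   # 3072 rows after top-k flattening
gather_rows = 2 * send_rows       # 6144 rows, worst case 2x on one EP rank
per_microbatch = (send_rows + gather_rows) * hidden * bytes_per_elem
total_mib = microbatches * per_microbatch / 2**20
print(f"~{total_mib:.0f} MiB of symmetric memory per EP rank")  # ~144 MiB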
torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py
@@ -4,8 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .triton_on_device_all_to_all_v import on_device_all_to_all_v
from .triton_on_device_all_to_all_v import OnDeviceAllToAllV

__all__ = [
"on_device_all_to_all_v",
"OnDeviceAllToAllV",
]
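For orientation, the (tokens, splits) contract that moe_forward relies on can be emulated with the stock collective. The sketch below is a reference for the interface only, not the PR's implementation; the real OnDeviceAllToAllV does the exchange on-device via symmetric memory and stays differentiable through .apply():

import torch
import torch.distributed as dist

def all_to_all_v_reference(send_buf: torch.Tensor, input_splits: torch.Tensor, group):
    # input_splits[i] = number of rows this rank sends to rank i.
    output_splits = torch.empty_like(input_splits)
    dist.all_to_all_single(output_splits, input_splits, group=group)  # exchange split sizes
    # The .sum()/.tolist() calls below are the GPU-to-CPU sync the symm-mem path avoids.
    out = send_buf.new_empty(int(output_splits.sum()), send_buf.shape[1])
    dist.all_to_all_single(
        out,
        send_buf[: int(input_splits.sum())].contiguous(),
        output_split_sizes=output_splits.tolist(),
        input_split_sizes=input_splits.tolist(),
        group=group,
    )
    return out, output_splits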