149 changes: 149 additions & 0 deletions cookbook/transformers/deepseek_v4_flash.py
@@ -0,0 +1,149 @@
import os

import twinkle
from peft import LoraConfig
from transformers import AutoConfig
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

logger = get_logger()
# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights.
# Convert the checkpoint before training by following:
# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
# Install `transformers==5.8.0` before running this cookbook.
MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')

NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4'))

BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2'))
LR = float(os.environ.get('LR', '1e-4'))
MAX_STEPS = int(os.environ.get('MAX_STEPS', '0'))
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50'))
RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1'
GRADIENT_CHECKPOINTING = True
IGNORE_MISMATCHED_SIZES = False
LORA_TARGET_MODULES = [
'q_a_proj',
'q_b_proj',
'kv_proj',
'o_b_proj',
'gate_proj',
'up_proj',
'down_proj',
]
ADAPTER_NAME = 'default'

device_mesh = DeviceMesh.from_sizes(
fsdp_size=4,
dp_size=1,
device_type=Platform.get_platform().device_prefix(),
)

twinkle.initialize(mode='local', global_device_mesh=device_mesh)


def create_dataset(data_slice=None):
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000)))
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
dataset.encode(batched=True)
return dataset


def eval(model):
Contributor (medium)

The function name eval shadows the Python built-in eval() function. It is recommended to rename it to something more descriptive, such as evaluate or run_eval, to avoid confusion and potential name resolution issues. Note that the call site at line 135 should also be updated.

Suggested change
-def eval(model):
+def evaluate(model):

dataset = create_dataset(data_slice=range(100))
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
for _, batch in enumerate(dataloader):
if callable(batch):
batch = batch()
model.forward_only(inputs=batch, adapter_name=ADAPTER_NAME)
model.calculate_loss(adapter_name=ADAPTER_NAME)
return model.calculate_metric(is_training=False, adapter_name=ADAPTER_NAME)


def train():
dataset = create_dataset()
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)

config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
config.num_hidden_layers = NUM_LAYERS
if hasattr(config, 'use_cache'):
config.use_cache = False

model = TransformersModel(
model_id=MODEL_ID,
config=config,
device_mesh=device_mesh,
strategy='native_fsdp',
memory_efficient_init=True,
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': RESHARD_AFTER_FORWARD,
},
)

lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=LORA_TARGET_MODULES)
model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS)

if not GRADIENT_CHECKPOINTING:
model.model.gradient_checkpointing_disable()

model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME)
model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME)
model.set_lr_scheduler(
scheduler_cls='CosineWarmupScheduler',
num_warmup_steps=5,
num_training_steps=len(dataloader),
adapter_name=ADAPTER_NAME,
)

logger.info(get_device_placement())
logger.info(model.get_train_configs(adapter_name=ADAPTER_NAME))
logger.info(
f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, '
f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, '
f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, '
f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, '
f'reshard_after_forward={RESHARD_AFTER_FORWARD}, '
f'lora_target_modules={LORA_TARGET_MODULES}')

best_loss = float('inf')
for step, batch in enumerate(dataloader):
if MAX_STEPS and step >= MAX_STEPS:
break
if callable(batch):
batch = batch()
model.forward_backward(
inputs=batch,
adapter_name=ADAPTER_NAME,
)
model.clip_grad_and_step(
adapter_name=ADAPTER_NAME,
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)

if step % 20 == 0:
metric = model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
logger.info(f'Step {step} of {len(dataloader)}, metric: {metric}')

if step > 0 and step % SAVE_STEPS == 0:
metrics = eval(model)
Contributor (medium)

Update the call site to match the renamed evaluation function.

Suggested change
-metrics = eval(model)
+metrics = evaluate(model)

logger.info(f'Eval metric: {metrics}')
loss = float(metrics['loss'])
if loss < best_loss:
model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME)
best_loss = loss

model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME)


if __name__ == '__main__':
train()
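
For reference, a small worked example of the learning-rate schedule requested above (set_lr_scheduler with CosineWarmupScheduler, num_warmup_steps=5, num_training_steps=len(dataloader)). This is the standard linear-warmup-then-cosine-decay formulation; twinkle's actual CosineWarmupScheduler may differ in detail, and the step counts below are illustrative.

import math

def cosine_warmup_lr(step, base_lr, warmup_steps, total_steps):
    # Linear warmup from base_lr / warmup_steps up to base_lr.
    if step < warmup_steps:
        return base_lr * (step + 1) / max(1, warmup_steps)
    # Cosine decay from base_lr down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Illustrative values with base_lr=1e-4, warmup_steps=5, total_steps=250:
#   step 0   -> 2.0e-05 (warmup)
#   step 4   -> 1.0e-04 (end of warmup)
#   step 127 -> ~5.0e-05 (midway through cosine decay)
#   step 249 -> ~0.0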
6 changes: 6 additions & 0 deletions cookbook/transformers/deepseek_v4_flash.sh
@@ -0,0 +1,6 @@
# `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights.
# Convert the checkpoint before training by following:
# https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87
# Install `transformers==5.8.0` before running this cookbook.

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 cookbook/transformers/deepseek_v4_flash.py
123 changes: 70 additions & 53 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -4,7 +4,7 @@
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set

from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context
@@ -30,7 +30,13 @@ def __init__(self,
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None

def pretrained_load_context(self):
return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
# Native FSDP loads pretrained weights via rank0 broadcast during wrap_model().
# Avoid Transformers' FSDP loading env here; some versions can hang non-rank0
# ranks in from_pretrained barriers.
return fsdp_pretrained_load_context(False)

def use_rank0_pretrained_broadcast(self) -> bool:
return self._memory_efficient_init and self.device_mesh is not None

def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
@@ -59,14 +65,14 @@ def wrap_model(self, model, optimizer=None):
if optimizer is not None:
_unbind_optimizer_params(optimizer)

# EP path requires experts on a real device, incompatible with meta-device flow.
use_meta = self._memory_efficient_init and not ep_enabled
use_meta = self.use_rank0_pretrained_broadcast() and not ep_enabled

original_sd = None
saved_buffers = None
if use_meta:
original_sd = model.state_dict()
saved_buffers = _get_non_persistent_buffers(model)
is_rank0 = (dist.get_rank() == 0)
original_sd = model.state_dict() if is_rank0 else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
model = model.to(torch.device('meta'))
if hasattr(model, 'tie_weights'):
model.tie_weights()
@@ -129,14 +135,9 @@ def wrap_model(self, model, optimizer=None):

if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
is_rank0 = (dist.get_rank() == 0)
_broadcast_sharded_state_dict(
model,
original_sd if is_rank0 else {},
device_type=device_type,
)
_load_rank0_full_state_dict(model, original_sd or {})
target_device = torch.device(device_type)
_restore_non_persistent_buffers(model, saved_buffers, device=target_device)
_broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device)
if hasattr(model, 'tie_weights'):
model.tie_weights()

@@ -322,16 +323,43 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', ))


def _get_decoder_layers(model: nn.Module) -> Optional[nn.ModuleList]:
def _get_decoder_layers(model: nn.Module) -> Optional[List[nn.Module]]:
no_split_modules = _get_no_split_module_names(model)
if no_split_modules:
layers = [
module for module in model.modules()
if module is not model and module.__class__.__name__ in no_split_modules
]
if layers:
return layers

inner_model = getattr(model, 'model', None)
if inner_model is not None:
inner_layers = getattr(inner_model, 'layers', None)
if isinstance(inner_layers, nn.ModuleList):
return inner_layers
return list(inner_layers)

return None


def _get_no_split_module_names(model: nn.Module) -> Set[str]:
names = _normalize_no_split_modules(getattr(model, '_no_split_modules', None))
if names:
return names

for module in model.modules():
names.update(_normalize_no_split_modules(getattr(module, '_no_split_modules', None)))
return names


def _normalize_no_split_modules(value) -> Set[str]:
if value is None:
return set()
if isinstance(value, str):
return {value}
return set(value)


def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
ignored: Set[nn.Parameter] = set()
ep_patched = False
@@ -473,41 +501,18 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor
return optimizer


def _broadcast_sharded_state_dict(
model: nn.Module,
full_sd: dict,
device_type: str = 'cuda',
) -> None:
"""Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
from torch.distributed.tensor import DTensor, distribute_tensor
def _load_rank0_full_state_dict(model: nn.Module, full_sd: dict) -> None:
"""Load rank0 full weights into a sharded FSDP2 model via DCP broadcast."""
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

meta_sharded_sd = model.state_dict()
sharded_sd = {}
is_rank0 = (dist.get_rank() == 0)

for param_name, sharded_param in meta_sharded_sd.items():
shape = sharded_param.size()
dtype = sharded_param.dtype

if is_rank0:
full_param = full_sd[param_name]
full_tensor = full_param.detach().to(device_type)
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
else:
full_tensor = torch.empty(shape, device=device_type, dtype=dtype)

dist.broadcast(full_tensor, src=0)
torch_util.synchronize()

device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements)
del full_tensor

sharded_sd[param_name] = sharded_tensor

model.load_state_dict(sharded_sd, assign=True)
set_model_state_dict(
model=model,
model_state_dict=full_sd,
options=StateDictOptions(
full_state_dict=True,
broadcast_from_rank0=True,
),
)
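
To make the new loading path easier to follow in isolation, here is a condensed, standalone sketch of what wrap_model() now does for pretrained weights: rank 0 keeps the only full copy of the state dict, every rank moves its module to the meta device, fully_shard turns the parameters into DTensor shards, and set_model_state_dict with broadcast_from_rank0=True broadcasts and re-shards the weights. This is an illustration only; the real flow (EP handling, optimizer rebinding, buffer restoration, and a possible to_empty() materialization step on some PyTorch versions) lives in wrap_model() and the helpers around it.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def shard_and_load_from_rank0(model: nn.Module) -> nn.Module:
    # On rank 0 the module already holds the pretrained weights
    # (e.g. loaded via from_pretrained); other ranks only need its structure.
    full_sd = model.state_dict() if dist.get_rank() == 0 else {}

    # Drop the per-rank full copy: parameters become meta tensors.
    model = model.to(torch.device('meta'))

    # FSDP2: shard parameters in place (they remain unmaterialized until loaded).
    fully_shard(model)

    # Broadcast rank 0's full state dict and scatter it into the local shards.
    set_model_state_dict(
        model=model,
        model_state_dict=full_sd,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
    return model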


def _get_non_persistent_buffers(model: nn.Module) -> Dict[str, torch.Tensor]:
@@ -529,14 +534,26 @@ def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
group['params'][i] = torch.empty(1, dtype=param.dtype, device=param.device)


def _restore_non_persistent_buffers(
def _broadcast_non_persistent_buffers(
model: nn.Module,
saved_buffers: Dict[str, torch.Tensor],
device: torch.device,
) -> None:
"""Re-register non-persistent buffers saved before to('meta')."""
for fqn, buf_tensor in saved_buffers.items():
buf_tensor = buf_tensor.to(device)
"""Broadcast rank0 non-persistent buffers and re-register them on all ranks."""
is_rank0 = (dist.get_rank() == 0)
metadata = None
if is_rank0:
metadata = [(name, tuple(tensor.shape), tensor.dtype) for name, tensor in saved_buffers.items()]
metadata_holder = [metadata]
dist.broadcast_object_list(metadata_holder, src=0)
metadata = metadata_holder[0] or []

for fqn, shape, dtype in metadata:
if is_rank0:
buf_tensor = saved_buffers[fqn].to(device)
else:
buf_tensor = torch.empty(shape, device=device, dtype=dtype)
dist.broadcast(buf_tensor, src=0)
if '.' in fqn:
parent_fqn, local_name = fqn.rsplit('.', 1)
parent = model.get_submodule(parent_fqn)
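
And the companion path for buffers, sketched separately: non-persistent buffers are not part of the state dict, so rank 0 announces their metadata via broadcast_object_list, sends each tensor with dist.broadcast, and every rank re-attaches them. The re-registration shown here (register_buffer with persistent=False) is an illustrative assumption that mirrors the helper above; the truncated tail of _broadcast_non_persistent_buffers may handle it differently.

from typing import Dict

import torch
import torch.distributed as dist
from torch import nn


def broadcast_buffers_from_rank0(model: nn.Module,
                                 saved_buffers: Dict[str, torch.Tensor],
                                 device: torch.device) -> None:
    is_rank0 = dist.get_rank() == 0
    # Rank 0 announces which buffers exist, with their shapes and dtypes.
    metadata = [(n, tuple(t.shape), t.dtype) for n, t in saved_buffers.items()] if is_rank0 else None
    holder = [metadata]
    dist.broadcast_object_list(holder, src=0)

    for fqn, shape, dtype in holder[0] or []:
        # Rank 0 sends the real tensor; other ranks receive into an empty one.
        buf = saved_buffers[fqn].to(device) if is_rank0 else torch.empty(shape, device=device, dtype=dtype)
        dist.broadcast(buf, src=0)
        # Re-attach on every rank (illustrative: resolve the parent module from
        # the fully qualified name, then register as a non-persistent buffer).
        parent_fqn, _, local_name = fqn.rpartition('.')
        parent = model.get_submodule(parent_fqn) if parent_fqn else model
        parent.register_buffer(local_name, buf, persistent=False)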