-
Notifications
You must be signed in to change notification settings - Fork 32
support deepseekv4 flash fsdp2 training and reduce Transformers v5 model loading memory pressure #190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
meichangsu1
wants to merge
8
commits into
modelscope:main
Choose a base branch
from
meichangsu1:dsv4_fsdp2_ljl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
support deepseekv4 flash fsdp2 training and reduce Transformers v5 model loading memory pressure #190
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
cf53e4e
feat(template): add agent template support with ReAct-style tool calling
meichangsu1 cc0a7ef
fix: enhance FSDP decoder layer detection with no_split_modules support
meichangsu1 400acfb
fix: broadcast non-persistent buffers and validate state dict shapes …
meichangsu1 a8dbf74
fix(template): rename DeepSeekV4AgentTemplate to DeepseekV4Template a…
meichangsu1 09879ae
fix: update LoRA target modules and NPU dtype alignment in deepseek_v…
meichangsu1 00c41dd
fix: update DeepSeek-V4-Flash model ID and LoRA target modules
meichangsu1 6cecbba
refactor(native_fsdp): replace custom sharded state dict broadcast wi…
meichangsu1 d292147
fix: correct is_npu_device detection in _ensure_lora_dtype
meichangsu1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| import os | ||
|
|
||
| import twinkle | ||
| from peft import LoraConfig | ||
| from transformers import AutoConfig | ||
| from twinkle import DeviceMesh, Platform, get_device_placement, get_logger | ||
| from twinkle.dataloader import DataLoader | ||
| from twinkle.dataset import Dataset, DatasetMeta | ||
| from twinkle.model import TransformersModel | ||
| from twinkle.preprocessor import SelfCognitionProcessor | ||
|
|
||
| logger = get_logger() | ||
| # `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. | ||
| # Convert the checkpoint before training by following: | ||
| # https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 | ||
| # Install `transformers==5.8.0` before running this cookbook. | ||
| MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-Flash') | ||
| DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') | ||
| TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') | ||
| OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') | ||
|
|
||
| NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '4')) | ||
|
|
||
| BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) | ||
| GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2')) | ||
| LR = float(os.environ.get('LR', '1e-4')) | ||
| MAX_STEPS = int(os.environ.get('MAX_STEPS', '0')) | ||
| SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50')) | ||
| RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1' | ||
| GRADIENT_CHECKPOINTING = True | ||
| IGNORE_MISMATCHED_SIZES = False | ||
| LORA_TARGET_MODULES = [ | ||
| 'q_a_proj', | ||
| 'q_b_proj', | ||
| 'kv_proj', | ||
| 'o_b_proj', | ||
| 'gate_proj', | ||
| 'up_proj', | ||
| 'down_proj', | ||
| ] | ||
| ADAPTER_NAME = 'default' | ||
|
|
||
| device_mesh = DeviceMesh.from_sizes( | ||
| fsdp_size=4, | ||
| dp_size=1, | ||
| device_type=Platform.get_platform().device_prefix(), | ||
| ) | ||
|
|
||
| twinkle.initialize(mode='local', global_device_mesh=device_mesh) | ||
|
|
||
|
|
||
| def create_dataset(data_slice=None): | ||
| dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000))) | ||
| dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) | ||
| dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) | ||
| dataset.encode(batched=True) | ||
| return dataset | ||
|
|
||
|
|
||
| def eval(model): | ||
| dataset = create_dataset(data_slice=range(100)) | ||
| dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) | ||
| for _, batch in enumerate(dataloader): | ||
| if callable(batch): | ||
| batch = batch() | ||
| model.forward_only(inputs=batch, adapter_name=ADAPTER_NAME) | ||
| model.calculate_loss(adapter_name=ADAPTER_NAME) | ||
| return model.calculate_metric(is_training=False, adapter_name=ADAPTER_NAME) | ||
|
|
||
|
|
||
| def train(): | ||
| dataset = create_dataset() | ||
| dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) | ||
|
|
||
| config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
| if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'): | ||
| config.num_hidden_layers = NUM_LAYERS | ||
| if hasattr(config, 'use_cache'): | ||
| config.use_cache = False | ||
|
|
||
| model = TransformersModel( | ||
| model_id=MODEL_ID, | ||
| config=config, | ||
| device_mesh=device_mesh, | ||
| strategy='native_fsdp', | ||
| memory_efficient_init=True, | ||
| ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, | ||
| fsdp_config={ | ||
| 'reshard_after_forward': RESHARD_AFTER_FORWARD, | ||
| }, | ||
| ) | ||
|
|
||
| lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=LORA_TARGET_MODULES) | ||
| model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS) | ||
|
|
||
| if not GRADIENT_CHECKPOINTING: | ||
| model.model.gradient_checkpointing_disable() | ||
|
|
||
| model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name=ADAPTER_NAME) | ||
| model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name=ADAPTER_NAME) | ||
| model.set_lr_scheduler( | ||
| scheduler_cls='CosineWarmupScheduler', | ||
| num_warmup_steps=5, | ||
| num_training_steps=len(dataloader), | ||
| adapter_name=ADAPTER_NAME, | ||
| ) | ||
|
|
||
| logger.info(get_device_placement()) | ||
| logger.info(model.get_train_configs(adapter_name=ADAPTER_NAME)) | ||
| logger.info( | ||
| f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, ' | ||
| f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, ' | ||
| f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, ' | ||
| f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, ' | ||
| f'reshard_after_forward={RESHARD_AFTER_FORWARD}, ' | ||
| f'lora_target_modules={LORA_TARGET_MODULES}') | ||
|
|
||
| best_loss = float('inf') | ||
| for step, batch in enumerate(dataloader): | ||
| if MAX_STEPS and step >= MAX_STEPS: | ||
| break | ||
| if callable(batch): | ||
| batch = batch() | ||
| model.forward_backward( | ||
| inputs=batch, | ||
| adapter_name=ADAPTER_NAME, | ||
| ) | ||
| model.clip_grad_and_step( | ||
| adapter_name=ADAPTER_NAME, | ||
| gradient_accumulation_steps=GRAD_ACCUM_STEPS, | ||
| ) | ||
|
|
||
| if step % 20 == 0: | ||
| metric = model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME) | ||
| logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') | ||
|
|
||
| if step > 0 and step % SAVE_STEPS == 0: | ||
| metrics = eval(model) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| logger.info(f'Eval metric: {metrics}') | ||
| loss = float(metrics['loss']) | ||
| if loss < best_loss: | ||
| model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME) | ||
| best_loss = loss | ||
|
|
||
| model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name=ADAPTER_NAME) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| train() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # `deepseek-ai/DeepSeek-V4-Flash` uses mixed FP4/FP8 weights. | ||
| # Convert the checkpoint before training by following: | ||
| # https://gitcode.com/cann/cann-recipes-train/blob/master/llm_pretrain/deepseekv4/README.md#%E6%A8%A1%E5%9E%8B%E6%9D%83%E9%87%8D%E5%87%86%E5%A4%87 | ||
| # Install `transformers==5.8.0` before running this cookbook. | ||
|
|
||
| CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 cookbook/transformers/deepseek_v4_flash.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name
evalshadows the Python built-ineval()function. It is recommended to rename it to something more descriptive, such asevaluateorrun_eval, to avoid confusion and potential name resolution issues. Note that the call site at line 135 should also be updated.