DeepNVMe update #966

Status: Open. Wants to merge 48 commits into base: master.

Commits (48):
- 8106bb8 Fast model checkpointing (tjruwase, Dec 30, 2021)
- 761e4e5 Support both legacy and serialized formats (tjruwase, Dec 31, 2021)
- 5967c79 Add io_buffer_mb option (tjruwase, Jan 3, 2022)
- d96f1f6 Bug fix (tjruwase, Jan 3, 2022)
- bbd96f2 Force flush (tjruwase, Jan 3, 2022)
- 3a16127 More model options; Refactor common codes (tjruwase, Jan 4, 2022)
- c3df495 --gpu option (tjruwase, Jan 5, 2022)
- 315f02a --half and more flexible options (tjruwase, Jan 5, 2022)
- a41ba08 Add deepspeed.save_checkpoint() (tjruwase, Jan 8, 2022)
- 4fcb060 Free ds memory (tjruwase, Jan 8, 2022)
- a49c542 Improve repro (tjruwase, Jan 8, 2022)
- 233b9e9 Double I/O buffer (#56) (tjruwase, Feb 22, 2022)
- b1f02b2 Double I/O buffer (#60) (tjruwase, Mar 11, 2022)
- a16ac9e Add checkpoint comparison (#62) (jerryyangli, Mar 15, 2022)
- b945adc save_checkpoint perf monitoring (tjruwase, Mar 19, 2022)
- 2c7a5ed Merge branch 'staging-fast-model-checkpoint-v2' of github.com:microso… (tjruwase, Mar 19, 2022)
- 64a8f75 Disable checkpoint save on exit (tjruwase, Mar 22, 2022)
- 44b8664 Perf statistics for save_checkpoint (#64) (tjruwase, Mar 22, 2022)
- ff4bd69 add logs for a100-80 (GuanhuaWang, Sep 21, 2022)
- e4817a1 add torch* error log with half flag but without fused flag (GuanhuaWang, Sep 22, 2022)
- b297e17 log for error (GuanhuaWang, Sep 22, 2022)
- f05dab1 local rank arg (tjruwase, Oct 5, 2022)
- fc4291f Merge branch 'staging-fast-model-checkpoint-v2' of github.com:microso… (tjruwase, Oct 5, 2022)
- db295f1 Merge branch 'staging-fast-model-checkpoint-v2' of github.com:microso… (tjruwase, Oct 5, 2022)
- 1aa971a Handle local_rank arg (#78) (tjruwase, Oct 5, 2022)
- 98b2f8a Single writer option (tjruwase, Oct 5, 2022)
- 2e42285 Single writer option (#79) (tjruwase, Oct 5, 2022)
- 09dbd8a Merge branch 'staging-fast-model-checkpoint-v3' of github.com:microso… (tjruwase, Oct 7, 2022)
- a567adf Allow missing folder (tjruwase, Oct 12, 2022)
- 65793bd DP writer refactor (tjruwase, Feb 10, 2023)
- 5bfdf04 Update for DS; Add GDS (tjruwase, Feb 12, 2025)
- 9a27914 Integrate GDS into deepspeed_model_save (tjruwase, Feb 20, 2025)
- 53572f8 Rebase fast persist (tjruwase, Feb 25, 2025)
- 515dded Rebase fast persist (#184) (tjruwase, Feb 25, 2025)
- d01aa27 Move folder (tjruwase, Mar 26, 2025)
- e5a316f Merge branch 'olruwase/fast_persist' of github.com:microsoft/DeepSpee… (tjruwase, Mar 26, 2025)
- 4059f80 Remove folder (tjruwase, Mar 26, 2025)
- 1c3a54c More cleanup (tjruwase, Mar 26, 2025)
- 9a8540b torch changes (tjruwase, Mar 27, 2025)
- ee2f081 sglang+zero_inference (tjruwase, Apr 7, 2025)
- ad81cec Remove file (tjruwase, Apr 7, 2025)
- dff5274 Add offload configs (tjruwase, Apr 8, 2025)
- d84bb56 Add pin_memory (tjruwase, Apr 8, 2025)
- db3b32b Cleanup scripts (tjruwase, Apr 8, 2025)
- 6ee91cb SGLang README (tjruwase, Apr 12, 2025)
- e283b74 Remove file (tjruwase, Apr 12, 2025)
- 54872e1 Merge branch 'master' into olruwase/fast_persist (tjruwase, Apr 14, 2025)
- d971d84 Merge branch 'master' into olruwase/fast_persist (loadams, May 15, 2025)
11 changes: 10 additions & 1 deletion deepnvme/file_access/aio_load_cpu_tensor.py

@@ -2,6 +2,7 @@
 import os, timeit, functools
 from deepspeed.ops.op_builder import AsyncIOBuilder
 from utils import parse_read_arguments, GIGA_UNIT
+from deepspeed.accelerator import get_accelerator
 
 def file_read(inp_f, handle, bounce_buffer):
     handle.sync_pread(bounce_buffer, inp_f)
@@ -14,7 +15,12 @@ def main():
     cnt = args.loop
 
     aio_handle = AsyncIOBuilder().load().aio_handle()
-    bounce_buffer = torch.empty(os.path.getsize(input_file), dtype=torch.uint8).pin_memory()
+    native_locked_tensor = get_accelerator()._name == 'cpu'
+
+    if native_locked_tensor:
+        bounce_buffer = aio_handle.new_cpu_locked_tensor(file_sz, torch.Tensor().to(torch.uint8))
+    else:
+        bounce_buffer = torch.empty(file_sz, dtype=torch.uint8).pin_memory()
 
     t = timeit.Timer(functools.partial(file_read, input_file, aio_handle, bounce_buffer))
     aio_t = t.timeit(cnt)
@@ -27,5 +33,8 @@ def main():
     py_tensor = py_file_read(input_file)
     print(f'Validation success = {aio_tensor.equal(py_tensor)}')
 
+    if native_locked_tensor:
+        aio_handle.free_cpu_locked_tensor(bounce_buffer)
+
 if __name__ == "__main__":
     main()
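The change above follows a simple rule: on the CPU accelerator, allocate the bounce buffer with the handle's native locked-tensor allocator and free it explicitly when done; on other accelerators, fall back to torch pinned memory. A minimal sketch of that rule as a reusable helper (the helper names are ours, not part of this PR):

```python
import torch
from deepspeed.accelerator import get_accelerator

def alloc_bounce_buffer(aio_handle, num_bytes):
    """Return (buffer, is_native_locked) for DeepNVMe reads/writes."""
    if get_accelerator()._name == 'cpu':
        # CPU accelerator: use the handle's locked-tensor allocator.
        return aio_handle.new_cpu_locked_tensor(
            num_bytes, torch.Tensor().to(torch.uint8)), True
    # Other accelerators: plain page-locked (pinned) host memory.
    return torch.empty(num_bytes, dtype=torch.uint8).pin_memory(), False

def free_bounce_buffer(aio_handle, buffer, is_native_locked):
    """Release a buffer returned by alloc_bounce_buffer."""
    if is_native_locked:
        aio_handle.free_cpu_locked_tensor(buffer)
```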
10 changes: 9 additions & 1 deletion deepnvme/file_access/aio_store_cpu_tensor.py

@@ -2,6 +2,7 @@
 import os, timeit, functools, pathlib
 from deepspeed.ops.op_builder import AsyncIOBuilder
 from utils import parse_write_arguments, GIGA_UNIT
+from deepspeed.accelerator import get_accelerator
 
 def file_write(out_f, tensor, handle, bounce_buffer):
     bounce_buffer.copy_(tensor)
@@ -14,9 +15,13 @@ def main():
     pathlib.Path(output_file).unlink(missing_ok=True)
     file_sz = args.mb_size*(1024**2)
     app_tensor = torch.empty(file_sz, dtype=torch.uint8, device='cpu', requires_grad=False)
+    native_locked_tensor = get_accelerator()._name == 'cpu'
 
     aio_handle = AsyncIOBuilder().load().aio_handle()
-    bounce_buffer = torch.empty(file_sz, dtype=torch.uint8, requires_grad=False).pin_memory()
+    if native_locked_tensor:
+        bounce_buffer = aio_handle.new_cpu_locked_tensor(file_sz, torch.Tensor().to(torch.uint8))
+    else:
+        bounce_buffer = torch.empty(file_sz, dtype=torch.uint8, requires_grad=False).pin_memory()
 
 
     t = timeit.Timer(functools.partial(file_write, output_file, app_tensor, aio_handle, bounce_buffer))
@@ -33,6 +38,9 @@ def main():
     filecmp.clear_cache()
     print(f'Validation success = {filecmp.cmp(py_ref_file, output_file, shallow=False) }')
 
+    if native_locked_tensor:
+        aio_handle.free_cpu_locked_tensor(bounce_buffer)
+
     pathlib.Path(output_file).unlink(missing_ok=True)
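On the write side, the staged pattern is the mirror image: copy the application tensor into the bounce buffer, then issue a blocking NVMe write. A sketch assuming the elided line of `file_write` calls `sync_pwrite` (the blocking-write counterpart of the read path's `sync_pread`); the commented wiring uses the hypothetical helpers sketched above:

```python
def file_write(out_f, tensor, handle, bounce_buffer):
    # Stage the tensor in locked/pinned memory, then write it to NVMe.
    bounce_buffer.copy_(tensor)
    handle.sync_pwrite(bounce_buffer, out_f)  # assumed, elided in the diff

# handle = AsyncIOBuilder().load().aio_handle()
# buffer, native = alloc_bounce_buffer(handle, file_sz)
# try:
#     file_write(output_file, app_tensor, handle, buffer)
# finally:
#     free_bounce_buffer(handle, buffer, native)
```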
139 changes: 139 additions & 0 deletions deepnvme/model_checkpoint/deepspeed_save_model.py

@@ -0,0 +1,139 @@
import time
import torch
import os
import shutil
import gc
import random
import numpy as np
import deepspeed
from deepspeed.accelerator import get_accelerator
from save_model_utils import get_model, validate_arguments, parse_arguments

def _get_ds_config(args, writer_type, use_gds):
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {
            "stage": args.zero_stage,
            "cpu_offload": args.cpu_offload
        },
        "fp16": {
            "enabled": args.half
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "torch_adam": not args.fused
            }
        },
        "checkpoint": {
            "checkpoint_serialization": not args.legacy
        },
        "aio": {
            "block_size": 8 * (1024**2),
            "queue_depth": 8,
            "single_submit": False,
            "overlap_events": True,
            "intra_op_parallelism": 2,
            "use_gds": use_gds,
        }
    }

    if writer_type:
        ds_config["checkpoint"]["writer"] = {
            "type": writer_type,
            "io_buffer_size": args.io_buffer_mb * (1024**2),
            "io_buffer_double": not args.single_io_buffer,
            "show_statistics": not args.no_statistics,
            "data_parallel": "socket"  # None # not args.single_writer
        }

    return ds_config


def _get_ds_engine(model, ds_config):
    ds_engine, _, _, _ = deepspeed.initialize(
        model=model, model_parameters=model.parameters(), config=ds_config)

    return ds_engine


def _do_optimizer_step(ds_engine):
    for p in ds_engine.module.parameters():
        p.grad = torch.zeros_like(p)
    ds_engine.step()


def _free_ds_memory(ds_engine):
    ds_engine.optimizer.optimizer = None
    ds_engine.optimizer = None
    ds_engine.module = None
    ds_engine = None
    del ds_engine
    gc.collect()
    get_accelerator().empty_cache()


def test_save(tag, folder, model, args, writer_type):
    use_gds = writer_type == 'fast' and 'gds' in tag
    ds_config = _get_ds_config(args, writer_type, use_gds)
    ds_engine = _get_ds_engine(model, ds_config)
    if args.zero_stage == 0:
        _do_optimizer_step(ds_engine)

    st = time.time()
    ds_engine.save_checkpoint(save_dir=folder, tag=tag)
    write_sec = time.time() - st
    _free_ds_memory(ds_engine)
    return write_sec


def _get_folder_size(folder):
    size = 0
    for path, _, files in os.walk(folder):
        size += sum([os.path.getsize(os.path.join(path, f)) for f in files])
    return size


def run(model, model_name, ckpt_name, args):
    print(f'Model name = {model_name}')
    writer_dict = {
        'test_save': None,
        'test_ds_mock_save': 'mock',
        'test_ds_py_save': 'python',
        'test_ds_aio_fast_save': 'fast',
        'test_ds_gds_fast_save': 'fast',
    }
    for tag, writer_type in writer_dict.items():
        folder = os.path.join(args.folder, ckpt_name, tag)
        if os.path.exists(folder):
            shutil.rmtree(folder, ignore_errors=True)
        # if not os.path.exists(folder):
        #     os.makedirs(folder, exist_ok=True)
        write_sec = test_save(tag, folder, model, args, writer_type)
        ckpt_size = _get_folder_size(folder)
        gb_size = ckpt_size / (1024**3)
        gb_per_sec = gb_size / write_sec
        print(
            f'{tag} -- {gb_size:5.2f} GB, {write_sec:5.2f} secs, {gb_per_sec:5.2f} GB/s'
        )
        print(f'*********************************************')


def main():
    print(
        f'Performance test of deepspeed integration of fast model checkpointing.'
    )
    print(f'torch version = {torch.__version__}')
    torch.manual_seed(42)
    np.random.seed(0)
    random.seed(0)
    args = parse_arguments()
    if not validate_arguments(args):
        quit()

    model, model_name, ckpt_name = get_model(args.model)
    run(model, model_name, ckpt_name, args)


if __name__ == "__main__":
    main()
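Stripped of the benchmarking scaffolding, the DeepNVMe-specific piece is the `checkpoint.writer` block plus a timed `save_checkpoint` call. A trimmed sketch; note that the `checkpoint.writer` config keys come from this PR branch rather than a released DeepSpeed, and the model and save path are placeholders:

```python
import time
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)  # placeholder model
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"torch_adam": True}},
    "checkpoint": {
        "checkpoint_serialization": True,
        "writer": {
            "type": "fast",                    # 'mock' | 'python' | 'fast'
            "io_buffer_size": 64 * (1024**2),  # pinned I/O buffer, bytes
            "io_buffer_double": True,          # double-buffer copy vs. write
            "show_statistics": True,
            "data_parallel": "socket",
        },
    },
    "aio": {"use_gds": False},  # True selects the GDS path on supported GPUs
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)
start = time.time()
engine.save_checkpoint(save_dir='/tmp/ckpt', tag='demo')
print(f'save took {time.time() - start:.2f} secs')
```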
1 change: 1 addition & 0 deletions deepnvme/model_checkpoint/requirements.txt

@@ -0,0 +1 @@
+transformers
116 changes: 116 additions & 0 deletions deepnvme/model_checkpoint/save_model_utils.py

@@ -0,0 +1,116 @@
import argparse
import os
from transformers import AutoModelForCausalLM
from transformers import T5ForConditionalGeneration
from torch_save_utils import PINNED_BUFFER_MB


GPT2L = 'gpt2-large'
TINY_T5 = 'tiny-t5'
PHI3_MINI = 'phi3'
PHI3_VISION = 'phi3-v'
LLAMA3_1B = 'llama3-1B'

HF_MODELS_DICT = {
    TINY_T5: "hf-internal-testing/tiny-random-t5",
    GPT2L: GPT2L,
    PHI3_MINI: "microsoft/Phi-3.5-mini-instruct",
    PHI3_VISION: "microsoft/Phi-3.5-vision-instruct",
    LLAMA3_1B: "meta-llama/Llama-3.2-1B",
}

def _get_hf_model(tag):
    model_name = HF_MODELS_DICT[tag]
    if tag == TINY_T5:
        model = T5ForConditionalGeneration.from_pretrained(model_name)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name)

    return model, model_name, tag

def get_model(model_tag):
    return _get_hf_model(model_tag)


def validate_arguments(args):
    success = True

    if not args.model in HF_MODELS_DICT:
        print(f'{args.model} is not a supported HF model tag')
        success = False

    if args.optimizer and args.half:
        if not args.gpu:
            print(f'mixed precision only supported with gpu tensors')
            success = False

    return success


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--folder',
                        default=None,
                        type=str,
                        required=True,
                        help='Folder to use for I/O.')

    parser.add_argument(
        '--model',
        default=None,
        type=str,
        required=True,
        help=f'HuggingFace tag of model. Available models = {list(HF_MODELS_DICT.keys())}')

    parser.add_argument('--local_rank',
                        type=int,
                        default=0,
                        help='Local rank')

    parser.add_argument('--legacy',
                        action='store_true',
                        help='Use torch legacy save format')

    parser.add_argument('--optimizer',
                        action='store_true',
                        help='Include optimizer state in checkpoint.')

    parser.add_argument('--fused',
                        action='store_true',
                        help='Use fused fp16 optimizer.')

    parser.add_argument('--gpu', action='store_true', help='Use gpu tensors.')

    parser.add_argument('--half',
                        action='store_true',
                        help='Use half-precision tensors.')

    parser.add_argument(
        '--io_buffer_mb',
        type=int,
        default=PINNED_BUFFER_MB,
        help=f'Size of pinned i/o buffer in MB. Default = {PINNED_BUFFER_MB}')

    parser.add_argument('--zero_stage',
                        type=int,
                        default=0,
                        help='ZeRO optimization stage. Default = 0')

    parser.add_argument('--cpu_offload',
                        action='store_true',
                        help='Enable CPU offload of optimizer state.')

    parser.add_argument('--no-statistics',
                        action='store_true',
                        help='Suppress low-level performance statistics.')

    parser.add_argument('--single_io_buffer',
                        action='store_true',
                        help='Disable double buffering of i/o buffer.')


    #parser.add_argument('--single_writer', action='store_true', help='Disable parallel rank writes of data parallel (replicated) state')

    args = parser.parse_args()
    print(f'args = {args}')
    return args
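For orientation, the helper module maps short CLI tags to HuggingFace checkpoints; the tiny T5 entry is handy for smoke tests. For example:

```python
from save_model_utils import get_model, HF_MODELS_DICT

print(HF_MODELS_DICT['tiny-t5'])   # hf-internal-testing/tiny-random-t5
model, model_name, ckpt_name = get_model('tiny-t5')
print(type(model).__name__)        # T5ForConditionalGeneration
```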