Commit 09858a7 (parent: 1c701d7)
Author: amaurya

Add preserves_storage_sharing for checkpoint engines

Signed-off-by: amaurya <[email protected]>

File tree: 7 files changed, +81 −6 lines

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py (+5)

@@ -32,3 +32,8 @@ def commit(self, tag):
     def wait(self):
         # To wait in asynchronous checkpoint engines (e.g. DataStates-LLM) for the previous snapshot to finish
         pass
+
+    def preserves_storage_sharing(self):
+        # Check if the checkpoint engine preserves storage sharing
+        # (set to false if cloning is required to get actual tensor sizes)
+        return False
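
The cloning caveat in the comment above comes from how `torch.save` works: it serializes a tensor's entire underlying storage, so a small view over a large shared buffer bloats the checkpoint unless it is cloned first. A minimal, self-contained illustration of that behavior (plain PyTorch, not DeepSpeed code; the sizes are arbitrary):

```python
import io

import torch


def saved_nbytes(t):
    # Serialize a tensor in memory and report the resulting size.
    buf = io.BytesIO()
    torch.save(t, buf)
    return buf.getbuffer().nbytes


flat = torch.zeros(1_000_000)   # large flat buffer (~4 MB of storage)
view = flat[:10]                # tiny view sharing flat's storage

print(saved_nbytes(view))          # ~4 MB: the whole shared storage is serialized
print(saved_nbytes(view.clone()))  # ~1 KB: just the 10 elements plus metadata
```

An engine that returns `True` from `preserves_storage_sharing()` is asserting that it handles such shared storages itself, so callers can skip `clone_tensors_for_torch_save` and avoid the extra host-side copy.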

deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py (+3)

@@ -32,3 +32,6 @@ def commit(self, tag):
 
     def wait(self):
         return self.ckpt_engine.wait()
+
+    def preserves_storage_sharing(self):
+        return True

deepspeed/runtime/pipe/module.py (+4 −4)

@@ -608,7 +608,6 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False):
         layer_list = self.forward_funcs[start:end]
 
         checkpoint_engine.makedirs(save_dir, exist_ok=True)
-        debloat_memory = "DataStatesCheckpointEngine" not in str(type(checkpoint_engine))
         for idx, layer in enumerate(layer_list):
             model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
@@ -619,10 +618,11 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=False):
             for n in self._get_frozen_parameter_names(layer):
                 del orig_state_dict[n]
 
-            if debloat_memory:
-                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
-            else:
+            if checkpoint_engine.preserves_storage_sharing():
                 final_state_dict = orig_state_dict
+            else:
+                final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
+
             checkpoint_engine.save(final_state_dict, model_ckpt_path)
 
     def load_state_dir(self, load_dir, checkpoint_engine, strict=True):
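
With the branch above keyed off the engine's `preserves_storage_sharing()` hook instead of a hard-coded class-name check, a third-party asynchronous engine can opt out of the cloning path by overriding the hook. A minimal sketch, assuming the `CheckpointEngine` base class shown earlier; `MyAsyncCheckpointEngine` and its `torch.save`-based body are hypothetical placeholders, not a real DeepSpeed engine:

```python
import torch

from deepspeed.runtime.checkpoint_engine.checkpoint_engine import CheckpointEngine


class MyAsyncCheckpointEngine(CheckpointEngine):
    """Hypothetical engine that snapshots tensors itself before flushing."""

    def create(self, tag):
        pass  # e.g. start a checkpoint session for this tag

    def save(self, state_dict, path):
        # A real asynchronous engine would copy tensors into pinned host
        # buffers and flush them in the background; torch.save is a stand-in.
        torch.save(state_dict, path)

    def commit(self, tag):
        return True

    def wait(self):
        pass

    def preserves_storage_sharing(self):
        # Tell callers such as PipelineModule.save_state_dict that tensors do
        # not need clone_tensors_for_torch_save() before being passed to save().
        return True
```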

deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py (+2 −2)

@@ -5,9 +5,10 @@
 """
 Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
 """
-import torch
+
 from deepspeed.ops.op_builder import AsyncIOBuilder
 from deepspeed import comm as dist
+import torch
 
 from deepspeed.runtime.swap_tensor.constants import *
 from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object
@@ -185,7 +186,6 @@ def _swap_out_optimizer_state(self, aio_handle, parameter, swap_in_op):
         for pinned_dst, unpinned_src in zip(new_alloc_buffers, unpinned_tensors):
             dst = get_sized_buffer(pinned_dst, unpinned_src.numel())
             dst.data.copy_(unpinned_src.data)
-            unpinned_src.data = torch.Tensor()
 
         swap_paths = param_info.swap_paths.copy()
         assert len(swap_paths) == len(swap_buffers)
New file: DataStates-LLM checkpointing tutorial (+67 lines)

---
title: "DataStates-LLM Checkpointing Engine"
tags: asynchronous checkpointing for minimizing I/O overheads.
---

This tutorial shows how to use [DataStates-LLM](https://github.com/DataStates/datastates-llm) for asynchronous checkpointing with DeepSpeed. DataStates-LLM introduces a lazy asynchronous checkpointing mechanism tailored for LLMs that minimizes I/O overhead and improves training efficiency; the sections below walk through integrating it with the DeepSpeed framework.

## Overview of DataStates-LLM

DataStates-LLM is designed to address the challenges of frequent checkpointing in LLM training by introducing a lazy asynchronous multi-level approach. It leverages the immutability of model parameters and optimizer states during forward and backward passes to perform non-blocking data transfers, thereby reducing interference with the training process. This method has demonstrated up to 48x faster checkpointing and 2.2x faster end-to-end training times compared to traditional approaches, as outlined in [DataStates-LLM: Lazy Asynchronous Checkpointing for Large Language Models](https://arxiv.org/abs/2406.10707).

## Prerequisites

Before integrating DataStates-LLM with DeepSpeed, ensure the following:

- **DeepSpeed Installation**: DeepSpeed should be installed in your environment. If not, refer to the [DeepSpeed Getting Started Guide](https://github.com/microsoft/DeepSpeed/blob/master/docs/_tutorials/getting-started.md) for installation instructions.

- **DataStates-LLM Repository**: Access the DataStates-LLM source code from its [GitHub repository](https://github.com/DataStates/datastates-llm) and follow the installation instructions provided therein.

## Configuring DeepSpeed for DataStates-LLM

To enable DataStates-LLM's asynchronous checkpointing within DeepSpeed, modify the `deepspeed_config.json` file to include the settings under the `datastates_ckpt` section. Below is an example configuration:

```json
{
    // ... other DeepSpeed configuration options
    "datastates_ckpt": {
        "host_cache_size": 16,
        "parser_threads": 8
    }
}
```

### Configuration Parameters

- **`host_cache_size`**: Specifies the amount of pinned host memory (in gigabytes) reserved for asynchronous checkpoint flushing. Adjust this value based on your system's memory capacity and the size of your model checkpoints.

- **`parser_threads`**: Determines the number of threads dedicated to parsing checkpoint file requests in parallel. Increasing this value can enhance parsing throughput but may also increase CPU utilization.

## Implementing DataStates-LLM in Your Training Script

After enabling DataStates checkpointing in `deepspeed_config.json`, configure the checkpointing frequency by specifying the number of iterations between checkpoints with the command-line parameter `--save-interval`, as sketched below.
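
As a rough illustration, the loop below saves a checkpoint every `--save-interval` iterations. The model, synthetic data, and argument handling are placeholders; `deepspeed.add_config_arguments()`, `deepspeed.initialize()`, `engine.backward()`/`engine.step()`, and `engine.save_checkpoint()` are the DeepSpeed APIs being exercised, and the `datastates_ckpt` section shown earlier is what makes each save asynchronous.

```python
import argparse

import torch
import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument("--save-interval", type=int, default=100,
                    help="iterations between checkpoints")
parser.add_argument("--local_rank", type=int, default=-1)
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed_config, etc.
args = parser.parse_args()

model = torch.nn.Linear(1024, 1024)  # placeholder model
engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=model,
                                       model_parameters=model.parameters())

for step in range(1, 10_001):
    batch = torch.randn(8, 1024, device=engine.device)  # synthetic batch
    loss = engine(batch).float().pow(2).mean()           # placeholder loss
    engine.backward(loss)
    engine.step()

    if step % args.save_interval == 0:
        # With datastates_ckpt enabled in the DeepSpeed config, this call
        # returns quickly and the flush to storage proceeds asynchronously.
        engine.save_checkpoint("checkpoints", tag=f"step{step}")
```
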
## Performance Results

The checkpoint acceleration achieved by DataStates-LLM for various models is shown in the figures below.

![Higher checkpointing throughput](/assets/images/datastates-async-checkpointing/diff-models-ckpt-throughput.png){: .align-center}

![Faster training iterations](/assets/images/datastates-async-checkpointing/diff-models-iter-times.png){: .align-center}

## Limitations and Ongoing Work

1. DataStates-LLM currently supports only the CUDA runtime on NVIDIA GPUs.

2. DataStates-LLM has only been tested with ZeRO stage 1, without offloading to any other tiers.

3. While the checkpoint layout of DataStates matches Hugging Face's [safetensors](https://huggingface.co/docs/safetensors/) format, it is not yet fully compatible with the safetensors library because of the pickled objects DeepSpeed requires during restart.

4. DataStates-LLM does not yet support universal or elastic checkpointing.

## Questions and Support

Please use the [DataStates-LLM GitHub repository](https://github.com/DataStates/datastates-llm) for any questions, issues, or feature requests.