Enable ZeRO set/get APIs for NVMe offload #7046

Open · wants to merge 44 commits into base: master
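A minimal usage sketch (not part of this PR's diff) of the ZeRO safe get/set APIs that this change enables for NVMe-offloaded optimizer state. The zero_out_exp_avg helper, the module handle, and the "exp_avg" state key are illustrative assumptions; the deepspeed.utils functions shown are the existing public API.

import torch
from deepspeed.utils import safe_get_full_optimizer_state, safe_set_full_optimizer_state

def zero_out_exp_avg(module):
    # Illustrative only: read and overwrite Adam's exp_avg for every parameter,
    # regardless of whether that state currently lives on GPU, CPU, or NVMe.
    for param in module.parameters():
        state = safe_get_full_optimizer_state(param, "exp_avg")
        if state is not None:
            safe_set_full_optimizer_state(param, torch.zeros_like(state), "exp_avg")

Called after deepspeed.initialize(), e.g. zero_out_exp_avg(engine.module), this would exercise the swap-in/write-back paths reworked in this PR.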
Commits (44)
ad5f833
Control trace cache warnings
tjruwase Feb 15, 2025
9c4a177
Update docs
tjruwase Feb 15, 2025
8c25621
Merge branch 'master' into olruwase/control_trace_cache_warnings
tjruwase Feb 15, 2025
4fd1b05
Enable safe_get/set APIs for NVMe offload
tjruwase Feb 17, 2025
6fd12e1
Formatting fixes
tjruwase Feb 17, 2025
e23bfab
Add vectorized update API
tjruwase Feb 18, 2025
de6d8b1
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 19, 2025
7e250d9
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 19, 2025
76da050
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 21, 2025
cc6ed24
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 21, 2025
f9ecab7
PR feedback
tjruwase Feb 25, 2025
48e5ad7
PR feedback
tjruwase Feb 25, 2025
f20abc1
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 25, 2025
28ba8af
Code cleanup
tjruwase Feb 25, 2025
3872984
Merge branch 'olruwase/update_nvme_offload_states' of github.com:deep…
tjruwase Feb 25, 2025
8bc000c
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Feb 26, 2025
6ac306a
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 27, 2025
f86e3ca
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 27, 2025
6c1ba6e
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Feb 28, 2025
17935e9
Handle offload_states
tjruwase Feb 28, 2025
61685dc
Use new dlpack api; Formatting fixes
tjruwase Mar 3, 2025
1667758
Merge branch 'olruwase/new_dlpack_api' of github.com:deepspeedai/Deep…
tjruwase Mar 3, 2025
66b40ce
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 3, 2025
5a215cc
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 4, 2025
797bb15
Revert change
tjruwase Mar 4, 2025
f033827
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 5, 2025
044db61
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 7, 2025
b0f1391
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 10, 2025
0099333
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 11, 2025
e203e8f
Add -x to test failure/debug
loadams Mar 11, 2025
d6c2999
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 12, 2025
03765b8
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 17, 2025
9ecbce8
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 17, 2025
84f5467
add destroy to tests
tohtana Mar 20, 2025
5897a8b
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 21, 2025
53f999e
Merge branch 'master' into tohtana/destroy_model_test_zero
tohtana Mar 21, 2025
e070b58
Debug CI
tjruwase Mar 24, 2025
a8991c4
Merge branch 'tohtana/destroy_model_test_zero' of github.com:deepspee…
tjruwase Mar 24, 2025
5991d43
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 25, 2025
573ac52
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 27, 2025
1274b2d
Merge branch 'master' into olruwase/update_nvme_offload_states
loadams Mar 27, 2025
0305597
Merge branch 'master' into olruwase/update_nvme_offload_states
tjruwase Mar 31, 2025
81d09a3
Quiet CI
tjruwase Mar 31, 2025
89c1233
Merge branch 'olruwase/update_nvme_offload_states' of github.com:deep…
tjruwase Mar 31, 2025
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-latest-v100.yml
@@ -55,5 +55,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
pytest $PYTEST_OPTS --forked -n 8 unit/ --torch_ver="2.6" --cuda_ver="12.4"
pytest -x $PYTEST_OPTS --forked -n 8 unit/ --torch_ver="2.6" --cuda_ver="12.4"
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.6" --cuda_ver="12.4"
1 change: 1 addition & 0 deletions deepspeed/runtime/swap_tensor/__init__.py
@@ -2,3 +2,4 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from .utils import MIN_SWAPPABLE_BYTES
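With MIN_SWAPPABLE_BYTES now exported from the swap_tensor package, callers can mirror the shape of the swappability test used by is_swappable_tensor() further down. A minimal sketch, assuming MIN_SWAPPABLE_BYTES is a byte threshold and using a hypothetical can_swap helper:

import torch
from deepspeed.runtime.swap_tensor import MIN_SWAPPABLE_BYTES

def can_swap(tensor: torch.Tensor) -> bool:
    # Mirrors the check in OptimizerSwapper.is_swappable_tensor(): a tensor is
    # NVMe-swappable only when its byte size reaches the minimum AIO transfer size.
    return MIN_SWAPPABLE_BYTES <= tensor.numel() * tensor.element_size()

print(can_swap(torch.empty(8, dtype=torch.float32)))                  # tiny tensor: expected False
print(can_swap(torch.empty(32 * 1024 * 1024, dtype=torch.float32)))   # 128 MB tensor: expected True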
84 changes: 65 additions & 19 deletions deepspeed/runtime/swap_tensor/optimizer_utils.py
@@ -26,45 +26,78 @@ def __init__(self, path, length, offset):
self.length = length


class SwapTensorContext(object):

def __init__(self, tensor, swap_folder):
self.compute_tensor = tensor
self.swap_tensor = torch.Tensor()
self.swap_path = os.path.join(swap_folder, f'{OptimizerSwapper.parameter_id(tensor)}.tensor.swp')

def release_memory(self):
self.compute_tensor.data = torch.Tensor()
self.swap_tensor.data = torch.Tensor()

def set_buffers(self, compute_buffer, swap_buffer):
self.compute_tensor.data = compute_buffer.data
self.swap_tensor.data = swap_buffer.data


class OptimizerStateSwapInfo(object):

def __init__(self, parameter, numel, base_folder):
self.tensors = []
self.param_id = OptimizerSwapper.parameter_id(parameter)
self.swap_folder = base_folder
self.swap_paths = []
self.swapped_gradients = {}
self.unswapped_gradients = {}
self.tensor_numel = numel
self.tensor_dtype = parameter.dtype
self.tensor_device = parameter.device
self.has_state_tensors = False
self.swap_buffers = []
self._add_tensors([parameter])

def numel(self):
return self.tensor_numel

def has_gradients(self):
return self.swapped_gradients or self.unswapped_gradients
return bool(self.swapped_gradients) or bool(self.unswapped_gradients)

def _add_tensors(self, tensor_list):
for t in tensor_list:
self.tensors.append(t)
self.swap_paths.append(os.path.join(self.swap_folder, f'{OptimizerSwapper.parameter_id(t)}.tensor.swp'))
self.tensors.append(SwapTensorContext(t, self.swap_folder))

def add_state_tensors(self, tensor_list):
self.has_state_tensors = True
self._add_tensors(tensor_list)

def num_tensors(self):
return len(self.tensors)

def device(self):
return self.tensor_device

def dtype(self):
return self.tensor_dtype

def release_memory(self):
for tensor in self.tensors:
tensor.data = torch.Tensor()
for t in self.tensors:
t.release_memory()

def get_compute_tensors(self):
return [t.compute_tensor for t in self.tensors]

def get_swap_paths(self):
return [t.swap_path for t in self.tensors]

def get_swap_buffers_and_paths(self, pinned):
swap_buffers = []
swap_paths = []
select_tensors = [t for t in self.tensors if get_accelerator().is_pinned(t.compute_tensor) == pinned]
for t in select_tensors:
swap_buffers.append(t.swap_tensor if pinned else t.compute_tensor)
swap_paths.append(t.swap_path)
return swap_buffers, swap_paths

def get_or_create_gradient_paths(self, offsets, lengths):
gradient_paths = []
@@ -77,11 +110,15 @@ def get_or_create_gradient_paths(self, offsets, lengths):

return gradient_paths

def set_swap_buffers(self, buffers):
compute_lengths = [self.numel()] * len(self.tensors)
def set_swap_buffers(self, buffers, aligned_numel):
num_tensors = len(self.tensors)
compute_lengths = [self.numel()] * num_tensors
compute_buffers = get_sized_buffers(buffers, compute_lengths)
for t, buffer in zip(self.tensors, compute_buffers):
t.data = buffer.data
swap_lengths = [aligned_numel] * num_tensors
swap_buffers = get_sized_buffers(buffers, swap_lengths)

for i, t in enumerate(self.tensors):
t.set_buffers(compute_buffer=compute_buffers[i], swap_buffer=swap_buffers[i])

def get_swap_gradient_buffers(self, swap_buffer):
assert self.numel() <= swap_buffer.numel()
@@ -91,7 +128,7 @@ def get_swap_gradient_paths(self):
return [grad.path for grad in self.swapped_gradients.values()]

def get_unpinned_state_tensors(self):
return [t for t in self.tensors if not get_accelerator().is_pinned(t)]
return [t.compute_tensor for t in self.tensors if not get_accelerator().is_pinned(t.compute_tensor)]

def read_unswapped_gradients(self, dest_buffer):
num_elem_count = 0
@@ -102,6 +139,15 @@ def read_unswapped_gradients(self, dest_buffer):

return num_elem_count

def write_unswapped_gradients(self, src_buffer):
num_elem_count = 0
for offset, grad_partition in self.unswapped_gradients.items():
src_tensor = src_buffer.narrow(0, offset, grad_partition.numel())
grad_partition.data.copy_(src_tensor.data)
num_elem_count += grad_partition.numel()

return num_elem_count

def release_unswapped_gradients(self):
self.unswapped_gradients = {}

@@ -158,10 +204,10 @@ def purge_state(self):
swap_info.tensors = [swap_info.tensors[0]]
swap_info.has_state_tensors = False

def swappable_tensor(self, param=None, numel=None):
assert param is not None or numel is not None, "Either param or numel must be provided"
if param is not None:
return self.min_aio_bytes <= (param.numel() * self.swap_element_size)
def is_swappable_tensor(self, tensor=None, numel=None):
assert tensor is not None or numel is not None, "Either tensor or numel must be provided"
if tensor is not None:
return self.min_aio_bytes <= (tensor.numel() * self.swap_element_size)
return self.min_aio_bytes <= (numel * self.swap_element_size)

def init_timers(self):
@@ -201,7 +247,7 @@ def _swap_out_gradients(self, parameter, gradient_offsets, gradient_tensors, gra

self._start_timer(SWAP_OUT_GRADIENT_TIMER)
for tensor, offset in zip(aligned_gradients, aligned_offsets):
if not self.swappable_tensor(param=tensor):
if not self.is_swappable_tensor(tensor=tensor):
swap_info.unswapped_gradients[offset] = tensor
continue

@@ -355,7 +401,7 @@ def _get_swap_paths(self, parameters, num_elems):
]
assert len(swap_info_list) == len(num_elems)

swap_paths = [info.swap_paths[0] for info in swap_info_list]
swap_paths = [info.tensors[0].swap_path for info in swap_info_list]
return swap_paths

def _swap_out_unpinned_tensors(self, aio_handle, unpinned_tensors, dest_paths, pinned_buffers):
@@ -386,7 +432,7 @@ def _adjust_for_misaligned_lengths(self, tensors, offsets):
new_offsets = []

for orig_tensor, orig_offset in zip(tensors, offsets):
if not self.swappable_tensor(param=orig_tensor):
if not self.is_swappable_tensor(tensor=orig_tensor):
new_tensors.append(orig_tensor)
new_offsets.append(orig_offset)
continue
@@ -430,7 +476,7 @@ def _get_state_tensors(self, parameter):

tensor_list = []
for state_name, value in self.optimizer.state[parameter].items():
if torch.is_tensor(value):
if torch.is_tensor(value) and self.is_swappable_tensor(tensor=value):
value.ds_id = state_name + '-' + parameter.ds_id
tensor_list.append(value)

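A short sketch of the two-view buffer scheme used by the new set_swap_buffers(buffers, aligned_numel): each pinned buffer yields a compute view of exactly numel() elements (bound to compute_tensor) and an aligned I/O view of aligned_numel elements (bound to swap_tensor) for AIO transfers. The sized_views helper below is an assumption standing in for get_sized_buffers():

import torch

def sized_views(buffers, lengths):
    # Assumed behaviour of get_sized_buffers(): one narrowed view per buffer.
    return [b.narrow(0, 0, n) for b, n in zip(buffers, lengths)]

numel, aligned_numel = 1000, 1024                        # compute size vs AIO-aligned size
pinned = [torch.empty(aligned_numel) for _ in range(3)]  # one buffer per state tensor

compute_views = sized_views(pinned, [numel] * len(pinned))       # bound to compute_tensor.data
swap_views = sized_views(pinned, [aligned_numel] * len(pinned))  # bound to swap_tensor.data

assert all(v.numel() == numel for v in compute_views)
assert all(v.numel() == aligned_numel for v in swap_views)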
114 changes: 61 additions & 53 deletions deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py
@@ -6,8 +6,6 @@
Functionality of swapping optimizer tensors to/from (NVMe) storage devices.
"""

import torch

from deepspeed.utils.logging import logger
from deepspeed.ops.op_builder import AsyncIOBuilder
from deepspeed import comm as dist
@@ -63,71 +61,98 @@ def initialize_from_swapped_fp16_params(self, fp16_partitions_info, fp16_num_ele
def flush_gradients(self):
self._flush_gradient_swapper(self.gradient_swapper)

def release_swap_buffers(self, parameter):
swap_info = self._get_param_swap_info(parameter)
if swap_info is None:
return
swap_info.release_memory()

self.swap_buffer_manager.free(swap_info.swap_buffers)
swap_info.swap_buffers = []

def swap_in_optimizer_state(self, parameter, async_parameter=None):
swap_info = self._get_param_swap_info(parameter)
if swap_info is None:
return

self._flush_gradient_swapper(self.gradient_swapper)

required_buffer_count = len(swap_info.tensors) + (1 if swap_info.has_gradients() else 0)
required_buffer_count = swap_info.num_tensors() + (1 if swap_info.has_gradients() else 0)
aligned_numel = self._io_aligned_numel(swap_info.numel())
pinned_buffers = self.swap_buffer_manager.allocate(num_elems=aligned_numel,
count=required_buffer_count,
dtype=parameter.dtype)
assert pinned_buffers is not None
self.allocated_swap_buffers = pinned_buffers.copy()
swap_info.swap_buffers = pinned_buffers.copy()

self._start_timer(SWAP_IN_PARAM_TIMER)
self._swap_in_parameter(aio_handle=self.aio_handle,
parameter=parameter,
dest_buffers=pinned_buffers[:required_buffer_count])
dest_buffers=pinned_buffers[:swap_info.num_tensors()])
self._stop_timer(SWAP_IN_PARAM_TIMER)
self.timer_names.add(SWAP_IN_PARAM_TIMER)

self._start_timer(SWAP_IN_GRADIENT_TIMER)
self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
self._stop_timer(SWAP_IN_GRADIENT_TIMER)
self.timer_names.add(SWAP_IN_GRADIENT_TIMER)

def swap_out_optimizer_state(self, parameter, async_swap=False):
swap_info = self._get_param_swap_info(parameter=parameter)

if swap_info is None:
return

self._start_timer(SWAP_OUT_PARAM_TIMER)
pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths = self._separate_pinned_tensors(swap_info)
swap_bytes = sum([self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.tensors])
if swap_info.has_gradients():
self._start_timer(SWAP_IN_GRADIENT_TIMER)
self._swap_in_gradients(aio_handle=self.aio_handle, parameter=parameter, dest_buffer=pinned_buffers[-1])
self._stop_timer(SWAP_IN_GRADIENT_TIMER)
self.timer_names.add(SWAP_IN_GRADIENT_TIMER)

def _swap_out_optimizer_state(self, swap_info):
pinned_tensors, pinned_paths = swap_info.get_swap_buffers_and_paths(True)
WRITE_TIMER = 'swap_submit_write'
self._start_timer(WRITE_TIMER)

swap_out_tensors(self.aio_handle, pinned_tensors, pinned_paths)
assert self.aio_handle.wait() == len(pinned_tensors)
for t in pinned_tensors:
t.data = torch.Tensor()

unpinned_tensors, unpinned_paths = swap_info.get_swap_buffers_and_paths(False)
if len(unpinned_tensors) > 0:
pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel, dtype=self.dtype)
self._swap_out_unpinned_tensors(aio_handle=self.aio_handle,
unpinned_tensors=unpinned_tensors,
dest_paths=unpinned_paths,
pinned_buffers=pinned_buffers)
self.allocated_swap_buffers += pinned_buffers
swap_info.swap_buffers += pinned_buffers.copy()

for t in unpinned_tensors:
t.data = torch.Tensor()
self._stop_timer(WRITE_TIMER)
self._log_timers([WRITE_TIMER])

def writeback_optimizer_state_and_gradients(self, parameter, write_opt_state, write_gradients):
swap_info = self._get_param_swap_info(parameter=parameter)

if swap_info is None:
return

self.swap_buffer_manager.free(self.allocated_swap_buffers)
self.allocated_swap_buffers = []
if write_opt_state:
self._swap_out_optimizer_state(swap_info)

if write_gradients and swap_info.has_gradients():
param_gradients = swap_info.swapped_gradients.values()
swap_buffers = [parameter.grad.narrow(0, grad.offset, grad.length) for grad in param_gradients]
swap_paths = [grad.path for grad in param_gradients]
swap_out_tensors(self.aio_handle, swap_buffers, swap_paths)
assert len(swap_buffers) == self.aio_handle.wait()
if swap_info.unswapped_gradients:
swap_info.write_unswapped_gradients(src_buffer=parameter.grad)

self.release_swap_buffers(parameter)

def swap_out_optimizer_state(self, parameter, async_swap=False):
swap_info = self._get_param_swap_info(parameter=parameter)

if swap_info is None:
return

swap_bytes = sum(
[self._io_aligned_numel(t.numel()) * t.element_size() for t in swap_info.get_compute_tensors()])

self._start_timer(SWAP_OUT_PARAM_TIMER)
self._swap_out_optimizer_state(swap_info)
self.release_swap_buffers(parameter)
self._stop_timer(SWAP_OUT_PARAM_TIMER)
self.timer_names.add(SWAP_OUT_PARAM_TIMER)

self._log_timers([WRITE_TIMER])

if DEBUG_MODE and dist.get_rank() == 0:
logger.info(f'optimizer_param_swap_out: {(swap_bytes/(1024**3)):5.2f} GB')

@@ -142,16 +167,20 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
if swap_info is None:
return

assert len(swap_info.tensors) <= len(dest_buffers)
num_swap_tensors = swap_info.num_tensors()
assert num_swap_tensors <= len(dest_buffers)

swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
swap_lengths = [self._io_aligned_numel(swap_info.numel())] * num_swap_tensors
swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

compute_lengths = [swap_info.numel()] * num_swap_tensors
compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)

READ_TIMER = 'swap_submit_read_param'
WAIT_TIMER = 'swap_wait_read_param'

self._start_timer(READ_TIMER)
swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
swap_in_tensors(aio_handle, swap_buffers, swap_info.get_swap_paths())
self._stop_timer(READ_TIMER)

swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])
@@ -160,40 +189,19 @@ def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
aio_handle.wait()
self._stop_timer(WAIT_TIMER)

compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
for t, buffer in zip(swap_info.tensors, compute_buffers):
t.data = buffer.data
swap_info.set_swap_buffers(dest_buffers, self._io_aligned_numel(swap_info.numel()))

self._log_timers([READ_TIMER, WAIT_TIMER])
if DEBUG_MODE and dist.get_rank() == 0:
logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')

def _separate_pinned_tensors(self, swap_info):
pinned_tensors = []
pinned_paths = []

unpinned_tensors = []
unpinned_paths = []

for tensor, path in zip(swap_info.tensors, swap_info.swap_paths):
if get_accelerator().is_pinned(tensor):
pinned_tensors.append(tensor)
pinned_paths.append(path)
else:
unpinned_tensors.append(tensor)
unpinned_paths.append(path)

return pinned_tensors, pinned_paths, unpinned_tensors, unpinned_paths

def _swap_in_pinned_gradients(self, aio_handle, parameter, gradient_tensor):
swap_info = self.swap_params_info[OptimizerSwapper.parameter_id(parameter)]
param_gradients = swap_info.swapped_gradients.values()
swap_buffers = [gradient_tensor.narrow(0, grad.offset, grad.length) for grad in param_gradients]
swap_paths = [grad.path for grad in param_gradients]
SWAP_READ_GRADIENTS = 'swap_submit_read_gradient'
SWAP_WAIT_GRADIENTS = 'swap_submit_wait_gradient'

self._start_timer(SWAP_READ_GRADIENTS)
swap_in_tensors(aio_handle, swap_buffers, swap_paths)
self._stop_timer(SWAP_READ_GRADIENTS)