[feat] implement record_stream when using CUDA streams during group offloading #11081

Merged: 17 commits, Apr 8, 2025
66 changes: 55 additions & 11 deletions src/diffusers/hooks/group_offloading.py
@@ -56,6 +56,7 @@ def __init__(
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = True,
) -> None:
@@ -68,36 +69,61 @@ def __init__(
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.record_stream = record_stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self

if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.")

if self.record_stream and not self.stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")

def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None

if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()

with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
a-r-r-o-w (Member) commented on Mar 18, 2025:

current_stream should be self.stream here, I think. We need to tell PyTorch that the param.data and buffer.data here are owned by the non-default stream. Currently, we're telling it that they are owned by the default stream, which seems incorrect to me.

Sorry for the back and forth, but I think we will have to run the benchmark once more with the change 😅

Apart from this, everything else looks good. We can button up the docs and merge after @DN6 takes a look. Let's make sure to mention that this may use more memory in comparison to record_stream=False in certain cases, and link to the torch::Tensor::record_stream docs.

a-r-r-o-w (Member) followed up:

Oh nevermind, ignore my comment.

From the torch.Tensor.record_stream docs:

> This method is most suitable for use cases where you are providing a function that created a tensor on a side stream, and want users to be able to make use of the tensor without having to think carefully about stream safety when making use of them.

We don't create anything on the non-default stream, so torch.cuda.current_stream is correct
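For readers following this thread, here is a minimal sketch of the pattern under discussion, assuming a pinned CPU tensor and illustrative names: the H2D copy is issued under a side stream, and the consumer (current) stream is recorded on the resulting tensor so the caching allocator will not reuse its memory until that stream's already-queued work finishes. (The hook itself orders the streams with `self.stream.synchronize()`; `wait_stream` below is just one way to do it in a standalone example.)

```python
import torch

copy_stream = torch.cuda.Stream()
cpu_tensor = torch.randn(1024, 1024, pin_memory=True)  # pinned so the copy can be asynchronous

with torch.cuda.stream(copy_stream):
    # Allocation and the H2D copy are enqueued on the side stream.
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)

# The tensor is consumed on the current (default) stream, so that is the stream
# we record on it; the caching allocator will not hand the tensor's memory back
# out for reuse until the work already queued on that stream has completed.
torch.cuda.current_stream().wait_stream(copy_stream)
gpu_tensor.record_stream(torch.cuda.current_stream())

out = gpu_tensor @ gpu_tensor  # runs on the current stream
```

This mirrors the hook above, which captures `torch.cuda.current_stream()` before entering the `torch.cuda.stream(self.stream)` context and records that stream on the onloaded `param.data` / `buffer.data`.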

The PR author (Member) replied:

@a-r-r-o-w no problem. Better to be rigorous and double-check everything.

Just ran the benchmark after merging main into this branch (on a DGX node, single 80GB A100):

{"record_stream": false, "memory": "1.3514", "time": "32.792"}
{"record_stream": true, "memory": "1.3514", "time": "30.944"}

Feel free to run it yourself if you want.

> Let's make sure to mention that this may use more memory in comparison to record_stream=False in certain cases and link to the torch::Tensor::record_stream docs

Absolutely. I will mention it in the docstrings. Is there any other place you wanted me to mention it?
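(The benchmark script referenced above is not part of this diff. The following is only a rough, hypothetical harness for reproducing a comparison like the one quoted; the model, checkpoint, input shape, and offload settings are assumptions made for the sketch, not what was actually benchmarked.)

```python
import json
import time

import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading


def benchmark(record_stream: bool) -> dict:
    # Small stand-in model; the numbers above were measured on a much larger pipeline.
    model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
        record_stream=record_stream,
    )
    x = torch.randn(1, 3, 512, 512, dtype=torch.float16, device="cuda")

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        model.encode(x)
    torch.cuda.synchronize()
    return {
        "record_stream": record_stream,
        "memory_gib": round(torch.cuda.max_memory_allocated() / 1024**3, 4),
        "time_s": round(time.perf_counter() - start, 3),
    }


for flag in (False, True):
    print(json.dumps(benchmark(flag)))
```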

for buffer in group_module.buffers():
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)

if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)

def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
if not self.record_stream:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if self.parameters is not None:
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
@@ -268,6 +294,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -314,6 +341,7 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream: TODO

Example:
```python
@@ -349,10 +377,10 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, record_stream
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, record_stream)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")

@@ -364,6 +392,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -382,14 +411,13 @@ def _apply_group_offloading_block_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream: TODO
"""

# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
cpu_param_dict = _get_pinned_cpu_param_dict(module)

# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -411,6 +439,7 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
cpu_param_dict=cpu_param_dict,
onload_self=stream is None,
)
@@ -448,6 +477,7 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
cpu_param_dict=None,
onload_self=True,
)
@@ -461,6 +491,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -481,14 +512,13 @@ def _apply_group_offloading_leaf_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream: TODO
"""

# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
cpu_param_dict = _get_pinned_cpu_param_dict(module)

# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
@@ -503,6 +533,7 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
@@ -548,6 +579,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
@@ -567,6 +599,7 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
record_stream=False,
cpu_param_dict=None,
onload_self=True,
)
@@ -604,6 +637,17 @@ def _apply_lazy_group_offloading_hook(
registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)


def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
cpu_param_dict = {}
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict[param] = param.data
for buffer in module.buffers():
buffer.data = buffer.data.cpu().pin_memory()
cpu_param_dict[buffer] = buffer.data
return cpu_param_dict


def _gather_parameters_with_no_group_offloading_parent(
module: torch.nn.Module, modules_with_group_offloading: Set[str]
) -> List[torch.nn.Parameter]:
10 changes: 9 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -546,6 +546,7 @@ def enable_group_offload(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
) -> None:
r"""
Activates group offloading for the current model.
@@ -584,7 +585,14 @@ def enable_group_offload(
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
)

def save_pretrained(
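For completeness, here is a hedged sketch of how the new flag surfaces through the model-level entry point shown above; the checkpoint and devices are illustrative. Note that `record_stream=True` is only meaningful together with `use_stream=True`, since the hook raises a `ValueError` when `record_stream` is set without a stream.

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# Illustrative checkpoint; any diffusers model exposing enable_group_offload works the same way.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
    record_stream=True,  # requires use_stream=True
)
```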