From ffce2d199c76de6d17c2b373c0ffa87aa11c65f4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 17 Mar 2025 11:20:25 +0530 Subject: [PATCH 01/10] implement record_stream for better performance. --- src/diffusers/hooks/group_offloading.py | 33 ++++++++++++++++++++++--- src/diffusers/models/modeling_utils.py | 3 ++- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index c389c5dc9826..cb915b0fc3a8 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -56,6 +56,7 @@ def __init__( buffers: Optional[List[torch.Tensor]] = None, non_blocking: bool = False, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None, onload_self: bool = True, ) -> None: @@ -68,15 +69,21 @@ def __init__( self.buffers = buffers self.non_blocking = non_blocking or stream is not None self.stream = stream + self.record_stream = record_stream self.cpu_param_dict = cpu_param_dict self.onload_self = onload_self if self.stream is not None and self.cpu_param_dict is None: - raise ValueError("cpu_param_dict must be provided when using stream for data transfer.") + raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.") + + if self.record_stream and not self.stream: + raise ValueError("`record_stream` cannot be True when `stream` is None.") def onload_(self): r"""Onloads the group of modules to the onload_device.""" context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) + current_stream = torch.cuda.current_stream() if self.record_stream else None + if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -84,16 +91,23 @@ def onload_(self): with context: for group_module in self.modules: group_module.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + for param in group_module.parameters(): + param.data.record_stream(current_stream) if self.parameters is not None: for param in self.parameters: param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) if self.buffers is not None: for buffer in self.buffers: buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) def offload_(self): r"""Offloads the group of modules to the offload_device.""" - if self.stream is not None: + if self.stream is not None and not self.record_stream: torch.cuda.current_stream().synchronize() for group_module in self.modules: for param in group_module.parameters(): @@ -268,6 +282,7 @@ def apply_group_offloading( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + record_stream: bool = False ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -314,6 +329,7 @@ def apply_group_offloading( use_stream (`bool`, defaults to `False`): If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for overlapping computation and data transfer. + record_stream: TODO Example: ```python @@ -349,10 +365,10 @@ def apply_group_offloading( raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") _apply_group_offloading_block_level( - module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream + module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, record_stream ) elif offload_type == "leaf_level": - _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream) + _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, record_stream) else: raise ValueError(f"Unsupported offload_type: {offload_type}") @@ -364,6 +380,7 @@ def _apply_group_offloading_block_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False ) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to @@ -382,6 +399,7 @@ def _apply_group_offloading_block_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. + record_stream: TODO """ # Create a pinned CPU parameter dict for async data transfer if streams are to be used @@ -411,6 +429,7 @@ def _apply_group_offloading_block_level( onload_leader=current_modules[0], non_blocking=non_blocking, stream=stream, + record_stream=record_stream, cpu_param_dict=cpu_param_dict, onload_self=stream is None, ) @@ -448,6 +467,7 @@ def _apply_group_offloading_block_level( buffers=buffers, non_blocking=False, stream=None, + record_stream=False, cpu_param_dict=None, onload_self=True, ) @@ -461,6 +481,7 @@ def _apply_group_offloading_leaf_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False ) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory @@ -481,6 +502,7 @@ def _apply_group_offloading_leaf_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. + record_stream: TODO """ # Create a pinned CPU parameter dict for async data transfer if streams are to be used @@ -503,6 +525,7 @@ def _apply_group_offloading_leaf_level( onload_leader=submodule, non_blocking=non_blocking, stream=stream, + record_stream=record_stream, cpu_param_dict=cpu_param_dict, onload_self=True, ) @@ -548,6 +571,7 @@ def _apply_group_offloading_leaf_level( buffers=buffers, non_blocking=non_blocking, stream=stream, + record_stream=record_stream, cpu_param_dict=cpu_param_dict, onload_self=True, ) @@ -567,6 +591,7 @@ def _apply_group_offloading_leaf_level( buffers=None, non_blocking=False, stream=None, + record_stream=False, cpu_param_dict=None, onload_self=True, ) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 6983940f139b..443500f5bb76 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -546,6 +546,7 @@ def enable_group_offload( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + record_stream: bool = False, ) -> None: r""" Activates group offloading for the current model. @@ -584,7 +585,7 @@ def enable_group_offload( f"open an issue at https://github.com/huggingface/diffusers/issues." ) apply_group_offloading( - self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream + self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream, record_stream ) def save_pretrained( From f25ea18cf9be4943ac0151ef077d432926c505f2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 17 Mar 2025 11:50:18 +0530 Subject: [PATCH 02/10] fix --- src/diffusers/hooks/group_offloading.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index cb915b0fc3a8..0d1680b904ff 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -107,8 +107,9 @@ def onload_(self): def offload_(self): r"""Offloads the group of modules to the offload_device.""" - if self.stream is not None and not self.record_stream: - torch.cuda.current_stream().synchronize() + if self.stream is not None: + if not self.record_stream: + torch.cuda.current_stream().synchronize() for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] From 2a28f6df88a82b8afabb3e8047f32657ffadca4f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 17 Mar 2025 12:25:43 +0530 Subject: [PATCH 03/10] style. --- src/diffusers/hooks/group_offloading.py | 10 +++++----- src/diffusers/models/modeling_utils.py | 9 ++++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 0d1680b904ff..5c1bee08d984 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -75,7 +75,7 @@ def __init__( if self.stream is not None and self.cpu_param_dict is None: raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.") - + if self.record_stream and not self.stream: raise ValueError("`record_stream` cannot be True when `stream` is None.") @@ -83,7 +83,7 @@ def onload_(self): r"""Onloads the group of modules to the onload_device.""" context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) current_stream = torch.cuda.current_stream() if self.record_stream else None - + if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -283,7 +283,7 @@ def apply_group_offloading( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, - record_stream: bool = False + record_stream: bool = False, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -381,7 +381,7 @@ def _apply_group_offloading_block_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, - record_stream: Optional[bool] = False + record_stream: Optional[bool] = False, ) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to @@ -482,7 +482,7 @@ def _apply_group_offloading_leaf_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, - record_stream: Optional[bool] = False + record_stream: Optional[bool] = False, ) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 443500f5bb76..7c6a2c1e836e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -585,7 +585,14 @@ def enable_group_offload( f"open an issue at https://github.com/huggingface/diffusers/issues." ) apply_group_offloading( - self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream, record_stream + self, + onload_device, + offload_device, + offload_type, + num_blocks_per_group, + non_blocking, + use_stream, + record_stream, ) def save_pretrained( From 41ea4c83ce9bdb6f6c4b94334067c5c67a394e11 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Mar 2025 01:39:22 +0000 Subject: [PATCH 04/10] merge #11097 --- src/diffusers/hooks/group_offloading.py | 36 ++++++++++++++++++------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 5c1bee08d984..e7f77ca35a23 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -90,10 +90,15 @@ def onload_(self): with context: for group_module in self.modules: - group_module.to(self.onload_device, non_blocking=self.non_blocking) - if self.record_stream: - for param in group_module.parameters(): + for param in group_module.parameters(): + param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: param.data.record_stream(current_stream) + for buffer in group_module.buffers(): + buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) + if self.parameters is not None: for param in self.parameters: param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) @@ -113,6 +118,12 @@ def offload_(self): for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] + if self.parameters is not None: + for param in self.parameters: + param.data = self.cpu_param_dict[param] + if self.buffers is not None: + for buffer in self.buffers: + buffer.data = self.cpu_param_dict[buffer] else: for group_module in self.modules: group_module.to(self.offload_device, non_blocking=self.non_blocking) @@ -406,9 +417,7 @@ def _apply_group_offloading_block_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} + cpu_param_dict = _get_pinned_cpu_param_dict(module) # Create module groups for ModuleList and Sequential blocks modules_with_group_offloading = set() @@ -509,9 +518,7 @@ def _apply_group_offloading_leaf_level( # Create a pinned CPU parameter dict for async data transfer if streams are to be used cpu_param_dict = None if stream is not None: - for param in module.parameters(): - param.data = param.data.cpu().pin_memory() - cpu_param_dict = {param: param.data for param in module.parameters()} + cpu_param_dict = _get_pinned_cpu_param_dict(module) # Create module groups for leaf modules and apply group offloading hooks modules_with_group_offloading = set() @@ -630,6 +637,17 @@ def _apply_lazy_group_offloading_hook( registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) +def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]: + cpu_param_dict = {} + for param in module.parameters(): + param.data = param.data.cpu().pin_memory() + cpu_param_dict[param] = param.data + for buffer in module.buffers(): + buffer.data = buffer.data.cpu().pin_memory() + cpu_param_dict[buffer] = buffer.data + return cpu_param_dict + + def _gather_parameters_with_no_group_offloading_parent( module: torch.nn.Module, modules_with_group_offloading: Set[str] ) -> List[torch.nn.Parameter]: From 9281e84a9ba4bd0fe03a6b2e1ddc4703add1a33c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Mar 2025 16:17:11 +0530 Subject: [PATCH 05/10] Update src/diffusers/hooks/group_offloading.py Co-authored-by: Aryan --- src/diffusers/hooks/group_offloading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 6503a581bf29..b9a9005f5130 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -76,7 +76,7 @@ def __init__( if self.stream is not None and self.cpu_param_dict is None: raise ValueError("`cpu_param_dict` must be provided when using stream for data transfer.") - if self.record_stream and not self.stream: + if self.stream is None and self.record_stream: raise ValueError("`record_stream` cannot be True when `stream` is None.") def onload_(self): From d5afea563b52771575032da79e20a9838c2e0f14 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 27 Mar 2025 18:05:01 +0100 Subject: [PATCH 06/10] fixes --- src/diffusers/hooks/group_offloading.py | 26 +++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 639c0140aee5..a6d4b19b4be8 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -436,18 +436,24 @@ def apply_group_offloading( raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") _apply_group_offloading_block_level( - module, - num_blocks_per_group, - offload_device, - onload_device, - non_blocking, - stream, - record_stream, - low_cpu_mem_usage, + module=module, + num_blocks_per_group=num_blocks_per_group, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, ) elif offload_type == "leaf_level": _apply_group_offloading_leaf_level( - module, offload_device, onload_device, non_blocking, stream, record_stream, low_cpu_mem_usage + module=module, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, ) else: raise ValueError(f"Unsupported offload_type: {offload_type}") @@ -660,7 +666,7 @@ def _apply_group_offloading_leaf_level( buffers=None, non_blocking=False, stream=None, - record_stream=record_stream, + record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) From 87a93fedd3a9ef5fe608d1ebd518759198c2a2b1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 17:48:25 +0530 Subject: [PATCH 07/10] docstring. --- src/diffusers/hooks/group_offloading.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index a6d4b19b4be8..9c8c5fc1be7c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -396,7 +396,10 @@ def apply_group_offloading( use_stream (`bool`, defaults to `False`): If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for overlapping computation and data transfer. - record_stream: TODO + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. low_cpu_mem_usage (`bool`, defaults to `False`): If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when From 1d4ca615a7f3b221b34c6de200d4df422f4110db Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 17:50:18 +0530 Subject: [PATCH 08/10] remaining todos in low_cpu_mem_usage --- src/diffusers/hooks/group_offloading.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 9c8c5fc1be7c..ac6cf653641b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -489,8 +489,14 @@ def _apply_group_offloading_block_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. - record_stream: TODO - low_cpu_mem_usage: TODO + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ # Create module groups for ModuleList and Sequential blocks @@ -586,8 +592,14 @@ def _apply_group_offloading_leaf_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. - record_stream: TODO - low_cpu_mem_usage: TODO + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ # Create module groups for leaf modules and apply group offloading hooks From 535dcd1b859684132929a3b3df100eaeaedc998d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 18:20:33 +0530 Subject: [PATCH 09/10] tests --- tests/models/test_modeling_common.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index d55ff6e62872..847677884a35 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1525,8 +1525,9 @@ def get_memory_usage(storage_dtype, compute_dtype): or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ) + @parameterized.expand([False, True]) @require_torch_gpu - def test_group_offloading(self): + def test_group_offloading(self, record_stream): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() torch.manual_seed(0) @@ -1566,7 +1567,9 @@ def run_forward(model): torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) + model.enable_group_offload( + torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream + ) output_with_group_offloading4 = run_forward(model) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) From 2ff9112c7da6ec8fefcf44ba646221b2cf27fa1c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 19:25:07 +0530 Subject: [PATCH 10/10] updates to docs. --- docs/source/en/optimization/memory.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index fd72957471c0..fc939477616f 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch # We can utilize the enable_group_offload method for Diffusers model implementations pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) +# Uncomment the following to also allow recording the current streams. +# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True) + # For any other model implementations, the apply_group_offloading function can be used apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2) apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level") @@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s - The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html) - If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems. - The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading. +- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this. For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].