diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index fd72957471c0..fc939477616f 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch # We can utilize the enable_group_offload method for Diffusers model implementations pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) +# Uncomment the following to also allow recording the current streams. +# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True) + # For any other model implementations, the apply_group_offloading function can be used apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2) apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level") @@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s - The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html) - If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems. - The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading. +- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this. For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`]. diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 4c1d354a0f59..ac6cf653641b 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -56,6 +56,7 @@ def __init__( buffers: Optional[List[torch.Tensor]] = None, non_blocking: bool = False, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, low_cpu_mem_usage=False, onload_self: bool = True, ) -> None: @@ -68,11 +69,14 @@ def __init__( self.buffers = buffers or [] self.non_blocking = non_blocking or stream is not None self.stream = stream + self.record_stream = record_stream self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage - self.cpu_param_dict = self._init_cpu_param_dict() + if self.stream is None and self.record_stream: + raise ValueError("`record_stream` cannot be True when `stream` is None.") + def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -112,6 +116,8 @@ def _pinned_memory_tensors(self): def onload_(self): r"""Onloads the group of modules to the onload_device.""" context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) + current_stream = torch.cuda.current_stream() if self.record_stream else None + if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -122,14 +128,22 @@ def onload_(self): for group_module in self.modules: for param in group_module.parameters(): param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) for buffer in group_module.buffers(): buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) for param in self.parameters: param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) for buffer in self.buffers: buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) else: for group_module in self.modules: @@ -143,11 +157,14 @@ def onload_(self): for buffer in self.buffers: buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) def offload_(self): r"""Offloads the group of modules to the offload_device.""" if self.stream is not None: - torch.cuda.current_stream().synchronize() + if not self.record_stream: + torch.cuda.current_stream().synchronize() for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] @@ -331,6 +348,7 @@ def apply_group_offloading( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + record_stream: bool = False, low_cpu_mem_usage: bool = False, ) -> None: r""" @@ -378,6 +396,10 @@ def apply_group_offloading( use_stream (`bool`, defaults to `False`): If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. low_cpu_mem_usage (`bool`, defaults to `False`): If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when @@ -417,11 +439,24 @@ def apply_group_offloading( raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") _apply_group_offloading_block_level( - module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage + module=module, + num_blocks_per_group=num_blocks_per_group, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, ) elif offload_type == "leaf_level": _apply_group_offloading_leaf_level( - module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage + module=module, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, ) else: raise ValueError(f"Unsupported offload_type: {offload_type}") @@ -434,6 +469,7 @@ def _apply_group_offloading_block_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, ) -> None: r""" @@ -453,6 +489,14 @@ def _apply_group_offloading_block_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ # Create module groups for ModuleList and Sequential blocks @@ -475,6 +519,7 @@ def _apply_group_offloading_block_level( onload_leader=current_modules[0], non_blocking=non_blocking, stream=stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=stream is None, ) @@ -512,6 +557,7 @@ def _apply_group_offloading_block_level( buffers=buffers, non_blocking=False, stream=None, + record_stream=False, onload_self=True, ) next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None @@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, ) -> None: r""" @@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ # Create module groups for leaf modules and apply group offloading hooks @@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level( onload_leader=submodule, non_blocking=non_blocking, stream=stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) @@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level( buffers=buffers, non_blocking=non_blocking, stream=stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) @@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level( buffers=None, non_blocking=False, stream=None, + record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 19ac868cdae0..2a22bc09ad7a 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -546,6 +546,7 @@ def enable_group_offload( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + record_stream: bool = False, low_cpu_mem_usage=False, ) -> None: r""" @@ -594,6 +595,7 @@ def enable_group_offload( num_blocks_per_group, non_blocking, use_stream, + record_stream, low_cpu_mem_usage=low_cpu_mem_usage, ) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index d55ff6e62872..847677884a35 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1525,8 +1525,9 @@ def get_memory_usage(storage_dtype, compute_dtype): or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ) + @parameterized.expand([False, True]) @require_torch_gpu - def test_group_offloading(self): + def test_group_offloading(self, record_stream): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() torch.manual_seed(0) @@ -1566,7 +1567,9 @@ def run_forward(model): torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) + model.enable_group_offload( + torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream + ) output_with_group_offloading4 = run_forward(model) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))