Skip to content

Commit 3be6706

Browse files
authored
Fix Group offloading behaviour when using streams (#11097)
* update
* update
1 parent cb1b8b2 commit 3be6706

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

src/diffusers/hooks/group_offloading.py

+17-10
Original file line number | Diff line number | Diff line change
@@ -181,6 +181,13 @@ def __init__(self):
181181
self._layer_execution_tracker_module_names = set()
182182

183183
def initialize_hook(self, module):
184+
def make_execution_order_update_callback(current_name, current_submodule):
185+
def callback():
186+
logger.debug(f"Adding {current_name} to the execution order")
187+
self.execution_order.append((current_name, current_submodule))
188+
189+
return callback
190+
184191
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
185192
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
186193
# layers are executed during the forward pass.
@@ -192,14 +199,8 @@ def initialize_hook(self, module):
192199
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
193200

194201
if group_offloading_hook is not None:
195-
196-
def make_execution_order_update_callback(current_name, current_submodule):
197-
def callback():
198-
logger.debug(f"Adding {current_name} to the execution order")
199-
self.execution_order.append((current_name, current_submodule))
200-
201-
return callback
202-
202+
# For the first forward pass (before lazy prefetching is configured), groups have to be loaded in a blocking manner; post_forward re-enables non_blocking afterwards
203+
group_offloading_hook.group.non_blocking = False
203204
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
204205
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
205206
self._layer_execution_tracker_module_names.add(name)
@@ -229,15 +230,21 @@ def post_forward(self, module, output):
229230
# Remove the layer execution tracker hooks from the submodules
230231
base_module_registry = module._diffusers_hook
231232
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
233+
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
232234

233235
for i in range(num_executed):
234236
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
235237

236238
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
237239
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
238240

239-
# Apply lazy prefetching by setting required attributes
240-
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
241+
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
242+
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
243+
# see the benefits of prefetching.
244+
for hook in group_offloading_hooks:
245+
hook.group.non_blocking = True
246+
247+
# Set required attributes for prefetching
241248
if num_executed > 0:
242249
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
243250
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group

0 commit comments

Comments
 (0)