@@ -181,6 +181,13 @@ def __init__(self):
181
181
self ._layer_execution_tracker_module_names = set ()
182
182
183
183
def initialize_hook (self , module ):
184
+ def make_execution_order_update_callback (current_name , current_submodule ):
185
+ def callback ():
186
+ logger .debug (f"Adding { current_name } to the execution order" )
187
+ self .execution_order .append ((current_name , current_submodule ))
188
+
189
+ return callback
190
+
184
191
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
185
192
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
186
193
# layers are executed during the forward pass.
@@ -192,14 +199,8 @@ def initialize_hook(self, module):
192
199
group_offloading_hook = registry .get_hook (_GROUP_OFFLOADING )
193
200
194
201
if group_offloading_hook is not None :
195
-
196
- def make_execution_order_update_callback (current_name , current_submodule ):
197
- def callback ():
198
- logger .debug (f"Adding { current_name } to the execution order" )
199
- self .execution_order .append ((current_name , current_submodule ))
200
-
201
- return callback
202
-
202
+ # For the first forward pass, we have to load in a blocking manner
203
+ group_offloading_hook .group .non_blocking = False
203
204
layer_tracker_hook = LayerExecutionTrackerHook (make_execution_order_update_callback (name , submodule ))
204
205
registry .register_hook (layer_tracker_hook , _LAYER_EXECUTION_TRACKER )
205
206
self ._layer_execution_tracker_module_names .add (name )
@@ -229,15 +230,21 @@ def post_forward(self, module, output):
229
230
# Remove the layer execution tracker hooks from the submodules
230
231
base_module_registry = module ._diffusers_hook
231
232
registries = [submodule ._diffusers_hook for _ , submodule in self .execution_order ]
233
+ group_offloading_hooks = [registry .get_hook (_GROUP_OFFLOADING ) for registry in registries ]
232
234
233
235
for i in range (num_executed ):
234
236
registries [i ].remove_hook (_LAYER_EXECUTION_TRACKER , recurse = False )
235
237
236
238
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
237
239
base_module_registry .remove_hook (_LAZY_PREFETCH_GROUP_OFFLOADING , recurse = False )
238
240
239
- # Apply lazy prefetching by setting required attributes
240
- group_offloading_hooks = [registry .get_hook (_GROUP_OFFLOADING ) for registry in registries ]
241
+ # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
242
+ # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
243
+ # see the benefits of prefetching.
244
+ for hook in group_offloading_hooks :
245
+ hook .group .non_blocking = True
246
+
247
+ # Set required attributes for prefetching
241
248
if num_executed > 0 :
242
249
base_module_group_offloading_hook = base_module_registry .get_hook (_GROUP_OFFLOADING )
243
250
base_module_group_offloading_hook .next_group = group_offloading_hooks [0 ].group
0 commit comments