@@ -56,6 +56,7 @@ def __init__(
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
+        record_stream: Optional[bool] = False,
         low_cpu_mem_usage=False,
         onload_self: bool = True,
     ) -> None:
@@ -68,11 +69,14 @@ def __init__(
         self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
+        self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-
         self.cpu_param_dict = self._init_cpu_param_dict()
 
+        if self.stream is None and self.record_stream:
+            raise ValueError("`record_stream` cannot be True when `stream` is None.")
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
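The new guard makes the dependency explicit: `record_stream` is only meaningful when transfers happen on a side stream. As background, `torch.Tensor.record_stream` tells the CUDA caching allocator that a tensor is used by a given stream, so its memory is not reused until that stream's pending work finishes. A minimal standalone sketch of the pattern (assumes a CUDA device; not code from this PR):

```python
import torch

transfer_stream = torch.cuda.Stream()
x_cpu = torch.randn(1024, 1024, pin_memory=True)  # pinned memory enables async H2D copies

with torch.cuda.stream(transfer_stream):
    # The copy is enqueued on the side (transfer) stream.
    x_gpu = x_cpu.to("cuda", non_blocking=True)

# Mark the tensor as used by the default (compute) stream so the caching
# allocator will not recycle its memory while compute kernels may read it.
x_gpu.record_stream(torch.cuda.current_stream())

# record_stream handles allocator reuse, not ordering: the compute stream
# must still wait for the transfer before consuming the tensor.
torch.cuda.current_stream().wait_stream(transfer_stream)
y = x_gpu @ x_gpu
```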
@@ -112,6 +116,8 @@ def _pinned_memory_tensors(self):
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+        current_stream = torch.cuda.current_stream() if self.record_stream else None
+
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
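A subtlety worth noting: `current_stream` is captured before entering the `torch.cuda.stream(self.stream)` context used below. Inside that context, `torch.cuda.current_stream()` returns the transfer stream itself, and recording a tensor on the very stream that produced it would not protect the compute-stream consumers. A short illustrative sketch (not PR code):

```python
import torch

transfer = torch.cuda.Stream()
compute = torch.cuda.current_stream()  # capture the consumer stream up front

with torch.cuda.stream(transfer):
    assert torch.cuda.current_stream() == transfer  # no longer the compute stream
    t = torch.ones(4, device="cuda")
    t.record_stream(compute)  # record on the *consumer* stream, not the producer
```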
@@ -122,14 +128,22 @@ def onload_(self):
                     for group_module in self.modules:
                         for param in group_module.parameters():
                             param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                            if self.record_stream:
+                                param.data.record_stream(current_stream)
                         for buffer in group_module.buffers():
                             buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                            if self.record_stream:
+                                buffer.data.record_stream(current_stream)
 
                     for param in self.parameters:
                         param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            param.data.record_stream(current_stream)
 
                     for buffer in self.buffers:
                         buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+                        if self.record_stream:
+                            buffer.data.record_stream(current_stream)
 
             else:
                 for group_module in self.modules:
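Distilled, the streamed branch above does the same three things for every parameter and buffer: copy the pinned CPU tensor to the device on the transfer stream, swap it into `.data`, and record it on the compute stream. A sketch with hypothetical names (`pinned` and `tensors` stand in for the hook's internal state; this is not the hook's actual API):

```python
import torch

def onload_tensors(pinned, tensors, transfer_stream, record_stream=True):
    """Hypothetical distillation of the streamed onload path."""
    compute_stream = torch.cuda.current_stream()
    transfer_stream.synchronize()  # previous prefetch must have finished
    with torch.cuda.stream(transfer_stream):
        for t in tensors:
            t.data = pinned[t].to("cuda", non_blocking=True)
            if record_stream:
                # The allocator keeps the device memory alive until the
                # compute stream's queued work has consumed it.
                t.data.record_stream(compute_stream)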
@@ -143,11 +157,14 @@ def onload_(self):
 
             for buffer in self.buffers:
                 buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                if self.record_stream:
+                    buffer.data.record_stream(current_stream)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
+            if not self.record_stream:
+                torch.cuda.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
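The offload-side change is the payoff. Without `record_stream`, rebinding `param.data` to the CPU copy can return device memory to the allocator while compute kernels are still reading it, so a blocking `torch.cuda.current_stream().synchronize()` is required first. With `record_stream=True`, every onloaded tensor was already recorded on the compute stream, so the allocator defers reuse of that memory on its own and the blocking call can be skipped. This is also where the "slightly more memory usage" comes from: freed blocks stay reserved a little longer. A sketch of the tradeoff (illustrative names, not the hook's API):

```python
import torch

def offload_tensors(params, cpu_copies, record_stream=False):
    """Hypothetical distillation of the offload path."""
    if not record_stream:
        # Safe point: all queued compute using the device tensors is done.
        torch.cuda.current_stream().synchronize()
    for p in params:
        # Rebinding drops the device tensor; with record_stream the caching
        # allocator delays reusing its memory until the recorded stream is done.
        p.data = cpu_copies[p]
```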
@@ -331,6 +348,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -378,6 +396,10 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, each onloaded tensor is marked
+            as used by the current stream, which lets the offload step skip a blocking stream synchronization. This
+            is faster at the expense of slightly more memory usage. Refer to the [PyTorch official
+            docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
         low_cpu_mem_usage (`bool`, defaults to `False`):
             If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
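For completeness, a usage sketch of the new flag through the public entry point. The model here is a toy stand-in, and the import path reflects where `apply_group_offloading` lives in `diffusers` at the time of writing; treat both as assumptions:

```python
import torch
from diffusers.hooks import apply_group_offloading  # import path assumed

# Toy module; any torch.nn.Module works with leaf_level offloading.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)).eval()

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    record_stream=True,  # requires use_stream=True, else ValueError (per this diff)
)

with torch.no_grad():
    # Inputs live on the onload device; weights are moved per group by the hooks.
    out = model(torch.randn(1, 64, device="cuda"))
```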
@@ -417,11 +439,24 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+            module=module,
+            num_blocks_per_group=num_blocks_per_group,
+            offload_device=offload_device,
+            onload_device=onload_device,
+            non_blocking=non_blocking,
+            stream=stream,
+            record_stream=record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+            module=module,
+            offload_device=offload_device,
+            onload_device=onload_device,
+            non_blocking=non_blocking,
+            stream=stream,
+            record_stream=record_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled alongside `stream`, each onloaded tensor is marked
+            as used by the current stream, which lets the offload step skip a blocking stream synchronization. This
+            is faster at the expense of slightly more memory usage. Refer to the [PyTorch official
+            docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
+        low_cpu_mem_usage (`bool`, defaults to `False`):
+            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for ModuleList and Sequential blocks
@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
@@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level(
         stream (`torch.cuda.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
+        record_stream (`bool`, defaults to `False`): When enabled alongside `stream`, each onloaded tensor is marked
+            as used by the current stream, which lets the offload step skip a blocking stream synchronization. This
+            is faster at the expense of slightly more memory usage. Refer to the [PyTorch official
+            docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more details.
+        low_cpu_mem_usage (`bool`, defaults to `False`):
+            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
+            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
+            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
     # Create module groups for leaf modules and apply group offloading hooks
@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
         buffers=None,
         non_blocking=False,
         stream=None,
+        record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
     )