|
6 | 6 | import sys |
7 | 7 | import types |
8 | 8 | import json |
9 | | -from typing import Optional, Union |
| 9 | +from typing import Callable, Optional, Union |
10 | 10 | import torch |
11 | 11 | from torch.optim import Optimizer |
12 | 12 | from torch.optim.lr_scheduler import _LRScheduler |
|
27 | 27 |
|
28 | 28 | from .accelerator import get_accelerator |
29 | 29 | from .constants import TORCH_DISTRIBUTED_DEFAULT_PORT |
| 30 | +from .runtime.base_optimizer import DeepSpeedOptimizer |
| 31 | +from .runtime.dataloader import DeepSpeedDataLoader, RepeatingLoader |
30 | 32 | from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable |
31 | 33 | from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER |
32 | 34 | from .runtime.hybrid_engine import DeepSpeedHybridEngine |
@@ -65,46 +67,44 @@ def _parse_version(version_str): |
65 | 67 | # Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init |
66 | 68 | dist = None |
67 | 69 |
|
# DeepSpeed's own wrapper types. These are recognized by isinstance checks and
# therefore never need the ``ds_is_inited`` marker attribute: they are treated
# as initialized by construction.
DS_PRIM_TYPES = (
    DeepSpeedEngine,
    DeepSpeedHybridEngine,
    DeepSpeedOptimizer,
    DeepSpeedDataLoader,
    RepeatingLoader,
)
68 | 72 |
|
def _mark_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Flag a user-supplied train object as consumed by ``deepspeed.initialize``.

    DeepSpeed-native objects (see ``DS_PRIM_TYPES``) are identified by type
    instead of by attribute, so they are left untouched here.
    """
    if isinstance(trainobj, DS_PRIM_TYPES):
        # DeepSpeed wrappers need no marker attribute.
        return
    trainobj.ds_is_inited = True
72 | 77 |
|
73 | 78 |
|
def _is_ds_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
    """Report whether a train object has been through ``deepspeed.initialize``.

    DeepSpeed-native objects (``DS_PRIM_TYPES``) count as initialized by
    construction; everything else is judged by the ``ds_is_inited`` marker
    attribute set in ``_mark_ds_initialized`` (absent attribute means False).
    """
    return isinstance(trainobj, DS_PRIM_TYPES) or getattr(trainobj, 'ds_is_inited', False)
| 85 | + |
| 86 | + |
def _ensure_and_mark_trainobjs_inited(
    model: torch.nn.Module,
    optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]],
    lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]],
    ensures_not_inited: bool = False,
):
    """Optionally verify, then mark, the given train objects as initialized.

    Args:
        model: model passed to ``deepspeed.initialize``; always checked/marked.
        optimizer: optimizer instance or a ``DeepSpeedOptimizerCallable``.
        lr_scheduler: scheduler instance or a ``DeepSpeedSchedulerCallable``.
        ensures_not_inited: when True, raise if any object was already used in
            a previous ``deepspeed.initialize`` call before marking it.

    Raises:
        ValueError: if ``ensures_not_inited`` is True and an object is already
            initialized.

    Note:
        Callables (``DeepSpeedOptimizerCallable`` / ``DeepSpeedSchedulerCallable``)
        are neither checked nor marked: they are stateless factories and may
        safely be reused across ``deepspeed.initialize`` calls.
    """
    trainobjs = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}

    for name, trainobj in trainobjs.items():
        if trainobj is None:
            continue
        if name in ("optimizer", "lr_scheduler") and not isinstance(trainobj, (Optimizer, _LRScheduler)):
            # Skip the stateless callable forms (see docstring note).
            continue
        if ensures_not_inited and _is_ds_initialized(trainobj):
            raise ValueError(
                f"{name} has already been initialized, please make sure to only call deepspeed.initialize on a {name} once."
            )
        _mark_ds_initialized(trainobj)
108 | 108 |
|
109 | 109 |
|
110 | 110 | def initialize(args=None, |
@@ -179,9 +179,7 @@ def initialize(args=None, |
179 | 179 |
|
180 | 180 | assert model is not None, "deepspeed.initialize requires a model" |
181 | 181 | # enforce that model, optimizer, and lr_scheduler have not been used in a previous deepspeed.initialize call |
182 | | - _assert_trainobjs_not_inited(model, optimizer, lr_scheduler) |
183 | | - # mark model, optimizer, and lr_scheduler as initialized |
184 | | - _mark_trainobjs_initialized(model, optimizer, lr_scheduler) |
| 182 | + _ensure_and_mark_trainobjs_inited(model, optimizer, lr_scheduler, ensures_not_inited=True) |
185 | 183 |
|
186 | 184 | global dist |
187 | 185 | from deepspeed import comm as dist |
@@ -267,7 +265,7 @@ def initialize(args=None, |
267 | 265 | zero.partition_parameters.restore_init_context() |
268 | 266 |
|
269 | 267 | # mark engine, optimizer, and lr_scheduler as initialized |
270 | | - _mark_trainobjs_initialized(engine, engine.optimizer, engine.lr_scheduler) |
| 268 | + _ensure_and_mark_trainobjs_inited(engine, engine.optimizer, engine.lr_scheduler, ensures_not_inited=False) |
271 | 269 |
|
272 | 270 | return_items = [ |
273 | 271 | engine, |
|
0 commit comments