@@ -1,11 +1,11 @@
 import copy
 import logging
 import os
+from contextlib import nullcontext
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-from contextlib import nullcontext

 import torch
 import torch.distributed as dist
@@ -26,14 +26,14 @@
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

 from .distributed_checkpoint_utils import (
+    MODEL_WEIGHT_PREFIX,
+    RestoreDefaultStateDictBehavior,
     create_model_metadata,
+    get_dist_files_name,
+    get_dist_meta_file_name,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
     save_metadata,
-    get_dist_files_name,
-    get_dist_meta_file_name,
-    MODEL_WEIGHT_PREFIX,
-    RestoreDefaultStateDictBehavior
 )
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -108,7 +108,7 @@ def _model_sharder(
     keep_vars: bool = False,
     size_per_shard: int = 1024,
     pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-    gather_dtensor: bool = True,
+    gather_dtensor: bool = True,
 ) -> Iterator[Tuple[OrderedDict, int]]:
     # An internel method that breaks state_dict of model into shards within limited size.

@@ -118,7 +118,7 @@ def _model_sharder(
     for name, param in model.named_parameters():
         if param is None:
             continue
-
+
         # Gather tensor pieces when using tensor parallel.
         param_ = gather_distributed_param(param, keep_vars=False)
         if is_padded_tensor(param_):
@@ -245,12 +245,12 @@ def save_sharded_model(
         model._force_wait_all_gather()
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
-
+
         model_metadata = None
         if not gather_dtensor:
             # Manage filenames of sharded weights and index file for each pipeline stage.
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-
+
         model = model.unwrap()

         if os.path.isfile(checkpoint):
@@ -280,7 +280,9 @@ def save_sharded_model(
             if not gather_dtensor:
                 dist_id = self.tp_size * self.pp_rank + self.tp_rank
                 weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
-                metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)
+                metadata_file = get_dist_meta_file_name(
+                    checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors
+                )

             if use_async:
                 total_size, writers = async_save_state_dict_shards(
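For context on the hunk above: the reworked call derives a flat per-rank id from the pipeline and tensor-parallel ranks (dist_id = tp_size * pp_rank + tp_rank) and threads it into the file-name helpers. Below is a minimal standalone sketch of that numbering scheme; shard_file_name is a hypothetical stand-in for get_dist_files_name, whose real naming format is not shown in this diff.

def flat_dist_id(pp_rank: int, tp_rank: int, tp_size: int) -> int:
    # Same arithmetic as in the diff: ranks of one pipeline stage are numbered
    # consecutively, so every (pp_rank, tp_rank) pair maps to a unique id.
    return tp_size * pp_rank + tp_rank


def shard_file_name(weights_name: str, dist_id: int) -> str:
    # Hypothetical naming helper, only to illustrate how a rank-specific id
    # could be folded into a shard file name.
    stem, ext = weights_name.rsplit(".", maxsplit=1)
    return f"{stem}-dist-{dist_id}.{ext}"


# Example: pp_size=2, tp_size=2 gives ids 0..3, one file name per rank.
for pp_rank in range(2):
    for tp_rank in range(2):
        dist_id = flat_dist_id(pp_rank, tp_rank, tp_size=2)
        print(dist_id, shard_file_name("model.safetensors", dist_id))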
@@ -413,9 +415,7 @@ def load_sharded_model(
             )
             model = model.unwrap()
             with RestoreDefaultStateDictBehavior(model):
-                load_state_dict_into_model(
-                    model, state_dict, missing_keys=[], strict=False, load_sub_module=True
-                )
+                load_state_dict_into_model(model, state_dict, missing_keys=[], strict=False, load_sub_module=True)
             return

         model_before_wrapping = model  # backup for model before wrapping
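The load path above runs load_state_dict_into_model inside a RestoreDefaultStateDictBehavior(model) context. Going only by the name, this looks like the usual override-then-restore-on-exit context-manager pattern; the sketch below shows that generic pattern with an illustrative helper, not ColossalAI's actual implementation.

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def default_state_dict_behavior(model: nn.Module):
    # Illustrative pattern: temporarily drop any instance-level
    # state_dict/load_state_dict overrides so nn.Module's stock
    # implementations are used, then put the overrides back on exit.
    saved = {}
    for name in ("state_dict", "load_state_dict"):
        if name in model.__dict__:
            saved[name] = model.__dict__.pop(name)
    try:
        yield model
    finally:
        model.__dict__.update(saved)  # restore the custom behavior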
@@ -897,7 +897,7 @@ def load_unsharded_model(
                 load_dtensor = True
                 break

-        model_metadata = None # used for dist model
+        model_metadata = None  # used for dist model
         if load_dtensor:
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)

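One note on the load_unsharded_model hunk above: load_dtensor is flipped on (apparently by a file scan whose loop header lies outside this hunk) before model_metadata is built for the distributed path. A small sketch of that detect-then-branch shape; looks_like_dist_meta is a hypothetical predicate standing in for is_pytorch_model_meta_dist_file, whose real matching rule is not shown here.

import os


def looks_like_dist_meta(filename: str) -> bool:
    # Hypothetical check, only to illustrate the gating logic.
    return filename.startswith("metadata") and filename.endswith(".json")


def should_load_as_dtensor(checkpoint_dir: str) -> bool:
    # Mirrors the shape of the scan implied by the diff: one matching file is
    # enough to choose the distributed (per-rank metadata) load path.
    return any(looks_like_dist_meta(f) for f in os.listdir(checkpoint_dir))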