
Commit 9984a64

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 794c6b1 · commit 9984a64

2 files changed: 20 additions, 22 deletions


colossalai/checkpoint_io/distributed_checkpoint_utils.py

Lines changed: 6 additions & 8 deletions
@@ -1,5 +1,6 @@
 import json
 import os
+from contextlib import contextmanager
 from typing import Dict
 
 import torch
@@ -9,12 +10,8 @@
 
 from colossalai.interface import ModelWrapper
 from colossalai.shardformer.layer.parallel_module import ParallelModule
-from contextlib import contextmanager
 
-from .utils import (
-    load_state_dict,
-    search_tp_partition_dim,
-)
+from .utils import load_state_dict, search_tp_partition_dim
 
 MODEL_META_PREFIX = "pytorch_model-meta-dist-"
 MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
@@ -34,8 +31,7 @@ def RestoreDefaultStateDictBehavior(model):
         yield model
     finally:
         for module, original_method in original_methods.items():
-            module._save_to_state_dict, module._load_from_state_dict = original_method
-
+            module._save_to_state_dict, module._load_from_state_dict = original_method
 
 
 def create_model_metadata(
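
For context, the context manager touched above swaps every submodule's _save_to_state_dict / _load_from_state_dict back to the stock torch.nn.Module versions and puts the originals back on exit. Only the finally block is visible in this hunk, so the setup half of the sketch below is an assumption inferred from it:

import types
from contextlib import contextmanager

import torch


@contextmanager
def restore_default_state_dict_behavior(model):
    # Remember each module's (possibly overridden) state-dict hooks,
    # then rebind the default torch.nn.Module implementations.
    original_methods = {}
    for module in model.modules():
        original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
        module._save_to_state_dict = types.MethodType(torch.nn.Module._save_to_state_dict, module)
        module._load_from_state_dict = types.MethodType(torch.nn.Module._load_from_state_dict, module)
    try:
        yield model
    finally:
        # Restore exactly what each module had before (the block this hunk reformats).
        for module, original_method in original_methods.items():
            module._save_to_state_dict, module._load_from_state_dict = original_method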
@@ -260,12 +256,14 @@ def load_dist_model(
 
     return state_dict
 
+
 def get_dist_files_name(weights_name, dist_id):
     weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
     weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
     return weights_name
 
+
 def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
     if use_safetensors:
         return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
+    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
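
The hunk above contains the full bodies of the two filename helpers, which only rewrite names. A self-contained sketch for reference (the two suffix constants are assumed values; only the prefixes appear in this diff):

import os

MODEL_META_PREFIX = "pytorch_model-meta-dist-"  # from the diff above
SHARD_META_SUFFIX = ".index.json"    # assumed value, not shown in this diff
UNSHARD_META_SUFFIX = ".json"        # assumed value, not shown in this diff


def get_dist_files_name(weights_name, dist_id):
    # e.g. "pytorch_model.bin" -> "pytorch_model-dist-00003-shard.bin" for dist_id=3
    weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
    weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
    return weights_name


def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
    # Metadata files live next to the weights, keyed by the same dist_id.
    if use_safetensors:
        return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")


print(get_dist_files_name("pytorch_model.safetensors", dist_id=3))
# pytorch_model-dist-00003-shard.safetensors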

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 14 additions & 14 deletions
@@ -1,11 +1,11 @@
 import copy
 import logging
 import os
+from contextlib import nullcontext
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
@@ -26,14 +26,14 @@
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
 from .distributed_checkpoint_utils import (
+    MODEL_WEIGHT_PREFIX,
+    RestoreDefaultStateDictBehavior,
     create_model_metadata,
+    get_dist_files_name,
+    get_dist_meta_file_name,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
     save_metadata,
-    get_dist_files_name,
-    get_dist_meta_file_name,
-    MODEL_WEIGHT_PREFIX,
-    RestoreDefaultStateDictBehavior
 )
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -108,7 +108,7 @@ def _model_sharder(
         keep_vars: bool = False,
         size_per_shard: int = 1024,
         pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-        gather_dtensor: bool = True,
+        gather_dtensor: bool = True,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         # An internel method that breaks state_dict of model into shards within limited size.
 
@@ -118,7 +118,7 @@ def _model_sharder(
         for name, param in model.named_parameters():
             if param is None:
                 continue
-
+
             # Gather tensor pieces when using tensor parallel.
             param_ = gather_distributed_param(param, keep_vars=False)
             if is_padded_tensor(param_):
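
Both hunks above sit in _model_sharder, whose comment says it breaks a model's state_dict into shards within a size limit. A minimal sketch of that sharding idea, not ColossalAI's actual implementation (names and the interpretation of size_per_shard as megabytes are assumptions):

from collections import OrderedDict
from typing import Dict, Iterator, Tuple

import torch


def sized_state_dict_shards(
    state_dict: Dict[str, torch.Tensor], size_per_shard: int = 1024
) -> Iterator[Tuple[OrderedDict, int]]:
    # Yield (shard, shard_size_bytes) pairs, starting a new shard whenever
    # adding the next tensor would exceed the budget.
    budget = size_per_shard * 1024 * 1024  # size_per_shard assumed to be in MB
    shard, shard_size = OrderedDict(), 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and shard_size + tensor_size > budget:
            yield shard, shard_size
            shard, shard_size = OrderedDict(), 0
        shard[name] = tensor
        shard_size += tensor_size
    if shard:  # flush the final partial shard
        yield shard, shard_size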
@@ -245,12 +245,12 @@ def save_sharded_model(
         model._force_wait_all_gather()
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
-
+
         model_metadata = None
         if not gather_dtensor:
             # Manage filenames of sharded weights and index file for each pipeline stage.
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -280,7 +280,9 @@ def save_sharded_model(
         if not gather_dtensor:
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
             weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
-            metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)
+            metadata_file = get_dist_meta_file_name(
+                checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors
+            )
 
         if use_async:
             total_size, writers = async_save_state_dict_shards(
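
For reference, the dist_id above flattens the (pipeline rank, tensor rank) pair into a single index, which the helpers from the first file then embed in the checkpoint file names. A small worked example with illustrative rank values:

# formula from the diff above: dist_id = tp_size * pp_rank + tp_rank
tp_size, pp_rank, tp_rank = 2, 1, 0
dist_id = tp_size * pp_rank + tp_rank  # -> 2
weights_name = "pytorch_model.safetensors".replace(
    ".safetensors", f"-dist-{dist_id:05d}-shard.safetensors"
)
print(weights_name)  # pytorch_model-dist-00002-shard.safetensors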
@@ -413,9 +415,7 @@ def load_sharded_model(
             )
             model = model.unwrap()
             with RestoreDefaultStateDictBehavior(model):
-                load_state_dict_into_model(
-                    model, state_dict, missing_keys=[], strict=False, load_sub_module=True
-                )
+                load_state_dict_into_model(model, state_dict, missing_keys=[], strict=False, load_sub_module=True)
             return
 
         model_before_wrapping = model  # backup for model before wrapping
@@ -897,7 +897,7 @@ def load_unsharded_model(
                 load_dtensor = True
                 break
 
-        model_metadata = None # used for dist model
+        model_metadata = None  # used for dist model
         if load_dtensor:
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
