
Commit 370e545

support distribute checkpoint io
1 parent 5b094a8 commit 370e545

File tree: 7 files changed, +784 -14 lines changed


colossalai/booster/plugin/hybrid_parallel_plugin.py (+3)
@@ -78,6 +78,9 @@ def __init__(
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
         self.use_fp8 = use_fp8
+        self.param_origin_shape = {}
+        for name, param in module.named_parameters():
+            self.param_origin_shape[name] = param.shape
 
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
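The three added lines record each parameter's original (pre-shard) shape before ShardFormer rewrites the module, so checkpoint IO can later relate sharded tensors back to their full-size layout. Below is a minimal, self-contained sketch of that idea; the `shard_param` helper and the 2-way split are illustrative assumptions, not ColossalAI code.

```python
import torch
import torch.nn as nn

# Hypothetical helper that simulates tensor-parallel sharding by splitting
# a weight along dim 0 across `world_size` ranks (illustration only).
def shard_param(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    return torch.chunk(param, world_size, dim=0)[rank].clone()

model = nn.Linear(8, 4)

# Record original (unsharded) shapes before any sharding happens,
# mirroring what the plugin does in __init__.
param_origin_shape = {name: p.shape for name, p in model.named_parameters()}

# Simulate sharding on rank 0 of a 2-way tensor-parallel group.
rank, world_size = 0, 2
model.weight.data = shard_param(model.weight.data, rank, world_size)

# At checkpoint time, the recorded shapes give the full tensor size,
# e.g. for allocating a buffer to gather shards into.
print(param_origin_shape["weight"])  # torch.Size([4, 8]) -- original
print(model.weight.shape)            # torch.Size([2, 8]) -- sharded
```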

colossalai/checkpoint_io/__init__.py (+3)
@@ -1,6 +1,8 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
+
+from .distributed_checkpoint_io import DistributedCheckpointIO
 from .index_file import CheckpointIndexFile
 from .moe_checkpoint import MoECheckpointIO
 
@@ -10,4 +12,5 @@
     "GeneralCheckpointIO",
     "HybridParallelCheckpointIO",
     "MoECheckpointIO",
+    "DistributedCheckpointIO",
 ]
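Re-exporting DistributedCheckpointIO at the package root lets callers import it alongside the other checkpoint IO backends. A hypothetical usage sketch follows, assuming the class implements the base CheckpointIO interface (save_model / load_model); the no-argument constructor, the keyword arguments, and the checkpoint path are assumptions rather than the verified API, and a distributed process group would normally have to be initialized first.

```python
import torch.nn as nn

from colossalai.checkpoint_io import DistributedCheckpointIO

# Placeholder model; in practice this would be a model wrapped by the
# HybridParallelPlugin, which records param_origin_shape as shown above.
model = nn.Linear(8, 4)

# Assumed usage: DistributedCheckpointIO is expected to follow the base
# CheckpointIO interface. The constructor call and the `shard` keyword
# are illustrative assumptions, not the verified signature.
ckpt_io = DistributedCheckpointIO()
ckpt_io.save_model(model, "checkpoints/my_model", shard=True)
ckpt_io.load_model(model, "checkpoints/my_model")
```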
