@@ -10,5 +10,4 @@
 import heapq
 import logging
-from functools import cache
 from math import prod
 from typing import Any, Dict, List, Tuple
@@ -23,6 +22,7 @@
     USE_MERGE_DIMS,
 )
 from distributed_shampoo.utils.shampoo_block_info import DDPBlockInfo
+from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh
 from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
 from distributed_shampoo.utils.shampoo_utils import (
     compress_list,
@@ -137,9 +137,10 @@ def __init__(
         # Instantiates this by using DeviceMesh.
         ranks_in_all_replicated_groups = self._hsdp_device_mesh.mesh.T
         for ranks_in_replicated_group in ranks_in_all_replicated_groups:
-            device_mesh = self._get_device_mesh(
+            device_mesh = get_device_mesh(
                 device_type=self._hsdp_device_mesh.device_type,
-                ranks_in_replicated_group=ranks_in_replicated_group,
+                mesh=ranks_in_replicated_group.view(-1, self._dist_group_size),
+                mesh_dim_names=("replicate", "shard"),
             )
             if dist.get_rank() in ranks_in_replicated_group:
                 self._dist_group = device_mesh.get_group("shard")
@@ -837,33 +838,6 @@ def block_within_tensor_shard_recovery(
                 block_end_idx=end_idx,
             )

-    def _get_device_mesh(
-        self,
-        device_type: str,
-        ranks_in_replicated_group: Tensor,
-    ) -> dtensor.DeviceMesh:
-        """Returns 2D device mesh from the provided device type and ranks in replicated group.
-        The 2D device mesh is formed in the way where the shard dimension is the same as self._dist_group_size.
-
-        Args:
-            device_type (str): Device type (specified as a string).
-            ranks_in_replicated_group (Tensor): Ranks in replicated group.
-
-        Returns:
-            device_mesh (dtensor.DeviceMesh): Device mesh.
-
-        """
-
-        @cache
-        def get_device_mesh(ranks_in_replicated_group: Tensor) -> dtensor.DeviceMesh:
-            return dtensor.DeviceMesh(
-                device_type=device_type,
-                mesh=ranks_in_replicated_group.view(-1, self._dist_group_size),
-                mesh_dim_names=("replicate", "shard"),
-            )
-
-        return get_device_mesh(ranks_in_replicated_group)
-
     def _allocate_zeros_distributed_tensor(
         self,
         shape: Tuple[int, ...],
@@ -887,9 +861,10 @@ def _allocate_zeros_distributed_tensor(
         ranks_in_replicated_group = torch.tensor(
             dist.get_process_group_ranks(self._hsdp_device_mesh.get_group(0))
         )
-        device_mesh_2d = self._get_device_mesh(
+        device_mesh_2d = get_device_mesh(
             device_type=device.type,
-            ranks_in_replicated_group=ranks_in_replicated_group,
+            mesh=ranks_in_replicated_group.view(-1, self._dist_group_size),
+            mesh_dim_names=("replicate", "shard"),
         )

         return dtensor_zeros(
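For context: both call sites in this diff now pass device_type, mesh, and mesh_dim_names to the module-level get_device_mesh imported from shampoo_dist_utils, replacing the removed per-instance _get_device_mesh helper that cached dtensor.DeviceMesh construction via functools.cache. The sketch below is one plausible shape for such a shared, cached factory, assuming the keyword signature used above; the helper name _cached_device_mesh, the public DeviceMesh import path, and the tuple-based cache key are illustrative assumptions, not the actual shampoo_dist_utils implementation.

# Sketch only: a module-level, cached DeviceMesh factory matching the keyword
# signature used at the call sites above; not the real shampoo_dist_utils code.
from functools import lru_cache
from typing import Optional, Tuple

import torch
from torch.distributed.device_mesh import DeviceMesh


@lru_cache(maxsize=None)
def _cached_device_mesh(
    device_type: str,
    mesh: Tuple[Tuple[int, ...], ...],
    mesh_dim_names: Optional[Tuple[str, ...]],
) -> DeviceMesh:
    # DeviceMesh construction sets up process groups, so reuse one instance
    # per unique (device_type, mesh, mesh_dim_names) combination.
    return DeviceMesh(
        device_type=device_type,
        mesh=torch.tensor(mesh),
        mesh_dim_names=mesh_dim_names,
    )


def get_device_mesh(
    device_type: str,
    mesh: torch.Tensor,
    mesh_dim_names: Optional[Tuple[str, ...]] = None,
) -> DeviceMesh:
    # Tensors are not hashable, so normalize the 2D mesh to nested tuples
    # before consulting the cache (an assumed normalization detail).
    return _cached_device_mesh(
        device_type, tuple(map(tuple, mesh.tolist())), mesh_dim_names
    )

With a factory like this, the two call sites above would reuse a single cached mesh whenever they are invoked with the same device type and the same reshaped ranks_in_replicated_group, which is the behavior the removed functools.cache closure provided per distributor instance.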