Commit 04a4ae2

tsunghsienlee authored and facebook-github-bot committed
Open-sourced update on 08/12/2024
Summary:
1. `shampoo_dist_utils.py` factors out utility functions used in the distributed setting.
2. Fix the incorrect import line in `matrix_functions_test.py`.
3. Various improvements to `shampoo_quantization.py`.
4. Fix the bug in `merge_small_dims()` when encountering an empty tensor.

Reviewed By: hjmshi

Differential Revision: D61168562

fbshipit-source-id: 75868e44ef53222f34f5585b0d842185360497ef
1 parent 461cebf commit 04a4ae2

11 files changed: +350 −143 lines
@@ -0,0 +1,73 @@
+"""
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+"""
+
+#!/usr/bin/env python3
+
+
+from unittest import mock
+
+import torch
+
+from distributed_shampoo.utils import shampoo_dist_utils
+from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh
+from torch.distributed.device_mesh import DeviceMesh
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
+
+class ShampooDistUtilsTest(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def _verify_deivce_mesh(self, device_mesh: DeviceMesh) -> None:
+        replicate_mesh = device_mesh["replicate"]
+        shard_mesh = device_mesh["shard"]
+
+        self.assertEqual(device_mesh.get_group(0), device_mesh.get_group("replicate"))
+        self.assertEqual(device_mesh.get_group(1), device_mesh.get_group("shard"))
+
+        self.assertEqual(device_mesh.get_group("shard"), shard_mesh.get_group())
+        self.assertEqual(device_mesh.get_group("replicate"), replicate_mesh.get_group())
+
+        self.assertCountEqual(
+            device_mesh.get_all_groups(),
+            (shard_mesh.get_group(), replicate_mesh.get_group()),
+        )
+
+    @with_comms
+    def test_get_device_mesh(self) -> None:
+        mesh = torch.tensor(range(self.world_size)).view(-1, self.world_size // 2)
+
+        self._verify_deivce_mesh(
+            device_mesh=get_device_mesh(
+                device_type=self.device_type,
+                mesh=mesh,
+                mesh_dim_names=("replicate", "shard"),
+            )
+        )
+
+        # Test the caching property of get_device_mesh() by mocking DeviceMesh.__init__().
+        # DeviceMesh.__init__() should not be called due to caching, and the output of
+        # get_device_mesh() should be the same as the previous one.
+        with mock.patch.object(
+            shampoo_dist_utils.DeviceMesh,
+            "__init__",
+        ) as mock_device_mesh_init:
+            device_mesh = get_device_mesh(
+                device_type=self.device_type,
+                mesh=mesh,
+                mesh_dim_names=("replicate", "shard"),
+            )
+
+            mock_device_mesh_init.assert_not_called()
+
+            self._verify_deivce_mesh(device_mesh=device_mesh)

distributed_shampoo/utils/shampoo_ddp_distributor.py

+5 −12

@@ -9,7 +9,7 @@
 
 import heapq
 import logging
-from functools import cache, partial
+from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -20,6 +20,7 @@
     PARAMS,
 )
 from distributed_shampoo.utils.shampoo_block_info import DDPBlockInfo
+from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh
 from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
 from distributed_shampoo.utils.shampoo_utils import (
     compress_list,
@@ -478,19 +479,11 @@ def _allocate_zeros_distributed_tensor(
             )
         )
 
-        @cache
-        def get_device_mesh(device_mesh_ranks: Tuple[int, ...]) -> dtensor.DeviceMesh:
-            """Returns device mesh from provided ranks. This function will cache previous meshes according to the input ranks.
-
-            Args:
-                device_mesh_ranks ([Tuple[int, ...]): Ranks to use in device mesh of desired tensor.
-
-            """
-            return dtensor.DeviceMesh(device_type=device.type, mesh=device_mesh_ranks)
-
         return dtensor_zeros(
             shape,
             dtype=dtype,
-            device_mesh=get_device_mesh(device_mesh_ranks),
+            device_mesh=get_device_mesh(
+                device_type=device.type, mesh=device_mesh_ranks
+            ),
             placements=[dtensor.Replicate()],
         )
distributed_shampoo/utils/shampoo_dist_utils.py

+38 −0

@@ -0,0 +1,38 @@
+"""
+Copyright (c) Meta Platforms, Inc. and affiliates.
+All rights reserved.
+
+This source code is licensed under the BSD-style license found in the
+LICENSE file in the root directory of this source tree.
+
+"""
+
+from functools import cache
+from typing import Optional
+
+import torch
+
+from torch.distributed import _tensor as dtensor
+from torch.distributed._tensor import DeviceMesh
+
+
+@cache
+def get_device_mesh(
+    device_type: str,
+    mesh: torch.Tensor | tuple[int, ...],
+    mesh_dim_names: Optional[tuple[str, ...]] = None,
+) -> dtensor.DeviceMesh:
+    """Returns device mesh from provided ranks. This function will cache previous meshes according to the input ranks.
+
+    Args:
+        device_type (str): The device type of the mesh. Currently supports: "cpu", "cuda/cuda-like".
+        mesh (torch.Tensor | tuple[int, ...]): A multi-dimensional array or an integer tensor describing the layout
+            of devices, where the IDs are global IDs of the default process group.
+        mesh_dim_names (Optional[tuple[str, ...]]): Names of mesh dimensions.
+
+    Returns:
+        device_mesh (dtensor.DeviceMesh): Device mesh.
+
+
+    """
+    return DeviceMesh(device_type=device_type, mesh=mesh, mesh_dim_names=mesh_dim_names)
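For orientation (not part of the commit): because get_device_mesh() is wrapped in functools.cache, repeated calls with the same hashable arguments return the same DeviceMesh rather than constructing a new one, which is what lets the DDP and HSDP distributors drop their locally cached helpers below. A minimal sketch, assuming a default process group is already initialized (e.g., under torchrun) and using "cpu" as a placeholder device type:

import torch.distributed as dist

from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh

# A 1D mesh over all ranks; a tuple is hashable, so it is a valid cache key.
ranks = tuple(range(dist.get_world_size()))

mesh_a = get_device_mesh(device_type="cpu", mesh=ranks)
mesh_b = get_device_mesh(device_type="cpu", mesh=ranks)

# The second call is served from the functools.cache entry; no new DeviceMesh is built.
assert mesh_a is mesh_b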

distributed_shampoo/utils/shampoo_hsdp_distributor.py

+7 −32

@@ -9,7 +9,6 @@
 
 import heapq
 import logging
-from functools import cache
 from math import prod
 from typing import Any, Dict, List, Tuple
 
@@ -23,6 +22,7 @@
     USE_MERGE_DIMS,
 )
 from distributed_shampoo.utils.shampoo_block_info import DDPBlockInfo
+from distributed_shampoo.utils.shampoo_dist_utils import get_device_mesh
 from distributed_shampoo.utils.shampoo_distributor import DistributorInterface
 from distributed_shampoo.utils.shampoo_utils import (
     compress_list,
@@ -137,9 +137,10 @@ def __init__(
         # Instantiates this by using DeviceMesh.
         ranks_in_all_replicated_groups = self._hsdp_device_mesh.mesh.T
         for ranks_in_replicated_group in ranks_in_all_replicated_groups:
-            device_mesh = self._get_device_mesh(
+            device_mesh = get_device_mesh(
                 device_type=self._hsdp_device_mesh.device_type,
-                ranks_in_replicated_group=ranks_in_replicated_group,
+                mesh=ranks_in_replicated_group.view(-1, self._dist_group_size),
+                mesh_dim_names=("replicate", "shard"),
             )
             if dist.get_rank() in ranks_in_replicated_group:
                 self._dist_group = device_mesh.get_group("shard")
@@ -837,33 +838,6 @@ def block_within_tensor_shard_recovery(
             block_end_idx=end_idx,
         )
 
-    def _get_device_mesh(
-        self,
-        device_type: str,
-        ranks_in_replicated_group: Tensor,
-    ) -> dtensor.DeviceMesh:
-        """Returns 2D device mesh from the provided device type and ranks in replicated group.
-        The 2D device mesh is formed in the way where the shard dimension is the same as self._dist_group_size.
-
-        Args:
-            device_type (str): Device type (specified as a string).
-            ranks_in_replicated_group (Tensor): Ranks in replicated group.
-
-        Returns:
-            device_mesh (dtensor.DeviceMesh): Device mesh.
-
-        """
-
-        @cache
-        def get_device_mesh(ranks_in_replicated_group: Tensor) -> dtensor.DeviceMesh:
-            return dtensor.DeviceMesh(
-                device_type=device_type,
-                mesh=ranks_in_replicated_group.view(-1, self._dist_group_size),
-                mesh_dim_names=("replicate", "shard"),
-            )
-
-        return get_device_mesh(ranks_in_replicated_group)
-
     def _allocate_zeros_distributed_tensor(
         self,
         shape: Tuple[int, ...],
@@ -887,9 +861,10 @@ def _allocate_zeros_distributed_tensor(
         ranks_in_replicated_group = torch.tensor(
             dist.get_process_group_ranks(self._hsdp_device_mesh.get_group(0))
         )
-        device_mesh_2d = self._get_device_mesh(
+        device_mesh_2d = get_device_mesh(
             device_type=device.type,
-            ranks_in_replicated_group=ranks_in_replicated_group,
+            mesh=ranks_in_replicated_group.view(-1, self._dist_group_size),
+            mesh_dim_names=("replicate", "shard"),
         )
 
         return dtensor_zeros(
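To make the reshaped mesh argument above concrete, here is a standalone sketch with assumed values (8 ranks in a replicated group and shard groups of size 4 standing in for self._dist_group_size); only torch is needed:

import torch

ranks_in_replicated_group = torch.arange(8)
dist_group_size = 4  # stands in for self._dist_group_size

# view(-1, dist_group_size) arranges the flat ranks into the 2D layout that
# get_device_mesh() labels ("replicate", "shard").
mesh = ranks_in_replicated_group.view(-1, dist_group_size)
# tensor([[0, 1, 2, 3],
#         [4, 5, 6, 7]])  -> dim 0 = "replicate", dim 1 = "shard"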

distributed_shampoo/utils/shampoo_preconditioner_list.py

+31 −1

@@ -12,7 +12,8 @@
 from dataclasses import dataclass, field
 
 from itertools import chain
-from typing import Any, DefaultDict, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, DefaultDict, Optional, Sequence, Tuple, Type, Union
 
 import torch
 from distributed_shampoo.utils.shampoo_block_info import BlockInfo
@@ -775,3 +776,32 @@ def compute_root_inverse_residuals(
             tuple(relative_errors),
             tuple(relative_residuals),
         )
+
+
+class DequantizePreconditionersContext:
+    """DequantizePreconditionersContext is used to automatically dequantize and then quantize the preconditioners used within this context.
+
+    Args:
+        preconditioner_list (PreconditionerList): Preconditioner list which contains the preconditioners to be dequantized and quantized.
+
+    Examples:
+        >>> with DequantizePreconditionersContext(preconditioner_list):
+        >>>     # Do something with the preconditioners, and preconditioner_list will be dequantized.
+        >>>     # After the context is exited, the preconditioners will be quantized.
+
+    """
+
+    def __init__(self, preconditioner_list: PreconditionerList) -> None:
+        self._preconditioner_list = preconditioner_list
+
+    def __enter__(self) -> "DequantizePreconditionersContext":
+        self._preconditioner_list.dequantize_preconditioners()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        self._preconditioner_list.quantize_preconditioners()
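A hedged usage sketch of the new context manager; only DequantizePreconditionersContext itself comes from this diff, while the wrapper function and the update_preconditioners() call below are assumptions about the surrounding PreconditionerList API. Because re-quantization happens in __exit__, it also runs when the body raises, and the exception still propagates since __exit__ returns None:

from distributed_shampoo.utils.shampoo_preconditioner_list import (
    DequantizePreconditionersContext,
    PreconditionerList,
)


def update_with_dequantized_state(
    preconditioner_list: PreconditionerList, masked_grad_list, step
) -> None:
    # Hypothetical wrapper: dequantize on entry, work in full precision, re-quantize on exit.
    with DequantizePreconditionersContext(preconditioner_list=preconditioner_list):
        # The exact method and arguments here are an assumption, not part of this diff.
        preconditioner_list.update_preconditioners(
            masked_grad_list=masked_grad_list, step=step
        )
    # On exit (even on exceptions), quantize_preconditioners() has been called.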

distributed_shampoo/utils/shampoo_quantization.py

+6 −5

@@ -76,7 +76,8 @@ def dequantize(self, dequantized_dtype: torch.dtype) -> Tensor:
                 self.quantized_values, dtype=dequantized_dtype
             )
             QuantizedTensor._convert_float_to_float(
-                dequantized_values, self.quantized_values
+                src=self.quantized_values,
+                dest=dequantized_values,
             )
             return dequantized_values
         else:
@@ -189,7 +190,7 @@ def dequantize(self) -> Tuple[Tensor, ...]:
         )
 
     def dequantize_(self) -> None:
-        if self.dequantized_value_list is not None:
+        if self.is_dequantized_stored():
             logger.warning(
                 "Dequantized values are already stored; overwriting these values..."
             )
@@ -217,7 +218,7 @@ def quantize(self, tensor_list: Tuple[Tensor, ...]) -> None:
         )
 
     def quantize_(self) -> None:
-        if self.dequantized_value_list is None:
+        if not self.is_dequantized_stored():
             logger.warning(
                 f"No stored dequantized values {self.dequantized_value_list=}. Must first call dequantize_()."
             )
@@ -232,7 +233,7 @@ def quantize_(self) -> None:
 
     @property
     def dequantized_value(self) -> Tuple[Tensor, ...]:
-        assert self.dequantized_value_list is not None
+        assert self.dequantized_value_list is not None  # make type checker happy
         return self.dequantized_value_list
 
     @property
@@ -243,7 +244,7 @@ def is_dequantized_stored(self) -> bool:
         return self.dequantized_value_list is not None
 
     def compress(self, selector: Tuple[bool, ...]) -> "QuantizedTensorList":
-        assert self.dequantized_value_list is None
+        assert not self.is_dequantized_stored()
         masked_quantized_value_list = compress_list(self.quantized_value_list, selector)
         masked_min_values = compress_list(self._min_values, selector)
         masked_max_values = compress_list(self._max_values, selector)
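For orientation, a small sketch of the call protocol these guards encode (the helper below is hypothetical; qtl stands for an already constructed QuantizedTensorList, whose construction is omitted, and only methods visible in this diff are used):

def run_in_dequantized_precision(qtl, fn) -> None:
    # Hypothetical helper: dequantize, operate, then re-quantize, mirroring the
    # is_dequantized_stored() checks that guard dequantize_()/quantize_() above.
    assert not qtl.is_dequantized_stored()  # no dequantized copies stored yet
    qtl.dequantize_()                       # materialize dequantized tensors
    try:
        fn(qtl.dequantized_value)           # operate on the dequantized values
    finally:
        qtl.quantize_()                     # write values back into quantized storage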

distributed_shampoo/utils/shampoo_utils.py

+1 −1

@@ -29,7 +29,7 @@ def merge_small_dims(tensor_shape: Sequence[int], threshold: int) -> Tuple[int,
 
     # Squeeze tensor shape to remove dimension with 1; if all dimensions are 1,
     # then add a 1 to the tensor shape.
-    squeezed_tensor_shape = list(filter(lambda t: t > 1, tensor_shape)) or [1]
+    squeezed_tensor_shape = list(filter(lambda t: t != 1, tensor_shape)) or [1]
     new_tensor_shape = [squeezed_tensor_shape[0]]
     for next_tensor_shape in squeezed_tensor_shape[1:]:
         if (new_dimension := new_tensor_shape[-1] * next_tensor_shape) <= threshold:
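This one-character change is the empty-tensor fix called out in the commit summary: a zero-sized dimension must survive the squeeze so the merged shape still describes an empty tensor. A standalone illustration of just the filter step (not the library function itself):

tensor_shape = (0, 5)  # shape of an empty tensor

# Old behavior: 0 is dropped along with the 1s, so the empty tensor is
# mistakenly treated as if it had shape [5].
old = list(filter(lambda t: t > 1, tensor_shape)) or [1]   # -> [5]

# New behavior: only size-1 dimensions are squeezed; the 0 is kept, and the
# merged shape still has zero elements.
new = list(filter(lambda t: t != 1, tensor_shape)) or [1]  # -> [0, 5]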
