Commit eaae2f5

Use DTensor-based tensor parallel
ghstack-source-id: b55b264d20bd2c0054f7248435fd605a452e876b
Pull Request resolved: #180
1 parent 6253c6b commit eaae2f5

File tree: 3 files changed, +70 -61 lines changed

  model.py (+4 -12)
  scripts/convert_hf_checkpoint.py (+3 -8)
  tp.py (+63 -41)
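
In short: tp.py now parallelizes float (non-quantized) linear layers with DTensor via parallelize_module (ColwiseParallel/RowwiseParallel over a 1-D device mesh), and keeps the old manual sharding plus explicit funcol all-reduce only for int4/int8 quantized linears. To support this, model.py replaces the fused wqkv projection with separate wq/wk/wv linears, and the HF checkpoint converter stops fusing the three projections. As orientation, a hedged usage sketch of the new flow is shown below; apply_tp and Transformer.from_name are pre-existing gpt-fast entry points assumed here, not part of this diff.

    # Hedged usage sketch, not part of the commit. Assumes gpt-fast's existing
    # apply_tp() entry point and Transformer.from_name(); launch under torchrun,
    # e.g. `torchrun --nproc_per_node=2 generate.py ...`.
    from tp import maybe_init_dist, apply_tp
    from model import Transformer

    rank = maybe_init_dist()      # init NCCL and the 1-D "cuda" device mesh (added in this commit)
    use_tp = rank is not None

    model = Transformer.from_name("Meta-Llama-3-8B")  # model name is illustrative
    if use_tp:
        # Float linears get DTensor TP via parallelize_module; int4/int8 linears
        # fall back to manual sharding plus a funcol all-reduce forward hook.
        apply_tp(model)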

model.py (+4 -12)

@@ -156,30 +156,22 @@ def __init__(self, config: ModelArgs):
         super().__init__()
         assert config.dim % config.n_head == 0

-        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
-        # key, query, value projections for all heads, but in a batch
-        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+        self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
+        self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None

         self.n_head = config.n_head
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
-        self._register_load_state_dict_pre_hook(self.load_hook)
-
-    def load_hook(self, state_dict, prefix, *args):
-        if prefix + "wq.weight" in state_dict:
-            wq = state_dict.pop(prefix + "wq.weight")
-            wk = state_dict.pop(prefix + "wk.weight")
-            wv = state_dict.pop(prefix + "wv.weight")
-            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

     def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wq(x), self.wk(x), self.wv(x)

         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
scripts/convert_hf_checkpoint.py (+3 -8)

@@ -83,7 +83,7 @@ def convert_hf_checkpoint(
         original_dir = checkpoint_dir / "original"
         pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
         bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-
+

     def permute(w, n_head):
         dim = config.dim
@@ -116,13 +116,8 @@ def permute(w, n_head):
             if "wq" in key:
                 q = final_result[key]
                 k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
+                final_result[key] = permute(q, config.n_head)
+                final_result[key.replace("wq", "wk")] = permute(k, config.n_local_heads)
     else:
         final_result = merged_result
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")

tp.py (+63 -41)

@@ -4,10 +4,13 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+from enum import Enum
 from typing import List, Optional

 import torch
 import torch.distributed as dist
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
 from torch import nn
 if os.uname().sysname != "Darwin":
     from torch.distributed import _functional_collectives as funcol
@@ -16,7 +19,7 @@
     funcol = None

 from model import Attention, FeedForward, Transformer
-from quantize import WeightOnlyInt4Linear
+from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear


 def _get_rank() -> int:
@@ -33,6 +36,12 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

+global device_mesh
+
+def _get_tp_mesh():
+    # device_mesh has only TP dimension for now
+    return device_mesh
+
 def maybe_init_dist() -> Optional[int]:
     try:
         # provided by torchrun
@@ -48,86 +57,97 @@ def maybe_init_dist() -> Optional[int]:

     torch.cuda.set_device(rank)
     dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    global device_mesh
+    device_mesh = dist.init_device_mesh(
+        "cuda",
+        (world_size,),  # Only TP dimension for now
+    )
     return rank

+class TPMode(Enum):
+    MANUAL = 0
+    DTENSOR = 1

-def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
+def _apply_tp_linear(linear: nn.Linear, style: str) -> None:
     rank = _get_rank()
     world_size = _get_world_size()
+    tp_mesh = _get_tp_mesh()

     # Linear's weight matrix is transposed, and is of shape
     # (linear.out_features, linear.in_features)
     dim_lookup = {
-        "colwise": (0, "out_features"),
-        "rowwise": (1, "in_features")
+        "colwise": (0, "out_features", ColwiseParallel()),
+        "rowwise": (1, "in_features", RowwiseParallel()),
     }
     assert style in dim_lookup
-    shard_dim, size_attr = dim_lookup[style]
+    shard_dim, size_attr, tp_plan = dim_lookup[style]

     # ensure we can shard evenly
     assert getattr(linear, size_attr) % world_size == 0
     def shard(x, dim):
         assert x.size(dim=dim) % world_size == 0
         return torch.tensor_split(x, world_size, dim=dim)[rank]

-    def shard_qkv(qkv, dim, weight_splits):
-        q, k, v = qkv.split(weight_splits, dim=dim)
-        q = shard(q, dim)
-        k = shard(k, dim)
-        v = shard(v, dim)
-        return torch.cat((q,k,v), dim=dim)
-
-    # shard
-    if weight_splits:
-        # attention
-        assert len(weight_splits) == 3
-
-        if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
-            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
-        else:
-            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
-    else:
-        sharded_weight = shard(linear.weight, shard_dim)
-        if isinstance(linear, WeightOnlyInt4Linear):
+    def shard_scale(linear, shard_dim):
+        if hasattr(linear, "scales_and_zeros"):
             linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
             if style == "rowwise":
                 assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
                 assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard(linear.scales, 0)
+        elif hasattr(linear, "scale"):
+            if style == "colwise":
+                linear.scales = shard(linear.scales, 0)
+
+    # shard
+    tp_mode: TPMode
+    if isinstance(linear, (WeightOnlyInt4Linear, WeightOnlyInt8Linear)):
+        # TODO: DTensor doesn't have a way to distribute quantized tensor yet.
+        # Should revisit when that capability is added.
+        sharded_weight = shard(linear.weight, shard_dim)
+        linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
+        shard_scale(linear, shard_dim)
+        tp_mode = TPMode.MANUAL
+    else:
+        # Use DTensor based TP
+        parallelize_module(linear, tp_mesh, tp_plan)
+        tp_mode = TPMode.DTENSOR

     # local_break()
-    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
     setattr(linear, size_attr, getattr(linear, size_attr) // world_size)

     # shape info should still be synced
     # assert linear.weight.shape == (linear.out_features, linear.in_features)
+    return tp_mode


 def _apply_tp_ffn(mlp: FeedForward) -> None:
     assert hasattr(mlp, "w1")
     assert hasattr(mlp, "w3")
     assert hasattr(mlp, "w2")

-    _apply_tp_linear(mlp.w1, "colwise")
-    _apply_tp_linear(mlp.w3, "colwise")
-    _apply_tp_linear(mlp.w2, "rowwise")
+    tp_mode = _apply_tp_linear(mlp.w1, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w3, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w2, "rowwise")

-    world_size = _get_world_size()
-    mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output, "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        world_size = _get_world_size()
+        mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output, "sum", list(range(world_size))))


 def _apply_tp_attn(attn: Attention) -> None:
-    assert hasattr(attn, "wqkv")
+    assert hasattr(attn, "wq")
+    assert hasattr(attn, "wk")
+    assert hasattr(attn, "wv")
     assert hasattr(attn, "wo")

     kv_size = attn.n_local_heads * attn.head_dim
-    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
-    _apply_tp_linear(attn.wo, "rowwise")
+    tp_mode = _apply_tp_linear(attn.wq, "colwise")
+    tp_mode = _apply_tp_linear(attn.wk, "colwise")
+    tp_mode = _apply_tp_linear(attn.wv, "colwise")
+    tp_mode = _apply_tp_linear(attn.wo, "rowwise")

     # overwrite
     world_size = _get_world_size()
@@ -136,8 +156,10 @@ def _apply_tp_attn(attn: Attention) -> None:
     attn.head_dim = attn.dim // attn.n_head
     attn.n_local_heads = attn.n_local_heads // world_size

-    attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output[0], "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output[0], "sum", list(range(world_size))))


 def _apply_tp_Transformer(Transformer: Transformer) -> None:
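
The heart of the DTENSOR path is that parallelize_module with ColwiseParallel/RowwiseParallel shards each nn.Linear weight as a DTensor and inserts the required communication itself: with default layouts, ColwiseParallel leaves its output sharded on the last dimension and RowwiseParallel all-reduces back to a replicated output, which is why the explicit funcol all-reduce forward hook survives only for the quantized TPMode.MANUAL case. Below is a minimal, self-contained sketch of that pattern, mirroring the colwise/rowwise treatment in _apply_tp_ffn; the toy module and sizes are assumptions, not from the repo, and it is meant to be launched with e.g. `torchrun --nproc_per_node=2 tp_demo.py`.

    # Minimal DTensor TP sketch (assumption-laden demo, not repo code).
    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.tensor.parallel import (
        ColwiseParallel, RowwiseParallel, parallelize_module
    )

    class ToyFFN(nn.Module):
        def __init__(self, dim=256, hidden=1024):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden, bias=False)  # colwise: shard out_features
            self.w2 = nn.Linear(hidden, dim, bias=False)  # rowwise: shard in_features
        def forward(self, x):
            return self.w2(torch.relu(self.w1(x)))

    def main():
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        torch.cuda.set_device(rank)
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        mesh = dist.init_device_mesh("cuda", (world_size,))  # 1-D mesh: TP only

        ffn = ToyFFN().cuda()
        # Plan maps submodule names to parallel styles; weights become DTensors.
        parallelize_module(ffn, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

        x = torch.randn(4, 256, device="cuda")
        y = ffn(x)  # RowwiseParallel replicates the output; no manual all-reduce hook
        print(rank, y.shape)

    if __name__ == "__main__":
        main()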
