From 547765c000c4bf1af43ab4855e15cf968efd1a45 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Wed, 12 Jun 2024 13:42:52 -0700
Subject: [PATCH] Use DTensor-based tensor parallel

[ghstack-poisoned]
---
 model.py                         |  16 ++---
 scripts/convert_hf_checkpoint.py |  11 +---
 tp.py                            | 104 +++++++++++++++++++------
 3 files changed, 70 insertions(+), 61 deletions(-)

diff --git a/model.py b/model.py
index 0660bc2b..13732c95 100644
--- a/model.py
+++ b/model.py
@@ -156,9 +156,9 @@ def __init__(self, config: ModelArgs):
         super().__init__()
         assert config.dim % config.n_head == 0
 
-        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
-        # key, query, value projections for all heads, but in a batch
-        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+        self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
+        self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None
 
@@ -166,20 +166,12 @@ def __init__(self, config: ModelArgs):
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim
-        self._register_load_state_dict_pre_hook(self.load_hook)
-
-    def load_hook(self, state_dict, prefix, *args):
-        if prefix + "wq.weight" in state_dict:
-            wq = state_dict.pop(prefix + "wq.weight")
-            wk = state_dict.pop(prefix + "wk.weight")
-            wv = state_dict.pop(prefix + "wv.weight")
-            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
     def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wq(x), self.wk(x), self.wv(x)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py
index 8a221067..f47e3518 100644
--- a/scripts/convert_hf_checkpoint.py
+++ b/scripts/convert_hf_checkpoint.py
@@ -83,7 +83,7 @@ def convert_hf_checkpoint(
     original_dir = checkpoint_dir / "original"
     pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
     bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}
-    
+
 
     def permute(w, n_head):
         dim = config.dim
@@ -116,13 +116,8 @@ def permute(w, n_head):
             if "wq" in key:
                 q = final_result[key]
                 k = final_result[key.replace("wq", "wk")]
-                v = final_result[key.replace("wq", "wv")]
-                q = permute(q, config.n_head)
-                k = permute(k, config.n_local_heads)
-                final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
-                del final_result[key]
-                del final_result[key.replace("wq", "wk")]
-                del final_result[key.replace("wq", "wv")]
+                final_result[key] = permute(q, config.n_head)
+                final_result[key.replace("wq", "wk")] = permute(k, config.n_local_heads)
     else:
         final_result = merged_result
     print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
diff --git a/tp.py b/tp.py
index a151a875..a8ef25f3 100644
--- a/tp.py
+++ b/tp.py
@@ -4,10 +4,13 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+from enum import Enum
 from typing import List, Optional
 
 import torch
 import torch.distributed as dist
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
 from torch import nn
 if os.uname().sysname != "Darwin":
     from torch.distributed import _functional_collectives as funcol
@@ -16,7 +19,7 @@
     funcol = None
 
 from model import Attention, FeedForward, Transformer
-from quantize import WeightOnlyInt4Linear
+from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear
 
 
 def _get_rank() -> int:
@@ -33,6 +36,12 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
 
+global device_mesh
+
+def _get_tp_mesh():
+    # device_mesh has only TP dimension for now
+    return device_mesh
+
 def maybe_init_dist() -> Optional[int]:
     try:
         # provided by torchrun
@@ -48,21 +57,31 @@ def maybe_init_dist() -> Optional[int]:
 
     torch.cuda.set_device(rank)
     dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    global device_mesh
+    device_mesh = dist.init_device_mesh(
+        "cuda",
+        (world_size,),  # Only TP dimension for now
+    )
     return rank
 
+class TPMode(Enum):
+    MANUAL = 0
+    DTENSOR = 1
 
-def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
+def _apply_tp_linear(linear: nn.Linear, style: str) -> TPMode:
     rank = _get_rank()
     world_size = _get_world_size()
+    tp_mesh = _get_tp_mesh()
 
     # Linear's weight matrix is transposed, and is of shape
     # (linear.out_features, linear.in_features)
     dim_lookup = {
-        "colwise": (0, "out_features"),
-        "rowwise": (1, "in_features")
+        "colwise": (0, "out_features", ColwiseParallel()),
+        "rowwise": (1, "in_features", RowwiseParallel()),
     }
     assert style in dim_lookup
-    shard_dim, size_attr = dim_lookup[style]
+    shard_dim, size_attr, tp_plan = dim_lookup[style]
 
     # ensure we can shard evenly
     assert getattr(linear, size_attr) % world_size == 0
@@ -70,41 +89,36 @@ def shard(x, dim):
         assert x.size(dim=dim) % world_size == 0
         return torch.tensor_split(x, world_size, dim=dim)[rank]
 
-    def shard_qkv(qkv, dim, weight_splits):
-        q, k, v = qkv.split(weight_splits, dim=dim)
-        q = shard(q, dim)
-        k = shard(k, dim)
-        v = shard(v, dim)
-        return torch.cat((q,k,v), dim=dim)
-
-    # shard
-    if weight_splits:
-        # attention
-        assert len(weight_splits) == 3
-
-        if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i//8 for i in weight_splits])
-            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
-        else:
-            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
-    else:
-        sharded_weight = shard(linear.weight, shard_dim)
-        if isinstance(linear, WeightOnlyInt4Linear):
+    def shard_scale(linear, shard_dim):
+        if hasattr(linear, "scales_and_zeros"):
             linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
             if style == "rowwise":
                 assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
                 assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard(linear.scales, 0)
+        elif hasattr(linear, "scales"):
+            if style == "colwise":
+                linear.scales = shard(linear.scales, 0)
+
+    # shard
+    tp_mode: TPMode
+    if isinstance(linear, (WeightOnlyInt4Linear, WeightOnlyInt8Linear)):
+        # TODO: DTensor doesn't have a way to distribute quantized tensor yet.
+        # Should revisit when that capability is added.
+        sharded_weight = shard(linear.weight, shard_dim)
+        linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
+        shard_scale(linear, shard_dim)
+        tp_mode = TPMode.MANUAL
+    else:
+        # Use DTensor based TP
+        parallelize_module(linear, tp_mesh, tp_plan)
+        tp_mode = TPMode.DTENSOR
 
     # local_break()
-    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
     setattr(linear, size_attr, getattr(linear, size_attr) // world_size)
 
     # shape info should still be synced
     # assert linear.weight.shape == (linear.out_features, linear.in_features)
+    return tp_mode
 
 
 def _apply_tp_ffn(mlp: FeedForward) -> None:
@@ -112,22 +126,28 @@ def _apply_tp_ffn(mlp: FeedForward) -> None:
     assert hasattr(mlp, "w3")
     assert hasattr(mlp, "w2")
 
-    _apply_tp_linear(mlp.w1, "colwise")
-    _apply_tp_linear(mlp.w3, "colwise")
-    _apply_tp_linear(mlp.w2, "rowwise")
+    tp_mode = _apply_tp_linear(mlp.w1, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w3, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w2, "rowwise")
 
-    world_size = _get_world_size()
-    mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output, "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        world_size = _get_world_size()
+        mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output, "sum", list(range(world_size))))
 
 
 def _apply_tp_attn(attn: Attention) -> None:
-    assert hasattr(attn, "wqkv")
+    assert hasattr(attn, "wq")
+    assert hasattr(attn, "wk")
+    assert hasattr(attn, "wv")
     assert hasattr(attn, "wo")
 
     kv_size = attn.n_local_heads * attn.head_dim
-    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
-    _apply_tp_linear(attn.wo, "rowwise")
+    tp_mode = _apply_tp_linear(attn.wq, "colwise")
+    tp_mode = _apply_tp_linear(attn.wk, "colwise")
+    tp_mode = _apply_tp_linear(attn.wv, "colwise")
+    tp_mode = _apply_tp_linear(attn.wo, "rowwise")
 
     # overwrite
     world_size = _get_world_size()
@@ -136,8 +156,10 @@ def _apply_tp_attn(attn: Attention) -> None:
     attn.head_dim = attn.dim // attn.n_head
     attn.n_local_heads = attn.n_local_heads // world_size
 
-    attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output[0], "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output[0], "sum", list(range(world_size))))
 
 
 def _apply_tp_Transformer(Transformer: Transformer) -> None:
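
Appendix (reviewer note, not part of the patch): a minimal, self-contained sketch of the DTensor-based flow this change switches to. It builds a one-dimensional device mesh and applies ColwiseParallel/RowwiseParallel plans to a pair of linears via parallelize_module, letting DTensor insert the trailing all-reduce that the manual TPMode path above still registers by hand. The ToyMLP module, its sizes, and the script name are illustrative assumptions, not code from this repository.

# toy_tp.py -- run with: torchrun --nproc_per_node=2 toy_tp.py
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    """Two linears in the Megatron pattern: colwise shard, then rowwise shard."""

    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)  # colwise: shard out_features
        self.w2 = nn.Linear(hidden, dim, bias=False)  # rowwise: shard in_features

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


def main():
    # torchrun provides these environment variables.
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    # One mesh dimension: tensor parallel only, as in maybe_init_dist() above.
    mesh = dist.init_device_mesh("cuda", (world_size,))

    model = ToyMLP().cuda()
    # Colwise followed by rowwise composes into a single all-reduce on the
    # output, which DTensor adds automatically -- no manual forward hook.
    parallelize_module(model, mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})

    x = torch.randn(8, 256, device="cuda")
    y = model(x)  # plain replicated tensor on every rank
    print(f"rank {rank}: output shape {tuple(y.shape)}")


if __name__ == "__main__":
    main()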