 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+from enum import Enum
 from typing import List, Optional

 import torch
 import torch.distributed as dist
+from torch.distributed import DeviceMesh
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
 from torch import nn
 if os.uname().sysname != "Darwin":
     from torch.distributed import _functional_collectives as funcol
 else:
     # Distributed is not supported on MacOS
     funcol = None

 from model import Attention, FeedForward, Transformer
-from quantize import WeightOnlyInt4Linear
+from quantize import WeightOnlyInt4Linear, WeightOnlyInt8Linear


 def _get_rank() -> int:
@@ -33,6 +36,12 @@ def local_break():
 def _get_world_size() -> int:
     return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

+global device_mesh
+
+def _get_tp_mesh():
+    # device_mesh has only TP dimension for now
+    return device_mesh
+
 def maybe_init_dist() -> Optional[int]:
     try:
         # provided by torchrun
@@ -48,86 +57,97 @@ def maybe_init_dist() -> Optional[int]:

     torch.cuda.set_device(rank)
     dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+
+    global device_mesh
+    device_mesh = dist.init_device_mesh(
+        "cuda",
+        (world_size,),  # Only TP dimension for now
+    )
     return rank

+class TPMode(Enum):
+    MANUAL = 0
+    DTENSOR = 1

-def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: List[int] = []) -> None:
+def _apply_tp_linear(linear: nn.Linear, style: str) -> TPMode:
     rank = _get_rank()
     world_size = _get_world_size()
+    tp_mesh = _get_tp_mesh()

     # Linear's weight matrix is transposed, and is of shape
     # (linear.out_features, linear.in_features)
     dim_lookup = {
-        "colwise": (0, "out_features"),
-        "rowwise": (1, "in_features")
+        "colwise": (0, "out_features", ColwiseParallel()),
+        "rowwise": (1, "in_features", RowwiseParallel()),
     }
     assert style in dim_lookup
-    shard_dim, size_attr = dim_lookup[style]
+    shard_dim, size_attr, tp_plan = dim_lookup[style]

     # ensure we can shard evenly
     assert getattr(linear, size_attr) % world_size == 0
     def shard(x, dim):
         assert x.size(dim=dim) % world_size == 0
         return torch.tensor_split(x, world_size, dim=dim)[rank]

-    def shard_qkv(qkv, dim, weight_splits):
-        q, k, v = qkv.split(weight_splits, dim=dim)
-        q = shard(q, dim)
-        k = shard(k, dim)
-        v = shard(v, dim)
-        return torch.cat((q, k, v), dim=dim)
-
-    # shard
-    if weight_splits:
-        # attention
-        assert len(weight_splits) == 3
-
-        if isinstance(linear, WeightOnlyInt4Linear):
-            sharded_weight = shard_qkv(linear.weight, shard_dim, [i // 8 for i in weight_splits])
-            linear.scales_and_zeros = shard_qkv(linear.scales_and_zeros, 1 - shard_dim, weight_splits)
-        else:
-            sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
-    else:
-        sharded_weight = shard(linear.weight, shard_dim)
-        if isinstance(linear, WeightOnlyInt4Linear):
+    def shard_scale(linear, shard_dim):
+        if hasattr(linear, "scales_and_zeros"):
             linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim)
             if style == "rowwise":
                 assert linear.scales_and_zeros.shape[0] * 32 == sharded_weight.shape[1] * sharded_weight.shape[2] * sharded_weight.shape[3]
                 assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8
-        if hasattr(linear, "scales") and style == "colwise":
-            linear.scales = shard(linear.scales, 0)
+        elif hasattr(linear, "scales"):
+            if style == "colwise":
+                linear.scales = shard(linear.scales, 0)
+
+    # shard
+    tp_mode: TPMode
+    if isinstance(linear, (WeightOnlyInt4Linear, WeightOnlyInt8Linear)):
+        # TODO: DTensor doesn't have a way to distribute quantized tensor yet.
+        # Should revisit when that capability is added.
+        sharded_weight = shard(linear.weight, shard_dim)
+        linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
+        shard_scale(linear, shard_dim)
+        tp_mode = TPMode.MANUAL
+    else:
+        # Use DTensor based TP
+        parallelize_module(linear, tp_mesh, tp_plan)
+        tp_mode = TPMode.DTENSOR

     # local_break()
-    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
     setattr(linear, size_attr, getattr(linear, size_attr) // world_size)

     # shape info should still be synced
     # assert linear.weight.shape == (linear.out_features, linear.in_features)
+    return tp_mode


 def _apply_tp_ffn(mlp: FeedForward) -> None:
     assert hasattr(mlp, "w1")
     assert hasattr(mlp, "w3")
     assert hasattr(mlp, "w2")

-    _apply_tp_linear(mlp.w1, "colwise")
-    _apply_tp_linear(mlp.w3, "colwise")
-    _apply_tp_linear(mlp.w2, "rowwise")
+    tp_mode = _apply_tp_linear(mlp.w1, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w3, "colwise")
+    tp_mode = _apply_tp_linear(mlp.w2, "rowwise")

-    world_size = _get_world_size()
-    mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output, "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        world_size = _get_world_size()
+        mlp.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output, "sum", list(range(world_size))))


 def _apply_tp_attn(attn: Attention) -> None:
-    assert hasattr(attn, "wqkv")
+    assert hasattr(attn, "wq")
+    assert hasattr(attn, "wk")
+    assert hasattr(attn, "wv")
     assert hasattr(attn, "wo")

     kv_size = attn.n_local_heads * attn.head_dim
-    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size])
-    _apply_tp_linear(attn.wo, "rowwise")
+    tp_mode = _apply_tp_linear(attn.wq, "colwise")
+    tp_mode = _apply_tp_linear(attn.wk, "colwise")
+    tp_mode = _apply_tp_linear(attn.wv, "colwise")
+    tp_mode = _apply_tp_linear(attn.wo, "rowwise")

     # overwrite
     world_size = _get_world_size()
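
Aside, not part of the patch: the DTENSOR branch above can skip both the manual weight slicing and the forward-hook all-reduce because parallelize_module replaces the plain parameter with a DTensor sharded over the TP mesh, and RowwiseParallel reduces the partial outputs itself. A minimal sketch of that behaviour, assuming a recent PyTorch where DTensor is exposed under torch.distributed.tensor (older builds use torch.distributed._tensor) and an already-initialized 1-D tp_mesh; inspect_colwise_linear is a hypothetical helper for illustration only:

from torch import nn
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older builds
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

def inspect_colwise_linear(linear: nn.Linear, tp_mesh) -> None:
    # After parallelize_module, linear.weight is a DTensor placed as Shard(0)
    # on the TP mesh: each rank holds out_features // world_size rows locally.
    parallelize_module(linear, tp_mesh, ColwiseParallel())
    assert isinstance(linear.weight, DTensor)
    print(linear.weight.placements, linear.weight.to_local().shape)
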
@@ -136,8 +156,10 @@ def _apply_tp_attn(attn: Attention) -> None:
     attn.head_dim = attn.dim // attn.n_head
     attn.n_local_heads = attn.n_local_heads // world_size

-    attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
-        output[0], "sum", list(range(world_size))))
+    if tp_mode == TPMode.MANUAL:
+        # In manual mode, we need to manually add an all-reduce at the end
+        attn.register_forward_hook(lambda _module, _input, output: funcol.all_reduce(
+            output[0], "sum", list(range(world_size))))


 def _apply_tp_Transformer(Transformer: Transformer) -> None:
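
For reviewers who want to try the new code path outside this repo, here is a small self-contained sketch of the DTensor flow the diff adopts: a 1-D device mesh plus ColwiseParallel/RowwiseParallel applied to a toy two-layer MLP. The layer sizes and the demo_tp name are invented for illustration; it assumes a single node with at least two CUDA devices and a launch via torchrun so the rendezvous environment variables are set.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

def demo_tp() -> None:
    dist.init_process_group(backend="nccl")  # rank/world size come from torchrun env vars
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # 1-D mesh: only a tensor-parallel dimension, mirroring maybe_init_dist() above
    mesh = dist.init_device_mesh("cuda", (world_size,))

    mlp = nn.Sequential(
        nn.Linear(256, 1024, bias=False),  # up projection -> shard columns (output dim)
        nn.Linear(1024, 256, bias=False),  # down projection -> shard rows (input dim)
    ).cuda()

    # Colwise output stays sharded on the last dim, which is exactly the input
    # layout RowwiseParallel expects; RowwiseParallel then all-reduces on its own,
    # so no manual forward hook is needed on this path.
    parallelize_module(mlp[0], mesh, ColwiseParallel())
    parallelize_module(mlp[1], mesh, RowwiseParallel())

    x = torch.randn(8, 256, device="cuda")
    y = mlp(x)  # replicated (identical) output on every TP rank
    assert y.shape == (8, 256)

if __name__ == "__main__":
    demo_tp()

The quantized WeightOnlyInt4Linear/WeightOnlyInt8Linear layers cannot take this route yet, which is why the patch keeps the MANUAL fallback with the explicit funcol.all_reduce hook.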