@@ -12,6 +12,7 @@
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
 from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 class ViTTransformerLayerInfer:
@@ -60,7 +61,9 @@ def tp_norm(self, input, weight):

     def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
+            b = rms_norm(
+                input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         else:
             b = torch.nn.functional.layer_norm(
                 input,
@@ -73,7 +76,9 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten

     def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
+            return rms_norm(
+                input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         else:
             return torch.nn.functional.layer_norm(
                 input,
@@ -84,20 +89,28 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten
         )

     def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
-        k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        if self.tp_world_size_ > 1:
+            q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
+            k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        else:
+            q_norm = rms_norm(
+                q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
+            k_norm = rms_norm(
+                k, weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         return q_norm, k_norm

     def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         batch_size = input.shape[0]
         seq_len = input.shape[1]
-        qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
+        qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
         qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
         q, k, v = qkv.unbind(2)
         return q, k, v

     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
-        out = torch.empty_like(q)
+        out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         batch_size = q.shape[0]
         seq_len = q.shape[1]
         flash_attention_fwd(q, k, v, out)
@@ -107,30 +120,33 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor
         batch_size = input.shape[0]
         seq_len = input.shape[1]
         o_tensor = layer_weight.o_proj.mm(
-            input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=False
+            input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=True
         )
         if layer_weight.use_ls:
-            o_tensor *= layer_weight.ls1
+            o_tensor.mul_(layer_weight.ls1)
         return o_tensor.reshape((batch_size, seq_len, -1))

     def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
-        # ffn1_out = torch.nn.functional.gelu(fc1)
-        ffn1_out = gelu_fwd(fc1)
+        fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
         input_shape = input.shape
         input = None
-        ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
-        if layer_weight.use_ls:
-            ffn2_out *= layer_weight.ls2
+        ffn1_out = gelu_fwd(fc1, use_custom_tensor_mananger=True)
+        ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=True)
         ffn1_out = None
+        if layer_weight.use_ls:
+            ffn2_out.mul_(layer_weight.ls2)
         return ffn2_out.reshape(input_shape)

     def _context_attention(self, input_embding, layer_weight):
         input1 = self._att_norm(input_embding, layer_weight)
         q, k, v = self._get_qkv(input1, layer_weight)
+        input1 = None
         if layer_weight.qk_norm:
             q, k = self._qk_norm(q, k, layer_weight)
         o = self._context_attention_kernel(q, k, v)
+        q = None
+        k = None
+        v = None
         o = self._get_o(o, layer_weight)
         if self.tp_world_size_ > 1:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
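
The common thread in this change is that intermediate buffers in the ViT layer now come from the shared cache tensor manager instead of fresh CUDA allocations: `use_custom_tensor_mananger=True` on the projection `.mm(...)` calls and on the `rms_norm`/`gelu_fwd` kernels, and `g_cache_manager.alloc_tensor(...)` in place of `torch.empty_like` for the attention output. Below is a minimal sketch of that allocation pattern; only the import path and call shape are taken from this diff, and the manager's pooling/reuse behavior is assumed rather than shown.

```python
import torch

# Shared tensor pool (import path as introduced in the diff above).
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


def attention_out_buffer(q: torch.Tensor) -> torch.Tensor:
    # Old pattern: allocate a fresh tensor on every forward pass.
    #   out = torch.empty_like(q)
    # New pattern: request a buffer of the same shape/dtype/device from the
    # cache manager, so repeated forward passes can reuse pooled memory
    # (assumed behavior of g_cache_manager; only the call shape comes from the diff).
    return g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
```

Dropping intermediates right after their last use (`input1 = None`, then `q`, `k`, `v` after the attention kernel) fits the same goal: those buffers stop being referenced before the next allocation in the layer, presumably so the pool can hand them back out.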