
Commit d45574e

reduce inference memory of vit (#819)
1 parent 281fbc7 commit d45574e

5 files changed: +55 −23 lines
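The same change recurs across all five files: intermediate tensors that were previously allocated fresh with `torch.empty_like` (or via `mm(..., use_custom_tensor_mananger=False)`) are now obtained from the shared cache tensor manager, and the ViT `forward` is bracketed by a cache scope so those buffers can be recycled. A minimal sketch of that pattern, using only calls that appear in this diff (`kernel_output_like` is a hypothetical helper name, and the recycling behavior is an assumption based on how the manager is scoped, not something the diff itself shows):

```python
# Sketch of the recurring pattern in this commit, not lightllm's implementation.
import torch
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


def kernel_output_like(x: torch.Tensor) -> torch.Tensor:
    # before this commit: out = torch.empty_like(x)  (a fresh allocation per call)
    return g_cache_manager.alloc_tensor(x.shape, x.dtype, device=x.device)


g_cache_manager.cache_env_in()    # opened once at the start of the ViT forward()
# ... run the layers, allocating intermediates via kernel_output_like(...)
g_cache_manager.cache_env_out()   # closed after the forward pass
```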

lightllm/models/vit/layer_infer/post_layer_infer.py

+1 −2

@@ -45,8 +45,7 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight):
             layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_
         )

-        # vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
-        vit_embeds_gelu = gelu_fwd(vit_embeds_1)
+        vit_embeds_gelu = gelu_fwd(vit_embeds_1, use_custom_tensor_mananger=True)

         vit_embeds_out = torch.addmm(
             layer_weight.mlp1_3_bias_,

lightllm/models/vit/layer_infer/transformer_layer_infer.py

+30 −14

@@ -12,6 +12,7 @@
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
 from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 class ViTTransformerLayerInfer:
@@ -60,7 +61,9 @@ def tp_norm(self, input, weight):

     def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
+            b = rms_norm(
+                input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         else:
             b = torch.nn.functional.layer_norm(
                 input,
@@ -73,7 +76,9 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:

     def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
+            return rms_norm(
+                input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         else:
             return torch.nn.functional.layer_norm(
                 input,
@@ -84,20 +89,28 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         )

     def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
-        k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        if self.tp_world_size_ > 1:
+            q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
+            k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        else:
+            q_norm = rms_norm(
+                q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
+            k_norm = rms_norm(
+                k, weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         return q_norm, k_norm

     def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         batch_size = input.shape[0]
         seq_len = input.shape[1]
-        qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
+        qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
         qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
         q, k, v = qkv.unbind(2)
         return q, k, v

     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
-        out = torch.empty_like(q)
+        out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         batch_size = q.shape[0]
         seq_len = q.shape[1]
         flash_attention_fwd(q, k, v, out)
@@ -107,30 +120,33 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         batch_size = input.shape[0]
         seq_len = input.shape[1]
         o_tensor = layer_weight.o_proj.mm(
-            input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=False
+            input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=True
         )
         if layer_weight.use_ls:
-            o_tensor *= layer_weight.ls1
+            o_tensor.mul_(layer_weight.ls1)
         return o_tensor.reshape((batch_size, seq_len, -1))

     def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
-        # ffn1_out = torch.nn.functional.gelu(fc1)
-        ffn1_out = gelu_fwd(fc1)
+        fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
         input_shape = input.shape
         input = None
-        ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
-        if layer_weight.use_ls:
-            ffn2_out *= layer_weight.ls2
+        ffn1_out = gelu_fwd(fc1, use_custom_tensor_mananger=True)
+        ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=True)
         ffn1_out = None
+        if layer_weight.use_ls:
+            ffn2_out.mul_(layer_weight.ls2)
         return ffn2_out.reshape(input_shape)

     def _context_attention(self, input_embding, layer_weight):
         input1 = self._att_norm(input_embding, layer_weight)
         q, k, v = self._get_qkv(input1, layer_weight)
+        input1 = None
         if layer_weight.qk_norm:
             q, k = self._qk_norm(q, k, layer_weight)
         o = self._context_attention_kernel(q, k, v)
+        q = None
+        k = None
+        v = None
         o = self._get_o(o, layer_weight)
         if self.tp_world_size_ > 1:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
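Besides routing allocations through `g_cache_manager`, this file drops references to large intermediates as soon as they are no longer needed (`input1 = None`, `q = None`, `k = None`, `v = None`, `ffn1_out = None`). A standalone sketch in plain PyTorch (no lightllm) of why that lowers peak memory: once the last reference is gone, the caching allocator can hand the freed block to the next allocation instead of growing the peak. The sizes below are arbitrary.

```python
# Minimal demonstration of the reference-dropping idiom used in _context_attention.
import torch


def peak_mib(fn):
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated() / 2 ** 20


def keep_reference():
    a = torch.empty(4096, 4096, device="cuda")  # 64 MiB fp32
    b = torch.empty(4096, 4096, device="cuda")  # `a` still alive -> ~128 MiB peak
    return b


def drop_reference():
    a = torch.empty(4096, 4096, device="cuda")
    a = None                                    # block returned to the allocator
    b = torch.empty(4096, 4096, device="cuda")  # reuses the freed block -> ~64 MiB peak
    return b


if torch.cuda.is_available():
    print("keep reference:", peak_mib(keep_reference), "MiB")
    print("drop reference:", peak_mib(drop_reference), "MiB")
```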

lightllm/models/vit/model.py

+4 −0

@@ -19,6 +19,7 @@
 from rpyc.utils.classic import obtain
 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 logger = init_logger(__name__)
@@ -128,11 +129,14 @@ def _init_datatype(self):
         else:
             raise ValueError(f"Unsupport datatype {self.data_type}!")

+    @torch.no_grad()
     def forward(self, pixel_values):
+        g_cache_manager.cache_env_in()
         input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
         for i in range(self.layers_num + self.select_layer + 1):
             input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
         input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
+        g_cache_manager.cache_env_out()
         return input_embs

     @torch.no_grad()
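`forward` now runs under `@torch.no_grad()`, so no autograd state is retained for the vision encoder, and the whole pass is bracketed by `cache_env_in()` / `cache_env_out()`. The diff does not show how `g_cache_manager` works internally; the toy class below only illustrates the contract the call sites imply (allocate through the manager while a scope is open, recycle those buffers when the scope closes) and is not lightllm's implementation.

```python
# Toy stand-in matching the call shape used in this commit. Illustrative only.
import torch


class ToyCacheTensorManager:
    def __init__(self):
        self._free = {}      # (shape, dtype, device) -> list of reusable tensors
        self._in_use = []

    def cache_env_in(self):
        self._in_use = []

    def alloc_tensor(self, shape, dtype, device="cuda"):
        key = (tuple(shape), dtype, str(device))
        pool = self._free.get(key, [])
        t = pool.pop() if pool else torch.empty(shape, dtype=dtype, device=device)
        self._in_use.append((key, t))
        return t

    def cache_env_out(self):
        # Scope closed: tensors handed out during this forward pass become
        # reusable for the next one, so steady-state memory stops growing.
        for key, t in self._in_use:
            self._free.setdefault(key, []).append(t)
        self._in_use = []
```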

lightllm/models/vit/triton_kernel/gelu_vit.py

+9 −2

@@ -1,6 +1,7 @@
 import torch
 import triton
 import triton.language as tl
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 @triton.jit
@@ -21,8 +22,14 @@ def gelu_kernel(output_ptr, input_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
     tl.store(output_ptr + offsets, output, mask=mask)


-def gelu_fwd(input):
-    output = torch.empty_like(input)
+def gelu_fwd(input, use_custom_tensor_mananger=False):
+    if use_custom_tensor_mananger:
+        shape = input.shape
+        dtype = input.dtype
+        device = input.device
+        output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
+    else:
+        output = torch.empty_like(input)
     n_elements = input.numel()
     grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
     gelu_kernel[grid](output, input, n_elements, BLOCK_SIZE=1024)
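A usage sketch for the extended `gelu_fwd` signature (CUDA and Triton required). It assumes `gelu_fwd` returns the output tensor, as its call sites suggest, and that the cached path is exercised inside an open cache scope; both assumptions go beyond what this hunk shows.

```python
import torch
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)

y_plain = gelu_fwd(x)  # default path: output comes from torch.empty_like

g_cache_manager.cache_env_in()
y_cached = gelu_fwd(x, use_custom_tensor_mananger=True)  # output from the cache manager
torch.testing.assert_close(y_plain, y_cached)            # same values, different allocation
g_cache_manager.cache_env_out()
```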

lightllm/models/vit/triton_kernel/rms_norm_vit.py

+11 −5

@@ -2,6 +2,7 @@
 import triton
 import triton.language as tl
 from torch import Tensor
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 @triton.jit
@@ -32,26 +33,31 @@ def rms_norm_kernel(
     tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)


-def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5):
+def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False):
     """Rms norm."""
     feat_size = weight.shape[0]
     seq_len = hidden_states.numel() // hidden_states.size(-1)
     input_stride = hidden_states.stride(-2)

     BLOCK_N = triton.next_power_of_2(feat_size)
-    out = torch.empty_like(hidden_states)
+    if use_custom_tensor_mananger:
+        shape = hidden_states.shape
+        dtype = hidden_states.dtype
+        device = hidden_states.device
+        output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
+    else:
+        output = torch.empty_like(hidden_states)

     grid = (seq_len,)
     rms_norm_kernel[grid](
         hidden_states,
         weight,
-        out,
+        output,
         input_row_stride=input_stride,
         eps=eps,
         N_COLS=feat_size,
         BLOCK_N=BLOCK_N,
         num_warps=4,
         num_stages=3,
     )
-
-    return out
+    return output
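A matching usage sketch for the extended `rms_norm` signature (CUDA and Triton required). The input must be at least 2-D because the wrapper reads `hidden_states.stride(-2)`; the `eps` value and shapes here are illustrative only.

```python
import torch
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

hidden = torch.randn(16, 1024, device="cuda", dtype=torch.float16)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)

out_plain = rms_norm(hidden, weight, eps=1e-6)  # default: torch.empty_like output

g_cache_manager.cache_env_in()
out_cached = rms_norm(hidden, weight, eps=1e-6, use_custom_tensor_mananger=True)
torch.testing.assert_close(out_plain, out_cached)
g_cache_manager.cache_env_out()
```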
