Commit 84d35d3

black-eleven and yihuiwen authored Apr 12, 2025
support tarsier2 (#821)
Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
1 parent 16eb6bf commit 84d35d3

File tree

9 files changed: +491 -2 lines changed
 

lightllm/models/tarsier2/__init__.py

Whitespace-only changes.

lightllm/models/tarsier2/layer_weights/__init__.py

Whitespace-only changes.
lightllm/models/tarsier2/layer_weights/pre_and_post_layer_weight.py

+38

import torch
import numpy as np
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight


# add key: language_model.xxx -> xxx
# only rename keys at PreAndPostLayerWeight load; TransformerLayerWeight keys are already correct
def rename_weight_keys(weights):
    prefix = "language_model."
    keys = list(weights.keys())
    for k in keys:
        if prefix in k:
            weights[k[len(prefix) :]] = weights[k]


class Tarsier2Qwen2PreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
        return


class Tarsier2LlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
        return
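Note that rename_weight_keys mutates the dict in place and keeps the original prefixed entries alongside the stripped aliases, so both spellings resolve during loading. A minimal sketch of the effect; the key names and string "tensor" values below are stubs purely for illustration:

weights = {
    "language_model.model.embed_tokens.weight": "emb",
    "language_model.lm_head.weight": "head",
    "visual.patch_embed.weight": "patch",  # no prefix: left untouched
}
rename_weight_keys(weights)
assert weights["model.embed_tokens.weight"] == "emb"  # stripped alias added
assert "language_model.lm_head.weight" in weights     # original key kept
assert "visual.patch_embed.weight" in weights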

lightllm/models/tarsier2/model.py

+143

import json
import os

from lightllm.common.build_utils import repair_config
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.models.qwen2.model import Qwen2TpPartModel
from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
from lightllm.models.qwen2_vl.vision_process import smart_resize
from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
from lightllm.models.tarsier2.layer_weights.pre_and_post_layer_weight import (
    Tarsier2Qwen2PreAndPostLayerWeight,
    Tarsier2LlamaPreAndPostLayerWeight,
)
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.server.core.objs import SamplingParams


class Tarsier2Tokenizer:
    def __init__(self, tokenizer=None, image_processor=None, **kwargs):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.image_start_id = kwargs["model_cfg"]["text_config"]["vision_start_token_id"]
        self.image_end_id = kwargs["model_cfg"]["text_config"]["vision_end_token_id"]
        self.image_token_id = kwargs["model_cfg"]["text_config"]["image_token_id"]

    def init_imageItem_extral_params(
        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
    ):
        return

    def get_image_token_length(self, img: ImageItem):
        width = img.image_w
        height = img.image_h
        resized_height, resized_width = smart_resize(height=height, width=width)
        self.patch_size = self.image_processor.patch_size
        self.merge_size = self.image_processor.merge_size
        grid_t = 1
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        merge_length = self.merge_size ** 2
        self.token_num = (grid_t * grid_h * grid_w) // merge_length
        self.image_length = self.token_num
        return self.image_length

    def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
        origin_ids = self.tokenizer.encode(prompt)

        # <img><image_pad></img> -> <img></img>
        origin_ids = [token for token in origin_ids if token != self.image_token_id]
        # <img></img> -> <img>id,id+1...id+num</img>
        input_ids = []
        image_id = 0
        start_idx = 0
        while True:
            try:
                start_idx = origin_ids.index(self.image_start_id, start_idx)
                if start_idx + 1 >= len(origin_ids):
                    break
                if origin_ids[start_idx + 1] == self.image_end_id:
                    input_ids.extend(origin_ids[: start_idx + 1])
                    token_id = multimodal_params.images[image_id].token_id
                    token_num = multimodal_params.images[image_id].token_num
                    input_ids.extend(range(token_id, token_id + token_num))
                    input_ids.append(self.image_end_id)
                    origin_ids = origin_ids[start_idx + 2 :]
                    start_idx = 0
                    image_id += 1
                else:
                    raise ValueError("image token error")
            except ValueError:
                break
        input_ids.extend(origin_ids[start_idx:])
        return input_ids

    def __getattr__(self, name):
        if name != "encode":
            return getattr(self.tokenizer, name)
        return self.encode
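Two details of the tokenizer are worth a worked example. For get_image_token_length, assuming the stock Qwen2-VL processor values patch_size=14 and merge_size=2, a 448x448 image passes through smart_resize unchanged (both sides are already multiples of 28), giving a 32x32 patch grid and (32*32)/2^2 = 256 token slots. encode then makes those slots explicit in the prompt: it strips every image_token_id pad and splices the contiguous range [token_id, token_id + token_num) between each vision start/end pair. A hypothetical trace with stub objects; every id below is an illustrative assumption (Qwen2-VL-style), not a value taken from this commit:

from types import SimpleNamespace

class _StubTokenizer:
    # stands in for the underlying HF tokenizer; encodes
    # "hi <vision_start> <pad> <pad> <vision_end> there"
    def encode(self, prompt):
        return [9906, 151652, 151655, 151655, 151653, 1917]

tok = Tarsier2Tokenizer(
    tokenizer=_StubTokenizer(),
    image_processor=None,  # not needed for encode()
    model_cfg={
        "text_config": {
            "vision_start_token_id": 151652,
            "vision_end_token_id": 151653,
            "image_token_id": 151655,
        }
    },
)
params = SimpleNamespace(images=[SimpleNamespace(token_id=200000, token_num=4)])
print(tok.encode("hi <image> there", multimodal_params=params))
# -> [9906, 151652, 200000, 200001, 200002, 200003, 151653, 1917]

The rest of model.py registers the three Tarsier2 model variants: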
class Tarsier2Qwen2TpPartModel(Qwen2TpPartModel):
    # weight class
    pre_and_post_weight_class = Tarsier2Qwen2PreAndPostLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)
        return

    def _init_config(self):
        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
            self.config = json.load(json_file)["text_config"]
        # rename keys
        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
        return


class Tarsier2Qwen2VLTpPartModel(Qwen2VLTpPartModel):
    # weight class
    pre_and_post_weight_class = Tarsier2Qwen2PreAndPostLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)
        return

    def _init_config(self):
        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
            self.config = json.load(json_file)["text_config"]
        # rename keys
        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
        return


class Tarsier2LlamaTpPartModel(LlamaTpPartModel):

    pre_and_post_weight_class = Tarsier2LlamaPreAndPostLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultimodalPreLayerInfer

    def __init__(self, kvargs):
        super().__init__(kvargs)
        return

    def _init_config(self):
        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
            self.config = json.load(json_file)["text_config"]
        # rename keys
        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
        return
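All three _init_config overrides load only the text_config section of Tarsier2's merged config.json, then normalize alternate key spellings so downstream code can read one canonical name. A rough stand-in for what that aliasing amounts to; the real repair_config lives in lightllm.common.build_utils and this sketch only illustrates the intent:

# Hypothetical re-implementation for illustration; not the shipped helper.
def repair_config_sketch(config, same_names):
    value = next((config[n] for n in same_names if n in config), None)
    if value is not None:
        for n in same_names:
            config[n] = value  # every alias now carries the same value

cfg = {"n_head": 32, "hidden_size": 4096}
repair_config_sketch(cfg, same_names=["num_attention_heads", "n_head"])
assert cfg["num_attention_heads"] == 32  # aliased from the legacy "n_head"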
