
Commit b1d319f

Merge pull request #312 from santlchogva/dev
Support MiniMonkey model
2 parents: 7960be4 + f2b7e0e

File tree

3 files changed (+337 -0):

  api/adapter/patcher.py
  api/engine/hf.py
  api/templates/minimonkey.py

api/adapter/patcher.py (+2)
@@ -164,6 +164,8 @@ def patch_config(
 
 
 def patch_model(model: "PreTrainedModel") -> None:
+    if model.config.model_type == "internvl_chat":
+        return
     if model.config.model_type == "minicpmv":
         return
     if "GenerationMixin" not in str(model.generate.__func__):

api/engine/hf.py (+5)
@@ -36,6 +36,7 @@
 from api.templates import get_template
 from api.templates.glm import generate_stream_chatglm, generate_stream_chatglm_v3
 from api.templates.minicpm import generate_stream_minicpm_v
+from api.templates.minimonkey import generate_stream_minimonkey
 from api.templates.stream import generate_stream
 from api.templates.utils import get_context_length
 from api.utils import create_error_response
@@ -78,6 +79,8 @@ def __init__(
             self.generate_stream_func = generate_stream_chatglm
         elif self.model.config.model_type == "minicpmv":
             self.generate_stream_func = generate_stream_minicpm_v
+        elif self.model.config.model_type == "internvl_chat":
+            self.generate_stream_func = generate_stream_minimonkey
 
         logger.info(f"Using {self.model_name} Model for Chat!")
         logger.info(f"Using {self.template} for Chat!")
@@ -98,6 +101,8 @@ def _generate(self, params: Dict[str, Any]) -> Iterator[dict]:
         else:
             if self.model.config.model_type == "minicpmv":
                 inputs = prompt_or_messages
+            elif self.model.config.model_type == "internvl_chat":
+                inputs = prompt_or_messages
             else:
                 inputs = self.template.convert_messages_to_ids(
                     prompt_or_messages,
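With the dispatch above, chat messages for an internvl_chat model are handed to generate_stream_minimonkey as-is instead of being converted to token ids. A minimal client sketch, assuming the server exposes the OpenAI-compatible chat route this project serves (the base URL and model name are placeholders):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
    stream = client.chat.completions.create(
        model="minimonkey",
        stream=True,
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }],
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")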

api/templates/minimonkey.py (+330, new file)
from __future__ import annotations

import gc
import time
import uuid
from typing import (
    Any,
    Dict,
    List,
    Iterator,
    TYPE_CHECKING,
)

import torch

from api.protocol import ChatCompletionMessageParam

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, PreTrainedModel


import queue
from threading import Thread

import torchvision.transforms as T
import transformers
from torchvision.transforms.functional import InterpolationMode
from transformers import BitsAndBytesConfig, TextIteratorStreamer

transformers.logging.set_verbosity_error()

# mx262/MiniMonkey

IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # tie-break: prefer the grid with more tiles when the image is large enough
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate candidate tile grids (i columns x j rows) with min_num..max_num tiles
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio


def dynamic_preprocess2(image, min_num=1, max_num=12, prior_aspect_ratio=None, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate candidate tile grids with min_num..max_num tiles
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
    # keep only grids that do not evenly divide the prior grid, so this second
    # tiling pass cuts the image along different boundaries than the first
    new_target_ratios = []
    for i in target_ratios:
        if prior_aspect_ratio[0] % i[0] or prior_aspect_ratio[1] % i[1]:
            new_target_ratios.append(i)
        else:
            continue
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image, input_size=448, min_num=1, max_num=12):
    image = image.convert('RGB')
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values, target_aspect_ratio


def load_image2(image, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
    image = image.convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
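A small sanity check of the two-pass tiling above (illustrative only; the synthetic image is not part of the commit). For a 1280x960 input, the first pass settles on a 4x3 grid of 448-pixel tiles plus a thumbnail; the second pass is restricted to grids that cut along different boundaries and settles on 3x2:

    from PIL import Image

    img = Image.new('RGB', (1280, 960))  # synthetic 4:3 image

    # pass 1: fine tiling with 4..12 tiles -> (4, 3) grid, 12 tiles + 1 thumbnail
    pixel_values, ratio = load_image(img, min_num=4, max_num=12)
    print(pixel_values.shape, ratio)  # torch.Size([13, 3, 448, 448]) (4, 3)

    # pass 2: coarse tiling with 3..7 tiles, filtered against the (4, 3) prior -> (3, 2)
    pixel_values2 = load_image2(img, min_num=3, max_num=7, target_aspect_ratio=ratio)
    print(pixel_values2.shape)  # torch.Size([7, 3, 448, 448])

generate_stream_minimonkey() below then concatenates the coarse tiles, the fine tiles, and the thumbnail into a single [19, 3, 448, 448] batch.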
@torch.inference_mode()
def generate_stream_minimonkey(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    params: Dict[str, Any],
) -> Iterator:
    """
    Generates text in a streaming manner using the MiniMonkey model.

    Args:
        model: The pre-trained model.
        tokenizer: The tokenizer used for tokenizing the input.
        params: A dictionary containing the input parameters.

    Yields:
        A dictionary representing each generated text completion.
    """
    inputs = params["inputs"]
    model_name = params.get("model", "llm")

    model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

    images, prompt = chatml_prompt_from_messages(inputs)

    # set the max number of tiles in `max_num`, XXX make an option
    pixel_values, target_aspect_ratio = load_image(images[-1], min_num=4, max_num=12)
    pixel_values2 = load_image2(images[-1], min_num=3, max_num=7, target_aspect_ratio=target_aspect_ratio)
    # combine the two tilings: coarse tiles first, then fine tiles, with the
    # single thumbnail kept at the end
    pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0).to(device=model.device, dtype=model.dtype)

    # expand the one <image> placeholder into the full image token block
    for num_patches in [pixel_values.shape[0]]:
        image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * model.num_image_token * num_patches + IMG_END_TOKEN
        prompt = prompt.replace('<image>', image_tokens, 1)

    model_inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = model_inputs['input_ids'].to(model.device)
    attention_mask = model_inputs['attention_mask'].to(model.device)

    inputs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        target_aspect_ratio=target_aspect_ratio,
    )

    eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
    new_params = dict(
        eos_token_id=[eos_token_id, tokenizer.eos_token_id],
        temperature=float(params.get("temperature", 1.0)),
        max_new_tokens=int(params.get("max_tokens", 256)),
        repetition_penalty=float(params.get("repetition_penalty", 1.0)),
        top_p=float(params.get("top_p", 1.0)),
        top_k=int(params.get("top_k", 50)),
    )

    generation_kwargs = dict(
        **inputs,
        **new_params,
    )

    # Todo: fix length for prompt
    input_echo_len = 0

    generated_text, previous_text = "", ""
    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
    created: int = int(time.time())
    for i, new_text in enumerate(threaded_streaming_generator(generate=model.generate, tokenizer=tokenizer, generation_kwargs=generation_kwargs)):
        generated_text += new_text
        delta_text = generated_text[len(previous_text):]
        previous_text = generated_text
        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "delta": delta_text,
            "text": generated_text,
            "logprobs": None,
            "finish_reason": None,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
        }

    gc.collect()
    torch.cuda.empty_cache()

def chatml_prompt_from_messages(messages: list[ChatCompletionMessageParam], img_tok="<image>\n"):
    prompt = ''
    images = []
    generation_msg = "<|im_start|>assistant\n"

    # if the conversation ends with a partial assistant turn, continue it
    if messages and messages[-1]['role'] == 'assistant':
        generation_msg += messages[-1]['content'][0]['text']
        messages.pop(-1)

    for m in messages:
        if m['role'] == 'user':
            text = ''
            has_image = False

            for c in m['content']:
                if c['type'] == 'image_url':
                    images.append(url_to_image(c['image_url']['url']))
                    has_image = True
                if c['type'] == 'text':
                    text = c['text']

            img_tag = img_tok if has_image else ''
            prompt += f"<|im_start|>user\n{img_tag}{text}<|im_end|>"
        elif m['role'] == 'assistant':
            for c in m['content']:
                if c['type'] == 'text':
                    prompt += f"<|im_start|>assistant\n{c['text']}<|im_end|>"
        elif m['role'] == 'system':
            for c in m['content']:
                if c['type'] == 'text':
                    prompt += f"<|im_start|>system\n{c['text']}<|im_end|>"

    prompt += generation_msg

    return images, prompt
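For illustration, a single user turn carrying one image produces a ChatML prompt with the <image> placeholder that generate_stream_minimonkey() later expands (the in-memory PNG is only there to keep the snippet self-contained):

    import base64, io
    from PIL import Image

    buf = io.BytesIO()
    Image.new('RGB', (64, 64)).save(buf, format='PNG')
    data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": data_url}},
        ],
    }]
    images, prompt = chatml_prompt_from_messages(messages)
    # prompt == '<|im_start|>user\n<image>\nDescribe this image.<|im_end|><|im_start|>assistant\n'
    # images == [a 64x64 PIL.Image]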

def url_to_image(image_url: str):
    from PIL import Image
    from io import BytesIO

    if image_url.startswith("data:"):
        import base64

        image_bytes = base64.b64decode(image_url.split(",")[1])
    else:
        import urllib.request

        with urllib.request.urlopen(image_url) as f:
            image_bytes = f.read()

    return Image.open(BytesIO(image_bytes)).convert("RGB")

def threaded_streaming_generator(generate, tokenizer, generation_kwargs):
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=60)

    generation_kwargs['streamer'] = streamer

    # exceptions from the generation thread are forwarded through this queue
    exq = queue.Queue()

    def wrapper():
        try:
            with torch.no_grad():
                generate(**generation_kwargs)

        except Exception as e:
            # logger.exception(e)
            exq.put(e)

        # always unblock the consuming iterator, even after an exception
        streamer.end()

    t = Thread(target=wrapper, daemon=True)
    t.start()

    for text in streamer:
        if text:
            yield text

        if not exq.empty():
            raise exq.get_nowait()
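End to end, the engine drives the generator roughly like this (a sketch; model and tokenizer must come from an InternVL-family checkpoint loaded with trust_remote_code=True, and the params keys mirror what api/engine/hf.py passes through):

    params = {
        "model": "minimonkey",
        "inputs": messages,  # raw chat messages from the request
        "temperature": 0.7,
        "max_tokens": 256,
    }
    for chunk in generate_stream_minimonkey(model, tokenizer, params):
        print(chunk["delta"], end="")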
