from __future__ import annotations

import gc
import queue
import time
import uuid
from threading import Thread
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    TYPE_CHECKING,
)

import torch
import torchvision.transforms as T
import transformers
from torchvision.transforms.functional import InterpolationMode
from transformers import BitsAndBytesConfig, TextIteratorStreamer

from api.protocol import ChatCompletionMessageParam

if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

transformers.logging.set_verbosity_error()

# mx262/MiniMonkey

IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the standard ImageNet-normalized tensor transform at a fixed square size."""
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ])
    return transform

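# For reference: build_transform(448)(img) yields a float tensor of shape
# (3, 448, 448), normalized with the ImageNet mean/std above.
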
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio best matches the image; ties are
    broken in favor of grids the image is large enough to fill."""
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # prefer the larger grid when the image covers enough of it
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Tile an image into between min_num and max_num crops of image_size x image_size,
    choosing the grid that best matches the image's aspect ratio."""
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate all (columns, rows) grids whose tile count is within bounds
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize, then split the image into image_size x image_size tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images, target_aspect_ratio

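# Worked example (illustrative, not executed): a 1344x896 image has aspect
# ratio 1.5, so with image_size=448 the closest grid is (3, 2). The image is
# resized to 1344x896 and cut into 6 tiles of 448x448; with use_thumbnail=True
# a 7th 448x448 thumbnail of the whole image is appended.
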
def dynamic_preprocess2(image, min_num=1, max_num=12, prior_aspect_ratio=None, image_size=448, use_thumbnail=False):
    """Like dynamic_preprocess, but only considers grids that do not evenly
    divide prior_aspect_ratio, yielding a complementary second-pass tiling."""
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate all (columns, rows) grids whose tile count is within bounds
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if min_num <= i * j <= max_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # keep only grids that do not evenly divide the prior grid on both axes
    new_target_ratios = [
        ratio for ratio in target_ratios
        if prior_aspect_ratio[0] % ratio[0] or prior_aspect_ratio[1] % ratio[1]
    ]

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize, then split the image into image_size x image_size tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image, input_size=448, min_num=1, max_num=12):
    """First-pass tiling: returns stacked pixel values and the chosen grid."""
    image = image.convert('RGB')
    transform = build_transform(input_size=input_size)
    images, target_aspect_ratio = dynamic_preprocess(
        image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
    pixel_values = torch.stack([transform(img) for img in images])
    return pixel_values, target_aspect_ratio

def load_image2(image, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
    """Second-pass tiling at a complementary grid (see dynamic_preprocess2)."""
    image = image.convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess2(
        image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num,
        prior_aspect_ratio=target_aspect_ratio)
    pixel_values = torch.stack([transform(img) for img in images])
    return pixel_values

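# The two passes are combined by generate_stream_minimonkey below; sketched:
#
#   pixel_values, ratio = load_image(img, min_num=4, max_num=12)   # coarse grid + thumbnail
#   pixel_values2 = load_image2(img, min_num=3, max_num=7,
#                               target_aspect_ratio=ratio)         # complementary grid + thumbnail
#
# Each result has shape (num_tiles + 1, 3, 448, 448); the tiles from both
# passes are concatenated and only the second pass's thumbnail is kept.
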
@torch.inference_mode()
def generate_stream_minimonkey(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    params: Dict[str, Any],
) -> Iterator:
    """
    Generates text in a streaming manner using the MiniMonkey model.

    Args:
        model: The pre-trained model.
        tokenizer: The tokenizer used for tokenizing the input.
        params: A dictionary containing the input parameters.

    Yields:
        A dictionary representing each generated text completion.
    """
    inputs = params["inputs"]
    model_name = params.get("model", "llm")

    model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

    images, prompt = chatml_prompt_from_messages(inputs)

    # set the max number of tiles in `max_num`, XXX make an option
    pixel_values, target_aspect_ratio = load_image(images[-1], min_num=4, max_num=12)
    pixel_values2 = load_image2(images[-1], min_num=3, max_num=7, target_aspect_ratio=target_aspect_ratio)
    # combine both tilings, keeping a single thumbnail (from the second pass)
    pixel_values = torch.cat(
        [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0
    ).to(device=model.device, dtype=model.dtype)

    # expand the '<image>' placeholder into one <IMG_CONTEXT> run per patch
    num_patches = pixel_values.shape[0]
    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * model.num_image_token * num_patches + IMG_END_TOKEN
    prompt = prompt.replace('<image>', image_tokens, 1)

    model_inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = model_inputs['input_ids'].to(model.device)
    attention_mask = model_inputs['attention_mask'].to(model.device)

    inputs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        target_aspect_ratio=target_aspect_ratio,
    )

    eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
    new_params = dict(
        eos_token_id=[eos_token_id, tokenizer.eos_token_id],
        temperature=float(params.get("temperature", 1.0)),
        max_new_tokens=int(params.get("max_tokens", 256)),
        repetition_penalty=float(params.get("repetition_penalty", 1.0)),
        top_p=float(params.get("top_p", 1.0)),
        top_k=int(params.get("top_k", 50)),
    )

    generation_kwargs = dict(
        **inputs,
        **new_params,
    )

    # TODO: fix length for prompt
    input_echo_len = 0

    generated_text, previous_text = "", ""
    completion_id: str = f"cmpl-{uuid.uuid4()}"
    created: int = int(time.time())
    for i, new_text in enumerate(threaded_streaming_generator(generate=model.generate, tokenizer=tokenizer, generation_kwargs=generation_kwargs)):
        generated_text += new_text
        delta_text = generated_text[len(previous_text):]
        previous_text = generated_text
        yield {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "delta": delta_text,
            "text": generated_text,
            "logprobs": None,
            "finish_reason": None,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
        }

    gc.collect()
    torch.cuda.empty_cache()

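# Each yielded chunk is an OpenAI-style streaming completion dict, e.g.
# (values illustrative):
#
#   {"id": "cmpl-…", "object": "text_completion", "created": 1700000000,
#    "model": "minimonkey", "delta": " a cat", "text": "It shows a cat",
#    "logprobs": None, "finish_reason": None,
#    "usage": {"prompt_tokens": 0, "completion_tokens": 3, "total_tokens": 3}}
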
def chatml_prompt_from_messages(messages: list[ChatCompletionMessageParam], img_tok="<image>\n"):
    """Flatten OpenAI-style chat messages into a ChatML prompt, collecting images."""
    prompt = ''
    images = []
    generation_msg = "<|im_start|>assistant\n"

    # if the last message is a partial assistant turn, continue it
    if messages and messages[-1]['role'] == 'assistant':
        generation_msg += messages[-1]['content'][0].text
        messages.pop(-1)

    for m in messages:
        if m['role'] == 'user':
            text = ''
            has_image = False

            for c in m['content']:
                if c['type'] == 'image_url':
                    images.append(url_to_image(c['image_url']['url']))
                    has_image = True
                if c['type'] == 'text':
                    text = c['text']

            img_tag = img_tok if has_image else ''
            prompt += f"<|im_start|>user\n{img_tag}{text}<|im_end|>"
        elif m['role'] == 'assistant':
            for c in m['content']:
                if c['type'] == 'text':
                    prompt += f"<|im_start|>assistant\n{c['text']}<|im_end|>"
        elif m['role'] == 'system':
            for c in m['content']:
                if c['type'] == 'text':
                    prompt += f"<|im_start|>system\n{c['text']}<|im_end|>"

    prompt += generation_msg

    return images, prompt

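# Example (illustrative): a single user turn containing one image and the text
# "Describe this image." produces
#
#   <|im_start|>user
#   <image>
#   Describe this image.<|im_end|><|im_start|>assistant
#
# generate_stream_minimonkey later replaces '<image>' with the real
# <img>…</img> token run sized to the number of image tiles.
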
def url_to_image(image_url: str):
    """Load a PIL image from a data: URI or an http(s) URL."""
    from PIL import Image
    from io import BytesIO

    if image_url.startswith("data:"):
        import base64

        image_bytes = base64.b64decode(image_url.split(",")[1])
    else:
        import urllib.request

        with urllib.request.urlopen(image_url) as f:
            image_bytes = f.read()

    return Image.open(BytesIO(image_bytes)).convert("RGB")

def threaded_streaming_generator(generate, tokenizer, generation_kwargs):
    """Run generate() on a background thread and yield decoded text pieces as
    they arrive; exceptions from the worker are re-raised in the consumer."""
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=60)

    generation_kwargs['streamer'] = streamer

    exq = queue.Queue()

    def wrapper():
        try:
            with torch.no_grad():
                generate(**generation_kwargs)

        except Exception as e:
            # hand the exception to the consuming thread
            exq.put(e)

        streamer.end()

    t = Thread(target=wrapper, daemon=True)
    t.start()

    for text in streamer:
        if text:
            yield text

    if not exq.empty():
        raise exq.get_nowait()
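
# Minimal usage sketch (illustrative; assumes a MiniMonkey checkpoint loaded
# with trust_remote_code=True and the OpenAI-style `params` schema used above;
# the checkpoint name and request values here are placeholders):
#
# from transformers import AutoModel, AutoTokenizer
#
# model = AutoModel.from_pretrained("mx262/MiniMonkey", trust_remote_code=True,
#                                   torch_dtype=torch.bfloat16).cuda().eval()
# tokenizer = AutoTokenizer.from_pretrained("mx262/MiniMonkey", trust_remote_code=True)
#
# params = {
#     "model": "minimonkey",
#     "max_tokens": 128,
#     "inputs": [{"role": "user", "content": [
#         {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
#         {"type": "text", "text": "Describe this image."},
#     ]}],
# }
#
# for chunk in generate_stream_minimonkey(model, tokenizer, params):
#     print(chunk["delta"], end="", flush=True)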