|
| 1 | +import base64 |
| 2 | +import io |
1 | 3 | from typing import AsyncGenerator, Dict, List, Optional
|
2 | 4 | from loguru import logger
|
3 | 5 |
|
@@ -92,11 +94,11 @@ async def make_request(
|
92 | 94 |
|
93 | 95 | request_args.update(self._request_args)
|
94 | 96 |
|
| 97 | + messages = self._build_messages(request) |
| 98 | + |
95 | 99 | payload = {
|
96 | 100 | "model": self._model,
|
97 |
| - "messages": [ |
98 |
| - {"role": "user", "content": request.prompt}, |
99 |
| - ], |
| 101 | + "messages": messages, |
100 | 102 | "stream": True,
|
101 | 103 | **request_args,
|
102 | 104 | }
|
@@ -158,3 +160,21 @@ def validate_connection(self):
|
158 | 160 | Validate the connection to the backend server.
|
159 | 161 | """
|
160 | 162 | logger.info("Connection validation is not explicitly implemented for aiohttp backend.")
|
| 163 | + |
| 164 | + def _build_messages(self, request: TextGenerationRequest) -> Dict: |
| 165 | + if request.number_images == 0: |
| 166 | + messages = [{"role": "user", "content": request.prompt}] |
| 167 | + else: |
| 168 | + content = [] |
| 169 | + for image in request.images: |
| 170 | + stream = io.BytesIO() |
| 171 | + im_format = image.image.format or "PNG" |
| 172 | + image.image.save(stream, format=im_format) |
| 173 | + im_b64 = base64.b64encode(stream.getvalue()).decode("utf-8") |
| 174 | + image_url = {"url": f"data:image/{im_format.lower()};base64,{im_b64}"} |
| 175 | + content.append({"type": "image_url", "image_url": image_url}) |
| 176 | + |
| 177 | + content.append({"type": "text", "text": request.prompt}) |
| 178 | + messages = [{"role": "user", "content": content}] |
| 179 | + |
| 180 | + return messages |
0 commit comments