|
@@ -2,12 +2,15 @@
 # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from typing import List, Union
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 
 from tornado.web import HTTPError
 
+from ads.aqua.app import logger
+from ads.aqua.client.client import Client, ExtendedRequestError
 from ads.aqua.common.decorator import handle_exceptions
+from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
 from ads.aqua.modeldeployment import AquaDeploymentApp
@@ -175,6 +178,102 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _get_model_deployment_response(
+        self,
+        model_deployment_id: str,
+        payload: dict,
+        route_override_header: Optional[str],
+    ):
+        """
+        Returns the model deployment inference response in a streaming fashion.
+
+        This method connects to the specified model deployment endpoint and
+        streams the inference output back to the caller, handling both text
+        and chat completion endpoints depending on the route override.
+
+        Parameters
+        ----------
+        model_deployment_id : str
+            The OCID of the model deployment to invoke.
+            Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'
+
+        payload : dict
+            Dictionary containing the model inference parameters.
+            Example payload for text completions:
+            {
+                "max_tokens": 1024,
+                "temperature": 0.5,
+                "prompt": "What are some good skills for a deep learning expert? Give some tips on how to structure an interview, with a coding example.",
+                "top_p": 0.4,
+                "top_k": 100,
+                "model": "odsc-llm",
+                "frequency_penalty": 1,
+                "presence_penalty": 1,
+                "stream": true
+            }
+
+        route_override_header : Optional[str]
+            Optional override for the inference route, used for routing between
+            different endpoint types (e.g., chat vs. text completions).
+            Example: '/v1/chat/completions'
+
+        Returns
+        -------
+        Generator[str]
+            A generator that yields strings of the model's output as they are received.
+
+        Raises
+        ------
+        HTTPError
+            If the request to the model deployment fails or if streaming cannot be established.
+        """
+
+        model_deployment = AquaDeploymentApp().get(model_deployment_id)
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
+        endpoint_type = model_deployment.environment_variables.get(
+            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
+        )
+        aqua_client = Client(endpoint=endpoint)
+
+        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
+            endpoint_type,
+            route_override_header,
+        ):
+            try:
+                for chunk in aqua_client.chat(
+                    messages=payload.pop("messages"),
+                    payload=payload,
+                    stream=True,
+                ):
+                    try:
+                        yield chunk["choices"][0]["delta"]["content"]
+                    except Exception as e:
+                        logger.debug(
+                            f"Exception occurred while parsing streaming response: {e}"
+                        )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex)) from ex
+            except Exception as ex:
+                raise HTTPError(500, str(ex)) from ex
+
+        elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
+            try:
+                for chunk in aqua_client.generate(
+                    prompt=payload.pop("prompt"),
+                    payload=payload,
+                    stream=True,
+                ):
+                    try:
+                        yield chunk["choices"][0]["text"]
+                    except Exception as e:
+                        logger.debug(
+                            f"Exception occurred while parsing streaming response: {e}"
+                        )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex)) from ex
+            except Exception as ex:
+                raise HTTPError(500, str(ex)) from ex
+
     @handle_exceptions
     def post(self, model_deployment_id):
         """
@@ -203,19 +302,17 @@ def post(self, model_deployment_id):
             raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
         route_override_header = self.request.headers.get("route", None)
         self.set_header("Content-Type", "text/event-stream")
-        self.set_header("Cache-Control", "no-cache")
-        self.set_header("Transfer-Encoding", "chunked")
-        self.flush()
+        response_gen = self._get_model_deployment_response(
+            model_deployment_id, input_data, route_override_header
+        )
         try:
-            response_gen = AquaDeploymentApp().get_model_deployment_response(
-                model_deployment_id, input_data, route_override_header
-            )
             for chunk in response_gen:
                 self.write(chunk)
                 self.flush()
+            self.finish()
         except Exception as ex:
-            raise HTTPError(500, str(ex)) from ex
-        finally:
+            self.set_status(getattr(ex, "status_code", 500))
+            self.write({"message": "Error occurred", "reason": str(ex)})
             self.finish()
 
 
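For context on how a client consumes this change: the new handler writes plain text chunks and flushes after each one, so any HTTP client that supports response streaming can read output incrementally. A minimal consumer sketch follows; the base URL and route path are illustrative assumptions, not part of this diff -- only the "route" header and the payload shape come from the handler itself.

# Minimal streaming-consumer sketch. URL and route are hypothetical.
import requests

OCID = "ocid1.datasciencemodeldeployment.iad.oc1.xxxyz"  # placeholder OCID, as in the docstring
payload = {
    "model": "odsc-llm",
    "prompt": "What are some good skills for a deep learning expert?",
    "max_tokens": 256,
    "stream": True,
}

with requests.post(
    f"http://localhost:8888/aqua/inference/{OCID}",  # hypothetical route to the handler
    json=payload,
    headers={"route": "/v1/completions"},  # optional route override read by post()
    stream=True,
) as resp:
    resp.raise_for_status()
    # The handler flushes each chunk as it is produced, so chunks can be
    # printed as they arrive rather than after the full response.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)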
|
|
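The docstring above shows only a text-completions payload. For the chat branch, _get_model_deployment_response pops "messages" and forwards the remaining keys to aqua_client.chat(), so a chat request would look roughly like the sketch below; the exact set of accepted sampling parameters is an assumption.

# Hedged sketch of a chat-completions payload for the same handler.
chat_payload = {
    "model": "odsc-llm",
    "messages": [  # popped by _get_model_deployment_response, passed to aqua_client.chat()
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Give three interview tips for deep learning roles."},
    ],
    "max_tokens": 512,
    "temperature": 0.5,
    "stream": True,
}

# The chat branch runs when either the deployment's
# MODEL_DEPLOY_PREDICT_ENDPOINT environment variable or the request's
# "route" header matches PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
# (e.g. '/v1/chat/completions'); otherwise the text-completions branch
# parses chunk["choices"][0]["text"] instead of the chat delta field.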