[AQUA] Adding handler for streaming inference predict endpoint #1190
@@ -2,16 +2,18 @@
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import List, Union
from typing import List, Optional, Union
from urllib.parse import urlparse

from tornado.web import HTTPError

from ads.aqua.app import logger
from ads.aqua.client.client import Client, ExtendedRequestError
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.enums import PredictEndpoints
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
from ads.aqua.modeldeployment.entities import ModelParams
from ads.aqua.modeldeployment import AquaDeploymentApp
from ads.config import COMPARTMENT_OCID

@@ -175,23 +177,107 @@ def list_shapes(self):
        )


class AquaDeploymentInferenceHandler(AquaAPIhandler):
    @staticmethod
    def validate_predict_url(endpoint):
        try:
            url = urlparse(endpoint)
            if url.scheme != "https":
                return False
            if not url.netloc:
                return False
            return url.path.endswith("/predict")
        except Exception:
            return False
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
    def _get_model_deployment_response(
        self,
        model_deployment_id: str,
        payload: dict,
        route_override_header: Optional[str],
    ):
        """
        Returns the model deployment inference response in a streaming fashion.

        This method connects to the specified model deployment endpoint and
        streams the inference output back to the caller, handling both text
        and chat completion endpoints depending on the route override.

        Parameters
        ----------
        model_deployment_id : str
            The OCID of the model deployment to invoke.
            Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'

        payload : dict
            Dictionary containing the model inference parameters.
            Same example for text completions:
            {
                "max_tokens": 1024,
                "temperature": 0.5,
                "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
                "top_p": 0.4,
                "top_k": 100,
                "model": "odsc-llm",
                "frequency_penalty": 1,
                "presence_penalty": 1,
                "stream": true
            }

        route_override_header : Optional[str]
            Optional override for the inference route, used for routing between
            different endpoint types (e.g., chat vs. text completions).
            Example: '/v1/chat/completions'

        Returns
        -------
        Generator[str]
            A generator that yields strings of the model's output as they are received.

        Raises
        ------
        HTTPError
            If the request to the model deployment fails or if streaming cannot be established.
        """

        model_deployment = AquaDeploymentApp().get(model_deployment_id)
        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
        endpoint_type = model_deployment.environment_variables.get(
            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
        )
        aqua_client = Client(endpoint=endpoint)

        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
            endpoint_type,
            route_override_header,
        ):
            try:
                for chunk in aqua_client.chat(
                    messages=payload.pop("messages"),
                    payload=payload,
                    stream=True,
                ):
                    try:
                        yield chunk["choices"][0]["delta"]["content"]
                    except Exception as e:
                        logger.debug(
                            f"Exception occurred while parsing streaming response: {e}"
                        )
            except ExtendedRequestError as ex:
                raise HTTPError(400, str(ex))
            except Exception as ex:
                raise HTTPError(500, str(ex))

        elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
            try:
                for chunk in aqua_client.generate(
                    prompt=payload.pop("prompt"),
                    payload=payload,
                    stream=True,
                ):
                    try:
                        yield chunk["choices"][0]["text"]
                    except Exception as e:
                        logger.debug(
                            f"Exception occurred while parsing streaming response: {e}"
                        )
            except ExtendedRequestError as ex:
                raise HTTPError(400, str(ex))
            except Exception as ex:
                raise HTTPError(500, str(ex))

    @handle_exceptions
    def post(self, *args, **kwargs):  # noqa: ARG002
    def post(self, model_deployment_id):
        """
        Handles inference request for the Active Model Deployments
        Handles streaming inference request for the Active Model Deployments
        Raises
        ------
        HTTPError
@@ -205,32 +291,29 @@ def post(self, *args, **kwargs):  # noqa: ARG002
        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        endpoint = input_data.get("endpoint")
        if not endpoint:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))

        if not self.validate_predict_url(endpoint):
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT.format("endpoint"))

        prompt = input_data.get("prompt")
        if not prompt:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))
        messages = input_data.get("messages")

Review comment: Can we accept chat_template as well, since we want to support it?
Reply: Yes, it will be accepted implicitly. chat_template will be optional, but the user can still provide it and the code will accept it and pass it through to the AQUA client.

        model_params = (
            input_data.get("model_params") if input_data.get("model_params") else {}
        )
        try:
            model_params_obj = ModelParams(**model_params)
        except Exception as ex:
        if not prompt and not messages:
            raise HTTPError(
                400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
            ) from ex

        return self.finish(
            MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
                endpoint
                400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
            )
        if not input_data.get("model"):
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
        route_override_header = self.request.headers.get("route", None)
        self.set_header("Content-Type", "text/event-stream")
        response_gen = self._get_model_deployment_response(
            model_deployment_id, input_data, route_override_header
        )
        try:
            for chunk in response_gen:
                self.write(chunk)
                self.flush()
            self.finish()
        except Exception as ex:
            self.set_status(ex.status_code)
            self.write({"message": "Error occurred", "reason": str(ex)})
            self.finish()


class AquaDeploymentParamsHandler(AquaAPIhandler):
@@ -294,5 +377,5 @@ def post(self, *args, **kwargs):  # noqa: ARG002
    ("deployments/?([^/]*)", AquaDeploymentHandler),
    ("deployments/?([^/]*)/activate", AquaDeploymentHandler),
    ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
    ("inference", AquaDeploymentInferenceHandler),
    ("inference/stream/?([^/]*)", AquaDeploymentStreamingInferenceHandler),
]
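
For context, here is a minimal sketch of how a caller such as the AQUA UI or a notebook could exercise the new streaming route registered above. The server base URL, route prefix, and deployment OCID below are placeholders assumed for illustration; the payload mirrors the text-completions example from the handler's docstring.

```python
import requests  # assumed available in the caller's environment

# Hypothetical values -- the actual base URL and prefix depend on where the
# AQUA extension handlers are mounted in the Jupyter server.
BASE_URL = "http://localhost:8888/aqua/deployments"
DEPLOYMENT_OCID = "ocid1.datasciencemodeldeployment.iad.oc1.xxxyz"

payload = {
    "model": "odsc-llm",
    "prompt": "What are some good skills for a deep learning expert?",
    "max_tokens": 1024,
    "temperature": 0.5,
    "stream": True,
}

# For chat completions, send "messages" instead of "prompt" and optionally set
# the route override header, e.g. headers={"route": "/v1/chat/completions"}.
with requests.post(
    f"{BASE_URL}/inference/stream/{DEPLOYMENT_OCID}",
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            # Tokens arrive incrementally as the handler writes and flushes them.
            print(chunk, end="", flush=True)
```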
@@ -1,8 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from ads.aqua.modeldeployment.deployment import AquaDeploymentApp
from ads.aqua.modeldeployment.inference import MDInferenceResponse

__all__ = ["AquaDeploymentApp", "MDInferenceResponse"]
__all__ = ["AquaDeploymentApp"]
This file was deleted.
Review comment: I didn't know that we used it at all.
Reply: It was added before we had the kernel messaging solution for inference. Since then it has been lying orphaned in the code. Now we can use this handler to expose a streaming inference API which our client (the AQUA UI) can use for inference instead of keeping the script in the UI.
Reply: Got it, let's use it, but with the AQUA Client. I've added this PR to update the docs. I think it would be useful to add a checkbox in the UI playground where users can choose whether or not they want to use streaming. By default it can use the streaming endpoint.
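
Since the thread above settles on routing inference through the AQUA client, the following is a minimal sketch of what the new handler does internally, and what a caller could also do directly against a deployment endpoint. The endpoint URL is a placeholder, and the optional chat_template key is shown only to illustrate the review discussion; it simply rides along in the payload dict that the handler forwards to the client.

```python
from ads.aqua.client.client import Client

# Placeholder endpoint -- the handler builds it from
# AquaDeploymentApp().get(<ocid>).endpoint + "/predictWithResponseStream".
endpoint = "https://<model-deployment-endpoint>/predictWithResponseStream"
client = Client(endpoint=endpoint)

payload = {
    "model": "odsc-llm",
    "max_tokens": 512,
    "temperature": 0.5,
    # Optional, per the review discussion: passed through untouched to the client.
    # "chat_template": "<custom chat template>",
}

# Streaming chat completion, mirroring the chat branch of the handler.
for chunk in client.chat(
    messages=[{"role": "user", "content": "Give some tips for a deep learning interview."}],
    payload=payload,
    stream=True,
):
    print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
```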