
[AQUA] Adding handler for streaming inference predict endpoint #1190


Merged · 12 commits · May 28, 2025
6 changes: 6 additions & 0 deletions ads/aqua/common/enums.py
@@ -20,6 +20,12 @@ class Resource(ExtendedEnum):
     MODEL_VERSION_SET = "model-version-sets"


+class PredictEndpoints(ExtendedEnum):
+    CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
+    TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
+    EMBEDDING_ENDPOINT = "/v1/embedding"
+
+
 class Tags(ExtendedEnum):
     TASK = "task"
     LICENSE = "license"
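For orientation, route selection in the new streaming handler keys off these enum values. A minimal sketch of that check (assuming, as the handler's membership test below implies, that ExtendedEnum members compare equal to their plain string values):

from ads.aqua.common.enums import PredictEndpoints

# Minimal sketch of the route selection performed by the streaming handler.
# Assumes PredictEndpoints members compare equal to plain strings (str-backed enum).
endpoint_type = "/v1/chat/completions"   # e.g. the MODEL_DEPLOY_PREDICT_ENDPOINT env var
route_override_header = None             # e.g. the optional "route" request header

if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (endpoint_type, route_override_header):
    print("route to /v1/chat/completions")
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
    print("route to /v1/completions")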
161 changes: 122 additions & 39 deletions ads/aqua/extension/deployment_handler.py
@@ -2,16 +2,18 @@
 # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

-from typing import List, Union
+from typing import List, Optional, Union
 from urllib.parse import urlparse

 from tornado.web import HTTPError

+from ads.aqua.app import logger
+from ads.aqua.client.client import Client, ExtendedRequestError
 from ads.aqua.common.decorator import handle_exceptions
+from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
-from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
-from ads.aqua.modeldeployment.entities import ModelParams
+from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.config import COMPARTMENT_OCID


@@ -175,23 +177,107 @@ def list_shapes(self):
)


-class AquaDeploymentInferenceHandler(AquaAPIhandler):
-    @staticmethod
-    def validate_predict_url(endpoint):
-        try:
-            url = urlparse(endpoint)
-            if url.scheme != "https":
-                return False
-            if not url.netloc:
-                return False
-            return url.path.endswith("/predict")
-        except Exception:
-            return False
+class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
Review comment (Member): I didn't know that we used this handler at all.

Review comment (Member Author): It was added before we had the kernel messaging solution for inference; since then it has been lying orphaned in the code. Now we can use this handler to expose a streaming inference API that our client (the AQUA UI) can call for inference, instead of keeping the script on the UI side.

Review comment (Member): Got it, let's use it, but with the AQUA Client. I've added this PR to update the docs. In the UI playground it would be useful to add a checkbox where users can choose whether to use streaming; by default it can use the streaming endpoint.
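For reference, a minimal sketch of the AQUA Client call this handler wraps (the endpoint URL is a placeholder; the chat() keyword arguments mirror the handler code below):

from ads.aqua.client.client import Client

# Placeholder endpoint; the handler builds it from the deployment OCID at runtime.
endpoint = "https://modeldeployment.../predictWithResponseStream"

client = Client(endpoint=endpoint)
for chunk in client.chat(
    messages=[{"role": "user", "content": "Hello"}],
    payload={"model": "odsc-llm", "max_tokens": 128, "stream": True},
    stream=True,
):
    # Chunks follow the OpenAI-style streaming schema parsed by the handler.
    print(chunk["choices"][0]["delta"].get("content", ""), end="")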

    def _get_model_deployment_response(
        self,
        model_deployment_id: str,
        payload: dict,
        route_override_header: Optional[str],
    ):
        """
        Returns the model deployment inference response in a streaming fashion.

        This method connects to the specified model deployment endpoint and
        streams the inference output back to the caller, handling both text
        and chat completion endpoints depending on the route override.

        Parameters
        ----------
        model_deployment_id : str
            The OCID of the model deployment to invoke.
            Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'

        payload : dict
            Dictionary containing the model inference parameters.
            Sample payload for text completions:
            {
                "max_tokens": 1024,
                "temperature": 0.5,
                "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
                "top_p": 0.4,
                "top_k": 100,
                "model": "odsc-llm",
                "frequency_penalty": 1,
                "presence_penalty": 1,
                "stream": true
            }

        route_override_header : Optional[str]
            Optional override for the inference route, used for routing between
            different endpoint types (e.g., chat vs. text completions).
            Example: '/v1/chat/completions'

        Returns
        -------
        Generator[str]
            A generator that yields strings of the model's output as they are received.

        Raises
        ------
        HTTPError
            If the request to the model deployment fails or if streaming cannot be established.
        """

model_deployment = AquaDeploymentApp().get(model_deployment_id)
endpoint = model_deployment.endpoint + "/predictWithResponseStream"
endpoint_type = model_deployment.environment_variables.get(
"MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
)
aqua_client = Client(endpoint=endpoint)

if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
endpoint_type,
route_override_header,
):
try:
for chunk in aqua_client.chat(
messages=payload.pop("messages"),
payload=payload,
stream=True,
):
try:
yield chunk["choices"][0]["delta"]["content"]
except Exception as e:
logger.debug(
f"Exception occurred while parsing streaming response: {e}"
)
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
except Exception as ex:
raise HTTPError(500, str(ex))

elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
try:
for chunk in aqua_client.generate(
prompt=payload.pop("prompt"),
payload=payload,
stream=True,
):
try:
yield chunk["choices"][0]["text"]
except Exception as e:
logger.debug(
f"Exception occurred while parsing streaming response: {e}"
)
except ExtendedRequestError as ex:
raise HTTPError(400, str(ex))
except Exception as ex:
raise HTTPError(500, str(ex))

     @handle_exceptions
-    def post(self, *args, **kwargs):  # noqa: ARG002
+    def post(self, model_deployment_id):
         """
-        Handles inference request for the Active Model Deployments
+        Handles streaming inference request for the Active Model Deployments
         Raises
         ------
         HTTPError
@@ -205,32 +291,29 @@ def post(self, *args, **kwargs):  # noqa: ARG002
         if not input_data:
             raise HTTPError(400, Errors.NO_INPUT_DATA)

-        endpoint = input_data.get("endpoint")
-        if not endpoint:
-            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))
-
-        if not self.validate_predict_url(endpoint):
-            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT.format("endpoint"))
-
         prompt = input_data.get("prompt")
-        if not prompt:
-            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))
+        messages = input_data.get("messages")
Review comment (Member): Can we accept chat_template as well, since we want to support it?

Review comment (Member Author): Yes, it will be accepted implicitly. chat_template will be optional, but if the user provides it the code will accept it and pass it to the AQUA client.
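A hypothetical request body for the chat route illustrating that pass-through (field names other than messages and model are forwarded untouched; the chat_template value here is only a placeholder):

request_body = {
    "model": "odsc-llm",
    "messages": [{"role": "user", "content": "Summarize the release notes."}],
    # Optional; not validated by the handler, simply forwarded to the AQUA client.
    "chat_template": "{% for m in messages %}{{ m.content }}{% endfor %}",
    "max_tokens": 512,
    "temperature": 0.5,
    "stream": True,
}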


-        model_params = (
-            input_data.get("model_params") if input_data.get("model_params") else {}
-        )
-        try:
-            model_params_obj = ModelParams(**model_params)
-        except Exception as ex:
+        if not prompt and not messages:
             raise HTTPError(
-                400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
-            ) from ex
-
-        return self.finish(
-            MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
-                endpoint
+                400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
             )
+        if not input_data.get("model"):
+            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
+        route_override_header = self.request.headers.get("route", None)
+        self.set_header("Content-Type", "text/event-stream")
+        response_gen = self._get_model_deployment_response(
+            model_deployment_id, input_data, route_override_header
+        )
+        try:
+            for chunk in response_gen:
+                self.write(chunk)
+                self.flush()
+            self.finish()
+        except Exception as ex:
+            self.set_status(ex.status_code)
+            self.write({"message": "Error occurred", "reason": str(ex)})
+            self.finish()


class AquaDeploymentParamsHandler(AquaAPIhandler):
@@ -294,5 +377,5 @@ def post(self, *args, **kwargs):  # noqa: ARG002
("deployments/?([^/]*)", AquaDeploymentHandler),
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
("inference", AquaDeploymentInferenceHandler),
("inference/stream/?([^/]*)", AquaDeploymentStreamingInferenceHandler),
]
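To exercise the new route end to end, a client can stream from the extension endpoint. A hypothetical call (the base URL and /aqua prefix depend on how the extension is mounted; the OCID is a placeholder):

import requests

# Placeholder URL; the route is registered as "inference/stream/<deployment OCID>".
url = "http://localhost:8888/aqua/inference/stream/ocid1.datasciencemodeldeployment.oc1..xxx"
body = {"model": "odsc-llm", "prompt": "Hello", "stream": True}

# An optional "route" header can override endpoint selection (e.g. "/v1/chat/completions").
with requests.post(url, json=body, stream=True) as resp:
    resp.raise_for_status()
    resp.encoding = "utf-8"  # the handler writes plain text chunks
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")  # raw text pieces flushed by the handler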
6 changes: 2 additions & 4 deletions ads/aqua/modeldeployment/__init__.py
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 from ads.aqua.modeldeployment.deployment import AquaDeploymentApp
-from ads.aqua.modeldeployment.inference import MDInferenceResponse

-__all__ = ["AquaDeploymentApp", "MDInferenceResponse"]
+__all__ = ["AquaDeploymentApp"]
17 changes: 10 additions & 7 deletions ads/aqua/modeldeployment/deployment.py
@@ -17,7 +17,11 @@
     ComputeShapeSummary,
     ContainerPath,
 )
-from ads.aqua.common.enums import InferenceContainerTypeFamily, ModelFormat, Tags
+from ads.aqua.common.enums import (
+    InferenceContainerTypeFamily,
+    ModelFormat,
+    Tags,
+)
 from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
 from ads.aqua.common.utils import (
     DEFINED_METADATA_TO_FILE_MAP,
@@ -628,7 +632,9 @@ def _create_multi(
config_data["model_task"] = model.model_task

if model.fine_tune_weights_location:
config_data["fine_tune_weights_location"] = model.fine_tune_weights_location
config_data["fine_tune_weights_location"] = (
model.fine_tune_weights_location
)

model_config.append(config_data)
model_name_list.append(model.model_name)
@@ -789,7 +795,7 @@ def _create_deployment(
         telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}

         if Tags.BASE_MODEL_CUSTOM in tags:
-            telemetry_kwargs[ "custom_base_model"] = True
+            telemetry_kwargs["custom_base_model"] = True

         # tracks unique deployments that were created in the user compartment
         self.telemetry.record_event_async(
@@ -934,7 +940,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
         model_deployment = self.ds_client.get_model_deployment(
             model_deployment_id=model_deployment_id, **kwargs
         ).data
-
         oci_aqua = (
             (
                 Tags.AQUA_TAG in model_deployment.freeform_tags
@@ -979,7 +984,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
         aqua_deployment = AquaDeployment.from_oci_model_deployment(
             model_deployment, self.region
         )
-
         if Tags.MULTIMODEL_TYPE_TAG in model_deployment.freeform_tags:
             aqua_model_id = model_deployment.freeform_tags.get(
                 Tags.AQUA_MODEL_ID_TAG, UNKNOWN
@@ -1010,7 +1014,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
             aqua_deployment.models = [
                 AquaMultiModelRef(**metadata) for metadata in multi_model_metadata
             ]
-
         return AquaDeploymentDetail(
             **vars(aqua_deployment),
             log_group=AquaResourceIdentifier(
@@ -1309,4 +1312,4 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
                 or gpu_specs.shapes.get(oci_shape.name.upper()),
             )
             for oci_shape in oci_shapes
-        ]
\ No newline at end of file
+        ]
74 changes: 0 additions & 74 deletions ads/aqua/modeldeployment/inference.py

This file was deleted.

2 changes: 1 addition & 1 deletion ads/common/oci_client.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python

-# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
+# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import logging