Skip to content

Commit bc41862

Browse files
committed
Merge branch 'main' of github.com:oracle/accelerated-data-science into ODSC-70841_update_md_tracking
2 parents 075d714 + 33c9966 commit bc41862

File tree

8 files changed

+179
-182
lines changed

8 files changed

+179
-182
lines changed

ads/aqua/common/enums.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ class Resource(ExtendedEnum):
2020
MODEL_VERSION_SET = "model-version-sets"
2121

2222

23+
class PredictEndpoints(ExtendedEnum):
24+
CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
25+
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
26+
EMBEDDING_ENDPOINT = "/v1/embedding"
27+
28+
2329
class Tags(ExtendedEnum):
2430
TASK = "task"
2531
LICENSE = "license"

ads/aqua/extension/deployment_handler.py

Lines changed: 122 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5-
from typing import List, Union
5+
from typing import List, Optional, Union
66
from urllib.parse import urlparse
77

88
from tornado.web import HTTPError
99

10+
from ads.aqua.app import logger
11+
from ads.aqua.client.client import Client, ExtendedRequestError
1012
from ads.aqua.common.decorator import handle_exceptions
13+
from ads.aqua.common.enums import PredictEndpoints
1114
from ads.aqua.extension.base_handler import AquaAPIhandler
1215
from ads.aqua.extension.errors import Errors
13-
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
14-
from ads.aqua.modeldeployment.entities import ModelParams
16+
from ads.aqua.modeldeployment import AquaDeploymentApp
1517
from ads.config import COMPARTMENT_OCID
1618

1719

@@ -175,23 +177,107 @@ def list_shapes(self):
175177
)
176178

177179

178-
class AquaDeploymentInferenceHandler(AquaAPIhandler):
179-
@staticmethod
180-
def validate_predict_url(endpoint):
181-
try:
182-
url = urlparse(endpoint)
183-
if url.scheme != "https":
184-
return False
185-
if not url.netloc:
186-
return False
187-
return url.path.endswith("/predict")
188-
except Exception:
189-
return False
180+
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
181+
def _get_model_deployment_response(
182+
self,
183+
model_deployment_id: str,
184+
payload: dict,
185+
route_override_header: Optional[str],
186+
):
187+
"""
188+
Returns the model deployment inference response in a streaming fashion.
189+
190+
This method connects to the specified model deployment endpoint and
191+
streams the inference output back to the caller, handling both text
192+
and chat completion endpoints depending on the route override.
193+
194+
Parameters
195+
----------
196+
model_deployment_id : str
197+
The OCID of the model deployment to invoke.
198+
Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'
199+
200+
payload : dict
201+
Dictionary containing the model inference parameters.
202+
Same example for text completions:
203+
{
204+
"max_tokens": 1024,
205+
"temperature": 0.5,
206+
"prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
207+
"top_p": 0.4,
208+
"top_k": 100,
209+
"model": "odsc-llm",
210+
"frequency_penalty": 1,
211+
"presence_penalty": 1,
212+
"stream": true
213+
}
214+
215+
route_override_header : Optional[str]
216+
Optional override for the inference route, used for routing between
217+
different endpoint types (e.g., chat vs. text completions).
218+
Example: '/v1/chat/completions'
219+
220+
Returns
221+
-------
222+
Generator[str]
223+
A generator that yields strings of the model's output as they are received.
224+
225+
Raises
226+
------
227+
HTTPError
228+
If the request to the model deployment fails or if streaming cannot be established.
229+
"""
230+
231+
model_deployment = AquaDeploymentApp().get(model_deployment_id)
232+
endpoint = model_deployment.endpoint + "/predictWithResponseStream"
233+
endpoint_type = model_deployment.environment_variables.get(
234+
"MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
235+
)
236+
aqua_client = Client(endpoint=endpoint)
237+
238+
if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
239+
endpoint_type,
240+
route_override_header,
241+
):
242+
try:
243+
for chunk in aqua_client.chat(
244+
messages=payload.pop("messages"),
245+
payload=payload,
246+
stream=True,
247+
):
248+
try:
249+
yield chunk["choices"][0]["delta"]["content"]
250+
except Exception as e:
251+
logger.debug(
252+
f"Exception occurred while parsing streaming response: {e}"
253+
)
254+
except ExtendedRequestError as ex:
255+
raise HTTPError(400, str(ex))
256+
except Exception as ex:
257+
raise HTTPError(500, str(ex))
258+
259+
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
260+
try:
261+
for chunk in aqua_client.generate(
262+
prompt=payload.pop("prompt"),
263+
payload=payload,
264+
stream=True,
265+
):
266+
try:
267+
yield chunk["choices"][0]["text"]
268+
except Exception as e:
269+
logger.debug(
270+
f"Exception occurred while parsing streaming response: {e}"
271+
)
272+
except ExtendedRequestError as ex:
273+
raise HTTPError(400, str(ex))
274+
except Exception as ex:
275+
raise HTTPError(500, str(ex))
190276

191277
@handle_exceptions
192-
def post(self, *args, **kwargs): # noqa: ARG002
278+
def post(self, model_deployment_id):
193279
"""
194-
Handles inference request for the Active Model Deployments
280+
Handles streaming inference request for the Active Model Deployments
195281
Raises
196282
------
197283
HTTPError
@@ -205,32 +291,29 @@ def post(self, *args, **kwargs): # noqa: ARG002
205291
if not input_data:
206292
raise HTTPError(400, Errors.NO_INPUT_DATA)
207293

208-
endpoint = input_data.get("endpoint")
209-
if not endpoint:
210-
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))
211-
212-
if not self.validate_predict_url(endpoint):
213-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT.format("endpoint"))
214-
215294
prompt = input_data.get("prompt")
216-
if not prompt:
217-
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))
295+
messages = input_data.get("messages")
218296

219-
model_params = (
220-
input_data.get("model_params") if input_data.get("model_params") else {}
221-
)
222-
try:
223-
model_params_obj = ModelParams(**model_params)
224-
except Exception as ex:
297+
if not prompt and not messages:
225298
raise HTTPError(
226-
400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
227-
) from ex
228-
229-
return self.finish(
230-
MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
231-
endpoint
299+
400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
232300
)
301+
if not input_data.get("model"):
302+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
303+
route_override_header = self.request.headers.get("route", None)
304+
self.set_header("Content-Type", "text/event-stream")
305+
response_gen = self._get_model_deployment_response(
306+
model_deployment_id, input_data, route_override_header
233307
)
308+
try:
309+
for chunk in response_gen:
310+
self.write(chunk)
311+
self.flush()
312+
self.finish()
313+
except Exception as ex:
314+
self.set_status(ex.status_code)
315+
self.write({"message": "Error occurred", "reason": str(ex)})
316+
self.finish()
234317

235318

236319
class AquaDeploymentParamsHandler(AquaAPIhandler):
@@ -294,5 +377,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
294377
("deployments/?([^/]*)", AquaDeploymentHandler),
295378
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
296379
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
297-
("inference", AquaDeploymentInferenceHandler),
380+
("inference/stream/?([^/]*)", AquaDeploymentStreamingInferenceHandler),
298381
]

ads/aqua/modeldeployment/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
from ads.aqua.modeldeployment.deployment import AquaDeploymentApp
6-
from ads.aqua.modeldeployment.inference import MDInferenceResponse
75

8-
__all__ = ["AquaDeploymentApp", "MDInferenceResponse"]
6+
__all__ = ["AquaDeploymentApp"]

ads/aqua/modeldeployment/deployment.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
ComputeShapeSummary,
2121
ContainerPath,
2222
)
23-
from ads.aqua.common.enums import InferenceContainerTypeFamily, ModelFormat, Tags
23+
from ads.aqua.common.enums import (
24+
InferenceContainerTypeFamily,
25+
ModelFormat,
26+
Tags,
27+
)
2428
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
2529
from ads.aqua.common.utils import (
2630
DEFINED_METADATA_TO_FILE_MAP,
@@ -632,7 +636,9 @@ def _create_multi(
632636
config_data["model_task"] = model.model_task
633637

634638
if model.fine_tune_weights_location:
635-
config_data["fine_tune_weights_location"] = model.fine_tune_weights_location
639+
config_data["fine_tune_weights_location"] = (
640+
model.fine_tune_weights_location
641+
)
636642

637643
model_config.append(config_data)
638644
model_name_list.append(model.model_name)
@@ -800,7 +806,7 @@ def _create_deployment(
800806
telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}
801807

802808
if Tags.BASE_MODEL_CUSTOM in tags:
803-
telemetry_kwargs[ "custom_base_model"] = True
809+
telemetry_kwargs["custom_base_model"] = True
804810

805811
# tracks unique deployments that were created in the user compartment
806812
self.telemetry.record_event_async(
@@ -945,7 +951,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
945951
model_deployment = self.ds_client.get_model_deployment(
946952
model_deployment_id=model_deployment_id, **kwargs
947953
).data
948-
949954
oci_aqua = (
950955
(
951956
Tags.AQUA_TAG in model_deployment.freeform_tags
@@ -990,7 +995,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
990995
aqua_deployment = AquaDeployment.from_oci_model_deployment(
991996
model_deployment, self.region
992997
)
993-
994998
if Tags.MULTIMODEL_TYPE_TAG in model_deployment.freeform_tags:
995999
aqua_model_id = model_deployment.freeform_tags.get(
9961000
Tags.AQUA_MODEL_ID_TAG, UNKNOWN
@@ -1021,7 +1025,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
10211025
aqua_deployment.models = [
10221026
AquaMultiModelRef(**metadata) for metadata in multi_model_metadata
10231027
]
1024-
10251028
return AquaDeploymentDetail(
10261029
**vars(aqua_deployment),
10271030
log_group=AquaResourceIdentifier(
@@ -1321,7 +1324,7 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
13211324
)
13221325
for oci_shape in oci_shapes
13231326
]
1324-
1327+
13251328
@threaded()
13261329
def get_deployment_status(self,model_deployment_id: str, work_request_id : str, model_type : str) -> None:
13271330
"""Waits for the data science model deployment to be completed and log its status in telemetry.
@@ -1366,5 +1369,4 @@ def get_deployment_status(self,model_deployment_id: str, work_request_id : str,
13661369
category=f"aqua/{model_type}/deployment/status",
13671370
action="SUCCEEDED",
13681371
**telemetry_kwargs
1369-
)
1370-
1372+
)

ads/aqua/modeldeployment/inference.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

ads/common/oci_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import logging

0 commit comments

Comments (0)