
Commit 934ddbe

Addressing review comments
1 parent 600cb3e commit 934ddbe

File tree

3 files changed: +111 -80 lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 106 additions & 9 deletions
@@ -2,12 +2,15 @@
 # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from typing import List, Union
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 
 from tornado.web import HTTPError
 
+from ads.aqua.app import logger
+from ads.aqua.client.client import Client, ExtendedRequestError
 from ads.aqua.common.decorator import handle_exceptions
+from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
 from ads.aqua.modeldeployment import AquaDeploymentApp
@@ -175,6 +178,102 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _get_model_deployment_response(
+        self,
+        model_deployment_id: str,
+        payload: dict,
+        route_override_header: Optional[str],
+    ):
+        """
+        Returns the model deployment inference response in a streaming fashion.
+
+        This method connects to the specified model deployment endpoint and
+        streams the inference output back to the caller, handling both text
+        and chat completion endpoints depending on the route override.
+
+        Parameters
+        ----------
+        model_deployment_id : str
+            The OCID of the model deployment to invoke.
+            Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'
+
+        payload : dict
+            Dictionary containing the model inference parameters.
+            Sample payload for text completions:
+            {
+                "max_tokens": 1024,
+                "temperature": 0.5,
+                "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
+                "top_p": 0.4,
+                "top_k": 100,
+                "model": "odsc-llm",
+                "frequency_penalty": 1,
+                "presence_penalty": 1,
+                "stream": true
+            }
+
+        route_override_header : Optional[str]
+            Optional override for the inference route, used for routing between
+            different endpoint types (e.g., chat vs. text completions).
+            Example: '/v1/chat/completions'
+
+        Returns
+        -------
+        Generator[str]
+            A generator that yields strings of the model's output as they are received.
+
+        Raises
+        ------
+        HTTPError
+            If the request to the model deployment fails or if streaming cannot be established.
+        """
+
+        model_deployment = AquaDeploymentApp().get(model_deployment_id)
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
+        endpoint_type = model_deployment.environment_variables.get(
+            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
+        )
+        aqua_client = Client(endpoint=endpoint)
+
+        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
+            endpoint_type,
+            route_override_header,
+        ):
+            try:
+                for chunk in aqua_client.chat(
+                    messages=payload.pop("messages"),
+                    payload=payload,
+                    stream=True,
+                ):
+                    try:
+                        yield chunk["choices"][0]["delta"]["content"]
+                    except Exception as e:
+                        logger.debug(
+                            f"Exception occurred while parsing streaming response: {e}"
+                        )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
+        elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
+            try:
+                for chunk in aqua_client.generate(
+                    prompt=payload.pop("prompt"),
+                    payload=payload,
+                    stream=True,
+                ):
+                    try:
+                        yield chunk["choices"][0]["text"]
+                    except Exception as e:
+                        logger.debug(
+                            f"Exception occurred while parsing streaming response: {e}"
+                        )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
     @handle_exceptions
     def post(self, model_deployment_id):
         """
@@ -203,19 +302,17 @@ def post(self, model_deployment_id):
             raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
         route_override_header = self.request.headers.get("route", None)
         self.set_header("Content-Type", "text/event-stream")
-        self.set_header("Cache-Control", "no-cache")
-        self.set_header("Transfer-Encoding", "chunked")
-        self.flush()
+        response_gen = self._get_model_deployment_response(
+            model_deployment_id, input_data, route_override_header
+        )
         try:
-            response_gen = AquaDeploymentApp().get_model_deployment_response(
-                model_deployment_id, input_data, route_override_header
-            )
             for chunk in response_gen:
                 self.write(chunk)
                 self.flush()
+            self.finish()
         except Exception as ex:
-            raise HTTPError(500, str(ex)) from ex
-        finally:
+            self.set_status(ex.status_code)
+            self.write({"message": "Error occurred", "reason": str(ex)})
             self.finish()
 
 
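For reference, a minimal consumption sketch of the streaming client API this handler now wraps. This is an illustration under assumptions, not code from this commit: the endpoint URL and OCID below are placeholders, while the Client import and the generate(prompt=..., payload=..., stream=True) call mirror the usage shown in the diff above.

# Hedged sketch -- not part of this commit. The endpoint value is a
# placeholder; Client and generate() mirror the calls in the diff above.
from ads.aqua.client.client import Client

# Placeholder deployment endpoint plus the streaming predict route.
endpoint = "https://<md-host>/ocid1.datasciencemodeldeployment.iad.oc1.xxxyz/predictWithResponseStream"
client = Client(endpoint=endpoint)

for chunk in client.generate(
    prompt="Hello",
    payload={"model": "odsc-llm", "max_tokens": 64},
    stream=True,
):
    # Each chunk follows the OpenAI-style completions schema the handler parses.
    print(chunk["choices"][0]["text"], end="", flush=True)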
ads/aqua/modeldeployment/deployment.py

Lines changed: 0 additions & 69 deletions
@@ -11,7 +11,6 @@
 from oci.data_science.models import ModelDeploymentShapeSummary
 from pydantic import ValidationError
 
-from ads.aqua import Client
 from ads.aqua.app import AquaApp, logger
 from ads.aqua.common.entities import (
     AquaMultiModelRef,
@@ -21,7 +20,6 @@
 from ads.aqua.common.enums import (
     InferenceContainerTypeFamily,
     ModelFormat,
-    PredictEndpoints,
     Tags,
 )
 from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
@@ -1315,70 +1313,3 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]:
         )
         for oci_shape in oci_shapes
     ]
-
-    @telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
-    def get_model_deployment_response(
-        self, model_deployment_id: str, payload: dict, route_override_header: str
-    ):
-        """
-        Returns Model deployment inference response in streaming fashion
-
-        Parameters
-        ----------
-        model_deployment_id: str
-            Model deployment ocid
-        payload: dict
-            model params.
-            {
-                "max_tokens": 1024,
-                "temperature": 0.5,
-                "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
-                "top_p": 0.4,
-                "top_k": 100,
-                "model": "odsc-llm",
-                "frequency_penalty": 1,
-                "presence_penalty": 1,
-                "stream": true
-            }
-
-        Returns
-        -------
-        Model deployment inference response in streaming fashion
-
-        """
-
-        model_deployment = self.get(model_deployment_id)
-        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
-        endpoint_type = model_deployment.environment_variables.get(
-            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
-        )
-        aqua_client = Client(endpoint=endpoint)
-
-        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
-            endpoint_type,
-            route_override_header,
-        ):
-            for chunk in aqua_client.chat(
-                messages=payload.pop("messages"),
-                payload=payload,
-                stream=True,
-            ):
-                try:
-                    yield chunk["choices"][0]["delta"]["content"]
-                except Exception as e:
-                    logger.debug(
-                        f"Exception occurred while parsing streaming response: {e}"
-                    )
-
-        elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
-            for chunk in aqua_client.generate(
-                prompt=payload.pop("prompt"),
-                payload=payload,
-                stream=True,
-            ):
-                try:
-                    yield chunk["choices"][0]["text"]
-                except Exception as e:
-                    logger.debug(
-                        f"Exception occurred while parsing streaming response: {e}"
-                    )

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 5 additions & 2 deletions
@@ -235,7 +235,9 @@ def setUp(self, ipython_init_mock) -> None:
         self.handler.flush = MagicMock()
         self.handler.finish = MagicMock()
 
-    @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_model_deployment_response")
+    @patch.object(
+        AquaDeploymentStreamingInferenceHandler, "_get_model_deployment_response"
+    )
     def test_post(self, mock_get_model_deployment_response):
         """Test post method to return model deployment response."""
         mock_response_gen = iter(["chunk1", "chunk2"])
@@ -245,7 +247,8 @@ def test_post(self, mock_get_model_deployment_response):
         self.handler.get_json_body = MagicMock(
             return_value={"prompt": "Hello", "model": "some-model"}
         )
-        self.handler.request.headers = {"route": "test-route"}
+        self.handler.request.headers = MagicMock()
+        self.handler.request.headers.get.return_value = "test-route"
 
         self.handler.post("mock-deployment-id")
 
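A possible follow-up for this test, shown only as a hedged sketch: assertions that the streamed chunks were written and the response was finished. These lines are not part of this commit, and they assume self.handler.write is a MagicMock like the flush and finish mocks in setUp.

# Hypothetical additions to test_post, not part of this commit:
self.handler.write.assert_any_call("chunk1")
self.handler.write.assert_any_call("chunk2")
self.handler.finish.assert_called()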