Skip to content

Commit 177f888

Browse files
Adding AQUA handler for streaming inference API
1 parent 350a59d commit 177f888

File tree

5 files changed

+66
-38
lines changed

5 files changed

+66
-38
lines changed

ads/aqua/app.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from ads.config import (
4141
AQUA_TELEMETRY_BUCKET,
4242
AQUA_TELEMETRY_BUCKET_NS,
43+
OCI_MD_SERVICE_ENDPOINT,
4344
OCI_ODSC_SERVICE_ENDPOINT,
4445
OCI_RESOURCE_PRINCIPAL_VERSION,
4546
)
@@ -63,8 +64,14 @@ def __init__(self) -> None:
6364
if OCI_RESOURCE_PRINCIPAL_VERSION:
6465
set_auth("resource_principal")
6566
self._auth = default_signer({"service_endpoint": OCI_ODSC_SERVICE_ENDPOINT})
67+
self._md_auth = default_signer({"service_endpoint": OCI_MD_SERVICE_ENDPOINT})
6668
self.ds_client = oc.OCIClientFactory(**self._auth).data_science
6769
self.compute_client = oc.OCIClientFactory(**default_signer()).compute
70+
print("self._md_auth: ", self._md_auth)
71+
print("OCI_MD_SERVICE_ENDPOINT: ", OCI_MD_SERVICE_ENDPOINT)
72+
self.model_deployment_client = oc.OCIClientFactory(
73+
**self._md_auth
74+
).model_deployment
6875
self.logging_client = oc.OCIClientFactory(**default_signer()).logging_management
6976
self.identity_client = oc.OCIClientFactory(**default_signer()).identity
7077
self.region = extract_region(self._auth)

ads/aqua/extension/deployment_handler.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List, Union
66
from urllib.parse import urlparse
77

8+
from tornado.iostream import StreamClosedError
89
from tornado.web import HTTPError
910

1011
from ads.aqua.common.decorator import handle_exceptions
@@ -175,21 +176,9 @@ def list_shapes(self):
175176
)
176177

177178

178-
class AquaDeploymentInferenceHandler(AquaAPIhandler):
179-
@staticmethod
def validate_predict_url(endpoint):
    """Check that ``endpoint`` looks like a model deployment predict URL.

    Parameters
    ----------
    endpoint:
        Candidate URL. Anything that is not a ``str`` is rejected.

    Returns
    -------
    bool
        ``True`` only when the URL uses the ``https`` scheme, has a
        non-empty network location and its path ends with ``/predict``;
        ``False`` otherwise, including when the URL cannot be parsed.
    """
    # urlparse() hands back bytes components for bytes-like input, which can
    # never compare equal to "https", and fails on other types. Reject
    # everything that is not text explicitly instead of relying on a
    # catch-all exception handler.
    if not isinstance(endpoint, str):
        return False

    try:
        url = urlparse(endpoint)
    except ValueError:
        # Raised for malformed URLs such as an unterminated IPv6 literal.
        return False

    return (
        url.scheme == "https"
        and bool(url.netloc)
        and url.path.endswith("/predict")
    )
190-
179+
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
191180
@handle_exceptions
192-
def post(self, *args, **kwargs): # noqa: ARG002
181+
async def post(self, *args, **kwargs): # noqa: ARG002
193182
"""
194183
Handles inference request for the Active Model Deployments
195184
Raises
@@ -205,12 +194,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
205194
if not input_data:
206195
raise HTTPError(400, Errors.NO_INPUT_DATA)
207196

208-
endpoint = input_data.get("endpoint")
209-
if not endpoint:
210-
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))
211-
212-
if not self.validate_predict_url(endpoint):
213-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT.format("endpoint"))
197+
model_deployment_id = input_data.get("id")
214198

215199
prompt = input_data.get("prompt")
216200
if not prompt:
@@ -226,11 +210,24 @@ def post(self, *args, **kwargs): # noqa: ARG002
226210
400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
227211
) from ex
228212

229-
return self.finish(
230-
MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
231-
endpoint
232-
)
233-
)
213+
self.set_header("Content-Type", "text/event-stream")
214+
self.set_header("Cache-Control", "no-cache")
215+
self.set_header("Transfer-Encoding", "chunked")
216+
await self.flush()
217+
218+
try:
219+
response_gen = MDInferenceResponse(
220+
prompt, model_params_obj
221+
).get_model_deployment_response(model_deployment_id)
222+
for chunk in response_gen:
223+
if not chunk:
224+
continue
225+
self.write(f"data: {chunk}\n\n")
226+
await self.flush()
227+
except StreamClosedError:
228+
self.log.warning("Client disconnected.")
229+
finally:
230+
self.finish()
234231

235232

236233
class AquaDeploymentParamsHandler(AquaAPIhandler):
@@ -294,5 +291,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
294291
("deployments/?([^/]*)", AquaDeploymentHandler),
295292
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
296293
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
297-
("inference", AquaDeploymentInferenceHandler),
294+
("inference", AquaDeploymentStreamingInferenceHandler),
298295
]

ads/aqua/modeldeployment/inference.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,12 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import json
7-
from dataclasses import dataclass, field
8-
9-
import requests
107

118
from ads.aqua.app import AquaApp
129
from ads.aqua.modeldeployment.entities import ModelParams
13-
from ads.common.auth import default_signer
1410
from ads.telemetry import telemetry
1511

1612

17-
@dataclass
1813
class MDInferenceResponse(AquaApp):
1914
"""Contains APIs for Aqua Model deployments Inference.
2015
@@ -30,11 +25,32 @@ class MDInferenceResponse(AquaApp):
3025
Creates an instance of model deployment via Aqua
3126
"""
3227

33-
prompt: str = None
34-
model_params: field(default_factory=ModelParams) = None
28+
def __init__(self, prompt=None, model_params=None):
    """Store the prompt and model parameters used for an inference call.

    Parameters
    ----------
    prompt:
        Prompt to send; kept as given (defaults to ``None``).
    model_params:
        Parameters for the request. When omitted or falsy, a default
        ``ModelParams()`` instance is created instead.
    """
    # Run the base-class initializer first so whatever it sets up is in
    # place before the request attributes are stored.
    super().__init__()
    self.model_params = model_params if model_params else ModelParams()
    self.prompt = prompt
32+
33+
@staticmethod
def stream_sanitizer(response):
    """Yield the ``choices[0]["text"]`` value of each ``data:`` event.

    Parameters
    ----------
    response:
        Object whose ``response.data.raw.stream(amt, decode_content=True)``
        yields ``bytes`` chunks of an event-stream body.
        NOTE(review): presumably the value returned by
        ``predict_with_response_stream`` -- confirm against the caller.

    Yields
    ------
    The ``text`` entry of the first choice of every event line whose
    payload is JSON shaped like ``{"choices": [{"text": ...}]}``.
    Lines that do not start with ``data:`` or whose payload does not have
    that shape (for example a ``data: [DONE]`` terminator) are skipped.

    Notes
    -----
    Transport chunks are not event boundaries: one chunk may carry several
    events and one event (or one multi-byte UTF-8 character) may be split
    across two chunks. Bytes are therefore buffered and parsed one complete
    line at a time instead of one chunk at a time.
    """
    skip = object()  # sentinel: "this line carries no usable text"

    def _text_from_line(raw_line):
        """Return the text carried by one raw event line, or ``skip``."""
        try:
            line = raw_line.decode("utf-8").strip()
        except UnicodeDecodeError:
            return skip
        if not line.startswith("data:"):
            return skip
        payload = line[len("data:"):].strip()
        try:
            return json.loads(payload)["choices"][0]["text"]
        except (ValueError, KeyError, IndexError, TypeError):
            # Invalid JSON (ValueError) or JSON of an unexpected shape.
            return skip

    pending = b""
    for chunk in response.data.raw.stream(1024 * 1024, decode_content=True):
        if not chunk:
            continue
        pending += chunk
        # Everything before the last newline is complete; the remainder may
        # be the start of a line that continues in the next chunk.
        *complete_lines, pending = pending.split(b"\n")
        for raw_line in complete_lines:
            text = _text_from_line(raw_line)
            if text is not skip:
                yield text

    # The stream may end without a trailing newline; parse what is left.
    text = _text_from_line(pending)
    if text is not skip:
        yield text
3551

3652
@telemetry(entry_point="plugin=inference&action=get_response", name="aqua")
37-
def get_model_deployment_response(self, endpoint):
53+
def get_model_deployment_response(self, model_deployment_id):
3854
"""
3955
Returns MD inference response
4056
@@ -67,8 +83,9 @@ def get_model_deployment_response(self, endpoint):
6783
key: value for key, value in params_dict.items() if value is not None
6884
}
6985
body = {"prompt": self.prompt, **params_dict}
70-
request_kwargs = {"json": body, "headers": {"Content-Type": "application/json"}}
71-
response = requests.post(
72-
endpoint, auth=default_signer()["signer"], **request_kwargs
86+
response = self.model_deployment_client.predict_with_response_stream(
87+
model_deployment_id=model_deployment_id, request_body=body
7388
)
74-
return json.loads(response.content)
89+
90+
for chunk in MDInferenceResponse.stream_sanitizer(response):
91+
yield chunk

ads/common/oci_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from oci.limits import LimitsClient
1919
from oci.logging import LoggingManagementClient
2020
from oci.marketplace import MarketplaceClient
21+
from oci.model_deployment import ModelDeploymentClient
2122
from oci.object_storage import ObjectStorageClient
2223
from oci.resource_search import ResourceSearchClient
2324
from oci.secrets import SecretsClient
@@ -69,6 +70,7 @@ def _client_impl(client):
6970
"vault": VaultsClient,
7071
"identity": IdentityClient,
7172
"compute": ComputeClient,
73+
"model_deployment": ModelDeploymentClient,
7274
"ai_language": AIServiceLanguageClient,
7375
"data_labeling_dp": DataLabelingClient,
7476
"data_labeling_cp": DataLabelingManagementClient,
@@ -114,6 +116,10 @@ def create_client(self, client_name):
114116
def object_storage(self):
115117
return self.create_client("object_storage")
116118

119+
@property
def model_deployment(self):
    """Return the client that ``create_client`` builds for ``"model_deployment"``."""
    client = self.create_client("model_deployment")
    return client
122+
117123
@property
118124
def compute(self):
119125
return self.create_client("compute")

ads/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ads.common.config import DEFAULT_CONFIG_PATH, DEFAULT_CONFIG_PROFILE, Config, Mode
1212

1313
OCI_ODSC_SERVICE_ENDPOINT = os.environ.get("OCI_ODSC_SERVICE_ENDPOINT")
14+
OCI_MD_SERVICE_ENDPOINT = os.environ.get("OCI_MD_SERVICE_ENDPOINT")
1415
OCI_IDENTITY_SERVICE_ENDPOINT = os.environ.get("OCI_IDENTITY_SERVICE_ENDPOINT")
1516
NB_SESSION_COMPARTMENT_OCID = os.environ.get("NB_SESSION_COMPARTMENT_OCID")
1617
PROJECT_OCID = os.environ.get("PROJECT_OCID") or os.environ.get("PIPELINE_PROJECT_OCID")

0 commit comments

Comments
 (0)