2
2
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
5
- from typing import List , Union
5
+ from typing import List , Optional , Union
6
6
from urllib .parse import urlparse
7
7
8
8
from tornado .web import HTTPError
9
9
10
+ from ads .aqua .app import logger
11
+ from ads .aqua .client .client import Client , ExtendedRequestError
10
12
from ads .aqua .common .decorator import handle_exceptions
13
+ from ads .aqua .common .enums import PredictEndpoints
11
14
from ads .aqua .extension .base_handler import AquaAPIhandler
12
15
from ads .aqua .extension .errors import Errors
13
- from ads .aqua .modeldeployment import AquaDeploymentApp , MDInferenceResponse
14
- from ads .aqua .modeldeployment .entities import ModelParams
16
+ from ads .aqua .modeldeployment import AquaDeploymentApp
15
17
from ads .config import COMPARTMENT_OCID
16
18
17
19
@@ -175,23 +177,107 @@ def list_shapes(self):
175
177
)
176
178
177
179
178
- class AquaDeploymentInferenceHandler (AquaAPIhandler ):
179
- @staticmethod
180
- def validate_predict_url (endpoint ):
181
- try :
182
- url = urlparse (endpoint )
183
- if url .scheme != "https" :
184
- return False
185
- if not url .netloc :
186
- return False
187
- return url .path .endswith ("/predict" )
188
- except Exception :
189
- return False
180
+ class AquaDeploymentStreamingInferenceHandler (AquaAPIhandler ):
181
+ def _get_model_deployment_response (
182
+ self ,
183
+ model_deployment_id : str ,
184
+ payload : dict ,
185
+ route_override_header : Optional [str ],
186
+ ):
187
+ """
188
+ Returns the model deployment inference response in a streaming fashion.
189
+
190
+ This method connects to the specified model deployment endpoint and
191
+ streams the inference output back to the caller, handling both text
192
+ and chat completion endpoints depending on the route override.
193
+
194
+ Parameters
195
+ ----------
196
+ model_deployment_id : str
197
+ The OCID of the model deployment to invoke.
198
+ Example: 'ocid1.datasciencemodeldeployment.iad.oc1.xxxyz'
199
+
200
+ payload : dict
201
+ Dictionary containing the model inference parameters.
202
+ Same example for text completions:
203
+ {
204
+ "max_tokens": 1024,
205
+ "temperature": 0.5,
206
+ "prompt": "what are some good skills deep learning expert. Give us some tips on how to structure interview with some coding example?",
207
+ "top_p": 0.4,
208
+ "top_k": 100,
209
+ "model": "odsc-llm",
210
+ "frequency_penalty": 1,
211
+ "presence_penalty": 1,
212
+ "stream": true
213
+ }
214
+
215
+ route_override_header : Optional[str]
216
+ Optional override for the inference route, used for routing between
217
+ different endpoint types (e.g., chat vs. text completions).
218
+ Example: '/v1/chat/completions'
219
+
220
+ Returns
221
+ -------
222
+ Generator[str]
223
+ A generator that yields strings of the model's output as they are received.
224
+
225
+ Raises
226
+ ------
227
+ HTTPError
228
+ If the request to the model deployment fails or if streaming cannot be established.
229
+ """
230
+
231
+ model_deployment = AquaDeploymentApp ().get (model_deployment_id )
232
+ endpoint = model_deployment .endpoint + "/predictWithResponseStream"
233
+ endpoint_type = model_deployment .environment_variables .get (
234
+ "MODEL_DEPLOY_PREDICT_ENDPOINT" , PredictEndpoints .TEXT_COMPLETIONS_ENDPOINT
235
+ )
236
+ aqua_client = Client (endpoint = endpoint )
237
+
238
+ if PredictEndpoints .CHAT_COMPLETIONS_ENDPOINT in (
239
+ endpoint_type ,
240
+ route_override_header ,
241
+ ):
242
+ try :
243
+ for chunk in aqua_client .chat (
244
+ messages = payload .pop ("messages" ),
245
+ payload = payload ,
246
+ stream = True ,
247
+ ):
248
+ try :
249
+ yield chunk ["choices" ][0 ]["delta" ]["content" ]
250
+ except Exception as e :
251
+ logger .debug (
252
+ f"Exception occurred while parsing streaming response: { e } "
253
+ )
254
+ except ExtendedRequestError as ex :
255
+ raise HTTPError (400 , str (ex ))
256
+ except Exception as ex :
257
+ raise HTTPError (500 , str (ex ))
258
+
259
+ elif endpoint_type == PredictEndpoints .TEXT_COMPLETIONS_ENDPOINT :
260
+ try :
261
+ for chunk in aqua_client .generate (
262
+ prompt = payload .pop ("prompt" ),
263
+ payload = payload ,
264
+ stream = True ,
265
+ ):
266
+ try :
267
+ yield chunk ["choices" ][0 ]["text" ]
268
+ except Exception as e :
269
+ logger .debug (
270
+ f"Exception occurred while parsing streaming response: { e } "
271
+ )
272
+ except ExtendedRequestError as ex :
273
+ raise HTTPError (400 , str (ex ))
274
+ except Exception as ex :
275
+ raise HTTPError (500 , str (ex ))
190
276
191
277
@handle_exceptions
192
- def post (self , * args , ** kwargs ): # noqa: ARG002
278
+ def post (self , model_deployment_id ):
193
279
"""
194
- Handles inference request for the Active Model Deployments
280
+ Handles streaming inference request for the Active Model Deployments
195
281
Raises
196
282
------
197
283
HTTPError
@@ -205,32 +291,29 @@ def post(self, *args, **kwargs): # noqa: ARG002
205
291
if not input_data :
206
292
raise HTTPError (400 , Errors .NO_INPUT_DATA )
207
293
208
- endpoint = input_data .get ("endpoint" )
209
- if not endpoint :
210
- raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("endpoint" ))
211
-
212
- if not self .validate_predict_url (endpoint ):
213
- raise HTTPError (400 , Errors .INVALID_INPUT_DATA_FORMAT .format ("endpoint" ))
214
-
215
294
prompt = input_data .get ("prompt" )
216
- if not prompt :
217
- raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("prompt" ))
295
+ messages = input_data .get ("messages" )
218
296
219
- model_params = (
220
- input_data .get ("model_params" ) if input_data .get ("model_params" ) else {}
221
- )
222
- try :
223
- model_params_obj = ModelParams (** model_params )
224
- except Exception as ex :
297
+ if not prompt and not messages :
225
298
raise HTTPError (
226
- 400 , Errors .INVALID_INPUT_DATA_FORMAT .format ("model_params" )
227
- ) from ex
228
-
229
- return self .finish (
230
- MDInferenceResponse (prompt , model_params_obj ).get_model_deployment_response (
231
- endpoint
299
+ 400 , Errors .MISSING_REQUIRED_PARAMETER .format ("prompt/messages" )
232
300
)
301
+ if not input_data .get ("model" ):
302
+ raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("model" ))
303
+ route_override_header = self .request .headers .get ("route" , None )
304
+ self .set_header ("Content-Type" , "text/event-stream" )
305
+ response_gen = self ._get_model_deployment_response (
306
+ model_deployment_id , input_data , route_override_header
233
307
)
308
+ try :
309
+ for chunk in response_gen :
310
+ self .write (chunk )
311
+ self .flush ()
312
+ self .finish ()
313
+ except Exception as ex :
314
+ self .set_status (ex .status_code )
315
+ self .write ({"message" : "Error occurred" , "reason" : str (ex )})
316
+ self .finish ()
234
317
235
318
236
319
class AquaDeploymentParamsHandler (AquaAPIhandler ):
@@ -294,5 +377,5 @@ def post(self, *args, **kwargs): # noqa: ARG002
294
377
("deployments/?([^/]*)" , AquaDeploymentHandler ),
295
378
("deployments/?([^/]*)/activate" , AquaDeploymentHandler ),
296
379
("deployments/?([^/]*)/deactivate" , AquaDeploymentHandler ),
297
- ("inference" , AquaDeploymentInferenceHandler ),
380
+ ("inference/stream/?([^/]*) " , AquaDeploymentStreamingInferenceHandler ),
298
381
]
0 commit comments