diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py
index 3c7f62c8f..37b32f337 100644
--- a/ads/aqua/common/enums.py
+++ b/ads/aqua/common/enums.py
@@ -20,6 +20,12 @@ class Resource(ExtendedEnum):
     MODEL_VERSION_SET = "model-version-sets"
 
 
+class PredictEndpoints(ExtendedEnum):
+    CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
+    TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
+    EMBEDDING_ENDPOINT = "/v1/embedding"
+
+
 class Tags(ExtendedEnum):
     TASK = "task"
     LICENSE = "license"
diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py
index 4c4fc2ac5..c0d2640ee 100644
--- a/ads/aqua/extension/deployment_handler.py
+++ b/ads/aqua/extension/deployment_handler.py
@@ -2,16 +2,18 @@
 # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-from typing import List, Union
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 
 from tornado.web import HTTPError
 
+from ads.aqua.app import logger
+from ads.aqua.client.client import Client, ExtendedRequestError
 from ads.aqua.common.decorator import handle_exceptions
+from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
 from ads.aqua.extension.errors import Errors
-from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
-from ads.aqua.modeldeployment.entities import ModelParams
+from ads.aqua.modeldeployment import AquaDeploymentApp
 from ads.config import COMPARTMENT_OCID
 
 
@@ -175,23 +177,107 @@ def list_shapes(self):
         )
 
 
-class AquaDeploymentInferenceHandler(AquaAPIhandler):
-    @staticmethod
-    def validate_predict_url(endpoint):
-        try:
-            url = urlparse(endpoint)
-            if url.scheme != "https":
-                return False
-            if not url.netloc:
-                return False
-            return url.path.endswith("/predict")
-        except Exception:
-            return False
+class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _get_model_deployment_response(
+        self,
+        model_deployment_id: str,
+        payload: dict,
+        route_override_header: Optional[str],
+    ):
+        """
+        Returns the model deployment inference response in a streaming fashion.
+
+        This method connects to the specified model deployment endpoint and
+        streams the inference output back to the caller, handling both text
+        and chat completion endpoints depending on the route override.
+
+        Parameters
+        ----------
+        model_deployment_id : str
+            The OCID of the model deployment to invoke.
+            Example: 'ocid1.datasciencemodeldeployment.oc1.iad.xxxyz'
+
+        payload : dict
+            Dictionary containing the model inference parameters.
+            Example payload for text completions:
+            {
+                "max_tokens": 1024,
+                "temperature": 0.5,
+                "prompt": "What are some good skills for a deep learning expert? Give some tips on how to structure an interview, with a coding example.",
+                "top_p": 0.4,
+                "top_k": 100,
+                "model": "odsc-llm",
+                "frequency_penalty": 1,
+                "presence_penalty": 1,
+                "stream": true
+            }
+
+        route_override_header : Optional[str]
+            Optional override for the inference route, used for routing between
+            different endpoint types (e.g., chat vs. text completions).
+            Example: '/v1/chat/completions'
+
+        Returns
+        -------
+        Generator[str]
+            A generator that yields strings of the model's output as they are received.
+
+        Raises
+        ------
+        HTTPError
+            If the request to the model deployment fails or if streaming cannot be established.
+        """
+
+        model_deployment = AquaDeploymentApp().get(model_deployment_id)
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
+        endpoint_type = model_deployment.environment_variables.get(
+            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
+        )
+        aqua_client = Client(endpoint=endpoint)
+
+        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
+            endpoint_type,
+            route_override_header,
+        ):
+            try:
+                for chunk in aqua_client.chat(
+                    messages=payload.pop("messages"),
+                    payload=payload,
+                    stream=True,
+                ):
+                    try:
+                        yield chunk["choices"][0]["delta"]["content"]
+                    except Exception as e:
+                        logger.debug(
+                            f"Exception occurred while parsing streaming response: {e}"
+                        )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
+        elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
+            try:
+                for chunk in aqua_client.generate(
+                    prompt=payload.pop("prompt"),
+                    payload=payload,
+                    stream=True,
+                ):
+                    try:
+                        yield chunk["choices"][0]["text"]
+                    except Exception as e:
+                        logger.debug(
+                            f"Exception occurred while parsing streaming response: {e}"
+                        )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
 
     @handle_exceptions
-    def post(self, *args, **kwargs):  # noqa: ARG002
+    def post(self, model_deployment_id):
         """
-        Handles inference request for the Active Model Deployments
+        Handles streaming inference request for the Active Model Deployments
 
         Raises
         ------
         HTTPError
@@ -205,32 +291,29 @@ def post(self, *args, **kwargs):  # noqa: ARG002
         if not input_data:
             raise HTTPError(400, Errors.NO_INPUT_DATA)
 
-        endpoint = input_data.get("endpoint")
-        if not endpoint:
-            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("endpoint"))
-
-        if not self.validate_predict_url(endpoint):
-            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT.format("endpoint"))
-
         prompt = input_data.get("prompt")
-        if not prompt:
-            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt"))
+        messages = input_data.get("messages")
 
-        model_params = (
-            input_data.get("model_params") if input_data.get("model_params") else {}
-        )
-        try:
-            model_params_obj = ModelParams(**model_params)
-        except Exception as ex:
+        if not prompt and not messages:
             raise HTTPError(
-                400, Errors.INVALID_INPUT_DATA_FORMAT.format("model_params")
-            ) from ex
-
-        return self.finish(
-            MDInferenceResponse(prompt, model_params_obj).get_model_deployment_response(
-                endpoint
+                400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
             )
+        if not input_data.get("model"):
+            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
+        route_override_header = self.request.headers.get("route", None)
+        self.set_header("Content-Type", "text/event-stream")
+        response_gen = self._get_model_deployment_response(
+            model_deployment_id, input_data, route_override_header
         )
+        try:
+            for chunk in response_gen:
+                self.write(chunk)
+                self.flush()
+            self.finish()
+        except Exception as ex:
+            self.set_status(getattr(ex, "status_code", 500))
+            self.write({"message": "Error occurred", "reason": str(ex)})
+            self.finish()
 
 
 class AquaDeploymentParamsHandler(AquaAPIhandler):
@@ -294,5 +377,5 @@ def post(self, *args, **kwargs):  # noqa: ARG002
     ("deployments/?([^/]*)", AquaDeploymentHandler),
     ("deployments/?([^/]*)/activate", AquaDeploymentHandler),
     ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
-    ("inference", AquaDeploymentInferenceHandler),
+    ("inference/stream/?([^/]*)", 
AquaDeploymentStreamingInferenceHandler), ] diff --git a/ads/aqua/modeldeployment/__init__.py b/ads/aqua/modeldeployment/__init__.py index baf5c5b53..4b44fbd3a 100644 --- a/ads/aqua/modeldeployment/__init__.py +++ b/ads/aqua/modeldeployment/__init__.py @@ -1,8 +1,6 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from ads.aqua.modeldeployment.deployment import AquaDeploymentApp -from ads.aqua.modeldeployment.inference import MDInferenceResponse -__all__ = ["AquaDeploymentApp", "MDInferenceResponse"] +__all__ = ["AquaDeploymentApp"] diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 9394dfe3d..8866eb667 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -17,7 +17,11 @@ ComputeShapeSummary, ContainerPath, ) -from ads.aqua.common.enums import InferenceContainerTypeFamily, ModelFormat, Tags +from ads.aqua.common.enums import ( + InferenceContainerTypeFamily, + ModelFormat, + Tags, +) from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( DEFINED_METADATA_TO_FILE_MAP, @@ -628,7 +632,9 @@ def _create_multi( config_data["model_task"] = model.model_task if model.fine_tune_weights_location: - config_data["fine_tune_weights_location"] = model.fine_tune_weights_location + config_data["fine_tune_weights_location"] = ( + model.fine_tune_weights_location + ) model_config.append(config_data) model_name_list.append(model.model_name) @@ -789,7 +795,7 @@ def _create_deployment( telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)} if Tags.BASE_MODEL_CUSTOM in tags: - telemetry_kwargs[ "custom_base_model"] = True + telemetry_kwargs["custom_base_model"] = True # tracks unique deployments that were created in the user compartment self.telemetry.record_event_async( @@ -934,7 +940,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": model_deployment = self.ds_client.get_model_deployment( model_deployment_id=model_deployment_id, **kwargs ).data - oci_aqua = ( ( Tags.AQUA_TAG in model_deployment.freeform_tags @@ -979,7 +984,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": aqua_deployment = AquaDeployment.from_oci_model_deployment( model_deployment, self.region ) - if Tags.MULTIMODEL_TYPE_TAG in model_deployment.freeform_tags: aqua_model_id = model_deployment.freeform_tags.get( Tags.AQUA_MODEL_ID_TAG, UNKNOWN @@ -1010,7 +1014,6 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": aqua_deployment.models = [ AquaMultiModelRef(**metadata) for metadata in multi_model_metadata ] - return AquaDeploymentDetail( **vars(aqua_deployment), log_group=AquaResourceIdentifier( @@ -1309,4 +1312,4 @@ def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]: or gpu_specs.shapes.get(oci_shape.name.upper()), ) for oci_shape in oci_shapes - ] \ No newline at end of file + ] diff --git a/ads/aqua/modeldeployment/inference.py b/ads/aqua/modeldeployment/inference.py deleted file mode 100644 index e5812ad25..000000000 --- a/ads/aqua/modeldeployment/inference.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python - -# Copyright (c) 2024, 2025 Oracle and/or its affiliates. 
-# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ - -import json -from dataclasses import dataclass, field - -import requests - -from ads.aqua.app import AquaApp -from ads.aqua.modeldeployment.entities import ModelParams -from ads.common.auth import default_signer -from ads.telemetry import telemetry - - -@dataclass -class MDInferenceResponse(AquaApp): - """Contains APIs for Aqua Model deployments Inference. - - Attributes - ---------- - - model_params: Dict - prompt: string - - Methods - ------- - get_model_deployment_response(self, **kwargs) -> "String" - Creates an instance of model deployment via Aqua - """ - - prompt: str = None - model_params: field(default_factory=ModelParams) = None - - @telemetry(entry_point="plugin=inference&action=get_response", name="aqua") - def get_model_deployment_response(self, endpoint): - """ - Returns MD inference response - - Parameters - ---------- - endpoint: str - MD predict url - prompt: str - User prompt. - - model_params: (Dict, optional) - Model parameters to be associated with the message. - Currently supported VLLM+OpenAI parameters. - - --model-params '{ - "max_tokens":500, - "temperature": 0.5, - "top_k": 10, - "top_p": 0.5, - "model": "/opt/ds/model/deployed_model", - ...}' - - Returns - ------- - model_response_content - """ - - params_dict = self.model_params.to_dict() - params_dict = { - key: value for key, value in params_dict.items() if value is not None - } - body = {"prompt": self.prompt, **params_dict} - request_kwargs = {"json": body, "headers": {"Content-Type": "application/json"}} - response = requests.post( - endpoint, auth=default_signer()["signer"], **request_kwargs - ) - return json.loads(response.content) diff --git a/ads/common/oci_client.py b/ads/common/oci_client.py index 6496f148d..b1c36262c 100644 --- a/ads/common/oci_client.py +++ b/ads/common/oci_client.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright (c) 2021, 2024 Oracle and/or its affiliates. +# Copyright (c) 2021, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import logging diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index d5d5c4565..df60a9f46 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -33,8 +33,7 @@ AquaContainerConfig, AquaContainerConfigItem, ) -from ads.aqua.model.enums import MultiModelSupportedTaskType -from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse +from ads.aqua.modeldeployment import AquaDeploymentApp from ads.aqua.modeldeployment.entities import ( AquaDeployment, AquaDeploymentConfig, @@ -487,7 +486,7 @@ class TestDataset: "model_name": "test_model_1", "model_task": "text_embedding", "artifact_location": "test_location_1", - "fine_tune_weights_location" : None + "fine_tune_weights_location": None, }, { "env_var": {}, @@ -496,7 +495,7 @@ class TestDataset: "model_name": "test_model_2", "model_task": "image_text_to_text", "artifact_location": "test_location_2", - "fine_tune_weights_location" : None + "fine_tune_weights_location": None, }, { "env_var": {}, @@ -505,7 +504,7 @@ class TestDataset: "model_name": "test_model_3", "model_task": "code_synthesis", "artifact_location": "test_location_3", - "fine_tune_weights_location" : "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad." 
+ "fine_tune_weights_location": "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.", }, ], "model_id": "ocid1.datasciencemodel.oc1..", @@ -972,7 +971,7 @@ class TestDataset: "model_name": "model_one", "model_task": "text_embedding", "artifact_location": "artifact_location_one", - "fine_tune_weights_location": None + "fine_tune_weights_location": None, }, { "env_var": {"--test_key_two": "test_value_two"}, @@ -981,7 +980,7 @@ class TestDataset: "model_name": "model_two", "model_task": "image_text_to_text", "artifact_location": "artifact_location_two", - "fine_tune_weights_location": None + "fine_tune_weights_location": None, }, { "env_var": {"--test_key_three": "test_value_three"}, @@ -990,7 +989,7 @@ class TestDataset: "model_name": "model_three", "model_task": "code_synthesis", "artifact_location": "artifact_location_three", - "fine_tune_weights_location" : "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad." + "fine_tune_weights_location": "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.", }, ] @@ -1817,7 +1816,7 @@ def test_create_deployment_for_multi_model( model_task="code_synthesis", gpu_count=2, artifact_location="test_location_3", - fine_tune_weights_location= "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad." + fine_tune_weights_location="oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.", ) result = self.app.create( @@ -2283,36 +2282,3 @@ def test_validate_multimodel_deployment_feasibility_positive_single( total_gpus, "test_data/deployment/aqua_summary_multi_model_single.json", ) - - -class TestMDInferenceResponse(unittest.TestCase): - def setUp(self): - self.app = MDInferenceResponse() - - @classmethod - def setUpClass(cls): - cls.curr_dir = os.path.dirname(os.path.abspath(__file__)) - - @classmethod - def tearDownClass(cls): - cls.curr_dir = None - - @patch("requests.post") - def test_get_model_deployment_response(self, mock_post): - """Test to check if model deployment response is returned correctly.""" - - endpoint = TestDataset.MODEL_DEPLOYMENT_URL + "/predict" - self.app.prompt = "What is 1+1?" 
- self.app.model_params = ModelParams(**TestDataset.model_params) - - mock_response = MagicMock() - response_json = os.path.join( - self.curr_dir, "test_data/deployment/aqua_deployment_response.json" - ) - with open(response_json, "r") as _file: - mock_response.content = _file.read() - mock_response.status_code = 200 - mock_post.return_value = mock_response - - result = self.app.get_model_deployment_response(endpoint) - assert result["choices"][0]["text"] == " The answer is 2" diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py index 9e9be2b34..78a814684 100644 --- a/tests/unitary/with_extras/aqua/test_deployment_handler.py +++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py @@ -16,8 +16,8 @@ import ads.config from ads.aqua.extension.deployment_handler import ( AquaDeploymentHandler, - AquaDeploymentInferenceHandler, AquaDeploymentParamsHandler, + AquaDeploymentStreamingInferenceHandler, ) @@ -224,23 +224,39 @@ def test_validate_deployment_params( ) -class TestAquaDeploymentInferenceHandler(unittest.TestCase): +class TestAquaDeploymentStreamingInferenceHandler(unittest.TestCase): @patch.object(IPythonHandler, "__init__") def setUp(self, ipython_init_mock) -> None: ipython_init_mock.return_value = None - self.inference_handler = AquaDeploymentInferenceHandler( - MagicMock(), MagicMock() - ) - self.inference_handler.request = MagicMock() - self.inference_handler.finish = MagicMock() - - @patch("ads.aqua.modeldeployment.MDInferenceResponse.get_model_deployment_response") + self.handler = AquaDeploymentStreamingInferenceHandler(MagicMock(), MagicMock()) + self.handler.request = MagicMock() + self.handler.set_header = MagicMock() + self.handler.write = MagicMock() + self.handler.flush = MagicMock() + self.handler.finish = MagicMock() + + @patch.object( + AquaDeploymentStreamingInferenceHandler, "_get_model_deployment_response" + ) def test_post(self, mock_get_model_deployment_response): """Test post method to return model deployment response.""" - self.inference_handler.get_json_body = MagicMock( - return_value=TestDataset.inference_request + mock_response_gen = iter(["chunk1", "chunk2"]) + + mock_get_model_deployment_response.return_value = mock_response_gen + + self.handler.get_json_body = MagicMock( + return_value={"prompt": "Hello", "model": "some-model"} ) - self.inference_handler.post() + self.handler.request.headers = MagicMock() + self.handler.request.headers.get.return_value = "test-route" + + self.handler.post("mock-deployment-id") + mock_get_model_deployment_response.assert_called_with( - TestDataset.inference_request["endpoint"] + "mock-deployment-id", + {"prompt": "Hello", "model": "some-model"}, + "test-route", ) + self.handler.write.assert_any_call("chunk1") + self.handler.write.assert_any_call("chunk2") + self.handler.finish.assert_called_once()
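For review context, here is a minimal sketch of how a client might consume the new streaming route. It assumes the extension's URL patterns are mounted under the Jupyter server's `/aqua` prefix on `localhost:8888` (both are assumptions, not confirmed by this diff), and the deployment OCID is a placeholder. The payload mirrors the validation in `post()`: `model` is required, and at least one of `prompt` or `messages` must be present.

```python
# Hypothetical consumer of the new streaming endpoint; the URL prefix, port,
# and OCID below are placeholders, not values confirmed by this change.
import requests

URL = (
    "http://localhost:8888/aqua/inference/stream/"
    "ocid1.datasciencemodeldeployment.oc1.iad.<placeholder>"
)

payload = {
    "prompt": "What is 1+1?",  # or "messages": [...] for the chat branch
    "model": "odsc-llm",
    "max_tokens": 100,
    "stream": True,
}

# The handler write()s and flush()es each text chunk as it arrives, so
# iterating over the raw body yields partial output instead of one final blob.
with requests.post(URL, json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```

A chat-style request would send `messages` instead of `prompt` and may set the optional `route` header to `/v1/chat/completions` to force the chat-completions branch; otherwise the handler falls back to the deployment's `MODEL_DEPLOY_PREDICT_ENDPOINT` environment variable.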