From 7b171984e4a3fc6cb4e90f0f3fce943e08e870f1 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Sun, 30 Mar 2025 15:49:53 -0700 Subject: [PATCH 1/7] Add AquaOpenAI and AsyncAquaOpenAI Clients --- ads/aqua/client/openai_client.py | 317 +++++++++++++++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 ads/aqua/client/openai_client.py diff --git a/ads/aqua/client/openai_client.py b/ads/aqua/client/openai_client.py new file mode 100644 index 000000000..0225d7609 --- /dev/null +++ b/ads/aqua/client/openai_client.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import json +import logging +import re +from typing import Any, Dict, Optional +from urllib.parse import urlparse, urlunparse + +import httpx +from git import Union + +from ads.aqua.client.client import get_async_httpx_client, get_httpx_client + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0) +DEFAULT_MAX_RETRIES = 2 + +try: + from openai import AsyncOpenAI, OpenAI +except ImportError as e: + raise ModuleNotFoundError( + "The custom OpenAI client requires the `openai-python` package. " + "Please install it with `pip install openai`." + ) from e + + +class AquaAIMixin: + """ + Mixin that provides common logic to patch request headers and URLs + for both synchronous and asynchronous clients. + """ + + def _patch_route(self, original_path: str) -> str: + """ + Dynamically determine the route header based on the URL path. + This method extracts the portion of the path that follows the deployment OCID. + It does so by normalizing the path, splitting it into segments, and then using a + regular expression to identify the OCID segment (e.g. "ocid1.datasciencemodeldeployment…"). + All segments following the OCID (skipping an initial + 'predict', if present) are joined to form the route header (prefixed with "/v1/"). + If no extra segment is found, it defaults to "/predict". + + Args: + original_path (str): The original URL path. + + Returns: + str: The computed route header. + """ + normalized = original_path.strip("/").lower() + segments = normalized.split("/") + ocid_pattern = re.compile( + r"^ocid\d+\.datasciencemodeldeployment", re.IGNORECASE + ) + base_index = None + for i, seg in enumerate(segments): + if ocid_pattern.match(seg): + base_index = i + break + + if base_index is None: + route = f"/v1/{segments[-1]}" if segments and segments[-1] else "/predict" + logger.debug("OCID not found; using fallback route: %s", route) + return route + + remainder = segments[base_index + 1 :] + if remainder and remainder[0] == "predict": + remainder = remainder[1:] + + route = f"/v1/{'/'.join(remainder)}" if remainder else "" + logger.debug("Computed route from path '%s': %s", original_path, route) + return route + + def _patch_url_path(self, original_url: str) -> httpx.URL: + """ + Normalize the URL path so that it always ends with '/predict'. + + This function uses the OCID in the URL to extract the base deployment path, + then discards any additional endpoint segments and appends '/predict'. + This design is robust against future changes, as it relies on identifying the OCID. + + Args: + original_url (str): The original URL. + + Returns: + httpx.URL: The normalized URL with its path ending in '/predict'. + """ + parsed_url = urlparse(original_url) + # Split the path into non-empty segments. + path_segments = [seg for seg in parsed_url.path.split("/") if seg] + ocid_pattern = re.compile( + r"^ocid\d+\.datasciencemodeldeployment", re.IGNORECASE + ) + base_index = None + for i, segment in enumerate(path_segments): + if ocid_pattern.match(segment): + base_index = i + break + + if base_index is not None: + base_path = "/" + "/".join(path_segments[: base_index + 1]) + else: + base_path = "" + logger.debug("OCID not found in URL path; using empty base.") + + new_path = f"{base_path}/predict" if base_path else "/predict" + new_url = urlunparse(parsed_url._replace(path=new_path, query="")) + logger.debug("Normalized URL path to: %s", new_url) + return httpx.URL(new_url) + + def _patch_streaming(self, request: httpx.Request) -> None: + """ + Set the 'enable-streaming' header based on whether the JSON request body contains + a 'stream': true parameter. + + If the Content-Type is JSON, the request body is parsed. If the key 'stream' is set to True, + the header 'enable-streaming' is set to "true". Otherwise, it is set to "false". + If parsing fails, a warning is logged and the default value remains "false". + + Args: + request (httpx.Request): The outgoing HTTP request. + """ + streaming_enabled = "false" + content_type = request.headers.get("Content-Type", "") + if "application/json" in content_type and request.content: + try: + body_str = ( + request.content.decode("utf-8") + if isinstance(request.content, bytes) + else request.content + ) + data = json.loads(body_str) + if data.get("stream") is True: + streaming_enabled = "true" + except Exception as e: + logger.exception("Failed to parse JSON from request body: %s", e) + request.headers.setdefault("enable-streaming", streaming_enabled) + logger.debug( + "Patched streaming header to: %s", request.headers["enable-streaming"] + ) + + def _patch_headers(self, request: httpx.Request) -> None: + """ + Patch the headers of the request by setting the 'enable-streaming' and 'route' headers. + + Args: + request (httpx.Request): The HTTP request to patch. + """ + self._patch_streaming(request) + request.headers.setdefault("route", self._patch_route(request.url.path)) + logger.debug("Patched route header to: %s", request.headers["route"]) + + def _prepare_request_common(self, request: httpx.Request) -> None: + """ + Prepare the HTTP request by patching headers and normalizing the URL path. + + This method: + 1. Automatically sets the 'enable-streaming' header based on the request body. + 2. Determines the 'route' header based on the original URL path using OCID-based extraction. + 3. Rewrites the URL path to always end with '/predict' based on the deployment base. + + Args: + request (httpx.Request): The outgoing HTTP request. + """ + logger.debug("Original headers: %s", request.headers) + self._patch_headers(request) + logger.debug("Headers after patching: %s", request.headers) + new_url = self._patch_url_path(str(request.url)) + logger.debug("Rewriting URL from %s to %s", request.url, new_url) + request.url = new_url + + +class AquaOpenAI(OpenAI, AquaAIMixin): + def __init__( + self, + *, + api_key: Optional[str] = None, + organization: Optional[str] = None, + project: Optional[str] = None, + base_url: Optional[Union[str, httpx.URL]] = None, + websocket_base_url: Optional[Union[str, httpx.URL]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Optional[Dict[str, str]] = None, + default_query: Optional[Dict[str, object]] = None, + http_client: Optional[httpx.Client] = None, + http_client_kwargs: Optional[Dict[str, Any]] = None, + _strict_response_validation: bool = False, + **kwargs: Any, + ) -> None: + """ + Construct a new synchronous AquaOpenAI client instance. + + If no http_client is provided, one will be automatically created using ads.aqua.get_httpx_client(). + + Args: + api_key (str, optional): API key for authentication. Defaults to env variable OPENAI_API_KEY. + organization (str, optional): Organization ID. Defaults to env variable OPENAI_ORG_ID. + project (str, optional): Project ID. Defaults to env variable OPENAI_PROJECT_ID. + base_url (str | httpx.URL, optional): Base URL for the API. + websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections. + timeout (float | httpx.Timeout, optional): Timeout for API requests. + max_retries (int, optional): Maximum number of retries for API requests. + default_headers (dict[str, str], optional): Additional headers. + default_query (dict[str, object], optional): Additional query parameters. + http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created. + http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client. + _strict_response_validation (bool, optional): Enable strict response validation. + **kwargs: Additional keyword arguments passed to the parent __init__. + """ + if http_client is None: + logger.debug( + "No http_client provided; auto-creating one using ads.aqua.get_httpx_client()" + ) + http_client = get_httpx_client(**(http_client_kwargs or {})) + if not api_key: + logger.debug("API key not provided; using default placeholder for OCI.") + api_key = "OCI" + + super().__init__( + api_key=api_key, + organization=organization, + project=project, + base_url=base_url, + websocket_base_url=websocket_base_url, + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + default_query=default_query, + http_client=http_client, + _strict_response_validation=_strict_response_validation, + **kwargs, + ) + + def _prepare_request(self, request: httpx.Request) -> None: + """ + Prepare the synchronous HTTP request by applying common modifications. + + Args: + request (httpx.Request): The outgoing HTTP request. + """ + self._prepare_request_common(request) + + +class AsyncAquaOpenAI(AsyncOpenAI, AquaAIMixin): + def __init__( + self, + *, + api_key: Optional[str] = None, + organization: Optional[str] = None, + project: Optional[str] = None, + base_url: Optional[Union[str, httpx.URL]] = None, + websocket_base_url: Optional[Union[str, httpx.URL]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT, + max_retries: int = DEFAULT_MAX_RETRIES, + default_headers: Optional[Dict[str, str]] = None, + default_query: Optional[Dict[str, object]] = None, + http_client: Optional[httpx.Client] = None, + http_client_kwargs: Optional[Dict[str, Any]] = None, + _strict_response_validation: bool = False, + **kwargs: Any, + ) -> None: + """ + Construct a new asynchronous AsyncAquaOpenAI client instance. + + If no http_client is provided, one will be automatically created using + ads.aqua.get_async_httpx_client(). + + Args: + api_key (str, optional): API key for authentication. Defaults to env variable OPENAI_API_KEY. + organization (str, optional): Organization ID. + project (str, optional): Project ID. + base_url (str | httpx.URL, optional): Base URL for the API. + websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections. + timeout (float | httpx.Timeout, optional): Timeout for API requests. + max_retries (int, optional): Maximum number of retries for API requests. + default_headers (dict[str, str], optional): Additional headers. + default_query (dict[str, object], optional): Additional query parameters. + http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created. + http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client. + _strict_response_validation (bool, optional): Enable strict response validation. + **kwargs: Additional keyword arguments passed to the parent __init__. + """ + if http_client is None: + logger.debug( + "No async http_client provided; auto-creating one using ads.aqua.get_async_httpx_client()" + ) + http_client = get_async_httpx_client(**(http_client_kwargs or {})) + if not api_key: + logger.debug("API key not provided; using default placeholder for OCI.") + api_key = "OCI" + + super().__init__( + api_key=api_key, + organization=organization, + project=project, + base_url=base_url, + websocket_base_url=websocket_base_url, + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + default_query=default_query, + http_client=http_client, + _strict_response_validation=_strict_response_validation, + **kwargs, + ) + + async def _prepare_request(self, request: httpx.Request) -> None: + """ + Asynchronously prepare the HTTP request by applying common modifications. + + Args: + request (httpx.Request): The outgoing HTTP request. + """ + self._prepare_request_common(request) From c5ac9e65535142cca362ac8089a7f09c7e242eb6 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Sun, 30 Mar 2025 22:31:32 -0700 Subject: [PATCH 2/7] Adds OpenAi client docs --- .../large_language_model/aqua_client.rst | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst index 45bf40578..b37a616fa 100644 --- a/docs/source/user_guide/large_language_model/aqua_client.rst +++ b/docs/source/user_guide/large_language_model/aqua_client.rst @@ -146,6 +146,7 @@ Usage .. code-block:: python3 import ads + import ads.aqua ads.set_auth(auth="security_token", profile="") @@ -167,7 +168,84 @@ Usage .. code-block:: python3 import ads + import ads.aqua ads.set_auth(auth="security_token", profile="") async_client = client = ads.aqua.get_async_httpx_client(timeout=10.0) + + +Aqua OpenAI Client +================== + +.. versionadded:: 2.13.4 + +The **AquaOpenAI** and **AsyncAquaOpenAI** clients extend the official OpenAI Python SDK to support OCI-based model deployments. They automatically patch request headers and normalize URL paths based on the deployment OCID, ensuring that API calls are sent in the proper format. + +Requirements +------------ +To use these clients, you must have the ``openai-python`` package installed. This package is an optional dependency. If it is not installed, you will receive an informative error when attempting to instantiate one of these clients. To install the package, run: + +.. code-block:: bash + + pip install openai + + +Usage +----- +Both synchronous and asynchronous versions are available. + +**Synchronous Client** + +The synchronous client, ``AquaOpenAI``, extends the OpenAI client. If no HTTP client is provided, it will automatically create one using ``ads.aqua.get_httpx_client()``. + +.. code-block:: python + + import ads + from ads.aqua.client.openai_client import AquaOpenAI + ads.set_auth(auth="security_token", profile="") + + client = AquaOpenAI( + base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/", + ) + + response = client.chat.completions.create( + model="odsc-llm", + messages=[ + { + "role": "user", + "content": "Tell me a joke.", + } + ], + # stream=True, # enable for streaming + ) + + print(response) + + +**Asynchronous Client** + +The asynchronous client, ``AsyncAquaOpenAI``, extends the AsyncOpenAI client. If no async HTTP client is provided, it will automatically create one using ``ads.aqua.get_async_httpx_client()``. + +.. code-block:: python + + import ads + import asyncio + import nest_asyncio + from ads.aqua.client.openai_client import AsyncAquaOpenAI + + ads.set_auth(auth="security_token") + + async def test_async() -> None: + client_async = AsyncAquaOpenAI( + base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/", + ) + response = await client_async.chat.completions.create( + model="odsc-llm", + messages=[{"role": "user", "content": "Tell me a long joke"}], + stream=True + ) + async for event in response: + print(event) + + asyncio.run(test_async()) From cebe0367d7c32f107856c82b5a8aabadd98c944f Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Mon, 31 Mar 2025 17:21:41 -0700 Subject: [PATCH 3/7] Simplifies the logic --- ads/aqua/client/openai_client.py | 83 +++++--------------------------- 1 file changed, 11 insertions(+), 72 deletions(-) diff --git a/ads/aqua/client/openai_client.py b/ads/aqua/client/openai_client.py index 0225d7609..8b0759291 100644 --- a/ads/aqua/client/openai_client.py +++ b/ads/aqua/client/openai_client.py @@ -4,9 +4,7 @@ import json import logging -import re from typing import Any, Dict, Optional -from urllib.parse import urlparse, urlunparse import httpx from git import Union @@ -35,80 +33,20 @@ class AquaAIMixin: def _patch_route(self, original_path: str) -> str: """ - Dynamically determine the route header based on the URL path. - This method extracts the portion of the path that follows the deployment OCID. - It does so by normalizing the path, splitting it into segments, and then using a - regular expression to identify the OCID segment (e.g. "ocid1.datasciencemodeldeployment…"). - All segments following the OCID (skipping an initial - 'predict', if present) are joined to form the route header (prefixed with "/v1/"). - If no extra segment is found, it defaults to "/predict". + Determine the appropriate route header based on the original URL path. Args: original_path (str): The original URL path. Returns: - str: The computed route header. + str: The route header value. """ - normalized = original_path.strip("/").lower() - segments = normalized.split("/") - ocid_pattern = re.compile( - r"^ocid\d+\.datasciencemodeldeployment", re.IGNORECASE + route = ( + original_path.lower() + .rstrip("/") + .replace(self.base_url.path.lower().rstrip("/"), "") ) - base_index = None - for i, seg in enumerate(segments): - if ocid_pattern.match(seg): - base_index = i - break - - if base_index is None: - route = f"/v1/{segments[-1]}" if segments and segments[-1] else "/predict" - logger.debug("OCID not found; using fallback route: %s", route) - return route - - remainder = segments[base_index + 1 :] - if remainder and remainder[0] == "predict": - remainder = remainder[1:] - - route = f"/v1/{'/'.join(remainder)}" if remainder else "" - logger.debug("Computed route from path '%s': %s", original_path, route) - return route - - def _patch_url_path(self, original_url: str) -> httpx.URL: - """ - Normalize the URL path so that it always ends with '/predict'. - - This function uses the OCID in the URL to extract the base deployment path, - then discards any additional endpoint segments and appends '/predict'. - This design is robust against future changes, as it relies on identifying the OCID. - - Args: - original_url (str): The original URL. - - Returns: - httpx.URL: The normalized URL with its path ending in '/predict'. - """ - parsed_url = urlparse(original_url) - # Split the path into non-empty segments. - path_segments = [seg for seg in parsed_url.path.split("/") if seg] - ocid_pattern = re.compile( - r"^ocid\d+\.datasciencemodeldeployment", re.IGNORECASE - ) - base_index = None - for i, segment in enumerate(path_segments): - if ocid_pattern.match(segment): - base_index = i - break - - if base_index is not None: - base_path = "/" + "/".join(path_segments[: base_index + 1]) - else: - base_path = "" - logger.debug("OCID not found in URL path; using empty base.") - - new_path = f"{base_path}/predict" if base_path else "/predict" - new_url = urlunparse(parsed_url._replace(path=new_path, query="")) - logger.debug("Normalized URL path to: %s", new_url) - return httpx.URL(new_url) + return f"/v1{route}" if route else "" def _patch_streaming(self, request: httpx.Request) -> None: """ @@ -164,12 +102,13 @@ def _prepare_request_common(self, request: httpx.Request) -> None: Args: request (httpx.Request): The outgoing HTTP request. """ + # Patches the headers logger.debug("Original headers: %s", request.headers) self._patch_headers(request) logger.debug("Headers after patching: %s", request.headers) - new_url = self._patch_url_path(str(request.url)) - logger.debug("Rewriting URL from %s to %s", request.url, new_url) - request.url = new_url + + # Patches the URL + request.url = self.base_url.copy_with(path=self.base_url.path.rstrip("/")) class AquaOpenAI(OpenAI, AquaAIMixin): From 3f0143e30a1c1059c6f70211743bb7af1bcf6bfc Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Mon, 31 Mar 2025 17:31:21 -0700 Subject: [PATCH 4/7] Fixes the docs --- docs/source/user_guide/large_language_model/aqua_client.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst index b37a616fa..45ad93edf 100644 --- a/docs/source/user_guide/large_language_model/aqua_client.rst +++ b/docs/source/user_guide/large_language_model/aqua_client.rst @@ -206,7 +206,7 @@ The synchronous client, ``AquaOpenAI``, extends the OpenAI client. If no HTTP cl ads.set_auth(auth="security_token", profile="") client = AquaOpenAI( - base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/", + base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", ) response = client.chat.completions.create( @@ -238,7 +238,7 @@ The asynchronous client, ``AsyncAquaOpenAI``, extends the AsyncOpenAI client. If async def test_async() -> None: client_async = AsyncAquaOpenAI( - base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/", + base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", ) response = await client_async.chat.completions.create( model="odsc-llm", From c1e3fb07a5c894ebbd847c7246c8089ac7b32e5c Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Mon, 31 Mar 2025 20:55:55 -0700 Subject: [PATCH 5/7] Chaned the names of the classes --- ads/aqua/client/openai_client.py | 12 ++++++------ .../large_language_model/aqua_client.rst | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/ads/aqua/client/openai_client.py b/ads/aqua/client/openai_client.py index 8b0759291..011032593 100644 --- a/ads/aqua/client/openai_client.py +++ b/ads/aqua/client/openai_client.py @@ -17,7 +17,7 @@ DEFAULT_MAX_RETRIES = 2 try: - from openai import AsyncOpenAI, OpenAI + import openai except ImportError as e: raise ModuleNotFoundError( "The custom OpenAI client requires the `openai-python` package. " @@ -25,7 +25,7 @@ ) from e -class AquaAIMixin: +class AquaOpenAIMixin: """ Mixin that provides common logic to patch request headers and URLs for both synchronous and asynchronous clients. @@ -111,7 +111,7 @@ def _prepare_request_common(self, request: httpx.Request) -> None: request.url = self.base_url.copy_with(path=self.base_url.path.rstrip("/")) -class AquaOpenAI(OpenAI, AquaAIMixin): +class OpenAI(openai.OpenAI, AquaOpenAIMixin): def __init__( self, *, @@ -130,7 +130,7 @@ def __init__( **kwargs: Any, ) -> None: """ - Construct a new synchronous AquaOpenAI client instance. + Construct a new synchronous OpenAI client instance. If no http_client is provided, one will be automatically created using ads.aqua.get_httpx_client(). @@ -183,7 +183,7 @@ def _prepare_request(self, request: httpx.Request) -> None: self._prepare_request_common(request) -class AsyncAquaOpenAI(AsyncOpenAI, AquaAIMixin): +class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin): def __init__( self, *, @@ -202,7 +202,7 @@ def __init__( **kwargs: Any, ) -> None: """ - Construct a new asynchronous AsyncAquaOpenAI client instance. + Construct a new asynchronous AsyncOpenAI client instance. If no http_client is provided, one will be automatically created using ads.aqua.get_async_httpx_client(). diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst index 45ad93edf..682569be3 100644 --- a/docs/source/user_guide/large_language_model/aqua_client.rst +++ b/docs/source/user_guide/large_language_model/aqua_client.rst @@ -180,7 +180,7 @@ Aqua OpenAI Client .. versionadded:: 2.13.4 -The **AquaOpenAI** and **AsyncAquaOpenAI** clients extend the official OpenAI Python SDK to support OCI-based model deployments. They automatically patch request headers and normalize URL paths based on the deployment OCID, ensuring that API calls are sent in the proper format. +The **OpenAI** and **AsyncOpenAI** clients extend the official OpenAI Python SDK to support OCI-based model deployments. They automatically patch request headers and normalize URL paths based on the deployment OCID, ensuring that API calls are sent in the proper format. Requirements ------------ @@ -197,15 +197,15 @@ Both synchronous and asynchronous versions are available. **Synchronous Client** -The synchronous client, ``AquaOpenAI``, extends the OpenAI client. If no HTTP client is provided, it will automatically create one using ``ads.aqua.get_httpx_client()``. +The synchronous client, ``OpenAI``, extends the OpenAI client. If no HTTP client is provided, it will automatically create one using ``ads.aqua.get_httpx_client()``. .. code-block:: python import ads - from ads.aqua.client.openai_client import AquaOpenAI + from ads.aqua.client.openai_client import OpenAI ads.set_auth(auth="security_token", profile="") - client = AquaOpenAI( + client = OpenAI( base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", ) @@ -225,19 +225,19 @@ The synchronous client, ``AquaOpenAI``, extends the OpenAI client. If no HTTP cl **Asynchronous Client** -The asynchronous client, ``AsyncAquaOpenAI``, extends the AsyncOpenAI client. If no async HTTP client is provided, it will automatically create one using ``ads.aqua.get_async_httpx_client()``. +The asynchronous client, ``AsynOpenAI``, extends the AsyncOpenAI client. If no async HTTP client is provided, it will automatically create one using ``ads.aqua.get_async_httpx_client()``. .. code-block:: python import ads import asyncio import nest_asyncio - from ads.aqua.client.openai_client import AsyncAquaOpenAI + from ads.aqua.client.openai_client import AsyncOpenAI ads.set_auth(auth="security_token") async def test_async() -> None: - client_async = AsyncAquaOpenAI( + client_async = AsyncOpenAI( base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", ) response = await client_async.chat.completions.create( From d8fe81235062a489d859a4b6540d7a6bc7330dd1 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 1 Apr 2025 12:44:22 -0700 Subject: [PATCH 6/7] Enhances documentaiton --- .../user_guide/large_language_model/aqua_client.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst index 682569be3..12c6313a7 100644 --- a/docs/source/user_guide/large_language_model/aqua_client.rst +++ b/docs/source/user_guide/large_language_model/aqua_client.rst @@ -180,7 +180,15 @@ Aqua OpenAI Client .. versionadded:: 2.13.4 -The **OpenAI** and **AsyncOpenAI** clients extend the official OpenAI Python SDK to support OCI-based model deployments. They automatically patch request headers and normalize URL paths based on the deployment OCID, ensuring that API calls are sent in the proper format. +The Oracle-ADS **OpenAI** and **AsyncOpenAI** clients extend the official OpenAI Python SDK to support model deployments on **OCI**. These clients automatically patch request headers and normalize URL paths based on the provided deployment OCID, ensuring that API calls are formatted correctly for OCI Model Deployment. + +You can refer to the official `Open AI quick start examples `_ for general usage patterns. +When working with **OCI Model Deployments**, make sure to import the client from the **oracle-ads** library: + +.. code-block:: python3 + + from ads.aqua.client.openai_client import OpenAI + Requirements ------------ From 92d8162e39fac9a6dcb7a01d7f65432b76aebc13 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 1 Apr 2025 17:19:03 -0700 Subject: [PATCH 7/7] Adds v1 suffix validation --- ads/aqua/client/openai_client.py | 123 ++++++++++++------ .../large_language_model/aqua_client.rst | 4 +- 2 files changed, 88 insertions(+), 39 deletions(-) diff --git a/ads/aqua/client/openai_client.py b/ads/aqua/client/openai_client.py index 011032593..6aefee05f 100644 --- a/ads/aqua/client/openai_client.py +++ b/ads/aqua/client/openai_client.py @@ -4,18 +4,21 @@ import json import logging +import re from typing import Any, Dict, Optional import httpx from git import Union from ads.aqua.client.client import get_async_httpx_client, get_httpx_client +from ads.common.extended_enum import ExtendedEnum logger = logging.getLogger(__name__) DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0) DEFAULT_MAX_RETRIES = 2 + try: import openai except ImportError as e: @@ -25,90 +28,136 @@ ) from e +class ModelDeploymentBaseEndpoint(ExtendedEnum): + """Supported base endpoints for model deployments.""" + + PREDICT = "predict" + PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream" + + class AquaOpenAIMixin: """ - Mixin that provides common logic to patch request headers and URLs - for both synchronous and asynchronous clients. + Mixin that provides common logic to patch HTTP request headers and URLs + for compatibility with the OCI Model Deployment service using the OpenAI API schema. """ def _patch_route(self, original_path: str) -> str: """ - Determine the appropriate route header based on the original URL path. + Extracts and formats the OpenAI-style route path from a full request path. Args: - original_path (str): The original URL path. + original_path (str): The full URL path from the incoming request. Returns: - str: The route header value. + str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions'). """ - route = ( - original_path.lower() - .rstrip("/") - .replace(self.base_url.path.lower().rstrip("/"), "") - ) - return f"/v1{route}" if route else "" + normalized_path = original_path.lower().rstrip("/") + + match = re.search(r"/predict(withresponsestream)?", normalized_path) + if not match: + logger.debug("Route header cannot be resolved from path: %s", original_path) + return "" + + route_suffix = normalized_path[match.end() :].lstrip("/") + if not route_suffix: + logger.warning( + "Missing OpenAI route suffix after '/predict'. " + "Expected something like '/v1/completions'." + ) + return "" + + if not route_suffix.startswith("v"): + logger.warning( + "Route suffix does not start with a version prefix (e.g., '/v1'). " + "This may lead to compatibility issues with OpenAI-style endpoints. " + "Consider updating the URL to include a version prefix, " + "such as '/predict/v1' or '/predictwithresponsestream/v1'." + ) + # route_suffix = f"v1/{route_suffix}" + + return f"/{route_suffix}" def _patch_streaming(self, request: httpx.Request) -> None: """ - Set the 'enable-streaming' header based on whether the JSON request body contains - a 'stream': true parameter. + Sets the 'enable-streaming' header based on the JSON request body contents. - If the Content-Type is JSON, the request body is parsed. If the key 'stream' is set to True, - the header 'enable-streaming' is set to "true". Otherwise, it is set to "false". - If parsing fails, a warning is logged and the default value remains "false". + If the request body contains `"stream": true`, the `enable-streaming` header is set to "true". + Otherwise, it defaults to "false". Args: - request (httpx.Request): The outgoing HTTP request. + request (httpx.Request): The outgoing HTTPX request. """ streaming_enabled = "false" content_type = request.headers.get("Content-Type", "") + if "application/json" in content_type and request.content: try: - body_str = ( + body = ( request.content.decode("utf-8") if isinstance(request.content, bytes) else request.content ) - data = json.loads(body_str) - if data.get("stream") is True: + payload = json.loads(body) + if payload.get("stream") is True: streaming_enabled = "true" except Exception as e: - logger.exception("Failed to parse JSON from request body: %s", e) + logger.exception( + "Failed to parse request JSON body for streaming flag: %s", e + ) + request.headers.setdefault("enable-streaming", streaming_enabled) - logger.debug( - "Patched streaming header to: %s", request.headers["enable-streaming"] - ) + logger.debug("Patched 'enable-streaming' header: %s", streaming_enabled) def _patch_headers(self, request: httpx.Request) -> None: """ - Patch the headers of the request by setting the 'enable-streaming' and 'route' headers. + Patches request headers by injecting OpenAI-compatible values: + - `enable-streaming` for streaming-aware endpoints + - `route` for backend routing Args: - request (httpx.Request): The HTTP request to patch. + request (httpx.Request): The outgoing HTTPX request. """ self._patch_streaming(request) - request.headers.setdefault("route", self._patch_route(request.url.path)) - logger.debug("Patched route header to: %s", request.headers["route"]) + route_header = self._patch_route(request.url.path) + request.headers.setdefault("route", route_header) + logger.debug("Patched 'route' header: %s", route_header) + + def _patch_url(self) -> httpx.URL: + """ + Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path. + + Returns: + httpx.URL: The normalized base URL with the correct model deployment path. + """ + base_path = f"{self.base_url.path.lower().rstrip('/')}/" + match = re.search(r"/predict(withresponsestream)?/", base_path) + if match: + trimmed = base_path[: match.end() - 1] + return self.base_url.copy_with(path=trimmed) + + logger.debug("Could not determine a valid endpoint from path: %s", base_path) + return self.base_url def _prepare_request_common(self, request: httpx.Request) -> None: """ - Prepare the HTTP request by patching headers and normalizing the URL path. + Common preparation routine for all requests. - This method: - 1. Automatically sets the 'enable-streaming' header based on the request body. - 2. Determines the 'route' header based on the original URL path using OCID-based extraction. - 3. Rewrites the URL path to always end with '/predict' based on the deployment base. + This includes: + - Patching headers with streaming and routing info. + - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`. Args: - request (httpx.Request): The outgoing HTTP request. + request (httpx.Request): The outgoing HTTPX request. """ - # Patches the headers + # Patch headers logger.debug("Original headers: %s", request.headers) self._patch_headers(request) logger.debug("Headers after patching: %s", request.headers) - # Patches the URL - request.url = self.base_url.copy_with(path=self.base_url.path.rstrip("/")) + # Patch URL + logger.debug("URL before patching: %s", request.url) + request.url = self._patch_url() + logger.debug("URL after patching: %s", request.url) class OpenAI(openai.OpenAI, AquaOpenAIMixin): diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst index 12c6313a7..4eec3eeb4 100644 --- a/docs/source/user_guide/large_language_model/aqua_client.rst +++ b/docs/source/user_guide/large_language_model/aqua_client.rst @@ -214,7 +214,7 @@ The synchronous client, ``OpenAI``, extends the OpenAI client. If no HTTP client ads.set_auth(auth="security_token", profile="") client = OpenAI( - base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict/v1", ) response = client.chat.completions.create( @@ -246,7 +246,7 @@ The asynchronous client, ``AsynOpenAI``, extends the AsyncOpenAI client. If no a async def test_async() -> None: client_async = AsyncOpenAI( - base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict/v1", ) response = await client_async.chat.completions.create( model="odsc-llm",