diff --git a/ads/aqua/client/openai_client.py b/ads/aqua/client/openai_client.py
new file mode 100644
index 000000000..6aefee05f
--- /dev/null
+++ b/ads/aqua/client/openai_client.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python
+# Copyright (c) 2025 Oracle and/or its affiliates.
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
+
+import json
+import logging
+import re
+from typing import Any, Dict, Optional
+
+import httpx
+from typing import Union
+
+from ads.aqua.client.client import get_async_httpx_client, get_httpx_client
+from ads.common.extended_enum import ExtendedEnum
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0)
+DEFAULT_MAX_RETRIES = 2
+
+
+try:
+    import openai
+except ImportError as e:
+    raise ModuleNotFoundError(
+        "The custom OpenAI client requires the `openai-python` package. "
+        "Please install it with `pip install openai`."
+    ) from e
+
+
+class ModelDeploymentBaseEndpoint(ExtendedEnum):
+    """Supported base endpoints for model deployments."""
+
+    PREDICT = "predict"
+    PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"
+
+
+class AquaOpenAIMixin:
+    """
+    Mixin that provides common logic to patch HTTP request headers and URLs
+    for compatibility with the OCI Model Deployment service using the OpenAI API schema.
+    """
+
+    def _patch_route(self, original_path: str) -> str:
+        """
+        Extracts and formats the OpenAI-style route path from a full request path.
+
+        Args:
+            original_path (str): The full URL path from the incoming request.
+
+        Returns:
+            str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
+        """
+        normalized_path = original_path.lower().rstrip("/")
+
+        match = re.search(r"/predict(withresponsestream)?", normalized_path)
+        if not match:
+            logger.debug("Route header cannot be resolved from path: %s", original_path)
+            return ""
+
+        route_suffix = normalized_path[match.end() :].lstrip("/")
+        if not route_suffix:
+            logger.warning(
+                "Missing OpenAI route suffix after '/predict'. "
+                "Expected something like '/v1/completions'."
+            )
+            return ""
+
+        if not route_suffix.startswith("v"):
+            logger.warning(
+                "Route suffix does not start with a version prefix (e.g., '/v1'). "
+                "This may lead to compatibility issues with OpenAI-style endpoints. "
+                "Consider updating the URL to include a version prefix, "
+                "such as '/predict/v1' or '/predictwithresponsestream/v1'."
+            )
+            # route_suffix = f"v1/{route_suffix}"
+
+        return f"/{route_suffix}"
+
+    def _patch_streaming(self, request: httpx.Request) -> None:
+        """
+        Sets the 'enable-streaming' header based on the JSON request body contents.
+
+        If the request body contains `"stream": true`, the `enable-streaming` header is set to "true".
+        Otherwise, it defaults to "false".
+
+        Args:
+            request (httpx.Request): The outgoing HTTPX request.
+        """
+        streaming_enabled = "false"
+        content_type = request.headers.get("Content-Type", "")
+
+        if "application/json" in content_type and request.content:
+            try:
+                body = (
+                    request.content.decode("utf-8")
+                    if isinstance(request.content, bytes)
+                    else request.content
+                )
+                payload = json.loads(body)
+                if payload.get("stream") is True:
+                    streaming_enabled = "true"
+            except Exception as e:
+                logger.exception(
+                    "Failed to parse request JSON body for streaming flag: %s", e
+                )
+
+        request.headers.setdefault("enable-streaming", streaming_enabled)
+        logger.debug("Patched 'enable-streaming' header: %s", streaming_enabled)
+
+    def _patch_headers(self, request: httpx.Request) -> None:
+        """
+        Patches request headers by injecting OpenAI-compatible values:
+        - `enable-streaming` for streaming-aware endpoints
+        - `route` for backend routing
+
+        Args:
+            request (httpx.Request): The outgoing HTTPX request.
+        """
+        self._patch_streaming(request)
+        route_header = self._patch_route(request.url.path)
+        request.headers.setdefault("route", route_header)
+        logger.debug("Patched 'route' header: %s", route_header)
+
+    def _patch_url(self) -> httpx.URL:
+        """
+        Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.
+
+        Returns:
+            httpx.URL: The normalized base URL with the correct model deployment path.
+        """
+        base_path = f"{self.base_url.path.lower().rstrip('/')}/"
+        match = re.search(r"/predict(withresponsestream)?/", base_path)
+        if match:
+            trimmed = base_path[: match.end() - 1]
+            return self.base_url.copy_with(path=trimmed)
+
+        logger.debug("Could not determine a valid endpoint from path: %s", base_path)
+        return self.base_url
+
+    def _prepare_request_common(self, request: httpx.Request) -> None:
+        """
+        Common preparation routine for all requests.
+
+        This includes:
+        - Patching headers with streaming and routing info.
+        - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.
+
+        Args:
+            request (httpx.Request): The outgoing HTTPX request.
+        """
+        # Patch headers
+        logger.debug("Original headers: %s", request.headers)
+        self._patch_headers(request)
+        logger.debug("Headers after patching: %s", request.headers)
+
+        # Patch URL
+        logger.debug("URL before patching: %s", request.url)
+        request.url = self._patch_url()
+        logger.debug("URL after patching: %s", request.url)
+
+
+class OpenAI(openai.OpenAI, AquaOpenAIMixin):
+    def __init__(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+        base_url: Optional[Union[str, httpx.URL]] = None,
+        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
+        max_retries: int = DEFAULT_MAX_RETRIES,
+        default_headers: Optional[Dict[str, str]] = None,
+        default_query: Optional[Dict[str, object]] = None,
+        http_client: Optional[httpx.Client] = None,
+        http_client_kwargs: Optional[Dict[str, Any]] = None,
+        _strict_response_validation: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Construct a new synchronous OpenAI client instance.
+
+        If no http_client is provided, one will be automatically created using ads.aqua.get_httpx_client().
+
+        Args:
+            api_key (str, optional): API key for authentication. Defaults to env variable OPENAI_API_KEY.
+            organization (str, optional): Organization ID. Defaults to env variable OPENAI_ORG_ID.
+            project (str, optional): Project ID. Defaults to env variable OPENAI_PROJECT_ID.
+            base_url (str | httpx.URL, optional): Base URL for the API.
+            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
+            timeout (float | httpx.Timeout, optional): Timeout for API requests.
+            max_retries (int, optional): Maximum number of retries for API requests.
+            default_headers (dict[str, str], optional): Additional headers.
+            default_query (dict[str, object], optional): Additional query parameters.
+            http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
+            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
+            _strict_response_validation (bool, optional): Enable strict response validation.
+            **kwargs: Additional keyword arguments passed to the parent __init__.
+        """
+        if http_client is None:
+            logger.debug(
+                "No http_client provided; auto-creating one using ads.aqua.get_httpx_client()"
+            )
+            http_client = get_httpx_client(**(http_client_kwargs or {}))
+        if not api_key:
+            logger.debug("API key not provided; using default placeholder for OCI.")
+            api_key = "OCI"
+
+        super().__init__(
+            api_key=api_key,
+            organization=organization,
+            project=project,
+            base_url=base_url,
+            websocket_base_url=websocket_base_url,
+            timeout=timeout,
+            max_retries=max_retries,
+            default_headers=default_headers,
+            default_query=default_query,
+            http_client=http_client,
+            _strict_response_validation=_strict_response_validation,
+            **kwargs,
+        )
+
+    def _prepare_request(self, request: httpx.Request) -> None:
+        """
+        Prepare the synchronous HTTP request by applying common modifications.
+
+        Args:
+            request (httpx.Request): The outgoing HTTP request.
+        """
+        self._prepare_request_common(request)
+
+
+class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
+    def __init__(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        organization: Optional[str] = None,
+        project: Optional[str] = None,
+        base_url: Optional[Union[str, httpx.URL]] = None,
+        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
+        max_retries: int = DEFAULT_MAX_RETRIES,
+        default_headers: Optional[Dict[str, str]] = None,
+        default_query: Optional[Dict[str, object]] = None,
+        http_client: Optional[httpx.AsyncClient] = None,
+        http_client_kwargs: Optional[Dict[str, Any]] = None,
+        _strict_response_validation: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Construct a new asynchronous AsyncOpenAI client instance.
+
+        If no http_client is provided, one will be automatically created using
+        ads.aqua.get_async_httpx_client().
+
+        Args:
+            api_key (str, optional): API key for authentication. Defaults to env variable OPENAI_API_KEY.
+            organization (str, optional): Organization ID.
+            project (str, optional): Project ID.
+            base_url (str | httpx.URL, optional): Base URL for the API.
+            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
+            timeout (float | httpx.Timeout, optional): Timeout for API requests.
+            max_retries (int, optional): Maximum number of retries for API requests.
+            default_headers (dict[str, str], optional): Additional headers.
+            default_query (dict[str, object], optional): Additional query parameters.
+            http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
+            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
+            _strict_response_validation (bool, optional): Enable strict response validation.
+            **kwargs: Additional keyword arguments passed to the parent __init__.
+        """
+        if http_client is None:
+            logger.debug(
+                "No async http_client provided; auto-creating one using ads.aqua.get_async_httpx_client()"
+            )
+            http_client = get_async_httpx_client(**(http_client_kwargs or {}))
+        if not api_key:
+            logger.debug("API key not provided; using default placeholder for OCI.")
+            api_key = "OCI"
+
+        super().__init__(
+            api_key=api_key,
+            organization=organization,
+            project=project,
+            base_url=base_url,
+            websocket_base_url=websocket_base_url,
+            timeout=timeout,
+            max_retries=max_retries,
+            default_headers=default_headers,
+            default_query=default_query,
+            http_client=http_client,
+            _strict_response_validation=_strict_response_validation,
+            **kwargs,
+        )
+
+    async def _prepare_request(self, request: httpx.Request) -> None:
+        """
+        Asynchronously prepare the HTTP request by applying common modifications.
+
+        Args:
+            request (httpx.Request): The outgoing HTTP request.
+        """
+        self._prepare_request_common(request)
diff --git a/docs/source/user_guide/large_language_model/aqua_client.rst b/docs/source/user_guide/large_language_model/aqua_client.rst
index 45bf40578..4eec3eeb4 100644
--- a/docs/source/user_guide/large_language_model/aqua_client.rst
+++ b/docs/source/user_guide/large_language_model/aqua_client.rst
@@ -146,6 +146,7 @@ Usage
 .. code-block:: python3
 
     import ads
+    import ads.aqua
 
     ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")
 
@@ -167,7 +168,92 @@ Usage
 .. code-block:: python3
 
     import ads
+    import ads.aqua
 
     ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")
 
     async_client = client = ads.aqua.get_async_httpx_client(timeout=10.0)
+
+
+Aqua OpenAI Client
+==================
+
+.. versionadded:: 2.13.4
+
+The Oracle-ADS **OpenAI** and **AsyncOpenAI** clients extend the official OpenAI Python SDK to support model deployments on **OCI**.
+These clients automatically patch request headers and normalize URL paths based on the provided deployment OCID, ensuring that API calls are formatted correctly for OCI Model Deployment.
+
+You can refer to the official `Open AI quick start examples <https://platform.openai.com/docs/quickstart?api-mode=responses>`_ for general usage patterns.
+When working with **OCI Model Deployments**, make sure to import the client from the **oracle-ads** library:
+
+.. code-block:: python3
+
+    from ads.aqua.client.openai_client import OpenAI
+
+
+Requirements
+------------
+To use these clients, you must have the ``openai-python`` package installed. This package is an optional dependency. If it is not installed, you will receive an informative error when attempting to instantiate one of these clients. To install the package, run:
+
+.. code-block:: bash
+
+    pip install openai
+
+
+Usage
+-----
+Both synchronous and asynchronous versions are available.
+
+**Synchronous Client**
+
+The synchronous client, ``OpenAI``, extends the OpenAI client. If no HTTP client is provided, it will automatically create one using ``ads.aqua.get_httpx_client()``.
+
+.. code-block:: python
+
+    import ads
+    from ads.aqua.client.openai_client import OpenAI
+
+    ads.set_auth(auth="security_token", profile="<replace-with-your-profile>")
+
+    client = OpenAI(
+        base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<OCID>/predict/v1",
+    )
+
+    response = client.chat.completions.create(
+        model="odsc-llm",
+        messages=[
+            {
+                "role": "user",
+                "content": "Tell me a joke.",
+            }
+        ],
+        # stream=True, # enable for streaming
+    )
+
+    print(response)
+
+
+**Asynchronous Client**
+
+The asynchronous client, ``AsyncOpenAI``, extends the AsyncOpenAI client. If no async HTTP client is provided, it will automatically create one using ``ads.aqua.get_async_httpx_client()``.
+
+.. code-block:: python
+
+    import ads
+    import asyncio
+    import nest_asyncio
+    from ads.aqua.client.openai_client import AsyncOpenAI
+
+    ads.set_auth(auth="security_token")
+
+    async def test_async() -> None:
+        client_async = AsyncOpenAI(
+            base_url="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<OCID>/predict/v1",
+        )
+        response = await client_async.chat.completions.create(
+            model="odsc-llm",
+            messages=[{"role": "user", "content": "Tell me a long joke"}],
+            stream=True
+        )
+        async for event in response:
+            print(event)
+
+    asyncio.run(test_async())