
Commit cc18bcf

Rebasing and Fixing UTs
2 parents: b890243 + 0ada1cc

File tree

20 files changed: +1036 −385 lines changed


ads/aqua/client/openai_client.py

+305
@@ -0,0 +1,305 @@
#!/usr/bin/env python
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import logging
import re
from typing import Any, Dict, Optional, Union

import httpx

from ads.aqua.client.client import get_async_httpx_client, get_httpx_client
from ads.common.extended_enum import ExtendedEnum

logger = logging.getLogger(__name__)

DEFAULT_TIMEOUT = httpx.Timeout(timeout=600, connect=5.0)
DEFAULT_MAX_RETRIES = 2


try:
    import openai
except ImportError as e:
    raise ModuleNotFoundError(
        "The custom OpenAI client requires the `openai-python` package. "
        "Please install it with `pip install openai`."
    ) from e


class ModelDeploymentBaseEndpoint(ExtendedEnum):
    """Supported base endpoints for model deployments."""

    PREDICT = "predict"
    PREDICT_WITH_RESPONSE_STREAM = "predictwithresponsestream"


class AquaOpenAIMixin:
    """
    Mixin that provides common logic to patch HTTP request headers and URLs
    for compatibility with the OCI Model Deployment service using the OpenAI API schema.
    """

    def _patch_route(self, original_path: str) -> str:
        """
        Extracts and formats the OpenAI-style route path from a full request path.

        Args:
            original_path (str): The full URL path from the incoming request.

        Returns:
            str: The normalized OpenAI-compatible route path (e.g., '/v1/chat/completions').
        """
        normalized_path = original_path.lower().rstrip("/")

        match = re.search(r"/predict(withresponsestream)?", normalized_path)
        if not match:
            logger.debug("Route header cannot be resolved from path: %s", original_path)
            return ""

        route_suffix = normalized_path[match.end() :].lstrip("/")
        if not route_suffix:
            logger.warning(
                "Missing OpenAI route suffix after '/predict'. "
                "Expected something like '/v1/completions'."
            )
            return ""

        if not route_suffix.startswith("v"):
            logger.warning(
                "Route suffix does not start with a version prefix (e.g., '/v1'). "
                "This may lead to compatibility issues with OpenAI-style endpoints. "
                "Consider updating the URL to include a version prefix, "
                "such as '/predict/v1' or '/predictwithresponsestream/v1'."
            )
            # route_suffix = f"v1/{route_suffix}"

        return f"/{route_suffix}"

    def _patch_streaming(self, request: httpx.Request) -> None:
        """
        Sets the 'enable-streaming' header based on the JSON request body contents.

        If the request body contains `"stream": true`, the `enable-streaming` header is set to "true".
        Otherwise, it defaults to "false".

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        streaming_enabled = "false"
        content_type = request.headers.get("Content-Type", "")

        if "application/json" in content_type and request.content:
            try:
                body = (
                    request.content.decode("utf-8")
                    if isinstance(request.content, bytes)
                    else request.content
                )
                payload = json.loads(body)
                if payload.get("stream") is True:
                    streaming_enabled = "true"
            except Exception as e:
                logger.exception(
                    "Failed to parse request JSON body for streaming flag: %s", e
                )

        request.headers.setdefault("enable-streaming", streaming_enabled)
        logger.debug("Patched 'enable-streaming' header: %s", streaming_enabled)

    def _patch_headers(self, request: httpx.Request) -> None:
        """
        Patches request headers by injecting OpenAI-compatible values:
        - `enable-streaming` for streaming-aware endpoints
        - `route` for backend routing

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        self._patch_streaming(request)
        route_header = self._patch_route(request.url.path)
        request.headers.setdefault("route", route_header)
        logger.debug("Patched 'route' header: %s", route_header)

    def _patch_url(self) -> httpx.URL:
        """
        Strips any suffixes from the base URL to retain only the `/predict` or `/predictwithresponsestream` path.

        Returns:
            httpx.URL: The normalized base URL with the correct model deployment path.
        """
        base_path = f"{self.base_url.path.lower().rstrip('/')}/"
        match = re.search(r"/predict(withresponsestream)?/", base_path)
        if match:
            trimmed = base_path[: match.end() - 1]
            return self.base_url.copy_with(path=trimmed)

        logger.debug("Could not determine a valid endpoint from path: %s", base_path)
        return self.base_url

    def _prepare_request_common(self, request: httpx.Request) -> None:
        """
        Common preparation routine for all requests.

        This includes:
        - Patching headers with streaming and routing info.
        - Normalizing the URL path to include only `/predict` or `/predictwithresponsestream`.

        Args:
            request (httpx.Request): The outgoing HTTPX request.
        """
        # Patch headers
        logger.debug("Original headers: %s", request.headers)
        self._patch_headers(request)
        logger.debug("Headers after patching: %s", request.headers)

        # Patch URL
        logger.debug("URL before patching: %s", request.url)
        request.url = self._patch_url()
        logger.debug("URL after patching: %s", request.url)


class OpenAI(openai.OpenAI, AquaOpenAIMixin):
    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Dict[str, str]] = None,
        default_query: Optional[Dict[str, object]] = None,
        http_client: Optional[httpx.Client] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
        _strict_response_validation: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Construct a new synchronous OpenAI client instance.

        If no http_client is provided, one will be automatically created using ads.aqua.get_httpx_client().

        Args:
            api_key (str, optional): API key for authentication. Defaults to env variable OPENAI_API_KEY.
            organization (str, optional): Organization ID. Defaults to env variable OPENAI_ORG_ID.
            project (str, optional): Project ID. Defaults to env variable OPENAI_PROJECT_ID.
            base_url (str | httpx.URL, optional): Base URL for the API.
            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
            timeout (float | httpx.Timeout, optional): Timeout for API requests.
            max_retries (int, optional): Maximum number of retries for API requests.
            default_headers (dict[str, str], optional): Additional headers.
            default_query (dict[str, object], optional): Additional query parameters.
            http_client (httpx.Client, optional): Custom HTTP client; if not provided, one will be auto-created.
            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
            _strict_response_validation (bool, optional): Enable strict response validation.
            **kwargs: Additional keyword arguments passed to the parent __init__.
        """
        if http_client is None:
            logger.debug(
                "No http_client provided; auto-creating one using ads.aqua.get_httpx_client()"
            )
            http_client = get_httpx_client(**(http_client_kwargs or {}))
        if not api_key:
            logger.debug("API key not provided; using default placeholder for OCI.")
            api_key = "OCI"

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
            **kwargs,
        )

    def _prepare_request(self, request: httpx.Request) -> None:
        """
        Prepare the synchronous HTTP request by applying common modifications.

        Args:
            request (httpx.Request): The outgoing HTTP request.
        """
        self._prepare_request_common(request)


class AsyncOpenAI(openai.AsyncOpenAI, AquaOpenAIMixin):
    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = DEFAULT_TIMEOUT,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Dict[str, str]] = None,
        default_query: Optional[Dict[str, object]] = None,
        http_client: Optional[httpx.AsyncClient] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
        _strict_response_validation: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        Construct a new asynchronous AsyncOpenAI client instance.

        If no http_client is provided, one will be automatically created using
        ads.aqua.get_async_httpx_client().

        Args:
            api_key (str, optional): API key for authentication. Defaults to env variable OPENAI_API_KEY.
            organization (str, optional): Organization ID.
            project (str, optional): Project ID.
            base_url (str | httpx.URL, optional): Base URL for the API.
            websocket_base_url (str | httpx.URL, optional): Base URL for WebSocket connections.
            timeout (float | httpx.Timeout, optional): Timeout for API requests.
            max_retries (int, optional): Maximum number of retries for API requests.
            default_headers (dict[str, str], optional): Additional headers.
            default_query (dict[str, object], optional): Additional query parameters.
            http_client (httpx.AsyncClient, optional): Custom asynchronous HTTP client; if not provided, one will be auto-created.
            http_client_kwargs (dict[str, Any], optional): Extra kwargs for auto-creating the HTTP client.
            _strict_response_validation (bool, optional): Enable strict response validation.
            **kwargs: Additional keyword arguments passed to the parent __init__.
        """
        if http_client is None:
            logger.debug(
                "No async http_client provided; auto-creating one using ads.aqua.get_async_httpx_client()"
            )
            http_client = get_async_httpx_client(**(http_client_kwargs or {}))
        if not api_key:
            logger.debug("API key not provided; using default placeholder for OCI.")
            api_key = "OCI"

        super().__init__(
            api_key=api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
            _strict_response_validation=_strict_response_validation,
            **kwargs,
        )

    async def _prepare_request(self, request: httpx.Request) -> None:
        """
        Asynchronously prepare the HTTP request by applying common modifications.

        Args:
            request (httpx.Request): The outgoing HTTP request.
        """
        self._prepare_request_common(request)

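For context, a minimal, hypothetical usage sketch of the patched synchronous client (not part of this commit): the deployment URL and model name below are placeholders, and credentials are assumed to come from the httpx client that ads.aqua auto-creates. Note the trailing "/predict/v1", which lets the mixin derive the "route" header (e.g. "/v1/chat/completions").

from ads.aqua.client.openai_client import OpenAI

# Placeholder model deployment endpoint; only the "/predict/v1" suffix matters here.
client = OpenAI(
    base_url=(
        "https://modeldeployment.us-ashburn-1.oci.customer-oci.com/"
        "ocid1.datasciencemodeldeployment.oc1..example/predict/v1"
    ),
)

response = client.chat.completions.create(
    model="odsc-llm",  # placeholder model name; depends on the deployed container
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,  # with stream=True the mixin also sets 'enable-streaming: true'
)
print(response.choices[0].message.content)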
ads/aqua/common/utils.py

+3 −1
@@ -1249,7 +1249,9 @@ def load_gpu_shapes_index(
     try:
         auth = auth or authutil.default_signer()
         # Construct the object storage path. Adjust bucket name and path as needed.
-        storage_path = f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/{file_name}/1"
+        storage_path = (
+            f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
+        )
         logger.debug("Loading GPU shapes index from Object Storage")
         with fsspec.open(storage_path, mode="r", **auth) as file_obj:
             data = json.load(file_obj)
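For illustration only, the rewritten path composes as follows; the bucket name, namespace, and file name below are hypothetical stand-ins for CONDA_BUCKET_NAME, CONDA_BUCKET_NS, and the file_name argument:

# Hypothetical values, shown only to illustrate the new "service_pack" path layout.
CONDA_BUCKET_NAME = "service-conda-packs"
CONDA_BUCKET_NS = "mytenancy-namespace"
file_name = "gpu_shapes_index.json"
storage_path = f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
# -> oci://service-conda-packs@mytenancy-namespace/service_pack/gpu_shapes_index.json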

ads/aqua/config/container_config.py

+1
@@ -22,6 +22,7 @@ class Usage(ExtendedEnum):
     INFERENCE = "inference"
     BATCH_INFERENCE = "batch_inference"
     MULTI_MODEL = "multi_model"
+    OTHER = "other"


 class AquaContainerConfigSpec(Serializable):

ads/aqua/constants.py

+25
@@ -42,6 +42,7 @@
 HF_METADATA_FOLDER = ".cache/"
 HF_LOGIN_DEFAULT_TIMEOUT = 2
 MODEL_NAME_DELIMITER = ";"
+AQUA_TROUBLESHOOTING_LINK = "https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/troubleshooting-tips.md"

 TRAINING_METRICS_FINAL = "training_metrics_final"
 VALIDATION_METRICS_FINAL = "validation_metrics_final"
@@ -87,3 +88,27 @@
     "--host",
 }
 TEI_CONTAINER_DEFAULT_HOST = "8080"
+
+OCI_OPERATION_FAILURES = {
+    "list_model_deployments": "Unable to list model deployments. See tips for troubleshooting: ",
+    "list_models": "Unable to list models. See tips for troubleshooting: ",
+    "get_namespace": "Unable to access specified Object Storage Bucket. See tips for troubleshooting: ",
+    "list_log_groups": "Unable to access logs. See tips for troubleshooting: ",
+    "list_buckets": "Unable to list Object Storage Bucket. See tips for troubleshooting: ",
+    "put_object": "Unable to access or find Object Storage Bucket. See tips for troubleshooting: ",
+    "list_model_version_sets": "Unable to create or fetch model version set. See tips for troubleshooting:",
+    "update_model": "Unable to update model. See tips for troubleshooting: ",
+    "list_data_science_private_endpoints": "Unable to access private endpoint. See tips for troubleshooting: ",
+    "create_model": "Unable to register model. See tips for troubleshooting: ",
+    "create_deployment": "Unable to create deployment. See tips for troubleshooting: ",
+    "create_model_version_sets": "Unable to create model version set. See tips for troubleshooting: ",
+    "create_job": "Unable to create job. See tips for troubleshooting: ",
+    "create_job_run": "Unable to create job run. See tips for troubleshooting: ",
+}
+
+STATUS_CODE_MESSAGES = {
+    "400": "Could not process your request due to invalid input.",
+    "403": "We're having trouble processing your request with the information provided.",
+    "404": "Authorization Failed: The resource you're looking for isn't accessible.",
+    "408": "Server is taking too long to respond, please try again.",
+}
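To show how these new lookup tables might be consumed together with AQUA_TROUBLESHOOTING_LINK, here is a minimal sketch (not code from this commit); the helper name friendly_error is hypothetical:

from ads.aqua.constants import (
    AQUA_TROUBLESHOOTING_LINK,
    OCI_OPERATION_FAILURES,
    STATUS_CODE_MESSAGES,
)


def friendly_error(operation: str, status_code: int) -> str:
    """Combine the per-operation hint, status message, and troubleshooting link."""
    reason = STATUS_CODE_MESSAGES.get(str(status_code), "An unexpected error occurred.")
    hint = OCI_OPERATION_FAILURES.get(operation)
    return f"{reason} {hint}{AQUA_TROUBLESHOOTING_LINK}" if hint else reason


print(friendly_error("create_deployment", 404))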
