Skip to content

Commit c45c38a

Browse files
mrDzurblu-ohaielizjoVipulMascarenhasdipatidar
authored
Adds AQUA Mutli-Model Deployment
Co-authored-by: Dmitrii Cherkasov <[email protected]> Co-authored-by: Lu Peng <[email protected]> Co-authored-by: Liz Johnson <[email protected]> Co-authored-by: Vipul Mascarenhas <[email protected]> Co-authored-by: Dipali Patidar <[email protected]>
1 parent a5586db commit c45c38a

39 files changed

+5357
-387
lines changed

README-development.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ All the unit tests can be found [here](https://github.com/oracle/accelerated-dat
248248
The following commands detail how the unit tests can be run.
249249
```
250250
# Run all tests in AQUA project
251-
python -m pytest -q tests/unitary/with_extras/aqua/test_deployment.py
251+
python -m pytest -q tests/unitary/with_extras/aqua/*
252252
253253
# Run all tests specific to a module within in AQUA project (ex. test_deployment.py, test_model.py, etc.)
254254
python -m pytest -q tests/unitary/with_extras/aqua/test_deployment.py

ads/aqua/app.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import os
77
import traceback
88
from dataclasses import fields
9+
from datetime import datetime, timedelta
910
from typing import Any, Dict, Optional, Union
1011

1112
import oci
13+
from cachetools import TTLCache, cached
1214
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails
1315

1416
from ads import set_auth
@@ -269,6 +271,7 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
269271
logger.info(f"Artifact not found in model {model_id}.")
270272
return False
271273

274+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
272275
def get_config(
273276
self,
274277
model_id: str,
@@ -337,6 +340,9 @@ def get_config(
337340
config_file_path = os.path.join(config_path, config_file_name)
338341
if is_path_exists(config_file_path):
339342
try:
343+
logger.debug(
344+
f"Loading config: `{config_file_name}` from `{config_path}`"
345+
)
340346
config = load_config(
341347
config_path,
342348
config_file_name=config_file_name,

ads/aqua/common/entities.py

Lines changed: 224 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
import re
56
from typing import Any, Dict, Optional
67

78
from oci.data_science.models import Model
8-
from pydantic import BaseModel, Field
9+
from pydantic import BaseModel, Field, model_validator
10+
11+
from ads.aqua import logger
12+
from ads.aqua.config.utils.serializer import Serializable
913

1014

1115
class ContainerSpec:
@@ -25,7 +29,6 @@ class ContainerSpec:
2529
class ModelConfigResult(BaseModel):
2630
"""
2731
Represents the result of getting the AQUA model configuration.
28-
2932
Attributes:
3033
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
3134
config (Dict[str, Any]): A dictionary of the loaded configuration.
@@ -42,3 +45,222 @@ class Config:
4245
extra = "ignore"
4346
arbitrary_types_allowed = True
4447
protected_namespaces = ()
48+
49+
50+
class GPUSpecs(Serializable):
51+
"""
52+
Represents the GPU specifications for a compute instance.
53+
"""
54+
55+
gpu_memory_in_gbs: Optional[int] = Field(
56+
default=None, description="The amount of GPU memory available (in GB)."
57+
)
58+
gpu_count: Optional[int] = Field(
59+
default=None, description="The number of GPUs available."
60+
)
61+
gpu_type: Optional[str] = Field(
62+
default=None, description="The type of GPU (e.g., 'V100, A100, H100')."
63+
)
64+
65+
66+
class GPUShapesIndex(Serializable):
67+
"""
68+
Represents the index of GPU shapes.
69+
70+
Attributes
71+
----------
72+
shapes (Dict[str, GPUSpecs]): A mapping of compute shape names to their GPU specifications.
73+
"""
74+
75+
shapes: Dict[str, GPUSpecs] = Field(
76+
default_factory=dict,
77+
description="Mapping of shape names to GPU specifications.",
78+
)
79+
80+
81+
class ComputeShapeSummary(Serializable):
82+
"""
83+
Represents the specifications of a compute instance's shape.
84+
"""
85+
86+
core_count: Optional[int] = Field(
87+
default=None, description="The number of CPU cores available."
88+
)
89+
memory_in_gbs: Optional[int] = Field(
90+
default=None, description="The amount of memory (in GB) available."
91+
)
92+
name: Optional[str] = Field(
93+
default=None, description="The name identifier of the compute shape."
94+
)
95+
shape_series: Optional[str] = Field(
96+
default=None, description="The series or category of the compute shape."
97+
)
98+
gpu_specs: Optional[GPUSpecs] = Field(
99+
default=None,
100+
description="The GPU specifications associated with the compute shape.",
101+
)
102+
103+
@model_validator(mode="after")
104+
@classmethod
105+
def set_gpu_specs(cls, model: "ComputeShapeSummary") -> "ComputeShapeSummary":
106+
"""
107+
Validates and populates GPU specifications if the shape_series indicates a GPU-based shape.
108+
109+
- If the shape_series contains "GPU", the validator first checks if the shape name exists
110+
in the GPU_SPECS dictionary. If found, it creates a GPUSpecs instance with the corresponding data.
111+
- If the shape is not found in the GPU_SPECS, it attempts to extract the GPU count from the shape name
112+
using a regex pattern (looking for a number following a dot at the end of the name).
113+
114+
The information about shapes is taken from: https://docs.oracle.com/en-us/iaas/data-science/using/supported-shapes.htm
115+
116+
Returns:
117+
ComputeShapeSummary: The updated instance with gpu_specs populated if applicable.
118+
"""
119+
try:
120+
if (
121+
model.shape_series
122+
and "GPU" in model.shape_series.upper()
123+
and model.name
124+
and not model.gpu_specs
125+
):
126+
# Try to extract gpu_count from the shape name using a regex (e.g., "VM.GPU3.2" -> gpu_count=2)
127+
match = re.search(r"\.(\d+)$", model.name)
128+
if match:
129+
gpu_count = int(match.group(1))
130+
model.gpu_specs = GPUSpecs(gpu_count=gpu_count)
131+
except Exception as err:
132+
logger.debug(
133+
f"Error occurred in attempt to extract GPU specification for the f{model.name}. "
134+
f"Details: {err}"
135+
)
136+
return model
137+
138+
139+
class AquaMultiModelRef(Serializable):
140+
"""
141+
Lightweight model descriptor used for multi-model deployment.
142+
143+
This class only contains essential details
144+
required to fetch complete model metadata and deploy models.
145+
146+
Attributes
147+
----------
148+
model_id : str
149+
The unique identifier of the model.
150+
model_name : Optional[str]
151+
The name of the model.
152+
gpu_count : Optional[int]
153+
Number of GPUs required for deployment.
154+
env_var : Optional[Dict[str, Any]]
155+
Optional environment variables to override during deployment.
156+
artifact_location : Optional[str]
157+
Artifact path of model in the multimodel group.
158+
"""
159+
160+
model_id: str = Field(..., description="The model OCID to deploy.")
161+
model_name: Optional[str] = Field(None, description="The name of model.")
162+
gpu_count: Optional[int] = Field(
163+
None, description="The gpu count allocation for the model."
164+
)
165+
env_var: Optional[dict] = Field(
166+
default_factory=dict, description="The environment variables of the model."
167+
)
168+
artifact_location: Optional[str] = Field(
169+
None, description="Artifact path of model in the multimodel group."
170+
)
171+
172+
class Config:
173+
extra = "ignore"
174+
protected_namespaces = ()
175+
176+
177+
class ContainerPath(Serializable):
178+
"""
179+
Represents a parsed container path, extracting the path, name, and version.
180+
181+
This model is designed to parse a container path string of the format
182+
'<image_path>:<version>'. It extracts the following components:
183+
- `path`: The full path up to the version.
184+
- `name`: The last segment of the path, representing the image name.
185+
- `version`: The version number following the final colon.
186+
187+
Example Usage:
188+
--------------
189+
>>> container = ContainerPath(full_path="iad.ocir.io/ociodscdev/odsc-llm-evaluate:0.1.2.9")
190+
>>> container.path
191+
'iad.ocir.io/ociodscdev/odsc-llm-evaluate'
192+
>>> container.name
193+
'odsc-llm-evaluate'
194+
>>> container.version
195+
'0.1.2.9'
196+
197+
>>> container = ContainerPath(full_path="custom-scheme://path/to/versioned-model:2.5.1")
198+
>>> container.path
199+
'custom-scheme://path/to/versioned-model'
200+
>>> container.name
201+
'versioned-model'
202+
>>> container.version
203+
'2.5.1'
204+
205+
Attributes
206+
----------
207+
full_path : str
208+
The complete container path string to be parsed.
209+
path : Optional[str]
210+
The full path up to the version (e.g., 'iad.ocir.io/ociodscdev/odsc-llm-evaluate').
211+
name : Optional[str]
212+
The image name, which is the last segment of `path` (e.g., 'odsc-llm-evaluate').
213+
version : Optional[str]
214+
The version number following the final colon in the path (e.g., '0.1.2.9').
215+
216+
Methods
217+
-------
218+
validate(values: Any) -> Any
219+
Validates and parses the `full_path`, extracting `path`, `name`, and `version`.
220+
"""
221+
222+
full_path: str
223+
path: Optional[str] = None
224+
name: Optional[str] = None
225+
version: Optional[str] = None
226+
227+
@model_validator(mode="before")
228+
@classmethod
229+
def validate(cls, values: Any) -> Any:
230+
"""
231+
Validates and parses the full container path, extracting the image path, image name, and version.
232+
233+
Parameters
234+
----------
235+
values : dict
236+
The dictionary of values being validated, containing 'full_path'.
237+
238+
Returns
239+
-------
240+
dict
241+
Updated values dictionary with extracted 'path', 'name', and 'version'.
242+
"""
243+
full_path = values.get("full_path", "").strip()
244+
245+
# Regex to parse <image_path>:<version>
246+
match = re.match(
247+
r"^(?P<image_path>.+?)(?::(?P<image_version>[\w\.]+))?$", full_path
248+
)
249+
250+
if not match:
251+
raise ValueError(
252+
"Invalid container path format. Expected format: '<image_path>:<version>'"
253+
)
254+
255+
# Extract image_path and version
256+
values["path"] = match.group("image_path")
257+
values["version"] = match.group("image_version")
258+
259+
# Extract image_name as the last segment of image_path
260+
values["name"] = values["path"].split("/")[-1]
261+
262+
return values
263+
264+
class Config:
265+
extra = "ignore"
266+
protected_namespaces = ()

ads/aqua/common/enums.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Tags(ExtendedEnum):
2525
AQUA_TAG = "OCI_AQUA"
2626
AQUA_SERVICE_MODEL_TAG = "aqua_service_model"
2727
AQUA_FINE_TUNED_MODEL_TAG = "aqua_fine_tuned_model"
28+
AQUA_MODEL_ID_TAG = "aqua_model_id"
2829
AQUA_MODEL_NAME_TAG = "aqua_model_name"
2930
AQUA_EVALUATION = "aqua_evaluation"
3031
AQUA_FINE_TUNING = "aqua_finetuning"
@@ -34,6 +35,7 @@ class Tags(ExtendedEnum):
3435
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
3536
MODEL_FORMAT = "model_format"
3637
MODEL_ARTIFACT_FILE = "model_file"
38+
MULTIMODEL_TYPE_TAG = "aqua_multimodel"
3739

3840

3941
class InferenceContainerType(ExtendedEnum):
@@ -44,6 +46,7 @@ class InferenceContainerType(ExtendedEnum):
4446

4547
class InferenceContainerTypeFamily(ExtendedEnum):
4648
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
49+
AQUA_VLLM_V1_CONTAINER_FAMILY = "odsc-vllm-serving-v1"
4750
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
4851
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
4952

0 commit comments

Comments
 (0)