Skip to content

Commit f967984

Browse files
authored
[AQUA][MMD] Add Support for Retrieving Deployment Configuration from Base Model for Fine-Tuned Models (#1185)
1 parent 75905e0 commit f967984

File tree

2 files changed

+95
-15
lines changed

2 files changed

+95
-15
lines changed

ads/aqua/app.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ads.aqua import logger
2323
from ads.aqua.common.entities import ModelConfigResult
2424
from ads.aqua.common.enums import ConfigFolder, Tags
25-
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
25+
from ads.aqua.common.errors import AquaValueError
2626
from ads.aqua.common.utils import (
2727
_is_valid_mvs,
2828
get_artifact_path,
@@ -284,8 +284,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
284284
logger.info(f"Artifact not found in model {model_id}.")
285285
return False
286286

287+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
287288
def get_config_from_metadata(
288-
self, model_id: str, metadata_key: str
289+
self,
290+
model_id: str,
291+
metadata_key: str,
289292
) -> ModelConfigResult:
290293
"""Gets the config for the given Aqua model from model catalog metadata content.
291294
@@ -300,8 +303,9 @@ def get_config_from_metadata(
300303
ModelConfigResult
301304
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
302305
"""
303-
config = {}
306+
config: Dict[str, Any] = {}
304307
oci_model = self.ds_client.get_model(model_id).data
308+
305309
try:
306310
config = self.ds_client.get_model_defined_metadatum_artifact_content(
307311
model_id, metadata_key
@@ -346,8 +350,10 @@ def get_config(
346350
ModelConfigResult
347351
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
348352
"""
349-
config_folder = config_folder or ConfigFolder.CONFIG
353+
config: Dict[str, Any] = {}
350354
oci_model = self.ds_client.get_model(model_id).data
355+
356+
config_folder = config_folder or ConfigFolder.CONFIG
351357
oci_aqua = (
352358
(
353359
Tags.AQUA_TAG in oci_model.freeform_tags
@@ -357,9 +363,10 @@ def get_config(
357363
else False
358364
)
359365
if not oci_aqua:
360-
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
366+
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
367+
return ModelConfigResult(config=config, model_details=oci_model)
368+
# raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
361369

362-
config: Dict[str, Any] = {}
363370
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
364371
if not artifact_path:
365372
logger.debug(

ads/aqua/modeldeployment/utils.py

+82-9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from ads.aqua.app import AquaApp
1414
from ads.aqua.common.entities import ComputeShapeSummary, ModelConfigResult
15+
from ads.aqua.common.enums import Tags
16+
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
1517
from ads.aqua.model.constants import AquaModelMetadataKeys
1618
from ads.aqua.modeldeployment.entities import (
1719
AquaDeploymentConfig,
@@ -194,24 +196,95 @@ def _fetch_deployment_configs_concurrently(
194196
}
195197

196198
def _fetch_deployment_config_from_metadata_and_oss(
197-
self, model_id
199+
self, model_id: str
198200
) -> ModelConfigResult:
201+
"""
202+
Attempts to retrieve the deployment configuration for a given model.
203+
204+
This method first checks whether the model is a fine-tuned model by inspecting its tags.
205+
If so, it tries to extract the base model ID from the custom metadata and fetch configuration
206+
from the base model instead.
207+
208+
It tries two sources in the following order:
209+
1. Model metadata (custom metadata key)
210+
2. Object Storage fallback (for backward compatibility or large configs)
211+
212+
Parameters
213+
----------
214+
model_id : str
215+
The OCID of the model in the Model Catalog.
216+
217+
Returns
218+
-------
219+
ModelConfigResult
220+
The configuration and model details, possibly empty if no config found.
221+
"""
222+
# Get model details from Model Catalog
223+
oci_model = self.deployment_app.ds_client.get_model(model_id).data
224+
225+
# Check if the model is fine-tuned
226+
is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in oci_model.freeform_tags
227+
if is_fine_tuned_model:
228+
logger.info(
229+
"Model '%s' is marked as fine-tuned. Attempting to retrieve base model ID.",
230+
model_id,
231+
)
232+
233+
base_model_id = next(
234+
(
235+
item.value
236+
for item in oci_model.custom_metadata_list
237+
if item.key == FineTuneCustomMetadata.FINE_TUNE_SOURCE
238+
),
239+
None,
240+
)
241+
242+
if not base_model_id:
243+
logger.warning(
244+
"Base model reference not found in custom metadata for fine-tuned model '%s'.",
245+
model_id,
246+
)
247+
return ModelConfigResult(config={}, model_details=oci_model)
248+
249+
logger.info(
250+
"Base model for fine-tuned model '%s' is '%s'. Using base model for config extraction.",
251+
model_id,
252+
base_model_id,
253+
)
254+
model_id = base_model_id
255+
256+
# Attempt to retrieve config from metadata
257+
metadata_key = AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION
199258
config = self.deployment_app.get_config_from_metadata(
200-
model_id, AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION
259+
model_id=model_id, metadata_key=metadata_key
201260
)
202-
if config:
261+
262+
if config and config.config:
203263
logger.info(
204-
f"Fetched metadata key '{AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION}' from defined metadata for model '{model_id}'"
264+
"Deployment configuration '%s' successfully retrieved from model metadata for model '%s'.",
265+
metadata_key,
266+
model_id,
205267
)
206268
return config
207-
else:
269+
270+
# Attempt to retrieve config from Object Storage
271+
logger.info(
272+
"Deployment configuration '%s' not found in metadata for model '%s'. Falling back to Object Storage.",
273+
metadata_key,
274+
model_id,
275+
)
276+
config = self.deployment_app.get_config(
277+
model_id=model_id, config_file_name=AQUA_MODEL_DEPLOYMENT_CONFIG
278+
)
279+
280+
if config and config.config:
208281
logger.info(
209-
f"Fetching '{AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION}' from object storage bucket for {model_id}'"
210-
)
211-
return self.deployment_app.get_config(
212-
model_id, AQUA_MODEL_DEPLOYMENT_CONFIG
282+
"Successfully retrieved deployment configuration from Object Storage for model '%s'.",
283+
model_id,
213284
)
214285

286+
return config or ModelConfigResult()
287+
215288
def _extract_model_shape_gpu(
216289
self,
217290
deployment_configs: Dict[str, AquaDeploymentConfig],

0 commit comments

Comments
 (0)