Skip to content

[WIP][AQUA] Add Supporting Fine-Tuned Models in Multi-Model Deployment #1186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ads.aqua import logger
from ads.aqua.common.entities import ModelConfigResult
from ads.aqua.common.enums import ConfigFolder, Tags
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.errors import AquaValueError
from ads.aqua.common.utils import (
_is_valid_mvs,
get_artifact_path,
Expand Down Expand Up @@ -284,8 +284,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
logger.info(f"Artifact not found in model {model_id}.")
return False

@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
def get_config_from_metadata(
self, model_id: str, metadata_key: str
self,
model_id: str,
metadata_key: str,
) -> ModelConfigResult:
"""Gets the config for the given Aqua model from model catalog metadata content.

Expand All @@ -300,8 +303,9 @@ def get_config_from_metadata(
ModelConfigResult
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
"""
config = {}
config: Dict[str, Any] = {}
oci_model = self.ds_client.get_model(model_id).data

try:
config = self.ds_client.get_model_defined_metadatum_artifact_content(
model_id, metadata_key
Expand Down Expand Up @@ -346,8 +350,10 @@ def get_config(
ModelConfigResult
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
"""
config_folder = config_folder or ConfigFolder.CONFIG
config: Dict[str, Any] = {}
oci_model = self.ds_client.get_model(model_id).data

config_folder = config_folder or ConfigFolder.CONFIG
oci_aqua = (
(
Tags.AQUA_TAG in oci_model.freeform_tags
Expand All @@ -357,9 +363,10 @@ def get_config(
else False
)
if not oci_aqua:
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")
logger.debug(f"Target model {oci_model.id} is not an Aqua model.")
return ModelConfigResult(config=config, model_details=oci_model)
# raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")

config: Dict[str, Any] = {}
artifact_path = get_artifact_path(oci_model.custom_metadata_list)
if not artifact_path:
logger.debug(
Expand Down
91 changes: 82 additions & 9 deletions ads/aqua/modeldeployment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from ads.aqua.app import AquaApp
from ads.aqua.common.entities import ComputeShapeSummary, ModelConfigResult
from ads.aqua.common.enums import Tags
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
from ads.aqua.model.constants import AquaModelMetadataKeys
from ads.aqua.modeldeployment.entities import (
AquaDeploymentConfig,
Expand Down Expand Up @@ -194,24 +196,95 @@ def _fetch_deployment_configs_concurrently(
}

def _fetch_deployment_config_from_metadata_and_oss(
self, model_id
self, model_id: str
) -> ModelConfigResult:
"""
Attempts to retrieve the deployment configuration for a given model.

This method first checks whether the model is a fine-tuned model by inspecting its tags.
If so, it tries to extract the base model ID from the custom metadata and fetch configuration
from the base model instead.

It tries two sources in the following order:
1. Model metadata (defined metadata key)
2. Object Storage fallback (for backward compatibility or large configs)

Parameters
----------
model_id : str
The OCID of the model in the Model Catalog.

Returns
-------
ModelConfigResult
The configuration and model details; the config is empty if none is found in either source, or if a fine-tuned model has no base model reference in its custom metadata.
"""
# Get model details from Model Catalog
oci_model = self.deployment_app.ds_client.get_model(model_id).data

# Check if the model is fine-tuned
is_fine_tuned_model = Tags.AQUA_FINE_TUNED_MODEL_TAG in oci_model.freeform_tags
if is_fine_tuned_model:
logger.info(
"Model '%s' is marked as fine-tuned. Attempting to retrieve base model ID.",
model_id,
)

base_model_id = next(
(
item.value
for item in oci_model.custom_metadata_list
if item.key == FineTuneCustomMetadata.FINE_TUNE_SOURCE
),
None,
)

if not base_model_id:
logger.warning(
"Base model reference not found in custom metadata for fine-tuned model '%s'.",
model_id,
)
return ModelConfigResult(config={}, model_details=oci_model)

logger.info(
"Base model for fine-tuned model '%s' is '%s'. Using base model for config extraction.",
model_id,
base_model_id,
)
model_id = base_model_id

# Attempt to retrieve config from metadata
metadata_key = AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION
config = self.deployment_app.get_config_from_metadata(
model_id, AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION
model_id=model_id, metadata_key=metadata_key
)
if config:

if config and config.config:
logger.info(
f"Fetched metadata key '{AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION}' from defined metadata for model '{model_id}'"
"Deployment configuration '%s' successfully retrieved from model metadata for model '%s'.",
metadata_key,
model_id,
)
return config
else:

# Attempt to retrieve config from Object Storage
logger.info(
"Deployment configuration '%s' not found in metadata for model '%s'. Falling back to Object Storage.",
metadata_key,
model_id,
)
config = self.deployment_app.get_config(
model_id=model_id, config_file_name=AQUA_MODEL_DEPLOYMENT_CONFIG
)

if config and config.config:
logger.info(
f"Fetching '{AquaModelMetadataKeys.DEPLOYMENT_CONFIGURATION}' from object storage bucket for {model_id}'"
)
return self.deployment_app.get_config(
model_id, AQUA_MODEL_DEPLOYMENT_CONFIG
"Successfully retrieved deployment configuration from Object Storage for model '%s'.",
model_id,
)

return config or ModelConfigResult()

def _extract_model_shape_gpu(
self,
deployment_configs: Dict[str, AquaDeploymentConfig],
Expand Down