[AQUA] Adding ADS support for embedding models in Multi Model Deployment #1163

Merged

merged 13 commits on Apr 25, 2025
3 changes: 3 additions & 0 deletions ads/aqua/common/entities.py
@@ -151,6 +151,8 @@ class AquaMultiModelRef(Serializable):
The name of the model.
gpu_count : Optional[int]
Number of GPUs required for deployment.
model_task : Optional[str]
The task that the model performs. Supported tasks are listed in MultiModelSupportedTaskType.
env_var : Optional[Dict[str, Any]]
Optional environment variables to override during deployment.
artifact_location : Optional[str]
@@ -162,6 +164,7 @@ class AquaMultiModelRef(Serializable):
gpu_count: Optional[int] = Field(
None, description="The gpu count allocation for the model."
)
model_task: Optional[str] = Field(
None, description="The task that the model performs. Supported tasks are listed in MultiModelSupportedTaskType."
)
env_var: Optional[dict] = Field(
default_factory=dict, description="The environment variables of the model."
)
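For context, a minimal sketch of how a caller might populate the new field, mirroring the constructor calls in the tests below; the OCID, model name, and artifact path are hypothetical placeholders:

from ads.aqua.common.entities import AquaMultiModelRef

# Hypothetical values throughout; model_task must be one of the values
# defined in MultiModelSupportedTaskType (e.g. "text_embedding").
embedding_model = AquaMultiModelRef(
    model_id="ocid1.datasciencemodel.oc1..<OCID>",
    model_name="my-embedding-model",
    model_task="text_embedding",
    gpu_count=1,
    artifact_location="oci://bucket@namespace/models/my-embedding-model/artifact/",
)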
6 changes: 4 additions & 2 deletions ads/aqua/model/enums.py
@@ -26,5 +26,7 @@ class FineTuningCustomMetadata(ExtendedEnum):


class MultiModelSupportedTaskType(ExtendedEnum):
TEXT_GENERATION = "text-generation"
TEXT_GENERATION_ALT = "text_generation"
TEXT_GENERATION = "text_generation"
IMAGE_TEXT_TO_TEXT = "image_text_to_text"
CODE_SYNTHESIS = "code_synthesis"
EMBEDDING = "text_embedding"
34 changes: 24 additions & 10 deletions ads/aqua/model/model.py
@@ -4,6 +4,7 @@
import json
import os
import pathlib
import re
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Dict, List, Optional, Set, Union
@@ -80,6 +81,7 @@
ImportModelDetails,
ModelValidationResult,
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
from ads.common.auth import default_signer
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
from ads.common.utils import (
@@ -307,18 +309,10 @@ def create_multi(
# "Currently only service models are supported for multi model deployment."
# )

-# TODO uncomment the section below if only the specific types of models should be allowed for multi-model deployment
-# if (
-# source_model.freeform_tags.get(Tags.TASK, UNKNOWN).lower()
-# not in MultiModelSupportedTaskType
-# ):
-# raise AquaValueError(
-# f"Invalid or missing {Tags.TASK} tag for selected model {display_name}. "
-# f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
-# )

display_name_list.append(display_name)

+self._extract_model_task(model, source_model)

# Retrieve model artifact
model_artifact_path = source_model.artifact
if not model_artifact_path:
@@ -707,6 +701,26 @@ def edit_registered_model(
else:
raise AquaRuntimeError("Only registered unverified models can be edited.")

def _extract_model_task(
self,
model: AquaMultiModelRef,
source_model: DataScienceModel,
) -> None:
"""In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user"""
# user does not supply model task, we extract from model metadata
if not model.model_task:
model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)

task_tag = re.sub(r"-", "_", model.model_task).lower()
# re-visit logic when more model task types are supported
if task_tag in MultiModelSupportedTaskType:
model.model_task = task_tag
else:
raise AquaValueError(
f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. "
f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
)

Member:

We can show a more informative error:

raise AquaValueError(
     f"Invalid or missing {task_tag} tag for selected model {display_name}. "
     f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."

Member Author:

fixed

Member:

Since we removed the task-level validation in the recent release, any reason to add the validation back in the function `_extract_model_task`?

This is fine for now since we only have one verified embedding model, but if in the future we start supporting (unverified) models, embedding models could have a task value of feature_extraction or sentence_similarity. Might be good to add a comment here to reconsider this logic when we start supporting additional models.
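A hypothetical sketch of the alias mapping the reviewer alludes to; nothing below is in this PR, and the names _TASK_ALIASES and resolve_task_alias are illustrative only:

_TASK_ALIASES = {
    "feature_extraction": "text_embedding",
    "sentence_similarity": "text_embedding",
}

def resolve_task_alias(task_tag: str) -> str:
    # Map alternative embedding task tags onto the supported enum value
    # before the membership check in _extract_model_task.
    return _TASK_ALIASES.get(task_tag, task_tag)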

def _fetch_metric_from_metadata(
self,
custom_metadata_list: ModelCustomMetadata,
10 changes: 6 additions & 4 deletions ads/aqua/modeldeployment/deployment.py
@@ -178,9 +178,7 @@ def create(
# validate instance shape availability in compartment
available_shapes = [
shape.name.lower()
-for shape in self.list_shapes(
-compartment_id=compartment_id
-)
+for shape in self.list_shapes(compartment_id=compartment_id)
]

if create_deployment_details.instance_shape.lower() not in available_shapes:
@@ -645,7 +643,11 @@ def _create_multi(
os_path = ObjectStorageDetails.from_path(artifact_path_prefix)
artifact_path_prefix = os_path.filepath.rstrip("/")

model_config.append({"params": params, "model_path": artifact_path_prefix})
# override by-default completion/ chat endpoint with other endpoint (embedding)
config_data = {"params": params, "model_path": artifact_path_prefix}
if model.model_task:
config_data["model_task"] = model.model_task
model_config.append(config_data)
model_name_list.append(model.model_name)

env_var.update({AQUA_MULTI_MODEL_CONFIG: json.dumps({"models": model_config})})
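To illustrate the effect of this change, a sketch of the config entry emitted for a single embedding model; the params and path values are assumed:

import json

config_data = {
    "params": "--served-model-name model_one --tensor-parallel-size 1",
    "model_path": "models/model_one/5be6479/artifact/",
}
model_task = "text_embedding"  # taken from AquaMultiModelRef.model_task
if model_task:
    config_data["model_task"] = model_task

# {"models": [{"params": ..., "model_path": ..., "model_task": "text_embedding"}]}
print(json.dumps({"models": [config_data]}))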
15 changes: 13 additions & 2 deletions tests/unitary/with_extras/aqua/test_deployment.py
@@ -45,6 +45,7 @@
ModelDeploymentConfigSummary,
ModelParams,
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
from ads.aqua.modeldeployment.utils import MultiModelDeploymentConfigLoader
from ads.model.datascience_model import DataScienceModel
from ads.model.deployment.model_deployment import ModelDeployment
@@ -276,7 +277,7 @@ class TestDataset:
"environment_configuration_type": "OCIR_CONTAINER",
"environment_variables": {
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}',
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}',
},
"health_check_port": 8080,
"image": "dsmc://image-name:1.0.0.0",
@@ -486,27 +487,30 @@ class TestDataset:
"gpu_count": 2,
"model_id": "test_model_id_1",
"model_name": "test_model_1",
"model_task": "text_embedding",
"artifact_location": "test_location_1",
},
{
"env_var": {},
"gpu_count": 2,
"model_id": "test_model_id_2",
"model_name": "test_model_2",
"model_task": "image_text_to_text",
"artifact_location": "test_location_2",
},
{
"env_var": {},
"gpu_count": 2,
"model_id": "test_model_id_3",
"model_name": "test_model_3",
"model_task": "code_synthesis",
"artifact_location": "test_location_3",
},
],
"model_id": "ocid1.datasciencemodel.oc1.<region>.<OCID>",
"environment_variables": {
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}',
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}',
},
"cmd": [],
"console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>?region=region-name",
@@ -965,20 +969,23 @@ class TestDataset:
"gpu_count": 1,
"model_id": "ocid1.compartment.oc1..<OCID>",
"model_name": "model_one",
"model_task": "text_embedding",
"artifact_location": "artifact_location_one",
},
{
"env_var": {"--test_key_two": "test_value_two"},
"gpu_count": 1,
"model_id": "ocid1.compartment.oc1..<OCID>",
"model_name": "model_two",
"model_task": "image_text_to_text",
"artifact_location": "artifact_location_two",
},
{
"env_var": {"--test_key_three": "test_value_three"},
"gpu_count": 1,
"model_id": "ocid1.compartment.oc1..<OCID>",
"model_name": "model_three",
"model_task": "code_synthesis",
"artifact_location": "artifact_location_three",
},
]
@@ -1787,20 +1794,23 @@ def test_create_deployment_for_multi_model(
model_info_1 = AquaMultiModelRef(
model_id="test_model_id_1",
model_name="test_model_1",
model_task="text_embedding",
gpu_count=2,
artifact_location="test_location_1",
)

model_info_2 = AquaMultiModelRef(
model_id="test_model_id_2",
model_name="test_model_2",
model_task="image_text_to_text",
gpu_count=2,
artifact_location="test_location_2",
)

model_info_3 = AquaMultiModelRef(
model_id="test_model_id_3",
model_name="test_model_3",
model_task="code_synthesis",
gpu_count=2,
artifact_location="test_location_3",
)
@@ -1826,6 +1836,7 @@

expected_attributes = set(AquaDeployment.__annotations__.keys())
actual_attributes = result.to_dict()

assert set(actual_attributes) == set(expected_attributes), "Attributes mismatch"
expected_result = copy.deepcopy(TestDataset.aqua_multi_deployment_object)
expected_result["state"] = "CREATING"
34 changes: 29 additions & 5 deletions tests/unitary/with_extras/aqua/test_model.py
@@ -5,6 +5,7 @@

import json
import os
import re
import shlex
import tempfile
from dataclasses import asdict
@@ -13,17 +14,

import oci
import pytest

-from ads.aqua.app import AquaApp
-from ads.aqua.config.container_config import AquaContainerConfig
from huggingface_hub.hf_api import HfApi, ModelInfo
from parameterized import parameterized

import ads.aqua.model
import ads.common
import ads.common.oci_client
import ads.config

+from ads.aqua.app import AquaApp
from ads.aqua.common.entities import AquaMultiModelRef
from ads.aqua.common.enums import ModelFormat, Tags
from ads.aqua.common.errors import (
@@ -32,6 +30,7 @@
AquaValueError,
)
from ads.aqua.common.utils import get_hf_model_info
+from ads.aqua.config.container_config import AquaContainerConfig
from ads.aqua.constants import HF_METADATA_FOLDER
from ads.aqua.model import AquaModelApp
from ads.aqua.model.entities import (
@@ -40,14 +39,14 @@
ImportModelDetails,
ModelValidationResult,
)
from ads.aqua.model.enums import MultiModelSupportedTaskType
from ads.common.object_storage_details import ObjectStorageDetails
from ads.model.datascience_model import DataScienceModel
from ads.model.model_metadata import (
ModelCustomMetadata,
ModelProvenanceMetadata,
ModelTaxonomyMetadata,
)

from tests.unitary.with_extras.aqua.utils import ServiceManagedContainers


@@ -397,12 +396,14 @@ def test_create_multimodel(
model_info_1 = AquaMultiModelRef(
model_id="test_model_id_1",
gpu_count=2,
model_task = "text_embedding",
env_var={"params": "--trust-remote-code --max-model-len 60000"},
)

model_info_2 = AquaMultiModelRef(
model_id="test_model_id_2",
gpu_count=2,
model_task = "image_text_to_text",
env_var={"params": "--trust-remote-code --max-model-len 32000"},
)

@@ -439,6 +440,29 @@ def test_create_multimodel(
mock_model.custom_metadata_list = custom_metadata_list
mock_from_id.return_value = mock_model

# testing _extract_model_task when a user passes an invalid task to AquaMultiModelRef
model_info_1.model_task = "invalid_task"

with pytest.raises(AquaValueError):
model = self.app.create_multi(
models=[model_info_1, model_info_2],
project_id="test_project_id",
compartment_id="test_compartment_id",
)

# testing if a user tries to invoke a model with a task type that is not yet supported
model_info_1.model_task = None
mock_model.freeform_tags["task"] = "unsupported_task"
with pytest.raises(AquaValueError):
model = self.app.create_multi(
models=[model_info_1, model_info_2],
project_id="test_project_id",
compartment_id="test_compartment_id",
)

mock_model.freeform_tags["task"] = "text-generation"
model_info_1.model_task = "text_embedding"

# will create a multi-model group
model = self.app.create_multi(
models=[model_info_1, model_info_2],