From 4248b986c04ad19e3528c2c2ead010e3ad9c333d Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Mon, 21 Apr 2025 13:49:52 -0700 Subject: [PATCH 01/11] added support for embedding models in multi model --- ads/aqua/common/entities.py | 4 ++++ ads/aqua/model/enums.py | 5 +++++ ads/aqua/model/model.py | 16 +++++++++++++++- ads/aqua/modeldeployment/deployment.py | 12 ++++++++---- 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index bd7b2ede8..82a596662 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -151,6 +151,9 @@ class AquaMultiModelRef(Serializable): The name of the model. gpu_count : Optional[int] Number of GPUs required for deployment. + model_task : Optional[str] + The task that model operates on. + If specified, overrides by-default completion | chat inference endpoints with embedding endpoint. env_var : Optional[Dict[str, Any]] Optional environment variables to override during deployment. artifact_location : Optional[str] @@ -162,6 +165,7 @@ class AquaMultiModelRef(Serializable): gpu_count: Optional[int] = Field( None, description="The gpu count allocation for the model." ) + model_task: Optional[str] = Field(None, description="The task that model operates on.") env_var: Optional[dict] = Field( default_factory=dict, description="The environment variables of the model." ) diff --git a/ads/aqua/model/enums.py b/ads/aqua/model/enums.py index 1a21adabc..439729147 100644 --- a/ads/aqua/model/enums.py +++ b/ads/aqua/model/enums.py @@ -28,3 +28,8 @@ class FineTuningCustomMetadata(ExtendedEnum): class MultiModelSupportedTaskType(ExtendedEnum): TEXT_GENERATION = "text-generation" TEXT_GENERATION_ALT = "text_generation" + EMBEDDING_ALT = "text_embedding" + +class MultiModelConfigMode(ExtendedEnum): + EMBEDDING = "embedding" + DEFAULT = "completion" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 51732222d..cc9925e17 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -80,7 +80,7 @@ ImportModelDetails, ModelValidationResult, ) -from ads.aqua.model.enums import MultiModelSupportedTaskType +from ads.aqua.model.enums import MultiModelSupportedTaskType, MultiModelConfigMode from ads.common.auth import default_signer from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import ( @@ -316,6 +316,11 @@ def create_multi( display_name_list.append(display_name) + model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) + + if model_task != UNKNOWN: + self._get_task(model, model_task) + # Retrieve model artifact model_artifact_path = source_model.artifact if not model_artifact_path: @@ -704,6 +709,15 @@ def edit_registered_model( else: raise AquaRuntimeError("Only registered unverified models can be edited.") + def _get_task( + self, + model: AquaMultiModelRef, + freeform_task_tag: str + ) -> str: + """Will set model task if freeform task tag from model needs a non-completion endpoint (embedding)""" + if freeform_task_tag == MultiModelSupportedTaskType.EMBEDDING_ALT: + model.model_task = MultiModelConfigMode.EMBEDDING + def _fetch_metric_from_metadata( self, custom_metadata_list: ModelCustomMetadata, diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index cdc77da3c..1a93d75c5 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -178,9 +178,7 @@ def create( # validate instance shape availability in compartment available_shapes = [ shape.name.lower() - for shape in self.list_shapes( - compartment_id=compartment_id - ) + for shape in self.list_shapes(compartment_id=compartment_id) ] if create_deployment_details.instance_shape.lower() not in available_shapes: @@ -645,9 +643,15 @@ def _create_multi( os_path = ObjectStorageDetails.from_path(artifact_path_prefix) artifact_path_prefix = os_path.filepath.rstrip("/") - model_config.append({"params": params, "model_path": artifact_path_prefix}) + # override by-default completion/ chat endpoint with other endpoint (embedding) + config_data = {"params": params, "model_path": artifact_path_prefix} + if model.model_task: + config_data["model_task"] = model.model_task + model_config.append(config_data) model_name_list.append(model.model_name) + print("***") + print(model_config) env_var.update({AQUA_MULTI_MODEL_CONFIG: json.dumps({"models": model_config})}) env_vars = container_spec.env_vars if container_spec else [] From 11802a7f5b9ae6e73b845e87657ba82a85e54280 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Tue, 22 Apr 2025 18:45:16 -0700 Subject: [PATCH 02/11] fixed unit tests for embedding multi model deployments --- ads/aqua/modeldeployment/deployment.py | 4 +--- tests/unitary/with_extras/aqua/test_deployment.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 1a93d75c5..0afb35f14 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -650,8 +650,6 @@ def _create_multi( model_config.append(config_data) model_name_list.append(model.model_name) - print("***") - print(model_config) env_var.update({AQUA_MULTI_MODEL_CONFIG: json.dumps({"models": model_config})}) env_vars = container_spec.env_vars if container_spec else [] @@ -798,7 +796,7 @@ def _create_deployment( .with_infrastructure(infrastructure) .with_runtime(container_runtime) ).deploy(wait_for_completion=False) - + print(deployment) deployment_id = deployment.id logger.info( f"Aqua model deployment {deployment_id} created for model {aqua_model_id}." diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index b0bc1d5fe..073c66e32 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -276,7 +276,7 @@ class TestDataset: "environment_configuration_type": "OCIR_CONTAINER", "environment_variables": { "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions", - "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}', + "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}', }, "health_check_port": 8080, "image": "dsmc://image-name:1.0.0.0", @@ -486,6 +486,7 @@ class TestDataset: "gpu_count": 2, "model_id": "test_model_id_1", "model_name": "test_model_1", + "model_task": "embedding", "artifact_location": "test_location_1", }, { @@ -493,6 +494,7 @@ class TestDataset: "gpu_count": 2, "model_id": "test_model_id_2", "model_name": "test_model_2", + "model_task": None, "artifact_location": "test_location_2", }, { @@ -500,13 +502,14 @@ class TestDataset: "gpu_count": 2, "model_id": "test_model_id_3", "model_name": "test_model_3", + "model_task": None, "artifact_location": "test_location_3", }, ], "model_id": "ocid1.datasciencemodel.oc1..", "environment_variables": { "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions", - "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}', + "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}', }, "cmd": [], "console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1..?region=region-name", @@ -965,6 +968,7 @@ class TestDataset: "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_one", + "model_task": "embedding", "artifact_location": "artifact_location_one", }, { @@ -972,6 +976,7 @@ class TestDataset: "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_two", + "model_task": None, "artifact_location": "artifact_location_two", }, { @@ -979,6 +984,7 @@ class TestDataset: "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_three", + "model_task": None, "artifact_location": "artifact_location_three", }, ] @@ -1787,6 +1793,7 @@ def test_create_deployment_for_multi_model( model_info_1 = AquaMultiModelRef( model_id="test_model_id_1", model_name="test_model_1", + model_task="embedding", gpu_count=2, artifact_location="test_location_1", ) @@ -1794,6 +1801,7 @@ def test_create_deployment_for_multi_model( model_info_2 = AquaMultiModelRef( model_id="test_model_id_2", model_name="test_model_2", + model_task=None, gpu_count=2, artifact_location="test_location_2", ) @@ -1801,6 +1809,7 @@ def test_create_deployment_for_multi_model( model_info_3 = AquaMultiModelRef( model_id="test_model_id_3", model_name="test_model_3", + model_task=None, gpu_count=2, artifact_location="test_location_3", ) @@ -1826,6 +1835,7 @@ def test_create_deployment_for_multi_model( expected_attributes = set(AquaDeployment.__annotations__.keys()) actual_attributes = result.to_dict() + print(result) assert set(actual_attributes) == set(expected_attributes), "Attributes mismatch" expected_result = copy.deepcopy(TestDataset.aqua_multi_deployment_object) expected_result["state"] = "CREATING" From c1a0478beca5b24f31233561ae6f17e7352a5130 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Tue, 22 Apr 2025 19:01:36 -0700 Subject: [PATCH 03/11] removed print statements and added clarifying docstring --- ads/aqua/model/model.py | 2 +- ads/aqua/modeldeployment/deployment.py | 2 +- tests/unitary/with_extras/aqua/test_deployment.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index cc9925e17..3c23a01cb 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -714,7 +714,7 @@ def _get_task( model: AquaMultiModelRef, freeform_task_tag: str ) -> str: - """Will set model task if freeform task tag from model needs a non-completion endpoint (embedding)""" + """In a Multi Model Deployment, will set model task if freeform task tag from model needs a non-completion endpoint (embedding)""" if freeform_task_tag == MultiModelSupportedTaskType.EMBEDDING_ALT: model.model_task = MultiModelConfigMode.EMBEDDING diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 0afb35f14..90f211af3 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -796,7 +796,7 @@ def _create_deployment( .with_infrastructure(infrastructure) .with_runtime(container_runtime) ).deploy(wait_for_completion=False) - print(deployment) + deployment_id = deployment.id logger.info( f"Aqua model deployment {deployment_id} created for model {aqua_model_id}." diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 073c66e32..920096b85 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -1835,7 +1835,7 @@ def test_create_deployment_for_multi_model( expected_attributes = set(AquaDeployment.__annotations__.keys()) actual_attributes = result.to_dict() - print(result) + assert set(actual_attributes) == set(expected_attributes), "Attributes mismatch" expected_result = copy.deepcopy(TestDataset.aqua_multi_deployment_object) expected_result["state"] = "CREATING" From 7897b38e3faa60796f72f950fe67fb5c66a28749 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Wed, 23 Apr 2025 15:36:29 -0700 Subject: [PATCH 04/11] added validation logic for model_task and unit test in test_model.py --- ads/aqua/model/enums.py | 8 +++---- ads/aqua/model/model.py | 21 +++++++++-------- .../with_extras/aqua/test_deployment.py | 23 ++++++++++--------- tests/unitary/with_extras/aqua/test_model.py | 22 ++++++++++++++---- 4 files changed, 44 insertions(+), 30 deletions(-) diff --git a/ads/aqua/model/enums.py b/ads/aqua/model/enums.py index 439729147..33a8dbda6 100644 --- a/ads/aqua/model/enums.py +++ b/ads/aqua/model/enums.py @@ -28,8 +28,6 @@ class FineTuningCustomMetadata(ExtendedEnum): class MultiModelSupportedTaskType(ExtendedEnum): TEXT_GENERATION = "text-generation" TEXT_GENERATION_ALT = "text_generation" - EMBEDDING_ALT = "text_embedding" - -class MultiModelConfigMode(ExtendedEnum): - EMBEDDING = "embedding" - DEFAULT = "completion" + IMAGE_TEXT_TO_TEXT = "image_text_to_text" + CODE_SYNTHESIS = "code_synthesis" + EMBEDDING = "text_embedding" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 3c23a01cb..a730002e3 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -4,6 +4,7 @@ import json import os import pathlib +import re from datetime import datetime, timedelta from threading import Lock from typing import Any, Dict, List, Optional, Set, Union @@ -80,7 +81,7 @@ ImportModelDetails, ModelValidationResult, ) -from ads.aqua.model.enums import MultiModelSupportedTaskType, MultiModelConfigMode +from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.common.auth import default_signer from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import ( @@ -709,14 +710,16 @@ def edit_registered_model( else: raise AquaRuntimeError("Only registered unverified models can be edited.") - def _get_task( - self, - model: AquaMultiModelRef, - freeform_task_tag: str - ) -> str: - """In a Multi Model Deployment, will set model task if freeform task tag from model needs a non-completion endpoint (embedding)""" - if freeform_task_tag == MultiModelSupportedTaskType.EMBEDDING_ALT: - model.model_task = MultiModelConfigMode.EMBEDDING + def _get_task(self, model: AquaMultiModelRef, freeform_task_tag: str) -> str: + """In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user""" + task_tag = re.sub(r"-", "_", freeform_task_tag) + + if task_tag in MultiModelSupportedTaskType: + model.model_task = task_tag + else: + raise AquaValueError( + f"{freeform_task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}." + ) def _fetch_metric_from_metadata( self, diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 920096b85..ce27b4348 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -45,6 +45,7 @@ ModelDeploymentConfigSummary, ModelParams, ) +from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.aqua.modeldeployment.utils import MultiModelDeploymentConfigLoader from ads.model.datascience_model import DataScienceModel from ads.model.deployment.model_deployment import ModelDeployment @@ -276,7 +277,7 @@ class TestDataset: "environment_configuration_type": "OCIR_CONTAINER", "environment_variables": { "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions", - "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}', + "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}', }, "health_check_port": 8080, "image": "dsmc://image-name:1.0.0.0", @@ -486,7 +487,7 @@ class TestDataset: "gpu_count": 2, "model_id": "test_model_id_1", "model_name": "test_model_1", - "model_task": "embedding", + "model_task": "text_embedding", "artifact_location": "test_location_1", }, { @@ -494,7 +495,7 @@ class TestDataset: "gpu_count": 2, "model_id": "test_model_id_2", "model_name": "test_model_2", - "model_task": None, + "model_task": "image_text_to_text", "artifact_location": "test_location_2", }, { @@ -502,14 +503,14 @@ class TestDataset: "gpu_count": 2, "model_id": "test_model_id_3", "model_name": "test_model_3", - "model_task": None, + "model_task": "code_synthesis", "artifact_location": "test_location_3", }, ], "model_id": "ocid1.datasciencemodel.oc1..", "environment_variables": { "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions", - "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}', + "MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}', }, "cmd": [], "console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1..?region=region-name", @@ -968,7 +969,7 @@ class TestDataset: "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_one", - "model_task": "embedding", + "model_task": "text_embedding", "artifact_location": "artifact_location_one", }, { @@ -976,7 +977,7 @@ class TestDataset: "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_two", - "model_task": None, + "model_task": "image_text_to_text", "artifact_location": "artifact_location_two", }, { @@ -984,7 +985,7 @@ class TestDataset: "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_three", - "model_task": None, + "model_task": "code_synthesis", "artifact_location": "artifact_location_three", }, ] @@ -1793,7 +1794,7 @@ def test_create_deployment_for_multi_model( model_info_1 = AquaMultiModelRef( model_id="test_model_id_1", model_name="test_model_1", - model_task="embedding", + model_task="text_embedding", gpu_count=2, artifact_location="test_location_1", ) @@ -1801,7 +1802,7 @@ def test_create_deployment_for_multi_model( model_info_2 = AquaMultiModelRef( model_id="test_model_id_2", model_name="test_model_2", - model_task=None, + model_task="image_text_to_text", gpu_count=2, artifact_location="test_location_2", ) @@ -1809,7 +1810,7 @@ def test_create_deployment_for_multi_model( model_info_3 = AquaMultiModelRef( model_id="test_model_id_3", model_name="test_model_3", - model_task=None, + model_task="code_synthesis", gpu_count=2, artifact_location="test_location_3", ) diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 0cb14c98f..40541fc04 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -5,6 +5,7 @@ import json import os +import re import shlex import tempfile from dataclasses import asdict @@ -13,9 +14,6 @@ import oci import pytest - -from ads.aqua.app import AquaApp -from ads.aqua.config.container_config import AquaContainerConfig from huggingface_hub.hf_api import HfApi, ModelInfo from parameterized import parameterized @@ -23,7 +21,7 @@ import ads.common import ads.common.oci_client import ads.config - +from ads.aqua.app import AquaApp from ads.aqua.common.entities import AquaMultiModelRef from ads.aqua.common.enums import ModelFormat, Tags from ads.aqua.common.errors import ( @@ -32,6 +30,7 @@ AquaValueError, ) from ads.aqua.common.utils import get_hf_model_info +from ads.aqua.config.container_config import AquaContainerConfig from ads.aqua.constants import HF_METADATA_FOLDER from ads.aqua.model import AquaModelApp from ads.aqua.model.entities import ( @@ -40,6 +39,7 @@ ImportModelDetails, ModelValidationResult, ) +from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.common.object_storage_details import ObjectStorageDetails from ads.model.datascience_model import DataScienceModel from ads.model.model_metadata import ( @@ -47,7 +47,6 @@ ModelProvenanceMetadata, ModelTaxonomyMetadata, ) - from tests.unitary.with_extras.aqua.utils import ServiceManagedContainers @@ -397,12 +396,14 @@ def test_create_multimodel( model_info_1 = AquaMultiModelRef( model_id="test_model_id_1", gpu_count=2, + model_task = "text_embedding", env_var={"params": "--trust-remote-code --max-model-len 60000"}, ) model_info_2 = AquaMultiModelRef( model_id="test_model_id_2", gpu_count=2, + model_task = "image_text_to_text", env_var={"params": "--trust-remote-code --max-model-len 32000"}, ) @@ -439,6 +440,17 @@ def test_create_multimodel( mock_model.custom_metadata_list = custom_metadata_list mock_from_id.return_value = mock_model + mock_model.freeform_tags["task"] = "invalid_task" + + with pytest.raises(AquaValueError): + model = self.app.create_multi( + models=[model_info_1, model_info_2], + project_id="test_project_id", + compartment_id="test_compartment_id", + ) + + mock_model.freeform_tags["task"] = "text-generation" + # will create a multi-model group model = self.app.create_multi( models=[model_info_1, model_info_2], From 987b60d4ba9360de395a5743d93285beee263b75 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 09:47:07 -0700 Subject: [PATCH 05/11] refactored _get_task --- ads/aqua/model/model.py | 19 ++++++++++++------- tests/unitary/with_extras/aqua/test_model.py | 14 +++++++++++++- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index a730002e3..307a2a031 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -317,10 +317,7 @@ def create_multi( display_name_list.append(display_name) - model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) - - if model_task != UNKNOWN: - self._get_task(model, model_task) + self._get_task(model, source_model) # Retrieve model artifact model_artifact_path = source_model.artifact @@ -710,15 +707,23 @@ def edit_registered_model( else: raise AquaRuntimeError("Only registered unverified models can be edited.") - def _get_task(self, model: AquaMultiModelRef, freeform_task_tag: str) -> str: + def _get_task( + self, + model: AquaMultiModelRef, + source_model: DataScienceModel, + ) -> str: """In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user""" - task_tag = re.sub(r"-", "_", freeform_task_tag) + # user does not supply model task, we extract from model metadata + if not model.model_task: + model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) + + task_tag = re.sub(r"-", "_", model.model_task) if task_tag in MultiModelSupportedTaskType: model.model_task = task_tag else: raise AquaValueError( - f"{freeform_task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}." + f"{task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}." ) def _fetch_metric_from_metadata( diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 40541fc04..abb3acfe6 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -440,7 +440,8 @@ def test_create_multimodel( mock_model.custom_metadata_list = custom_metadata_list mock_from_id.return_value = mock_model - mock_model.freeform_tags["task"] = "invalid_task" + # testing _get_task when a user passes an invalid task to AquaMultiModelRef + model_info_1.model_task = "invalid_task" with pytest.raises(AquaValueError): model = self.app.create_multi( @@ -449,7 +450,18 @@ def test_create_multimodel( compartment_id="test_compartment_id", ) + # testing if a user tries to invoke a model with a task mode that is not yet supported + model_info_1.model_task = None + mock_model.freeform_tags["task"] = "unsupported_task" + with pytest.raises(AquaValueError): + model = self.app.create_multi( + models=[model_info_1, model_info_2], + project_id="test_project_id", + compartment_id="test_compartment_id", + ) + mock_model.freeform_tags["task"] = "text-generation" + model_info_1.model_task = "text_embedding" # will create a multi-model group model = self.app.create_multi( From f37c6e067a085e5eb1750c2ec2e46e941f7ec4e7 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 09:53:38 -0700 Subject: [PATCH 06/11] removed comment and added case-insensitive match for model_task --- ads/aqua/model/enums.py | 3 +-- ads/aqua/model/model.py | 12 +----------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/ads/aqua/model/enums.py b/ads/aqua/model/enums.py index 33a8dbda6..b953b2383 100644 --- a/ads/aqua/model/enums.py +++ b/ads/aqua/model/enums.py @@ -26,8 +26,7 @@ class FineTuningCustomMetadata(ExtendedEnum): class MultiModelSupportedTaskType(ExtendedEnum): - TEXT_GENERATION = "text-generation" - TEXT_GENERATION_ALT = "text_generation" + TEXT_GENERATION = "text_generation" IMAGE_TEXT_TO_TEXT = "image_text_to_text" CODE_SYNTHESIS = "code_synthesis" EMBEDDING = "text_embedding" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 307a2a031..5baf3822c 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -305,16 +305,6 @@ def create_multi( # "Currently only service models are supported for multi model deployment." # ) - # TODO uncomment the section below if only the specific types of models should be allowed for multi-model deployment - # if ( - # source_model.freeform_tags.get(Tags.TASK, UNKNOWN).lower() - # not in MultiModelSupportedTaskType - # ): - # raise AquaValueError( - # f"Invalid or missing {Tags.TASK} tag for selected model {display_name}. " - # f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment." - # ) - display_name_list.append(display_name) self._get_task(model, source_model) @@ -717,7 +707,7 @@ def _get_task( if not model.model_task: model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) - task_tag = re.sub(r"-", "_", model.model_task) + task_tag = re.sub(r"-", "_", model.model_task).lower() if task_tag in MultiModelSupportedTaskType: model.model_task = task_tag From e0bbc9b5c3e09e9c3ba5e4cd64379e60bf47e520 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 09:57:45 -0700 Subject: [PATCH 07/11] updated comment on model_task pydantic parameter description --- ads/aqua/common/entities.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index 82a596662..e33599926 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -152,8 +152,7 @@ class AquaMultiModelRef(Serializable): gpu_count : Optional[int] Number of GPUs required for deployment. model_task : Optional[str] - The task that model operates on. - If specified, overrides by-default completion | chat inference endpoints with embedding endpoint. + The task that model operates on. Supported tasks are in MultiModelSupportedTaskType env_var : Optional[Dict[str, Any]] Optional environment variables to override during deployment. artifact_location : Optional[str] @@ -165,7 +164,7 @@ class AquaMultiModelRef(Serializable): gpu_count: Optional[int] = Field( None, description="The gpu count allocation for the model." ) - model_task: Optional[str] = Field(None, description="The task that model operates on.") + model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType") env_var: Optional[dict] = Field( default_factory=dict, description="The environment variables of the model." ) From 44c25b8b9f745267dc598d07273c0c077d7128c9 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 11:16:00 -0700 Subject: [PATCH 08/11] fixed type signature and changed method name --- ads/aqua/model/model.py | 6 +++--- tests/unitary/with_extras/aqua/test_model.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 18b86da4f..f7659b4b9 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -310,7 +310,7 @@ def create_multi( display_name_list.append(display_name) - self._get_task(model, source_model) + self._extract_model_task(model, source_model) # Retrieve model artifact model_artifact_path = source_model.artifact @@ -700,11 +700,11 @@ def edit_registered_model( else: raise AquaRuntimeError("Only registered unverified models can be edited.") - def _get_task( + def _extract_model_task( self, model: AquaMultiModelRef, source_model: DataScienceModel, - ) -> str: + ) -> None: """In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user""" # user does not supply model task, we extract from model metadata if not model.model_task: diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index abb3acfe6..1587b7592 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -440,7 +440,7 @@ def test_create_multimodel( mock_model.custom_metadata_list = custom_metadata_list mock_from_id.return_value = mock_model - # testing _get_task when a user passes an invalid task to AquaMultiModelRef + # testing _extract_model_task when a user passes an invalid task to AquaMultiModelRef model_info_1.model_task = "invalid_task" with pytest.raises(AquaValueError): From d6b66f1145178391e1a090495463f07e7d0e3827 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 11:30:35 -0700 Subject: [PATCH 09/11] updated error messages --- ads/aqua/model/model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index f7659b4b9..11c081b30 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -81,6 +81,7 @@ ImportModelDetails, ModelValidationResult, ) +from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.common.auth import default_signer from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import ( @@ -184,12 +185,8 @@ def create( target_project = project_id or PROJECT_OCID target_compartment = compartment_id or COMPARTMENT_OCID - # Skip model copying if it is registered model or fine-tuned model - if ( - service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None - or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG) - is not None - ): + # Skip model copying if it is registered model + if service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None: logger.info( f"Aqua Model {model_id} already exists in the user's compartment." "Skipped copying." @@ -716,7 +713,8 @@ def _extract_model_task( model.model_task = task_tag else: raise AquaValueError( - f"{task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}." + f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. " + f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment." ) def _fetch_metric_from_metadata( From c21d4cbc77cb0fc54214a4a75a6a2e6e20c1061e Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 11:48:52 -0700 Subject: [PATCH 10/11] fixed diff for model.py --- ads/aqua/model/model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 11c081b30..f770829aa 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -185,8 +185,12 @@ def create( target_project = project_id or PROJECT_OCID target_compartment = compartment_id or COMPARTMENT_OCID - # Skip model copying if it is registered model - if service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None: + # Skip model copying if it is registered model or fine-tuned model + if ( + service_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) is not None + or service_model.freeform_tags.get(Tags.AQUA_FINE_TUNED_MODEL_TAG) + is not None + ): logger.info( f"Aqua Model {model_id} already exists in the user's compartment." "Skipped copying." From 2bb7a9f0b06f4c95a2bb54d276e3d25b0d8c7022 Mon Sep 17 00:00:00 2001 From: Liz Johnson Date: Thu, 24 Apr 2025 14:54:41 -0700 Subject: [PATCH 11/11] added comment about revisiting logic --- ads/aqua/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index f770829aa..7fe419280 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -712,7 +712,7 @@ def _extract_model_task( model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) task_tag = re.sub(r"-", "_", model.model_task).lower() - + # re-visit logic when more model task types are supported if task_tag in MultiModelSupportedTaskType: model.model_task = task_tag else: