Skip to content

Commit 8d7b9d5

Browse files
authored
[AQUA] Adding ADS support for embedding models in Multi Model Deployment (#1163)
2 parents e3f1f20 + 2bb7a9f commit 8d7b9d5

File tree

6 files changed

+79
-23
lines changed

6 files changed

+79
-23
lines changed

ads/aqua/common/entities.py

+3
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class AquaMultiModelRef(Serializable):
151151
The name of the model.
152152
gpu_count : Optional[int]
153153
Number of GPUs required for deployment.
154+
model_task : Optional[str]
155+
The task that model operates on. Supported tasks are in MultiModelSupportedTaskType
154156
env_var : Optional[Dict[str, Any]]
155157
Optional environment variables to override during deployment.
156158
artifact_location : Optional[str]
@@ -162,6 +164,7 @@ class AquaMultiModelRef(Serializable):
162164
gpu_count: Optional[int] = Field(
163165
None, description="The gpu count allocation for the model."
164166
)
167+
model_task: Optional[str] = Field(None, description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType")
165168
env_var: Optional[dict] = Field(
166169
default_factory=dict, description="The environment variables of the model."
167170
)

ads/aqua/model/enums.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,7 @@ class FineTuningCustomMetadata(ExtendedEnum):
2626

2727

2828
class MultiModelSupportedTaskType(ExtendedEnum):
29-
TEXT_GENERATION = "text-generation"
30-
TEXT_GENERATION_ALT = "text_generation"
29+
TEXT_GENERATION = "text_generation"
30+
IMAGE_TEXT_TO_TEXT = "image_text_to_text"
31+
CODE_SYNTHESIS = "code_synthesis"
32+
EMBEDDING = "text_embedding"

ads/aqua/model/model.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import os
66
import pathlib
7+
import re
78
from datetime import datetime, timedelta
89
from threading import Lock
910
from typing import Any, Dict, List, Optional, Set, Union
@@ -80,6 +81,7 @@
8081
ImportModelDetails,
8182
ModelValidationResult,
8283
)
84+
from ads.aqua.model.enums import MultiModelSupportedTaskType
8385
from ads.common.auth import default_signer
8486
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
8587
from ads.common.utils import (
@@ -307,18 +309,10 @@ def create_multi(
307309
# "Currently only service models are supported for multi model deployment."
308310
# )
309311

310-
# TODO uncomment the section below if only the specific types of models should be allowed for multi-model deployment
311-
# if (
312-
# source_model.freeform_tags.get(Tags.TASK, UNKNOWN).lower()
313-
# not in MultiModelSupportedTaskType
314-
# ):
315-
# raise AquaValueError(
316-
# f"Invalid or missing {Tags.TASK} tag for selected model {display_name}. "
317-
# f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
318-
# )
319-
320312
display_name_list.append(display_name)
321313

314+
self._extract_model_task(model, source_model)
315+
322316
# Retrieve model artifact
323317
model_artifact_path = source_model.artifact
324318
if not model_artifact_path:
@@ -707,6 +701,26 @@ def edit_registered_model(
707701
else:
708702
raise AquaRuntimeError("Only registered unverified models can be edited.")
709703

704+
def _extract_model_task(
705+
self,
706+
model: AquaMultiModelRef,
707+
source_model: DataScienceModel,
708+
) -> None:
709+
"""In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user"""
710+
# user does not supply model task, we extract from model metadata
711+
if not model.model_task:
712+
model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
713+
714+
task_tag = re.sub(r"-", "_", model.model_task).lower()
715+
# re-visit logic when more model task types are supported
716+
if task_tag in MultiModelSupportedTaskType:
717+
model.model_task = task_tag
718+
else:
719+
raise AquaValueError(
720+
f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. "
721+
f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
722+
)
723+
710724
def _fetch_metric_from_metadata(
711725
self,
712726
custom_metadata_list: ModelCustomMetadata,

ads/aqua/modeldeployment/deployment.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,7 @@ def create(
178178
# validate instance shape availability in compartment
179179
available_shapes = [
180180
shape.name.lower()
181-
for shape in self.list_shapes(
182-
compartment_id=compartment_id
183-
)
181+
for shape in self.list_shapes(compartment_id=compartment_id)
184182
]
185183

186184
if create_deployment_details.instance_shape.lower() not in available_shapes:
@@ -645,7 +643,11 @@ def _create_multi(
645643
os_path = ObjectStorageDetails.from_path(artifact_path_prefix)
646644
artifact_path_prefix = os_path.filepath.rstrip("/")
647645

648-
model_config.append({"params": params, "model_path": artifact_path_prefix})
646+
# override by-default completion/ chat endpoint with other endpoint (embedding)
647+
config_data = {"params": params, "model_path": artifact_path_prefix}
648+
if model.model_task:
649+
config_data["model_task"] = model.model_task
650+
model_config.append(config_data)
649651
model_name_list.append(model.model_name)
650652

651653
env_var.update({AQUA_MULTI_MODEL_CONFIG: json.dumps({"models": model_config})})

tests/unitary/with_extras/aqua/test_deployment.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
ModelDeploymentConfigSummary,
4646
ModelParams,
4747
)
48+
from ads.aqua.model.enums import MultiModelSupportedTaskType
4849
from ads.aqua.modeldeployment.utils import MultiModelDeploymentConfigLoader
4950
from ads.model.datascience_model import DataScienceModel
5051
from ads.model.deployment.model_deployment import ModelDeployment
@@ -276,7 +277,7 @@ class TestDataset:
276277
"environment_configuration_type": "OCIR_CONTAINER",
277278
"environment_variables": {
278279
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
279-
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}',
280+
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}',
280281
},
281282
"health_check_port": 8080,
282283
"image": "dsmc://image-name:1.0.0.0",
@@ -486,27 +487,30 @@ class TestDataset:
486487
"gpu_count": 2,
487488
"model_id": "test_model_id_1",
488489
"model_name": "test_model_1",
490+
"model_task": "text_embedding",
489491
"artifact_location": "test_location_1",
490492
},
491493
{
492494
"env_var": {},
493495
"gpu_count": 2,
494496
"model_id": "test_model_id_2",
495497
"model_name": "test_model_2",
498+
"model_task": "image_text_to_text",
496499
"artifact_location": "test_location_2",
497500
},
498501
{
499502
"env_var": {},
500503
"gpu_count": 2,
501504
"model_id": "test_model_id_3",
502505
"model_name": "test_model_3",
506+
"model_task": "code_synthesis",
503507
"artifact_location": "test_location_3",
504508
},
505509
],
506510
"model_id": "ocid1.datasciencemodel.oc1.<region>.<OCID>",
507511
"environment_variables": {
508512
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
509-
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}',
513+
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}',
510514
},
511515
"cmd": [],
512516
"console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>?region=region-name",
@@ -965,20 +969,23 @@ class TestDataset:
965969
"gpu_count": 1,
966970
"model_id": "ocid1.compartment.oc1..<OCID>",
967971
"model_name": "model_one",
972+
"model_task": "text_embedding",
968973
"artifact_location": "artifact_location_one",
969974
},
970975
{
971976
"env_var": {"--test_key_two": "test_value_two"},
972977
"gpu_count": 1,
973978
"model_id": "ocid1.compartment.oc1..<OCID>",
974979
"model_name": "model_two",
980+
"model_task": "image_text_to_text",
975981
"artifact_location": "artifact_location_two",
976982
},
977983
{
978984
"env_var": {"--test_key_three": "test_value_three"},
979985
"gpu_count": 1,
980986
"model_id": "ocid1.compartment.oc1..<OCID>",
981987
"model_name": "model_three",
988+
"model_task": "code_synthesis",
982989
"artifact_location": "artifact_location_three",
983990
},
984991
]
@@ -1787,20 +1794,23 @@ def test_create_deployment_for_multi_model(
17871794
model_info_1 = AquaMultiModelRef(
17881795
model_id="test_model_id_1",
17891796
model_name="test_model_1",
1797+
model_task="text_embedding",
17901798
gpu_count=2,
17911799
artifact_location="test_location_1",
17921800
)
17931801

17941802
model_info_2 = AquaMultiModelRef(
17951803
model_id="test_model_id_2",
17961804
model_name="test_model_2",
1805+
model_task="image_text_to_text",
17971806
gpu_count=2,
17981807
artifact_location="test_location_2",
17991808
)
18001809

18011810
model_info_3 = AquaMultiModelRef(
18021811
model_id="test_model_id_3",
18031812
model_name="test_model_3",
1813+
model_task="code_synthesis",
18041814
gpu_count=2,
18051815
artifact_location="test_location_3",
18061816
)
@@ -1826,6 +1836,7 @@ def test_create_deployment_for_multi_model(
18261836

18271837
expected_attributes = set(AquaDeployment.__annotations__.keys())
18281838
actual_attributes = result.to_dict()
1839+
18291840
assert set(actual_attributes) == set(expected_attributes), "Attributes mismatch"
18301841
expected_result = copy.deepcopy(TestDataset.aqua_multi_deployment_object)
18311842
expected_result["state"] = "CREATING"

tests/unitary/with_extras/aqua/test_model.py

+29-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import json
77
import os
8+
import re
89
import shlex
910
import tempfile
1011
from dataclasses import asdict
@@ -13,17 +14,14 @@
1314

1415
import oci
1516
import pytest
16-
17-
from ads.aqua.app import AquaApp
18-
from ads.aqua.config.container_config import AquaContainerConfig
1917
from huggingface_hub.hf_api import HfApi, ModelInfo
2018
from parameterized import parameterized
2119

2220
import ads.aqua.model
2321
import ads.common
2422
import ads.common.oci_client
2523
import ads.config
26-
24+
from ads.aqua.app import AquaApp
2725
from ads.aqua.common.entities import AquaMultiModelRef
2826
from ads.aqua.common.enums import ModelFormat, Tags
2927
from ads.aqua.common.errors import (
@@ -32,6 +30,7 @@
3230
AquaValueError,
3331
)
3432
from ads.aqua.common.utils import get_hf_model_info
33+
from ads.aqua.config.container_config import AquaContainerConfig
3534
from ads.aqua.constants import HF_METADATA_FOLDER
3635
from ads.aqua.model import AquaModelApp
3736
from ads.aqua.model.entities import (
@@ -40,14 +39,14 @@
4039
ImportModelDetails,
4140
ModelValidationResult,
4241
)
42+
from ads.aqua.model.enums import MultiModelSupportedTaskType
4343
from ads.common.object_storage_details import ObjectStorageDetails
4444
from ads.model.datascience_model import DataScienceModel
4545
from ads.model.model_metadata import (
4646
ModelCustomMetadata,
4747
ModelProvenanceMetadata,
4848
ModelTaxonomyMetadata,
4949
)
50-
5150
from tests.unitary.with_extras.aqua.utils import ServiceManagedContainers
5251

5352

@@ -397,12 +396,14 @@ def test_create_multimodel(
397396
model_info_1 = AquaMultiModelRef(
398397
model_id="test_model_id_1",
399398
gpu_count=2,
399+
model_task = "text_embedding",
400400
env_var={"params": "--trust-remote-code --max-model-len 60000"},
401401
)
402402

403403
model_info_2 = AquaMultiModelRef(
404404
model_id="test_model_id_2",
405405
gpu_count=2,
406+
model_task = "image_text_to_text",
406407
env_var={"params": "--trust-remote-code --max-model-len 32000"},
407408
)
408409

@@ -439,6 +440,29 @@ def test_create_multimodel(
439440
mock_model.custom_metadata_list = custom_metadata_list
440441
mock_from_id.return_value = mock_model
441442

443+
# testing _extract_model_task when a user passes an invalid task to AquaMultiModelRef
444+
model_info_1.model_task = "invalid_task"
445+
446+
with pytest.raises(AquaValueError):
447+
model = self.app.create_multi(
448+
models=[model_info_1, model_info_2],
449+
project_id="test_project_id",
450+
compartment_id="test_compartment_id",
451+
)
452+
453+
# testing if a user tries to invoke a model with a task mode that is not yet supported
454+
model_info_1.model_task = None
455+
mock_model.freeform_tags["task"] = "unsupported_task"
456+
with pytest.raises(AquaValueError):
457+
model = self.app.create_multi(
458+
models=[model_info_1, model_info_2],
459+
project_id="test_project_id",
460+
compartment_id="test_compartment_id",
461+
)
462+
463+
mock_model.freeform_tags["task"] = "text-generation"
464+
model_info_1.model_task = "text_embedding"
465+
442466
# will create a multi-model group
443467
model = self.app.create_multi(
444468
models=[model_info_1, model_info_2],

0 commit comments

Comments
 (0)