Skip to content

Commit e65f09c

Browse files
authored
Enhance the container family validation for multi-model deployment (#1148)
2 parents e619a07 + 2ca57d9 commit e65f09c

File tree

4 files changed

+93
-6
lines changed

4 files changed

+93
-6
lines changed

ads/aqua/common/enums.py

+14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
from typing import Dict, List
6+
57
from ads.common.extended_enum import ExtendedEnum
68

79

@@ -106,3 +108,15 @@ class ModelFormat(ExtendedEnum):
106108
class Platform(ExtendedEnum):
107109
ARM_CPU = "ARM_CPU"
108110
NVIDIA_GPU = "NVIDIA_GPU"
111+
112+
# Compatibility groups for container families.
#
# Each entry maps:
#   - key:   the preferred container family, i.e. the one to use when several
#            compatible families are selected together;
#   - value: every family that belongs to that compatibility group
#            (the preferred family itself is part of the list).
CONTAINER_FAMILY_COMPATIBILITY: Dict[str, List[str]] = {
    InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY: [
        InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY,
        InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY,
    ],
}

ads/aqua/common/utils.py

+38
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from ads.aqua.common.entities import GPUShapesIndex
3939
from ads.aqua.common.enums import (
40+
CONTAINER_FAMILY_COMPATIBILITY,
4041
InferenceContainerParamType,
4142
InferenceContainerType,
4243
RqsAdditionalDetails,
@@ -1316,3 +1317,40 @@ def load_gpu_shapes_index(
13161317
)
13171318

13181319
return GPUShapesIndex(**data)
1320+
1321+
# The return annotation is quoted so that `str | None` is never evaluated at
# runtime (the PEP 604 union operator is only available on Python 3.10+).
def get_preferred_compatible_family(selected_families: set[str]) -> "str | None":
    """
    Determines the preferred container family from a given set of container families.

    This method is used in the context of multi-model deployment to handle cases
    where models selected for deployment use different, but compatible, container families.

    It checks the input `selected_families` set against the
    `CONTAINER_FAMILY_COMPATIBILITY` map. If a compatibility group exists that fully
    includes all the families in the input, the corresponding key
    (i.e., the preferred family) is returned.

    Parameters
    ----------
    selected_families : set[str]
        A set of container family identifiers.

    Returns
    -------
    str or None
        The preferred container family if all families are compatible within one group;
        otherwise `None`, indicating that no compatible family group was found.
        `None` is also returned when `selected_families` is empty.

    Example
    -------
    >>> get_preferred_compatible_family({"odsc-vllm-serving", "odsc-vllm-serving-v1"})
    'odsc-vllm-serving-v1'

    Incompatible families yield ``None``:

    >>> get_preferred_compatible_family({"odsc-vllm-serving", "odsc-tgi-serving"}) is None
    True
    """
    # An empty set is a subset of every compatibility group, so without this
    # guard an empty selection would be reported as matching the first group.
    if not selected_families:
        return None

    for preferred, compatible_list in CONTAINER_FAMILY_COMPATIBILITY.items():
        # `set.issubset` accepts any iterable, so the list can be passed as is.
        if selected_families.issubset(compatible_list):
            return preferred

    return None

ads/aqua/model/model.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
get_artifact_path,
4040
get_container_config,
4141
get_hf_model_info,
42+
get_preferred_compatible_family,
4243
list_os_files_with_extension,
4344
load_config,
4445
read_file,
@@ -337,15 +338,25 @@ def create_multi(
337338

338339
selected_models_deployment_containers.add(deployment_container)
339340

340-
# Check if the all models in the group shares same container family
341-
if len(selected_models_deployment_containers) > 1:
341+
if not selected_models_deployment_containers:
342342
raise AquaValueError(
343-
"The selected models are associated with different container families: "
344-
f"{list(selected_models_deployment_containers)}."
345-
"For multi-model deployment, all models in the group must share the same container family."
343+
"None of the selected models are associated with a recognized container family. "
344+
"Please review the selected models, or select a different group of models."
346345
)
347346

348-
deployment_container = selected_models_deployment_containers.pop()
347+
# Check whether all models in the group share the same (or a compatible) container family
348+
if len(selected_models_deployment_containers) > 1:
349+
deployment_container = get_preferred_compatible_family(
350+
selected_families=selected_models_deployment_containers
351+
)
352+
if not deployment_container:
353+
raise AquaValueError(
354+
"The selected models are associated with different container families: "
355+
f"{list(selected_models_deployment_containers)}."
356+
"For multi-model deployment, all models in the group must share the same container family."
357+
)
358+
else:
359+
deployment_container = selected_models_deployment_containers.pop()
349360

350361
# Generate model group details
351362
timestamp = datetime.now().strftime("%Y%m%d")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2025 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import pytest
8+
from ads.aqua.common.utils import get_preferred_compatible_family
9+
10+
class TestCommonUtils:
    """Unit tests for `get_preferred_compatible_family`."""

    @pytest.mark.parametrize(
        ("input_families", "expected"),
        [
            # Both families belong to one compatibility group: the preferred one wins.
            (
                {"odsc-vllm-serving", "odsc-vllm-serving-v1"},
                "odsc-vllm-serving-v1",
            ),
            # Known families that do not share a compatibility group.
            (
                {"odsc-tgi-serving", "odsc-vllm-serving"},
                None,
            ),
            # An unknown family mixed with a known one.
            (
                {"non-existing-one", "odsc-tgi-serving"},
                None,
            ),
        ],
    )
    def test_get_preferred_compatible_family(self, input_families, expected):
        """The helper returns the expected preferred family, or None, for each case."""
        actual = get_preferred_compatible_family(input_families)
        assert actual == expected

0 commit comments

Comments
 (0)