diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py index 6144877ee..bf0cc7c99 100644 --- a/ads/aqua/common/enums.py +++ b/ads/aqua/common/enums.py @@ -2,6 +2,8 @@ # Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from typing import Dict, List + from ads.common.extended_enum import ExtendedEnum @@ -106,3 +108,15 @@ class ModelFormat(ExtendedEnum): class Platform(ExtendedEnum): ARM_CPU = "ARM_CPU" NVIDIA_GPU = "NVIDIA_GPU" + + +# This dictionary defines compatibility groups for container families. +# The structure is: +# - Key: The preferred container family to use when multiple compatible families are selected. +# - Value: A list of all compatible families (including the preferred one). +CONTAINER_FAMILY_COMPATIBILITY: Dict[str, List[str]] = { + InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY: [ + InferenceContainerTypeFamily.AQUA_VLLM_V1_CONTAINER_FAMILY, + InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY, + ], +} diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index a1df4a99b..4d3251274 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -37,6 +37,7 @@ from ads.aqua.common.entities import GPUShapesIndex from ads.aqua.common.enums import ( + CONTAINER_FAMILY_COMPATIBILITY, InferenceContainerParamType, InferenceContainerType, RqsAdditionalDetails, @@ -1316,3 +1317,40 @@ def load_gpu_shapes_index( ) return GPUShapesIndex(**data) + + +def get_preferred_compatible_family(selected_families: set[str]) -> str: + """ + Determines the preferred container family from a given set of container families. + + This method is used in the context of multi-model deployment to handle cases + where models selected for deployment use different, but compatible, container families. + + It checks the input `families` set against the `CONTAINER_FAMILY_COMPATIBILITY` map. + If a compatibility group exists that fully includes all the families in the input, + the corresponding key (i.e., the preferred family) is returned. + + Parameters + ---------- + families : set[str] + A set of container family identifiers. + + Returns + ------- + Optional[str] + The preferred container family if all families are compatible within one group; + otherwise, returns `None` indicating that no compatible family group was found. + + Example + ------- + >>> get_preferred_compatible_family({"odsc-vllm-serving", "odsc-vllm-serving-v1"}) + 'odsc-vllm-serving-v1' + + >>> get_preferred_compatible_family({"odsc-vllm-serving", "odsc-tgi-serving"}) + None # Incompatible families + """ + for preferred, compatible_list in CONTAINER_FAMILY_COMPATIBILITY.items(): + if selected_families.issubset(set(compatible_list)): + return preferred + + return None diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 027985702..705f24717 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -39,6 +39,7 @@ get_artifact_path, get_container_config, get_hf_model_info, + get_preferred_compatible_family, list_os_files_with_extension, load_config, read_file, @@ -337,15 +338,25 @@ def create_multi( selected_models_deployment_containers.add(deployment_container) - # Check if the all models in the group shares same container family - if len(selected_models_deployment_containers) > 1: + if not selected_models_deployment_containers: raise AquaValueError( - "The selected models are associated with different container families: " - f"{list(selected_models_deployment_containers)}." - "For multi-model deployment, all models in the group must share the same container family." + "None of the selected models are associated with a recognized container family. " + "Please review the selected models, or select a different group of models." ) - deployment_container = selected_models_deployment_containers.pop() + # Check if the all models in the group shares same container family + if len(selected_models_deployment_containers) > 1: + deployment_container = get_preferred_compatible_family( + selected_families=selected_models_deployment_containers + ) + if not deployment_container: + raise AquaValueError( + "The selected models are associated with different container families: " + f"{list(selected_models_deployment_containers)}." + "For multi-model deployment, all models in the group must share the same container family." + ) + else: + deployment_container = selected_models_deployment_containers.pop() # Generate model group details timestamp = datetime.now().strftime("%Y%m%d") diff --git a/tests/unitary/with_extras/aqua/test_common_utils.py b/tests/unitary/with_extras/aqua/test_common_utils.py new file mode 100644 index 000000000..9c548ba85 --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_common_utils.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import pytest +from ads.aqua.common.utils import get_preferred_compatible_family + + +class TestCommonUtils: + @pytest.mark.parametrize( + "input_families, expected", + [ + ( + {"odsc-vllm-serving", "odsc-vllm-serving-v1"}, + "odsc-vllm-serving-v1", + ), + ({"odsc-tgi-serving", "odsc-vllm-serving"}, None), + ({"non-existing-one", "odsc-tgi-serving"}, None), + ], + ) + def test_get_preferred_compatible_family(self, input_families, expected): + assert get_preferred_compatible_family(input_families) == expected