Skip to content

Commit 4936025

Browse files
[ODSC-63984] BYOC TEI deployment for embedding models (#975)
2 parents c5bb94e + cf81e28 commit 4936025

File tree

13 files changed

+668
-87
lines changed

13 files changed

+668
-87
lines changed

ads/aqua/common/enums.py

+9
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55+
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5556

5657

5758
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
@@ -80,3 +81,11 @@ class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta):
8081
MODEL_VERSION_SET_NAME = "modelVersionSetName"
8182
PROJECT_ID = "projectId"
8283
VERSION_LABEL = "versionLabel"
84+
85+
86+
class TextEmbeddingInferenceContainerParams(str, metaclass=ExtendedEnumMeta):
87+
"""Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options
88+
are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments"""
89+
90+
MODEL_ID = "model-id"
91+
PORT = "port"

ads/aqua/common/utils.py

+83-5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
InferenceContainerParamType,
3636
InferenceContainerType,
3737
RqsAdditionalDetails,
38+
TextEmbeddingInferenceContainerParams,
3839
)
3940
from ads.aqua.common.errors import (
4041
AquaFileNotFoundError,
@@ -51,6 +52,7 @@
5152
MODEL_BY_REFERENCE_OSS_PATH_KEY,
5253
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
5354
SUPPORTED_FILE_FORMATS,
55+
TEI_CONTAINER_DEFAULT_HOST,
5456
TGI_INFERENCE_RESTRICTED_PARAMS,
5557
UNKNOWN,
5658
UNKNOWN_JSON_STR,
@@ -63,7 +65,12 @@
6365
from ads.common.object_storage_details import ObjectStorageDetails
6466
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
6567
from ads.common.utils import copy_file, get_console_link, upload_to_os
66-
from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID
68+
from ads.config import (
69+
AQUA_MODEL_DEPLOYMENT_FOLDER,
70+
AQUA_SERVICE_MODELS_BUCKET,
71+
CONDA_BUCKET_NS,
72+
TENANCY_OCID,
73+
)
6774
from ads.model import DataScienceModel, ModelVersionSet
6875

6976
logger = logging.getLogger("ads.aqua")
@@ -569,15 +576,13 @@ def get_container_image(
569576
A dict of allowed configs.
570577
"""
571578

579+
container_image = UNKNOWN
572580
config = config_file_name or get_container_config()
573581
config_file_name = service_config_path()
574582

575583
if container_type not in config:
576-
raise AquaValueError(
577-
f"{config_file_name} does not have config details for model: {container_type}"
578-
)
584+
return UNKNOWN
579585

580-
container_image = None
581586
mapping = config[container_type]
582587
versions = [obj["version"] for obj in mapping]
583588
# assumes numbered versions, update if `latest` is used
@@ -1078,3 +1083,76 @@ def list_hf_models(query: str) -> List[str]:
10781083
return [model.id for model in models if model.disabled is None]
10791084
except HfHubHTTPError as err:
10801085
raise format_hf_custom_error_message(err) from err
1086+
1087+
1088+
def generate_tei_cmd_var(os_path: str) -> List[str]:
1089+
"""This utility functions generates CMD params for Text Embedding Inference container. Only the
1090+
essential parameters for OCI model deployment are added, defaults are used for the rest.
1091+
Parameters
1092+
----------
1093+
os_path: str
1094+
OCI bucket path where the model artifacts are uploaded - oci://bucket@namespace/prefix
1095+
1096+
Returns
1097+
-------
1098+
cmd_var:
1099+
List of command line arguments
1100+
"""
1101+
1102+
cmd_prefix = "--"
1103+
cmd_var = [
1104+
f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.MODEL_ID}",
1105+
f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{ObjectStorageDetails.from_path(os_path.rstrip('/')).filepath}/",
1106+
f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.PORT}",
1107+
TEI_CONTAINER_DEFAULT_HOST,
1108+
]
1109+
1110+
return cmd_var
1111+
1112+
1113+
def parse_cmd_var(cmd_list: List[str]) -> dict:
1114+
"""Helper functions that parses a list into a key-value dictionary. The list contains keys separated by the prefix
1115+
'--' and the value of the key is the subsequent element.
1116+
"""
1117+
parsed_cmd = {}
1118+
1119+
for i, cmd in enumerate(cmd_list):
1120+
if cmd.startswith("--"):
1121+
if i + 1 < len(cmd_list) and not cmd_list[i + 1].startswith("--"):
1122+
parsed_cmd[cmd] = cmd_list[i + 1]
1123+
i += 1
1124+
else:
1125+
parsed_cmd[cmd] = None
1126+
return parsed_cmd
1127+
1128+
1129+
def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
1130+
"""This function accepts two lists of parameters and combines them. If the second list shares the common parameter
1131+
names/keys, then it raises an error.
1132+
Parameters
1133+
----------
1134+
cmd_var: List[str]
1135+
Default list of parameters
1136+
overrides: List[str]
1137+
List of parameters to override
1138+
Returns
1139+
-------
1140+
List[str] of combined parameters
1141+
"""
1142+
cmd_var = [str(x) for x in cmd_var]
1143+
if not overrides:
1144+
return cmd_var
1145+
overrides = [str(x) for x in overrides]
1146+
1147+
cmd_dict = parse_cmd_var(cmd_var)
1148+
overrides_dict = parse_cmd_var(overrides)
1149+
1150+
# check for conflicts
1151+
common_keys = set(cmd_dict.keys()) & set(overrides_dict.keys())
1152+
if common_keys:
1153+
raise AquaValueError(
1154+
f"The following CMD input cannot be overridden for model deployment: {', '.join(common_keys)}"
1155+
)
1156+
1157+
combined_cmd_var = cmd_var + overrides
1158+
return combined_cmd_var

ads/aqua/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@
8080
"--port",
8181
"--host",
8282
}
83+
TEI_CONTAINER_DEFAULT_HOST = "8080"

ads/aqua/model/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):
1717
DEPLOYMENT_CONTAINER = "deployment-container"
1818
EVALUATION_CONTAINER = "evaluation-container"
1919
FINETUNE_CONTAINER = "finetune-container"
20+
DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"
2021

2122

2223
class ModelTask(str, metaclass=ExtendedEnumMeta):

ads/aqua/model/entities.py

+2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class AquaModel(AquaModelSummary, DataClassSerializable):
9898

9999
model_card: str = None
100100
inference_container: str = None
101+
inference_container_uri: str = None
101102
finetuning_container: str = None
102103
evaluation_container: str = None
103104
artifact_location: str = None
@@ -287,6 +288,7 @@ class ImportModelDetails(CLIBuilderMixin):
287288
compartment_id: Optional[str] = None
288289
project_id: Optional[str] = None
289290
model_file: Optional[str] = None
291+
inference_container_uri: Optional[str] = None
290292

291293
def __post_init__(self):
292294
self._command = "model register"

0 commit comments

Comments
 (0)