|
35 | 35 | InferenceContainerParamType,
|
36 | 36 | InferenceContainerType,
|
37 | 37 | RqsAdditionalDetails,
|
| 38 | + TextEmbeddingInferenceContainerParams, |
38 | 39 | )
|
39 | 40 | from ads.aqua.common.errors import (
|
40 | 41 | AquaFileNotFoundError,
|
|
51 | 52 | MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
52 | 53 | SERVICE_MANAGED_CONTAINER_URI_SCHEME,
|
53 | 54 | SUPPORTED_FILE_FORMATS,
|
| 55 | + TEI_CONTAINER_DEFAULT_HOST, |
54 | 56 | TGI_INFERENCE_RESTRICTED_PARAMS,
|
55 | 57 | UNKNOWN,
|
56 | 58 | UNKNOWN_JSON_STR,
|
|
63 | 65 | from ads.common.object_storage_details import ObjectStorageDetails
|
64 | 66 | from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
65 | 67 | from ads.common.utils import copy_file, get_console_link, upload_to_os
|
66 |
| -from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID |
| 68 | +from ads.config import ( |
| 69 | + AQUA_MODEL_DEPLOYMENT_FOLDER, |
| 70 | + AQUA_SERVICE_MODELS_BUCKET, |
| 71 | + CONDA_BUCKET_NS, |
| 72 | + TENANCY_OCID, |
| 73 | +) |
67 | 74 | from ads.model import DataScienceModel, ModelVersionSet
|
68 | 75 |
|
69 | 76 | logger = logging.getLogger("ads.aqua")
|
@@ -569,15 +576,13 @@ def get_container_image(
|
569 | 576 | A dict of allowed configs.
|
570 | 577 | """
|
571 | 578 |
|
| 579 | + container_image = UNKNOWN |
572 | 580 | config = config_file_name or get_container_config()
|
573 | 581 | config_file_name = service_config_path()
|
574 | 582 |
|
575 | 583 | if container_type not in config:
|
576 |
| - raise AquaValueError( |
577 |
| - f"{config_file_name} does not have config details for model: {container_type}" |
578 |
| - ) |
| 584 | + return UNKNOWN |
579 | 585 |
|
580 |
| - container_image = None |
581 | 586 | mapping = config[container_type]
|
582 | 587 | versions = [obj["version"] for obj in mapping]
|
583 | 588 | # assumes numbered versions, update if `latest` is used
|
@@ -1078,3 +1083,76 @@ def list_hf_models(query: str) -> List[str]:
|
1078 | 1083 | return [model.id for model in models if model.disabled is None]
|
1079 | 1084 | except HfHubHTTPError as err:
|
1080 | 1085 | raise format_hf_custom_error_message(err) from err
|
| 1086 | + |
| 1087 | + |
| 1088 | +def generate_tei_cmd_var(os_path: str) -> List[str]: |
| 1089 | + """This utility functions generates CMD params for Text Embedding Inference container. Only the |
| 1090 | + essential parameters for OCI model deployment are added, defaults are used for the rest. |
| 1091 | + Parameters |
| 1092 | + ---------- |
| 1093 | + os_path: str |
| 1094 | + OCI bucket path where the model artifacts are uploaded - oci://bucket@namespace/prefix |
| 1095 | +
|
| 1096 | + Returns |
| 1097 | + ------- |
| 1098 | + cmd_var: |
| 1099 | + List of command line arguments |
| 1100 | + """ |
| 1101 | + |
| 1102 | + cmd_prefix = "--" |
| 1103 | + cmd_var = [ |
| 1104 | + f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.MODEL_ID}", |
| 1105 | + f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{ObjectStorageDetails.from_path(os_path.rstrip('/')).filepath}/", |
| 1106 | + f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.PORT}", |
| 1107 | + TEI_CONTAINER_DEFAULT_HOST, |
| 1108 | + ] |
| 1109 | + |
| 1110 | + return cmd_var |
| 1111 | + |
| 1112 | + |
| 1113 | +def parse_cmd_var(cmd_list: List[str]) -> dict: |
| 1114 | + """Helper functions that parses a list into a key-value dictionary. The list contains keys separated by the prefix |
| 1115 | + '--' and the value of the key is the subsequent element. |
| 1116 | + """ |
| 1117 | + parsed_cmd = {} |
| 1118 | + |
| 1119 | + for i, cmd in enumerate(cmd_list): |
| 1120 | + if cmd.startswith("--"): |
| 1121 | + if i + 1 < len(cmd_list) and not cmd_list[i + 1].startswith("--"): |
| 1122 | + parsed_cmd[cmd] = cmd_list[i + 1] |
| 1123 | + i += 1 |
| 1124 | + else: |
| 1125 | + parsed_cmd[cmd] = None |
| 1126 | + return parsed_cmd |
| 1127 | + |
| 1128 | + |
| 1129 | +def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]: |
| 1130 | + """This function accepts two lists of parameters and combines them. If the second list shares the common parameter |
| 1131 | + names/keys, then it raises an error. |
| 1132 | + Parameters |
| 1133 | + ---------- |
| 1134 | + cmd_var: List[str] |
| 1135 | + Default list of parameters |
| 1136 | + overrides: List[str] |
| 1137 | + List of parameters to override |
| 1138 | + Returns |
| 1139 | + ------- |
| 1140 | + List[str] of combined parameters |
| 1141 | + """ |
| 1142 | + cmd_var = [str(x) for x in cmd_var] |
| 1143 | + if not overrides: |
| 1144 | + return cmd_var |
| 1145 | + overrides = [str(x) for x in overrides] |
| 1146 | + |
| 1147 | + cmd_dict = parse_cmd_var(cmd_var) |
| 1148 | + overrides_dict = parse_cmd_var(overrides) |
| 1149 | + |
| 1150 | + # check for conflicts |
| 1151 | + common_keys = set(cmd_dict.keys()) & set(overrides_dict.keys()) |
| 1152 | + if common_keys: |
| 1153 | + raise AquaValueError( |
| 1154 | + f"The following CMD input cannot be overridden for model deployment: {', '.join(common_keys)}" |
| 1155 | + ) |
| 1156 | + |
| 1157 | + combined_cmd_var = cmd_var + overrides |
| 1158 | + return combined_cmd_var |
0 commit comments