[AQUA] Adding ADS support for embedding models in Multi Model Deployment #1163


Merged
merged 13 commits into main on Apr 25, 2025

Conversation

elizjo
Member

@elizjo elizjo commented Apr 23, 2025


This PR adds support for using embedding models in a multi-model deployment.

To accomplish this, we pass data specifying the model task from ADS into the MULTI_MODEL_CONFIG environment variable used by model deployment.

## Before PR

```
MULTI_MODEL_CONFIG='{
  "models": [
    {
      "params": "--served-model-name bge --tensor-parallel-size 1 --trust-remote-code --max-model-len 4096",
      "model_path": "bge-m3"
    },
    {
      "params": "--served-model-name llama --enforce-eager --max-num-seqs 16 --tensor-parallel-size 2 --max-model-len 16000",
      "model_path": "Llama-3.2-11B-Vision"
    }
  ]
}'
```

## After PR (note the new 'model_task' key)

```
MULTI_MODEL_CONFIG='{
  "models": [
    {
      "params": "--served-model-name bge --tensor-parallel-size 1 --trust-remote-code --max-model-len 4096",
      "model_path": "bge-m3",
      "model_task": "embedding"
    },
    {
      "params": "--served-model-name llama --enforce-eager --max-num-seqs 16 --tensor-parallel-size 2 --max-model-len 16000",
      "model_path": "Llama-3.2-11B-Vision"
    }
  ]
}'
```

The 'model_task' key is only added for embedding models used in a multi-model deployment.

  • We determine whether a model is an embedding model by reading its freeform tags (the 'task' tag).
  • We added `model_task` as an optional parameter on the `AquaMultiModelRef` object.
  • The `model_task` parameter is used to construct the MULTI_MODEL_CONFIG, which now carries the ("model_task": "embedding") key.
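For illustration, the wiring described above could be sketched as follows. The helper name `build_multi_model_config` and the input shape are hypothetical; only the `params`, `model_path`, and `model_task` keys come from this PR.

```python
import json

# Hypothetical sketch: serialize per-model entries into the MULTI_MODEL_CONFIG
# JSON string, attaching the optional "model_task" key only for embedding models.
def build_multi_model_config(models: list) -> str:
    entries = []
    for m in models:
        entry = {"params": m["params"], "model_path": m["model_path"]}
        if m.get("model_task") == "embedding":
            entry["model_task"] = "embedding"
        entries.append(entry)
    return json.dumps({"models": entries})

config = build_multi_model_config([
    {"params": "--served-model-name bge --tensor-parallel-size 1",
     "model_path": "bge-m3", "model_task": "embedding"},
    {"params": "--served-model-name llama --max-model-len 16000",
     "model_path": "Llama-3.2-11B-Vision"},
])
```

The resulting string matches the "After PR" shape above: the embedding entry carries `model_task`, the text-generation entry does not.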

All unit tests pass. This PR was tested by modifying the existing unit test `test_create_deployment_for_multi_model`.

@oracle-contributor-agreement oracle-contributor-agreement bot added the OCA Verified All contributors have signed the Oracle Contributor Agreement. label Apr 23, 2025

📌 Cov diff with main:

Coverage-93%

📌 Overall coverage:

Coverage-58.60%

@mrDzurb mrDzurb changed the title Adding ADS support for embedding models in Multi Model Deployment [AQUA] Adding ADS support for embedding models in Multi Model Deployment Apr 23, 2025

```diff
@@ -28,3 +28,8 @@ class FineTuningCustomMetadata(ExtendedEnum):
 class MultiModelSupportedTaskType(ExtendedEnum):
     TEXT_GENERATION = "text-generation"
     TEXT_GENERATION_ALT = "text_generation"
+    EMBEDDING_ALT = "text_embedding"
```
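For context, a minimal stand-in for the `ExtendedEnum` base class could look like the sketch below; the real ADS implementation may differ, but something like this supports the `.values()` lookups used elsewhere in this PR.

```python
from enum import Enum

# Hypothetical stand-in for ADS's ExtendedEnum: a str-valued Enum whose
# values() classmethod lists all member values, enabling checks such as
# task_tag in MultiModelSupportedTaskType.values().
class ExtendedEnum(str, Enum):
    @classmethod
    def values(cls):
        return [member.value for member in cls]

class MultiModelSupportedTaskType(ExtendedEnum):
    TEXT_GENERATION = "text-generation"
    TEXT_GENERATION_ALT = "text_generation"
    EMBEDDING_ALT = "text_embedding"
```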
Member

Shouldn't we add embedding as well?

Member Author

Fixed. We add embedding at the SMC level.

```diff
@@ -316,6 +316,11 @@ def create_multi(
 
         display_name_list.append(display_name)
 
+        model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
```
Member

I would rather move this logic to a `_get_task()` method:

```python
model.task = self._get_task(model, source_model)
```

```python
def _get_task(model_ref: AquaMultiModelRef, source_model: DataScienceModel) -> str:
    # Extract the task from model_ref itself; if the task is not present there,
    # then extract it from the freeform tags.
    # model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
    ...
    return task
```

I believe we should also allow users to pass the task within `AquaMultiModelRef`, in case the tags were not populated correctly.

Member Author

We allow the user to pass the task; if it is not provided, we use the freeform tags of the source model.
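The resolution order described here (an explicitly passed task wins, otherwise fall back to the source model's freeform 'task' tag) could be sketched as follows; the class and helper names are illustrative stand-ins, not the actual ADS implementation.

```python
from dataclasses import dataclass, field

# Illustrative supported-task values; the real set lives in MultiModelSupportedTaskType.
SUPPORTED_TASKS = {"text-generation", "text_generation", "embedding", "text_embedding"}

@dataclass
class ModelRef:  # stand-in for AquaMultiModelRef
    model_task: str = ""

@dataclass
class SourceModel:  # stand-in for DataScienceModel
    freeform_tags: dict = field(default_factory=dict)

def extract_model_task(model: ModelRef, source: SourceModel) -> str:
    # Prefer the user-supplied task; fall back to the freeform 'task' tag.
    task = model.model_task or source.freeform_tags.get("task", "")
    if task not in SUPPORTED_TASKS:
        raise ValueError(
            f"Invalid or missing task tag. Supported tasks: {sorted(SUPPORTED_TASKS)}."
        )
    return task
```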


📌 Cov diff with main:

Coverage-0%

📌 Overall coverage:

Coverage-19.16%


```python
            model.model_task = task_tag
        else:
            raise AquaValueError(
                f"{task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}."
```
Member

In the case of an empty `task_tag`, what will the error look like?
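For illustration, with the error string formatted as above, an empty `task_tag` produces a message that starts with a bare space (the values list below is illustrative, not the full supported set):

```python
task_tag = ""
supported = ["text-generation", "text_generation", "embedding"]  # illustrative values
message = f"{task_tag} is not supported. Valid model_task inputs are: {supported}."
print(message)
```

This is part of why including more context (such as the model's display name) in the error, as suggested below in this review, makes the message clearer.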

```diff
@@ -707,6 +700,25 @@ def edit_registered_model(
         else:
             raise AquaRuntimeError("Only registered unverified models can be edited.")
 
+    def _get_task(
```
def _get_task(
Member

Looks like this method doesn't return any value, yet its signature indicates a return type of str. Should we update the type hint to reflect that it returns None, or adjust the implementation to return a string as specified?

```python
        display_name_list.append(display_name)

        self._get_task(model, source_model)
```
Member

I think it might be clearer if we do something like this:

```python
model.model_task = self._extract_model_task(model, source_model)
```

Member Author

fixed

```python
        if task_tag in MultiModelSupportedTaskType:
            model.model_task = task_tag
        else:
            raise AquaValueError(
```
Member

We can show a more informative error:

```python
raise AquaValueError(
    f"Invalid or missing {task_tag} tag for selected model {display_name}. "
    f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
)
```

Member Author

fixed

Member

Since we removed the task-level validation in the recent release, is there any reason to add the validation back in the function `_extract_model_task`?

This is fine for now since we only have one verified embedding model, but if in the future we start supporting (unverified) models, embedding models could have a task value of feature_extraction or sentence_similarity. It might be good to add a comment here to reconsider this logic when we start supporting additional models.


📌 Cov diff with main:

Coverage-100%

📌 Overall coverage:

Coverage-58.60%

mrDzurb
mrDzurb previously approved these changes Apr 24, 2025
Member

@VipulMascarenhas VipulMascarenhas left a comment

minor comment, overall looks good.


@elizjo elizjo dismissed stale reviews from VipulMascarenhas and mrDzurb via 2bb7a9f April 24, 2025 21:54

@elizjo elizjo merged commit 8d7b9d5 into main Apr 25, 2025
23 checks passed