[AQUA] Adding ADS support for embedding models in Multi Model Deployment #1163

Conversation
ads/aqua/model/enums.py (Outdated)

@@ -28,3 +28,8 @@ class FineTuningCustomMetadata(ExtendedEnum):
class MultiModelSupportedTaskType(ExtendedEnum):
    TEXT_GENERATION = "text-generation"
    TEXT_GENERATION_ALT = "text_generation"
    EMBEDDING_ALT = "text_embedding"
Shouldn't we add embedding as well?
Fixed: we add embedding at the SMC level.
ads/aqua/model/model.py (Outdated)

@@ -316,6 +316,11 @@ def create_multi(

        display_name_list.append(display_name)

        model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
I would rather move this logic into a _get_task() method:

model.task = self._get_task(model, source_model)

def _get_task(model_ref: AquaMultiModelRef, source_model: DataScienceModel) -> str:
    # Extract the task from model_ref itself; if the task is not present there,
    # then extract it from the freeform tags.
    # model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
    ...
    return task
I believe we should also allow users to pass the task within AquaMultiModelRef, just in case the tags were not populated correctly.
We allow the user to pass the task; if it is not provided, we use the freeform tags of the source model.
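A minimal sketch of how that resolution could look (the helper name matches the `_extract_model_task` rename discussed below; import paths and attribute names are assumptions, not the authoritative implementation):

```python
# Sketch only; import paths are approximate and may differ in the repo.
from ads.aqua.common.entities import AquaMultiModelRef
from ads.aqua.common.enums import Tags
from ads.aqua.common.errors import AquaValueError
from ads.aqua.constants import UNKNOWN
from ads.aqua.model.enums import MultiModelSupportedTaskType
from ads.model.datascience_model import DataScienceModel


def _extract_model_task(
    model: AquaMultiModelRef, source_model: DataScienceModel
) -> str:
    """Resolve the task for one model in a multi-model deployment."""
    # A task passed explicitly on the model reference wins; otherwise fall
    # back to the freeform tags set when the model was registered.
    task_tag = model.model_task or source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
    if task_tag in MultiModelSupportedTaskType:
        return task_tag
    raise AquaValueError(
        f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. "
        f"Currently only `{MultiModelSupportedTaskType.values()}` models are "
        "supported for multi model deployment."
    )
```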
ads/aqua/model/model.py (Outdated)

                model.model_task = task_tag
            else:
                raise AquaValueError(
                    f"{task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}."
In case of an empty task_tag, what will the error look like?
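For illustration, with an empty tag the message would start with a bare space (this sketch substitutes the enum values from the diff above, assuming `ExtendedEnum.values()` returns the list of string values):

```python
task_tag = ""
print(
    f"{task_tag} is not supported. "
    f"Valid model_task inputs are: {['text-generation', 'text_generation', 'text_embedding']}."
)
# Output: " is not supported. Valid model_task inputs are: ['text-generation', ...]."
```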
ads/aqua/model/model.py (Outdated)

@@ -707,6 +700,25 @@ def edit_registered_model(
        else:
            raise AquaRuntimeError("Only registered unverified models can be edited.")

    def _get_task(
Looks like this method doesn't return any value, yet its signature indicates a return type of str. Should we update the type hint to reflect that it returns None, or adjust the implementation to return a string as specified?
ads/aqua/model/model.py (Outdated)

        display_name_list.append(display_name)

        self._get_task(model, source_model)
I think it might be clearer if we do something like this:

model.model_task = self._extract_model_task(model, source_model)
fixed
            if task_tag in MultiModelSupportedTaskType:
                model.model_task = task_tag
            else:
                raise AquaValueError(
We can show a more informative error:

raise AquaValueError(
    f"Invalid or missing {task_tag} tag for selected model {display_name}. "
    f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment."
)
fixed
Since we removed the task-level validation in a recent release, is there any reason to add the validation back in `_extract_model_task`?
This is fine for now since we only have one verified embedding model, but if we start supporting unverified models in the future, embedding models could have a task value of feature_extraction or sentence_similarity. Might be good to add a comment here to reconsider this logic when we start supporting additional models.
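For example, something like the following (wording is only a suggestion):

```python
# TODO: Revisit this validation once unverified models are supported;
# embedding models may then carry task tags like "feature_extraction"
# or "sentence_similarity" instead of "text_embedding".
```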
minor comment, overall looks good.
This PR adds support for using embedding models in a multi-model deployment.
To accomplish this, we pass data specifying each model's task from ADS to the MULTI_MODEL_CONFIG environment variable in the model deployment.
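For illustration, the resulting environment variable might look roughly like this; the surrounding schema and key names other than `model_task` are assumptions, not the authoritative format:

```python
import json

multi_model_config = {
    "models": [
        {
            # Placeholder OCIDs; real values come from the deployed models.
            "model_id": "ocid1.datasciencemodel.oc1..<llm>",
            "model_task": "text_generation",
        },
        {
            "model_id": "ocid1.datasciencemodel.oc1..<embedding>",
            "model_task": "text_embedding",
        },
    ]
}

# The serialized config is passed to the deployment as an environment variable.
env_vars = {"MULTI_MODEL_CONFIG": json.dumps(multi_model_config)}
```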
Only the `model_task` key is added for embedding models used in a multi-model deployment.
All unit tests pass (see screenshot). This PR was tested by modifying the existing unit test `test_create_deployment_for_multi_model`.
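As a hypothetical sketch of the kind of assertion that test could add (the actual fixtures and key names in the repo may differ):

```python
import json

def assert_model_tasks_present(env_vars: dict) -> None:
    """Hypothetical helper: every MULTI_MODEL_CONFIG entry should carry model_task."""
    config = json.loads(env_vars["MULTI_MODEL_CONFIG"])
    for model in config.get("models", []):
        assert model.get("model_task"), f"model_task missing for entry: {model}"
```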