diff --git a/google/genai/models.py b/google/genai/models.py index 01d8cd9eb..d0f313bec 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -714,6 +714,38 @@ def _EmbedContentConfig_to_vertex( return to_object +def _EmbedContentConfig_to_vertex_embed_content( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ['task_type']) is not None: + setv( + parent_object, + ['taskType'], + getv(from_object, ['task_type']), + ) + + if getv(from_object, ['title']) is not None: + setv(parent_object, ['title'], getv(from_object, ['title'])) + + if getv(from_object, ['output_dimensionality']) is not None: + setv( + parent_object, + ['outputDimensionality'], + getv(from_object, ['output_dimensionality']), + ) + + if getv(from_object, ['auto_truncate']) is not None: + setv( + parent_object, + ['autoTruncate'], + getv(from_object, ['auto_truncate']), + ) + + return to_object + def _EmbedContentParameters_to_mldev( api_client: BaseApiClient, from_object: Union[dict[str, Any], object], @@ -750,6 +782,41 @@ def _EmbedContentParameters_to_mldev( return to_object +def _EmbedContentParameters_to_vertex_embed_content( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['model']) is not None: + setv( + to_object, + ['_url', 'model'], + t.t_model(api_client, getv(from_object, ['model'])), + ) + + if getv(from_object, ['contents']) is not None: + contents = getv(from_object, ['contents']) + if len(contents) != 1: + raise ValueError( + 'Only a single input content is supported for the given model at this' + ' time. Ensure you are using the most recent version of the GenAI' + ' SDK.' + ) + setv( + to_object, + ['content'], + t.t_content(contents[0]), + ) + + if getv(from_object, ['config']) is not None: + _EmbedContentConfig_to_vertex_embed_content( + getv(from_object, ['config']), to_object + ) + + return to_object + + def _EmbedContentParameters_to_vertex( api_client: BaseApiClient, from_object: Union[dict[str, Any], object], @@ -830,6 +897,80 @@ def _EmbedContentResponse_from_vertex( return to_object +def _EmbedContentResponse_from_vertex_embed_content( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['sdkHttpResponse']) is not None: + setv( + to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse']) + ) + + if getv(from_object, ['embedding']) is not None: + setv( + to_object, + ['embeddings'], + [getv(from_object, ['embedding'])], + ) + + if getv(from_object, ['truncated']) is not None: + setv( + to_object, + ['statistics', 'truncated'], + getv(from_object, ['truncated']), + ) + + if getv(from_object, ['usageMetadata']) is not None: + setv( + to_object, + ['statistics', 'token_count'], + getv(from_object, ['usageMetadata', 'totalTokenCount']), + ) + + if getv(from_object, ['metadata']) is not None: + setv(to_object, ['metadata'], getv(from_object, ['metadata'])) + + return to_object + + +def _EmbedContentResponse_from_vertex_embed_content( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ['sdkHttpResponse']) is not None: + setv( + to_object, ['sdk_http_response'], getv(from_object, ['sdkHttpResponse']) + ) + + if getv(from_object, ['embedding']) is not None: + setv( + to_object, + ['embeddings'], + [getv(from_object, ['embedding'])], + ) + + if getv(from_object, ['truncated']) is not None: + setv( + to_object, + ['statistics', 'truncated'], + getv(from_object, ['truncated']), + ) + + if getv(from_object, ['usageMetadata']) is not None: + setv( + to_object, + ['statistics', 'token_count'], + getv(from_object, ['usageMetadata', 'totalTokenCount']), + ) + + if getv(from_object, ['metadata']) is not None: + setv(to_object, ['metadata'], getv(from_object, ['metadata'])) + + return to_object + + def _Endpoint_from_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -3783,6 +3924,10 @@ def _Video_to_vertex( return to_object +def _is_vertex_embed_content_model(model: str) -> bool: + return 'gemini' in model and 'embedding' in model and '001' not in model + + class Models(_api_module.BaseModule): def _generate_content( @@ -3982,24 +4127,32 @@ def embed_content( request_url_dict: Optional[dict[str, str]] + is_vertex_embed_content_model = _is_vertex_embed_content_model(model) + if self._api_client.vertexai: - request_dict = _EmbedContentParameters_to_vertex( - self._api_client, parameter_model - ) - request_url_dict = request_dict.get('_url') - if request_url_dict: - path = '{model}:predict'.format_map(request_url_dict) + # Special handling for models exposed on the Vertex EmbedContent API. + if is_vertex_embed_content_model: + request_dict = _EmbedContentParameters_to_vertex_embed_content( + self._api_client, parameter_model + ) + uri_format = '{model}:embedContent' else: - path = '{model}:predict' + request_dict = _EmbedContentParameters_to_vertex( + self._api_client, parameter_model + ) + uri_format = '{model}:predict' else: request_dict = _EmbedContentParameters_to_mldev( self._api_client, parameter_model ) - request_url_dict = request_dict.get('_url') - if request_url_dict: - path = '{model}:batchEmbedContents'.format_map(request_url_dict) - else: - path = '{model}:batchEmbedContents' + uri_format = '{model}:batchEmbedContents' + + request_url_dict = request_dict.get('_url') + if request_url_dict: + path = uri_format.format_map(request_url_dict) + else: + path = uri_format + query_params = request_dict.get('_query') if query_params: path = f'{path}?{urlencode(query_params)}' @@ -4023,9 +4176,13 @@ def embed_content( response_dict = {} if not response.body else json.loads(response.body) if self._api_client.vertexai: - response_dict = _EmbedContentResponse_from_vertex(response_dict) - - if not self._api_client.vertexai: + if is_vertex_embed_content_model: + response_dict = _EmbedContentResponse_from_vertex_embed_content( + response_dict + ) + else: + response_dict = _EmbedContentResponse_from_vertex(response_dict) + else: response_dict = _EmbedContentResponse_from_mldev(response_dict) return_value = types.EmbedContentResponse._from_response( diff --git a/google/genai/tests/models/test_embed_content.py b/google/genai/tests/models/test_embed_content.py index be3ea0162..ffd1bb526 100644 --- a/google/genai/tests/models/test_embed_content.py +++ b/google/genai/tests/models/test_embed_content.py @@ -24,7 +24,7 @@ from .. import pytest_helper -test_table: list[pytest_helper.TestTableItem] = [ +text_embedding_test_table: list[pytest_helper.TestTableItem] = [ pytest_helper.TestTableItem( name='test_single_text', parameters=types._EmbedContentParameters( @@ -76,11 +76,42 @@ ), ] +new_api_test_table: list[pytest_helper.TestTableItem] = [ + pytest_helper.TestTableItem( + name='test_vertex_new_api_text_only', + parameters=types._EmbedContentParameters( + model='gemini-embedding-2.0-exp-11-25', + contents=t.t_contents('What is your name?'), + ), + # Model not exposed on MLDev. + exception_if_mldev='not found', + ), + pytest_helper.TestTableItem( + name='test_vertex_new_api_text_only_with_config', + parameters=types._EmbedContentParameters( + model='gemini-embedding-2.0-exp-11-25', + contents=t.t_contents('What is your name?'), + config={ + 'output_dimensionality': 10, + 'title': 'test_title', + 'task_type': 'RETRIEVAL_DOCUMENT', + 'http_options': { + 'headers': {'test': 'headers'}, + }, + 'auto_truncate': True, + }, + ), + # auto_truncate not supported on MLDev. + exception_if_mldev='parameter is not supported', + ), +] + + pytestmark = pytest_helper.setup( file=__file__, globals_for_file=globals(), test_method='models.embed_content', - test_table=test_table, + test_table=[*text_embedding_test_table, *new_api_test_table], )