Skip to content

Commit cf2b7a9

Browse files
authored
feat(py/vertexai): Enhance VertexAI plugin (#2184)
1 parent 5348a01 commit cf2b7a9

File tree

14 files changed

+329
-72
lines changed

14 files changed

+329
-72
lines changed

py/packages/genkit/src/genkit/ai/embedding.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from collections.abc import Callable
5+
from typing import Any
56

67
from pydantic import BaseModel
78

@@ -14,6 +15,7 @@ class EmbedRequest(BaseModel):
1415
"""
1516

1617
documents: list[str]
18+
options: dict[str, Any] | None = None
1719

1820

1921
class EmbedResponse(BaseModel):

py/packages/genkit/src/genkit/veneer/veneer.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -435,21 +435,27 @@ def generate_stream(
435435
return (stream, stream.closed)
436436

437437
async def embed(
    self,
    model: str | None = None,
    documents: list[str] | None = None,
    options: dict[str, Any] | None = None,
) -> EmbedResponse:
    """Calculates embeddings for documents.

    Args:
        model: Optional embedder model name to use.
        documents: Texts to embed.
        options: Embedder-specific options forwarded with the request.

    Returns:
        The generated response with embeddings.
    """
    action = self.registry.lookup_action(ActionKind.EMBEDDER, model)
    request = EmbedRequest(documents=documents, options=options)
    result = await action.arun(request)
    return result.response
454460

455461

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
enabling the use of Vertex AI models and services within the Genkit framework.
77
"""
88

9-
from genkit.plugins.vertex_ai.embedding import EmbeddingModels
9+
from genkit.plugins.vertex_ai.embedding import (
10+
EmbeddingModels,
11+
EmbeddingsTaskType,
12+
)
1013
from genkit.plugins.vertex_ai.gemini import GeminiVersion
11-
from genkit.plugins.vertex_ai.imagen import ImagenVersion
14+
from genkit.plugins.vertex_ai.imagen import ImagenOptions, ImagenVersion
1215
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name
1316

1417

@@ -26,6 +29,8 @@ def package_name() -> str:
2629
VertexAI.__name__,
2730
vertexai_name.__name__,
2831
EmbeddingModels.__name__,
32+
EmbeddingsTaskType.__name__,
2933
GeminiVersion.__name__,
3034
ImagenVersion.__name__,
35+
ImagenOptions.__name__,
3136
]

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/embedding.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class EmbeddingModels(StrEnum):
2828
TEXT_EMBEDDING_002_MULTILINGUAL = 'text-multilingual-embedding-002'
2929

3030

31-
class TaskType(StrEnum):
31+
class EmbeddingsTaskType(StrEnum):
3232
"""Task types supported by Vertex AI.
3333
3434
Attributes:
@@ -47,24 +47,21 @@ class TaskType(StrEnum):
4747
CLUSTERING = 'CLUSTERING'
4848
RETRIEVAL_DOCUMENT = 'RETRIEVAL_DOCUMENT'
4949
RETRIEVAL_QUERY = 'RETRIEVAL_QUERY'
50+
QUESTION_ANSWERING = 'QUESTION_ANSWERING'
51+
FACT_VERIFICATION = 'FACT_VERIFICATION'
52+
CODE_RETRIEVAL_QUERY = 'CODE_RETRIEVAL_QUERY'
5053

5154

5255
class Embedder:
53-
"""Embedder for Vertex AI.
56+
"""Embedder for Vertex AI."""
5457

55-
Attributes:
56-
version: The version of the embedding model to use.
57-
task: The task type to use for the embedding.
58-
dimensionality: The dimensionality of the embedding.
59-
"""
60-
61-
TASK = TaskType.RETRIEVAL_QUERY
58+
TASK_KEY = 'task'
59+
DEFAULT_TASK = EmbeddingsTaskType.RETRIEVAL_QUERY
6260

6361
# By default, the model generates embeddings with 768 dimensions.
6462
# Models such as `text-embedding-004`, `text-embedding-005`,
6563
# and `text-multilingual-embedding-002`allow the output dimensionality
6664
# to be adjusted between 1 and 768.
67-
DIMENSIONALITY = 768
6865

6966
def __init__(self, version: EmbeddingModels):
7067
"""Initialize the embedder.
@@ -81,9 +78,11 @@ def embedding_model(self) -> TextEmbeddingModel:
8178
Returns:
8279
The embedding model.
8380
"""
81+
82+
# TODO: pass additional parameters
8483
return TextEmbeddingModel.from_pretrained(self._version)
8584

86-
def handle_request(self, request: EmbedRequest) -> EmbedResponse:
85+
def generate(self, request: EmbedRequest) -> EmbedResponse:
    """Handle an embedding request.

    Args:
        request: The embedding request containing documents and optional
            embedder options (e.g. ``{'task': ..., 'output_dimensionality': ...}``).

    Returns:
        The embedding response, one vector per input document.

    Raises:
        ValueError: If the requested task type is not supported by Vertex AI.
    """
    # Copy so we never mutate the caller's options dict, and tolerate
    # request.options being None.
    options = dict(request.options) if request.options else {}
    # pop() falls back to the default task when the caller did not specify
    # one, and removes the key so only embedder kwargs remain in `options`.
    task = options.pop(self.TASK_KEY, self.DEFAULT_TASK)

    # Compare against member values: `x in StrEnum` raises TypeError for
    # plain strings on Python 3.11, so a set of values is version-safe.
    valid_tasks = {t.value for t in EmbeddingsTaskType}
    if task not in valid_tasks:
        raise ValueError(f'Unsupported task {task} for VertexAI.')

    inputs = [TextEmbeddingInput(text, task) for text in request.documents]
    vertexai_embeddings = self.embedding_model.get_embeddings(inputs, **options)
    embeddings = [embedding.values for embedding in vertexai_embeddings]

    return EmbedResponse(embeddings=embeddings)
101107

102108
@property

py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/gemini.py

+81-16
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,29 @@
88
definitions and a client class for making requests to Gemini models.
99
"""
1010

11+
import logging
1112
from enum import StrEnum
1213
from typing import Any
1314

15+
import vertexai.generative_models as genai
16+
from genkit.core.action import ActionRunContext
1417
from genkit.core.typing import (
18+
CustomPart,
19+
DataPart,
1520
GenerateRequest,
1621
GenerateResponse,
22+
GenerateResponseChunk,
23+
MediaPart,
1724
Message,
1825
ModelInfo,
1926
Role,
2027
Supports,
2128
TextPart,
29+
ToolRequestPart,
30+
ToolResponsePart,
2231
)
23-
from vertexai.generative_models import Content, GenerativeModel, Part
32+
33+
LOG = logging.getLogger(__name__)
2434

2535

2636
class GeminiVersion(StrEnum):
@@ -83,7 +93,7 @@ class Gemini:
8393
handling message formatting and response processing.
8494
"""
8595

86-
def __init__(self, version: str):
96+
def __init__(self, version: str | GeminiVersion):
8797
"""Initialize a Gemini client.
8898
8999
Args:
@@ -92,38 +102,93 @@ def __init__(self, version: str):
92102
"""
93103
self._version = version
94104

105+
def is_multimode(self):
106+
return SUPPORTED_MODELS[self._version].supports.media
107+
108+
def build_messages(self, request: GenerateRequest) -> list[genai.Content]:
    """Builds a list of VertexAI content from a request.

    Args:
        request: A packed request for the model.

    Returns:
        A list of VertexAI GenAI Content entries, one per request message.
    """
    contents: list[genai.Content] = []
    for msg in request.messages:
        converted: list[genai.Part] = []
        for item in msg.content:
            root = item.root
            if isinstance(root, TextPart):
                converted.append(genai.Part.from_text(root.text))
            elif isinstance(root, MediaPart):
                if not self.is_multimode():
                    LOG.error(
                        f'The model {self._version} does not'
                        f' support multimode input'
                    )
                    continue
                media_part = genai.Part.from_uri(
                    mime_type=root.media.content_type,
                    uri=root.media.url,
                )
                converted.append(media_part)
            elif isinstance(root, (ToolRequestPart, ToolResponsePart)):
                LOG.warning('Tools are not supported yet')
            elif isinstance(root, CustomPart):
                # TODO: handle CustomPart
                LOG.warning('The code part is not supported yet.')
            else:
                LOG.error('The type is not supported')
        contents.append(genai.Content(role=msg.role.value, parts=converted))

    return contents
95147
@property
96-
def gemini_model(self) -> GenerativeModel:
148+
def gemini_model(self) -> genai.GenerativeModel:
97149
"""Get the Vertex AI GenerativeModel instance.
98150
99151
Returns:
100152
A configured GenerativeModel instance for the specified version.
101153
"""
102-
return GenerativeModel(self._version)
154+
return genai.GenerativeModel(self._version)
103155

104-
def handle_request(self, request: GenerateRequest) -> GenerateResponse:
156+
def generate(
    self, request: GenerateRequest, ctx: ActionRunContext
) -> GenerateResponse | None:
    """Handle a generation request using the Gemini model.

    Args:
        request: The generation request containing messages and parameters.
        ctx: Additional context, including whether the caller is streaming.

    Returns:
        The model's response to the generation request. When streaming,
        chunks are emitted via ``ctx.send_chunk`` and the final response
        carries the concatenated streamed text (previously it was empty).
    """
    messages = self.build_messages(request)
    response = self.gemini_model.generate_content(
        contents=messages, stream=ctx.is_streaming
    )

    text_response = ''
    if ctx.is_streaming:
        for chunk in response:
            # TODO: Support other types of output
            ctx.send_chunk(
                GenerateResponseChunk(
                    role=Role.MODEL,
                    content=[TextPart(text=chunk.text)],
                )
            )
            # Accumulate so the terminal response is populated in streaming
            # mode, matching the non-streaming path.
            text_response += chunk.text
    else:
        text_response = response.text

    return GenerateResponse(
        message=Message(
            role=Role.MODEL,
            content=[TextPart(text=text_response)],
        )
    )
129194

0 commit comments

Comments
 (0)