Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 12 additions & 65 deletions google/genai/tests/tunings/test_tune.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
pytest_helper.TestTableItem(
name="test_dataset_gcs_uri",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-1.5-pro-002",
base_model="gemini-2.5-flash",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
),
Expand Down Expand Up @@ -81,7 +81,7 @@
pytest_helper.TestTableItem(
name="test_dataset_gcs_uri_all_parameters",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-1.5-pro-002",
base_model="gemini-2.5-flash",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
),
Expand All @@ -90,92 +90,39 @@
epoch_count=1,
learning_rate_multiplier=1.0,
adapter_size="ADAPTER_SIZE_ONE",
validation_dataset=genai_types.TuningDataset(
validation_dataset=genai_types.TuningValidationDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_validation_data.jsonl",
),
# Not supported in Vertex AI
# batch_size=4,
# learning_rate=0.01,
),
),
exception_if_mldev="gcs_uri parameter is not supported in Gemini API.",
),
pytest_helper.TestTableItem(
name="test_dataset_gcs_uri_parameters_unsupported_by_vertex",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-1.5-pro-002",
training_dataset=genai_types.TuningDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
),
config=genai_types.CreateTuningJobConfig(
# Not supported in Vertex AI
batch_size=4,
learning_rate=0.01,
),
),
exception_if_vertex="batch_size parameter is not supported in Vertex AI.",
exception_if_mldev="gcs_uri parameter is not supported in Gemini API.",
),
pytest_helper.TestTableItem(
name="test_dataset_examples_parameters_unsupported_by_mldev",
parameters=genai_types.CreateTuningJobParameters(
# Error: "models/gemini-1.5-pro-002 is not found for
# CREATE TUNED MODEL at API version v1beta."
# base_model="gemini-1.5-pro-002",
base_model="models/gemini-1.0-pro-001",
training_dataset=genai_types.TuningDataset(
examples=[
genai_types.TuningExample(
text_input=f"Input text {i}",
output=f"Output text {i}",
)
for i in range(5)
],
),
# Required for MLDev:
# "Either tuned_model_id or display_name must be set."
config=genai_types.CreateTuningJobConfig(
tuned_model_display_name="Model display name",
# Not supported in MLDev
adapter_size="ADAPTER_SIZE_ONE",
# Generator issue: "validationDatasetUri": {}. See b/375079287
# validation_dataset=genai_types.TuningDataset(
# gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_validation_data.jsonl",
# ),
),
),
exception_if_mldev="adapter_size parameter is not supported in Gemini API.",
exception_if_vertex="examples parameter is not supported in Vertex AI.",
),
pytest_helper.TestTableItem(
name="test_dataset_vertex_dataset_resource",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-1.5-pro-002",
base_model="gemini-2.5-flash",
training_dataset=genai_types.TuningDataset(
vertex_dataset_resource="projects/613165508263/locations/us-central1/datasets/8254568702121345024",
vertex_dataset_resource="projects/801452371447/locations/us-central1/datasets/5779918772206829568",
),
),
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
),
pytest_helper.TestTableItem(
name="test_dataset_dataset_resource_all_parameters",
parameters=genai_types.CreateTuningJobParameters(
base_model="gemini-1.5-pro-002",
base_model="gemini-2.5-flash",
training_dataset=genai_types.TuningDataset(
vertex_dataset_resource="projects/613165508263/locations/us-central1/datasets/8254568702121345024",
vertex_dataset_resource="projects/801452371447/locations/us-central1/datasets/5779918772206829568",
),
config=genai_types.CreateTuningJobConfig(
tuned_model_display_name="Model display name",
epoch_count=1,
learning_rate_multiplier=1.0,
adapter_size="ADAPTER_SIZE_ONE",
validation_dataset=genai_types.TuningDataset(
vertex_dataset_resource="projects/613165508263/locations/us-central1/datasets/5556912525326417920",
validation_dataset=genai_types.TuningValidationDataset(
vertex_dataset_resource="projects/801452371447/locations/us-central1/datasets/1168232753779441664",
),
labels={"testlabelkey": "testlabelvalue"},
# Not supported in Vertex AI
# batch_size=4,
# learning_rate=0.01,
),
),
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
Expand Down Expand Up @@ -208,7 +155,7 @@ def test_eval_config(client):
epoch_count=1,
learning_rate_multiplier=1.0,
adapter_size="ADAPTER_SIZE_ONE",
validation_dataset=genai_types.TuningDataset(
validation_dataset=genai_types.TuningValidationDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
),
evaluation_config=evaluation_config,
Expand Down Expand Up @@ -255,7 +202,7 @@ def test_eval_config_with_metrics(client):
epoch_count=1,
learning_rate_multiplier=1.0,
adapter_size="ADAPTER_SIZE_ONE",
validation_dataset=genai_types.TuningDataset(
validation_dataset=genai_types.TuningValidationDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
),
evaluation_config=evaluation_config,
Expand All @@ -280,7 +227,7 @@ async def test_eval_config_async(client):
epoch_count=1,
learning_rate_multiplier=1.0,
adapter_size="ADAPTER_SIZE_ONE",
validation_dataset=genai_types.TuningDataset(
validation_dataset=genai_types.TuningValidationDataset(
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
),
evaluation_config=evaluation_config,
Expand Down
4 changes: 2 additions & 2 deletions google/genai/tunings.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,8 @@ def _TuningValidationDataset_to_vertex(

if getv(from_object, ['vertex_dataset_resource']) is not None:
setv(
parent_object,
['supervisedTuningSpec', 'trainingDatasetUri'],
to_object,
['validationDatasetUri'],
getv(from_object, ['vertex_dataset_resource']),
)

Expand Down
8 changes: 4 additions & 4 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10932,7 +10932,7 @@ class TuningValidationDataset(_common.BaseModel):
)
vertex_dataset_resource: Optional[str] = Field(
default=None,
description="""The resource name of the Vertex Multimodal Dataset that is used as training dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'.""",
description="""The resource name of the Vertex Multimodal Dataset that is used as validation dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'.""",
)


Expand All @@ -10942,7 +10942,7 @@ class TuningValidationDatasetDict(TypedDict, total=False):
"""GCS URI of the file containing validation dataset in JSONL format."""

vertex_dataset_resource: Optional[str]
"""The resource name of the Vertex Multimodal Dataset that is used as training dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'."""
"""The resource name of the Vertex Multimodal Dataset that is used as validation dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'."""


TuningValidationDatasetOrDict = Union[
Expand All @@ -10958,7 +10958,7 @@ class CreateTuningJobConfig(_common.BaseModel):
)
validation_dataset: Optional[TuningValidationDataset] = Field(
default=None,
description="""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""",
description="""Validation dataset for tuning. The dataset must be formatted as a JSONL file.""",
)
tuned_model_display_name: Optional[str] = Field(
default=None,
Expand Down Expand Up @@ -11010,7 +11010,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
"""Used to override HTTP request options."""

validation_dataset: Optional[TuningValidationDatasetDict]
"""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file."""
"""Validation dataset for tuning. The dataset must be formatted as a JSONL file."""

tuned_model_display_name: Optional[str]
"""The display name of the tuned Model. The name can be up to 128 characters long and can consist of any UTF-8 characters."""
Expand Down