Skip to content

Commit 6150fb9

Browse files
speedstorm1 authored and copybara-github committed
chore: Update tuning tests and the Tuning validation dataset field conversion
FUTURE_COPYBARA_INTEGRATE_REVIEW=#1526 from googleapis:release-please--branches--main f06d3b9 PiperOrigin-RevId: 822184287
1 parent c2bbe11 commit 6150fb9

File tree

3 files changed

+18
-71
lines changed

3 files changed

+18
-71
lines changed

google/genai/tests/tunings/test_tune.py

File mode changed: 100644 → 100755
Lines changed: 12 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@
3535
pytest_helper.TestTableItem(
3636
name="test_dataset_gcs_uri",
3737
parameters=genai_types.CreateTuningJobParameters(
38-
base_model="gemini-1.5-pro-002",
38+
base_model="gemini-2.5-flash",
3939
training_dataset=genai_types.TuningDataset(
4040
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
4141
),
@@ -81,7 +81,7 @@
8181
pytest_helper.TestTableItem(
8282
name="test_dataset_gcs_uri_all_parameters",
8383
parameters=genai_types.CreateTuningJobParameters(
84-
base_model="gemini-1.5-pro-002",
84+
base_model="gemini-2.5-flash",
8585
training_dataset=genai_types.TuningDataset(
8686
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
8787
),
@@ -90,92 +90,39 @@
9090
epoch_count=1,
9191
learning_rate_multiplier=1.0,
9292
adapter_size="ADAPTER_SIZE_ONE",
93-
validation_dataset=genai_types.TuningDataset(
93+
validation_dataset=genai_types.TuningValidationDataset(
9494
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_validation_data.jsonl",
9595
),
96-
# Not supported in Vertex AI
97-
# batch_size=4,
98-
# learning_rate=0.01,
99-
),
100-
),
101-
exception_if_mldev="gcs_uri parameter is not supported in Gemini API.",
102-
),
103-
pytest_helper.TestTableItem(
104-
name="test_dataset_gcs_uri_parameters_unsupported_by_vertex",
105-
parameters=genai_types.CreateTuningJobParameters(
106-
base_model="gemini-1.5-pro-002",
107-
training_dataset=genai_types.TuningDataset(
108-
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_train_data.jsonl",
109-
),
110-
config=genai_types.CreateTuningJobConfig(
111-
# Not supported in Vertex AI
112-
batch_size=4,
113-
learning_rate=0.01,
11496
),
11597
),
116-
exception_if_vertex="batch_size parameter is not supported in Vertex AI.",
11798
exception_if_mldev="gcs_uri parameter is not supported in Gemini API.",
11899
),
119-
pytest_helper.TestTableItem(
120-
name="test_dataset_examples_parameters_unsupported_by_mldev",
121-
parameters=genai_types.CreateTuningJobParameters(
122-
# Error: "models/gemini-1.5-pro-002 is not found for
123-
# CREATE TUNED MODEL at API version v1beta."
124-
# base_model="gemini-1.5-pro-002",
125-
base_model="models/gemini-1.0-pro-001",
126-
training_dataset=genai_types.TuningDataset(
127-
examples=[
128-
genai_types.TuningExample(
129-
text_input=f"Input text {i}",
130-
output=f"Output text {i}",
131-
)
132-
for i in range(5)
133-
],
134-
),
135-
# Required for MLDev:
136-
# "Either tuned_model_id or display_name must be set."
137-
config=genai_types.CreateTuningJobConfig(
138-
tuned_model_display_name="Model display name",
139-
# Not supported in MLDev
140-
adapter_size="ADAPTER_SIZE_ONE",
141-
# Generator issue: "validationDatasetUri": {}. See b/375079287
142-
# validation_dataset=genai_types.TuningDataset(
143-
# gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-1_5/text/sft_validation_data.jsonl",
144-
# ),
145-
),
146-
),
147-
exception_if_mldev="adapter_size parameter is not supported in Gemini API.",
148-
exception_if_vertex="examples parameter is not supported in Vertex AI.",
149-
),
150100
pytest_helper.TestTableItem(
151101
name="test_dataset_vertex_dataset_resource",
152102
parameters=genai_types.CreateTuningJobParameters(
153-
base_model="gemini-1.5-pro-002",
103+
base_model="gemini-2.5-flash",
154104
training_dataset=genai_types.TuningDataset(
155-
vertex_dataset_resource="projects/613165508263/locations/us-central1/datasets/8254568702121345024",
105+
vertex_dataset_resource="projects/801452371447/locations/us-central1/datasets/5779918772206829568",
156106
),
157107
),
158108
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
159109
),
160110
pytest_helper.TestTableItem(
161111
name="test_dataset_dataset_resource_all_parameters",
162112
parameters=genai_types.CreateTuningJobParameters(
163-
base_model="gemini-1.5-pro-002",
113+
base_model="gemini-2.5-flash",
164114
training_dataset=genai_types.TuningDataset(
165-
vertex_dataset_resource="projects/613165508263/locations/us-central1/datasets/8254568702121345024",
115+
vertex_dataset_resource="projects/801452371447/locations/us-central1/datasets/5779918772206829568",
166116
),
167117
config=genai_types.CreateTuningJobConfig(
168118
tuned_model_display_name="Model display name",
169119
epoch_count=1,
170120
learning_rate_multiplier=1.0,
171121
adapter_size="ADAPTER_SIZE_ONE",
172-
validation_dataset=genai_types.TuningDataset(
173-
vertex_dataset_resource="projects/613165508263/locations/us-central1/datasets/5556912525326417920",
122+
validation_dataset=genai_types.TuningValidationDataset(
123+
vertex_dataset_resource="projects/801452371447/locations/us-central1/datasets/1168232753779441664",
174124
),
175125
labels={"testlabelkey": "testlabelvalue"},
176-
# Not supported in Vertex AI
177-
# batch_size=4,
178-
# learning_rate=0.01,
179126
),
180127
),
181128
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
@@ -208,7 +155,7 @@ def test_eval_config(client):
208155
epoch_count=1,
209156
learning_rate_multiplier=1.0,
210157
adapter_size="ADAPTER_SIZE_ONE",
211-
validation_dataset=genai_types.TuningDataset(
158+
validation_dataset=genai_types.TuningValidationDataset(
212159
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
213160
),
214161
evaluation_config=evaluation_config,
@@ -255,7 +202,7 @@ def test_eval_config_with_metrics(client):
255202
epoch_count=1,
256203
learning_rate_multiplier=1.0,
257204
adapter_size="ADAPTER_SIZE_ONE",
258-
validation_dataset=genai_types.TuningDataset(
205+
validation_dataset=genai_types.TuningValidationDataset(
259206
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
260207
),
261208
evaluation_config=evaluation_config,
@@ -280,7 +227,7 @@ async def test_eval_config_async(client):
280227
epoch_count=1,
281228
learning_rate_multiplier=1.0,
282229
adapter_size="ADAPTER_SIZE_ONE",
283-
validation_dataset=genai_types.TuningDataset(
230+
validation_dataset=genai_types.TuningValidationDataset(
284231
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini-2_0/text/sft_validation_data.jsonl"
285232
),
286233
evaluation_config=evaluation_config,

google/genai/tunings.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -700,8 +700,8 @@ def _TuningValidationDataset_to_vertex(
700700

701701
if getv(from_object, ['vertex_dataset_resource']) is not None:
702702
setv(
703-
parent_object,
704-
['supervisedTuningSpec', 'trainingDatasetUri'],
703+
to_object,
704+
['validationDatasetUri'],
705705
getv(from_object, ['vertex_dataset_resource']),
706706
)
707707

google/genai/types.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10932,7 +10932,7 @@ class TuningValidationDataset(_common.BaseModel):
1093210932
)
1093310933
vertex_dataset_resource: Optional[str] = Field(
1093410934
default=None,
10935-
description="""The resource name of the Vertex Multimodal Dataset that is used as training dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'.""",
10935+
description="""The resource name of the Vertex Multimodal Dataset that is used as validation dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'.""",
1093610936
)
1093710937

1093810938

@@ -10942,7 +10942,7 @@ class TuningValidationDatasetDict(TypedDict, total=False):
1094210942
"""GCS URI of the file containing validation dataset in JSONL format."""
1094310943

1094410944
vertex_dataset_resource: Optional[str]
10945-
"""The resource name of the Vertex Multimodal Dataset that is used as training dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'."""
10945+
"""The resource name of the Vertex Multimodal Dataset that is used as validation dataset. Example: 'projects/my-project-id-or-number/locations/my-location/datasets/my-dataset-id'."""
1094610946

1094710947

1094810948
TuningValidationDatasetOrDict = Union[
@@ -10958,7 +10958,7 @@ class CreateTuningJobConfig(_common.BaseModel):
1095810958
)
1095910959
validation_dataset: Optional[TuningValidationDataset] = Field(
1096010960
default=None,
10961-
description="""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file.""",
10961+
description="""Validation dataset for tuning. The dataset must be formatted as a JSONL file.""",
1096210962
)
1096310963
tuned_model_display_name: Optional[str] = Field(
1096410964
default=None,
@@ -11010,7 +11010,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
1101011010
"""Used to override HTTP request options."""
1101111011

1101211012
validation_dataset: Optional[TuningValidationDatasetDict]
11013-
"""Cloud Storage path to file containing training dataset for tuning. The dataset must be formatted as a JSONL file."""
11013+
"""Validation dataset for tuning. The dataset must be formatted as a JSONL file."""
1101411014

1101511015
tuned_model_display_name: Optional[str]
1101611016
"""The display name of the tuned Model. The name can be up to 128 characters long and can consist of any UTF-8 characters."""

0 commit comments

Comments
 (0)