Skip to content

Commit de107b4

Browse files
speedstorm1copybara-github
authored andcommitted
feat: support hyperparameters in distillation tuning
PiperOrigin-RevId: 882708166
1 parent 1ddc853 commit de107b4

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

google/genai/tests/tunings/test_tune.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,25 @@
245245
),
246246
exception_if_mldev="not supported in Gemini API",
247247
),
248+
pytest_helper.TestTableItem(
249+
name="test_tune_oss_distillation_hyperparams",
250+
parameters=genai_types.CreateTuningJobParameters(
251+
base_model="qwen/qwen3@qwen3-4b",
252+
training_dataset=genai_types.TuningDataset(
253+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
254+
),
255+
config=genai_types.CreateTuningJobConfig(
256+
method="DISTILLATION",
257+
base_teacher_model="deepseek-ai/deepseek-r1-0528-maas",
258+
learning_rate=1e-4,
259+
batch_size=4,
260+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test",
261+
tuning_mode="TUNING_MODE_FULL",
262+
http_options=VERTEX_HTTP_OPTIONS,
263+
),
264+
),
265+
exception_if_mldev="not supported in Gemini API",
266+
),
248267
pytest_helper.TestTableItem(
249268
name="test_tune_encryption_spec",
250269
parameters=genai_types.CreateTuningJobParameters(

google/genai/tunings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,13 @@ def _CreateTuningJobConfig_to_vertex(
409409
['supervisedTuningSpec', 'tuningMode'],
410410
getv(from_object, ['tuning_mode']),
411411
)
412+
elif discriminator == 'DISTILLATION':
413+
if getv(from_object, ['tuning_mode']) is not None:
414+
setv(
415+
parent_object,
416+
['distillationSpec', 'tuningMode'],
417+
getv(from_object, ['tuning_mode']),
418+
)
412419

413420
if getv(from_object, ['custom_base_model']) is not None:
414421
setv(
@@ -427,6 +434,13 @@ def _CreateTuningJobConfig_to_vertex(
427434
['supervisedTuningSpec', 'hyperParameters', 'batchSize'],
428435
getv(from_object, ['batch_size']),
429436
)
437+
elif discriminator == 'DISTILLATION':
438+
if getv(from_object, ['batch_size']) is not None:
439+
setv(
440+
parent_object,
441+
['distillationSpec', 'hyperParameters', 'batchSize'],
442+
getv(from_object, ['batch_size']),
443+
)
430444

431445
discriminator = getv(root_object, ['config', 'method'])
432446
if discriminator is None:
@@ -438,6 +452,13 @@ def _CreateTuningJobConfig_to_vertex(
438452
['supervisedTuningSpec', 'hyperParameters', 'learningRate'],
439453
getv(from_object, ['learning_rate']),
440454
)
455+
elif discriminator == 'DISTILLATION':
456+
if getv(from_object, ['learning_rate']) is not None:
457+
setv(
458+
parent_object,
459+
['distillationSpec', 'hyperParameters', 'learningRate'],
460+
getv(from_object, ['learning_rate']),
461+
)
441462

442463
discriminator = getv(root_object, ['config', 'method'])
443464
if discriminator is None:

google/genai/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13421,7 +13421,7 @@ class CreateTuningJobConfig(_common.BaseModel):
1342113421
default=None, description="""Adapter size for tuning."""
1342213422
)
1342313423
tuning_mode: Optional[TuningMode] = Field(
13424-
default=None, description="""Tuning mode for SFT tuning."""
13424+
default=None, description="""Tuning mode for tuning."""
1342513425
)
1342613426
custom_base_model: Optional[str] = Field(
1342713427
default=None,
@@ -13502,7 +13502,7 @@ class CreateTuningJobConfigDict(TypedDict, total=False):
1350213502
"""Adapter size for tuning."""
1350313503

1350413504
tuning_mode: Optional[TuningMode]
13505-
"""Tuning mode for SFT tuning."""
13505+
"""Tuning mode for tuning."""
1350613506

1350713507
custom_base_model: Optional[str]
1350813508
"""Custom base model for tuning. This is only supported for OSS models in Vertex."""

0 commit comments

Comments
 (0)