Commit 0095923

Merge pull request #100 from AI-Hypercomputer/carlosbus/training_v6e_gemma3_12b
Add recipes for Gemma3-12B on v6e
2 parents 1d6678a + 6472c99 commit 0095923

6 files changed: 293 additions & 0 deletions

Lines changed: 88 additions & 0 deletions

# Instructions for training Gemma3-12B-MaxText on TPU Trillium (2 slices of v6e-256)

## XPK setup
Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/XPK_README.md) to create your GKE cluster with XPK.
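
For example, a two-slice v6e-256 cluster can be created along these lines (a sketch only; the exact flags and values are covered in the XPK_README, and the variables below are placeholders for your own project, zone, and cluster name):
```
python3 xpk.py cluster create \
    --cluster=${CLUSTER_NAME} \
    --tpu-type=v6e-256 \
    --num-slices=2 \
    --project=$PROJECT \
    --zone=$ZONE
```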

## Prep for MaxText

### Install MaxText and Build Docker Image
Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/MAXTEXT_README.md) to install MaxText and build the Docker image. Use the following settings:

In step 1, use the MaxText [tpu-recipes-v0.1.5](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.5) tag to run this recipe:
```
git checkout tpu-recipes-v0.1.5
```

In step 3, use:
```
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable JAX_VERSION=0.7.0
```

## Run MaxText Gemma3-12B workloads on GKE

### Starting workload

From the MaxText root directory, start your Gemma3-12B workload.
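
The command below assumes the usual recipe environment variables are set; the values here are placeholders for your own project, zone, XPK cluster, and GCS output bucket:
```
export PROJECT=<your-gcp-project>
export ZONE=<your-zone>
export CLUSTER_NAME=<your-xpk-cluster>
export OUTPUT_DIR=gs://<your-bucket>/gemma3-12b
```
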
```
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=2 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="gemma3_12b_32768_2x_v6e256" \
    --base_docker_image=maxtext_base_image
```

From your workload logs, you should start seeing step-time logs like the following:
```
completed step: 29, seconds: 7.793, TFLOP/s/device: 328.139, Tokens/s/device: 4204.799, total_weights: 16777216, loss: 11.151
```
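
As a rough sanity check, these numbers follow from the workload config below: with `per_device_batch_size=1` and `max_target_length=32768` on 2 × 256 chips, each step processes 1 × 32768 × 512 = 16,777,216 tokens (the `total_weights` above), and per-device throughput is approximately the per-device tokens per step divided by the step time:
```
# sanity check: tokens per step and per-device throughput, 2 slices of v6e-256
echo $(( 1 * 32768 * 2 * 256 ))                     # 16777216 == total_weights
awk 'BEGIN { printf "%.1f\n", 1 * 32768 / 7.793 }'  # ~4204.8 Tokens/s/device
```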

### Workload Details

For reference, here are the `gemma3_12b_32768_2x_v6e256` workload details as found in `benchmarks/maxtext_trillium_model_configs.py` at the `tpu-recipes-v0.1.5` tag:

```
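# Excerpt from benchmarks/maxtext_trillium_model_configs.py; relies on that
# module's imports and helpers (os, MaxTextModel, xla_flags_library), so it
# is not runnable standalone.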
MaxTextModel(
    model_name="gemma3-12b-32768-2x-v6e256",
    model_type="gemma3-12b",
    tuning_params={
        "per_device_batch_size": 1,
        "num_vocab_tiling": 16,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "device",
        "query_proj": "remat",
        "key_proj": "remat",
        "value_proj": "remat",
        "max_target_length": 32768,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        "enable_checkpointing": False,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 2,
        "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
        "sa_block_q": 1024,
        "sa_block_kv": 1024,
        "sa_block_kv_compute": 1024,
        "sa_block_q_dkv": 512,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 512,
        "sa_block_q_dq": 1024,
        "sa_block_kv_dq": 1024,
    },
    xla_flags=(
        xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)
    ),
)
```

This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/50bafeb98299458f73d853b1325787a6d241d10c/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository.

Lines changed: 10 additions & 0 deletions

# Run this command from the MaxText root directory using the setup described in the README.
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=2 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="gemma3_12b_32768_2x_v6e256" \
    --base_docker_image=maxtext_base_image

Lines changed: 88 additions & 0 deletions

# Instructions for training Gemma3-12B-MaxText on TPU Trillium (4 slices of v6e-256)

## XPK setup
Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/XPK_README.md) to create your GKE cluster with XPK.

## Prep for MaxText

### Install MaxText and Build Docker Image
Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/MAXTEXT_README.md) to install MaxText and build the Docker image. Use the following settings:

In step 1, use the MaxText [tpu-recipes-v0.1.5](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.5) tag to run this recipe:
```
git checkout tpu-recipes-v0.1.5
```

In step 3, use:
```
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable JAX_VERSION=0.7.0
```

## Run MaxText Gemma3-12B workloads on GKE

### Starting workload

From the MaxText root directory, start your Gemma3-12B workload.
```
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=4 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="gemma3_12b_32768_4x_v6e256" \
    --base_docker_image=maxtext_base_image
```

From your workload logs, you should start seeing step-time logs like the following:
```
completed step: 29, seconds: 8.390, TFLOP/s/device: 304.788, Tokens/s/device: 3905.572, total_weights: 33554432, loss: 11.643
```
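
The same rough sanity check as the 2-slice recipe applies here: 1 × 32768 × (4 × 256) = 33,554,432 tokens per step, with per-device throughput approximately the sequence length over the step time:
```
# sanity check: tokens per step and per-device throughput, 4 slices of v6e-256
echo $(( 1 * 32768 * 4 * 256 ))                     # 33554432 == total_weights
awk 'BEGIN { printf "%.1f\n", 1 * 32768 / 8.390 }'  # ~3905.6 Tokens/s/device
```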

### Workload Details

For reference, here are the `gemma3_12b_32768_4x_v6e256` workload details as found in `benchmarks/maxtext_trillium_model_configs.py` at the `tpu-recipes-v0.1.5` tag:

```
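# Excerpt from benchmarks/maxtext_trillium_model_configs.py; relies on that
# module's imports and helpers (os, MaxTextModel, xla_flags_library), so it
# is not runnable standalone.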
MaxTextModel(
    model_name="gemma3-12b-32768-4x-v6e256",
    model_type="gemma3-12b",
    tuning_params={
        "per_device_batch_size": 1,
        "num_vocab_tiling": 16,
        "ici_fsdp_parallelism": 1,
        "ici_fsdp_transpose_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "device",
        "query_proj": "remat",
        "key_proj": "remat",
        "value_proj": "remat",
        "max_target_length": 32768,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        "enable_checkpointing": False,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 2,
        "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
        "sa_block_q": 1024,
        "sa_block_kv": 1024,
        "sa_block_kv_compute": 1024,
        "sa_block_q_dkv": 512,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 512,
        "sa_block_q_dq": 1024,
        "sa_block_kv_dq": 1024,
    },
    xla_flags=(
        xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)
    ),
)
```

This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/50bafeb98299458f73d853b1325787a6d241d10c/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository.

Lines changed: 10 additions & 0 deletions

# Run this command from the MaxText root directory using the setup described in the README.
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=4 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="gemma3_12b_32768_4x_v6e256" \
    --base_docker_image=maxtext_base_image

Lines changed: 87 additions & 0 deletions

# Instructions for training Gemma3-12B-MaxText on TPU Trillium (v6e-256)

## XPK setup
Please follow the [XPK_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/XPK_README.md) to create your GKE cluster with XPK.

## Prep for MaxText

### Install MaxText and Build Docker Image
Please follow the [MAXTEXT_README](https://github.com/AI-Hypercomputer/tpu-recipes/blob/main/training/MAXTEXT_README.md) to install MaxText and build the Docker image. Use the following settings:

In step 1, use the MaxText [tpu-recipes-v0.1.5](https://github.com/AI-Hypercomputer/maxtext/releases/tag/tpu-recipes-v0.1.5) tag to run this recipe:
```
git checkout tpu-recipes-v0.1.5
```

In step 3, use:
```
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable JAX_VERSION=0.7.0
```

## Run MaxText Gemma3-12B workloads on GKE

### Starting workload

From the MaxText root directory, start your Gemma3-12B workload.
```
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="gemma3_12b_32768_v6e256" \
    --base_docker_image=maxtext_base_image
```

From your workload logs, you should start seeing step-time logs like the following:
```
completed step: 29, seconds: 7.318, TFLOP/s/device: 349.442, Tokens/s/device: 4477.768, total_weights: 8388608, loss: 10.495
```
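
Again, the arithmetic lines up with the config below as a rough check: 1 × 32768 × 256 = 8,388,608 tokens per step on a single slice:
```
# sanity check: tokens per step and per-device throughput, 1 slice of v6e-256
echo $(( 1 * 32768 * 1 * 256 ))                     # 8388608 == total_weights
awk 'BEGIN { printf "%.1f\n", 1 * 32768 / 7.318 }'  # ~4477.8 Tokens/s/device
```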

### Workload Details

For reference, here are the `gemma3_12b_32768_v6e256` workload details as found in `benchmarks/maxtext_trillium_model_configs.py` at the `tpu-recipes-v0.1.5` tag:

```
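# Excerpt from benchmarks/maxtext_trillium_model_configs.py; relies on that
# module's imports and helpers (os, MaxTextModel, xla_flags_library), so it
# is not runnable standalone.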
MaxTextModel(
    model_name="gemma3-12b-32768-v6e256",
    model_type="gemma3-12b",
    tuning_params={
        "per_device_batch_size": 1,
        "num_vocab_tiling": 16,
        "ici_fsdp_parallelism": -1,
        "remat_policy": "custom",
        "decoder_layer_input": "device",
        "query_proj": "remat",
        "key_proj": "remat",
        "value_proj": "remat",
        "max_target_length": 32768,
        "attention": "flash",
        "gcs_metrics": True,
        "use_iota_embed": True,
        "dataset_path": "gs://max-datasets-rogue",
        "dataset_type": "synthetic",
        "reuse_example_batch": 1,
        "enable_checkpointing": False,
        "profiler": "xplane",
        "skip_first_n_steps_for_profiler": 10,
        "profiler_steps": 2,
        "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
        "sa_block_q": 1024,
        "sa_block_kv": 1024,
        "sa_block_kv_compute": 1024,
        "sa_block_q_dkv": 512,
        "sa_block_kv_dkv": 2048,
        "sa_block_kv_dkv_compute": 512,
        "sa_block_q_dq": 1024,
        "sa_block_kv_dq": 1024,
    },
    xla_flags=(
        xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)
    ),
)
```

This equivalent workload code can be found in the [maxtext_trillium_model_configs.py](https://github.com/AI-Hypercomputer/maxtext/blob/50bafeb98299458f73d853b1325787a6d241d10c/benchmarks/maxtext_trillium_model_configs.py) file within the MaxText repository.

Lines changed: 10 additions & 0 deletions

# Run this command from the MaxText root directory using the setup described in the README.
python3 -m benchmarks.benchmark_runner xpk \
    --project=$PROJECT \
    --zone=$ZONE \
    --device_type=v6e-256 \
    --num_slices=1 \
    --cluster_name=${CLUSTER_NAME} \
    --base_output_directory=${OUTPUT_DIR} \
    --model_name="gemma3_12b_32768_v6e256" \
    --base_docker_image=maxtext_base_image
