Skip to content

Commit 50bafeb

Browse files
Merge pull request #2488 from AI-Hypercomputer:carlosbus/training_v6e_gemma3_12b
PiperOrigin-RevId: 820979201
2 parents e0b5028 + 291ec23 commit 50bafeb

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed

benchmarks/maxtext_trillium_model_configs.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,125 @@
17141714
),
17151715
)
17161716

1717+
# Trillium (v6e-256) benchmark config: Gemma3-12B, 32768 max target length,
# sharded with FSDP only (ici_fsdp_parallelism=-1 takes all ICI devices).
gemma3_12b_32768_v6e256 = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="gemma3-12b-32768-v6e256",
        model_type="gemma3-12b",
        tuning_params={
            "per_device_batch_size": 1,
            "num_vocab_tiling": 16,
            "ici_fsdp_parallelism": -1,  # -1: span the full ICI mesh with FSDP
            # Custom remat: keep decoder layer input on device, remat Q/K/V projections.
            "remat_policy": "custom",
            "decoder_layer_input": "device",
            "query_proj": "remat",
            "key_proj": "remat",
            "value_proj": "remat",
            "max_target_length": 32768,
            "attention": "flash",
            "gcs_metrics": True,
            "use_iota_embed": True,
            # Synthetic data: the dataset path is unused for actual reads here.
            "dataset_path": "gs://max-datasets-rogue",
            "dataset_type": "synthetic",
            "reuse_example_batch": 1,
            "enable_checkpointing": False,
            # Profile 2 steps after a 10-step warmup.
            "profiler": "xplane",
            "skip_first_n_steps_for_profiler": 10,
            "profiler_steps": 2,
            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
            # Splash/flash-attention block sizes (forward and backward passes).
            "sa_block_q": 1024,
            "sa_block_kv": 1024,
            "sa_block_kv_compute": 1024,
            "sa_block_q_dkv": 512,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 512,
            "sa_block_q_dq": 1024,
            "sa_block_kv_dq": 1024,
        },
        # Raise the scoped VMEM limit to 120 MiB (122880 KiB) for this workload.
        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
    ),
)
1755+
1756+
# Trillium (v6e-256, 2 slices) benchmark config: Gemma3-12B at 32768 tokens.
# Unlike the single-slice config, FSDP is moved to the transpose axis
# (ici_fsdp_parallelism=1, ici_fsdp_transpose_parallelism=-1).
gemma3_12b_32768_2x_v6e256 = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="gemma3-12b-32768-2x-v6e256",
        model_type="gemma3-12b",
        tuning_params={
            "per_device_batch_size": 1,
            "num_vocab_tiling": 16,
            "ici_fsdp_parallelism": 1,
            "ici_fsdp_transpose_parallelism": -1,  # -1: remaining devices on this axis
            # Custom remat: decoder layer input stays on device; Q/K/V are rematerialized.
            "remat_policy": "custom",
            "decoder_layer_input": "device",
            "query_proj": "remat",
            "key_proj": "remat",
            "value_proj": "remat",
            "max_target_length": 32768,
            "attention": "flash",
            "gcs_metrics": True,
            "use_iota_embed": True,
            # Synthetic data: the dataset path is unused for actual reads here.
            "dataset_path": "gs://max-datasets-rogue",
            "dataset_type": "synthetic",
            "reuse_example_batch": 1,
            "enable_checkpointing": False,
            # Profile 2 steps after a 10-step warmup.
            "profiler": "xplane",
            "skip_first_n_steps_for_profiler": 10,
            "profiler_steps": 2,
            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
            # Splash/flash-attention block sizes (forward and backward passes).
            "sa_block_q": 1024,
            "sa_block_kv": 1024,
            "sa_block_kv_compute": 1024,
            "sa_block_q_dkv": 512,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 512,
            "sa_block_q_dq": 1024,
            "sa_block_kv_dq": 1024,
        },
        # Raise the scoped VMEM limit to 120 MiB (122880 KiB) for this workload.
        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
    ),
)
1795+
1796+
# Trillium (v6e-256, 4 slices) benchmark config: Gemma3-12B at 32768 tokens.
# NOTE(review): tuning params are identical to the 2x variant except the model
# name — presumably the slice count is selected by the launcher; confirm.
gemma3_12b_32768_4x_v6e256 = _add_to_model_dictionary(
    trillium_model_dict,
    MaxTextModel(
        model_name="gemma3-12b-32768-4x-v6e256",
        model_type="gemma3-12b",
        tuning_params={
            "per_device_batch_size": 1,
            "num_vocab_tiling": 16,
            "ici_fsdp_parallelism": 1,
            "ici_fsdp_transpose_parallelism": -1,  # -1: remaining devices on this axis
            # Custom remat: decoder layer input stays on device; Q/K/V are rematerialized.
            "remat_policy": "custom",
            "decoder_layer_input": "device",
            "query_proj": "remat",
            "key_proj": "remat",
            "value_proj": "remat",
            "max_target_length": 32768,
            "attention": "flash",
            "gcs_metrics": True,
            "use_iota_embed": True,
            # Synthetic data: the dataset path is unused for actual reads here.
            "dataset_path": "gs://max-datasets-rogue",
            "dataset_type": "synthetic",
            "reuse_example_batch": 1,
            "enable_checkpointing": False,
            # Profile 2 steps after a 10-step warmup.
            "profiler": "xplane",
            "skip_first_n_steps_for_profiler": 10,
            "profiler_steps": 2,
            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
            # Splash/flash-attention block sizes (forward and backward passes).
            "sa_block_q": 1024,
            "sa_block_kv": 1024,
            "sa_block_kv_compute": 1024,
            "sa_block_q_dkv": 512,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 512,
            "sa_block_q_dq": 1024,
            "sa_block_kv_dq": 1024,
        },
        # Raise the scoped VMEM limit to 120 MiB (122880 KiB) for this workload.
        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
    ),
)
1835+
17171836
# Config for Llama3.1 70B model with 131072 max target length aka context length
17181837
llama3_1_70b_131072 = _add_to_model_dictionary(
17191838
trillium_model_dict,

0 commit comments

Comments
 (0)