     ),
 )
 
+gemma3_12b_32768_v6e256 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="gemma3-12b-32768-v6e256",
+        model_type="gemma3-12b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "num_vocab_tiling": 16,
+            "ici_fsdp_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "device",
+            "query_proj": "remat",
+            "key_proj": "remat",
+            "value_proj": "remat",
+            "max_target_length": 32768,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 2,
+            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+            "sa_block_q": 1024,
+            "sa_block_kv": 1024,
+            "sa_block_kv_compute": 1024,
+            "sa_block_q_dkv": 512,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 512,
+            "sa_block_q_dq": 1024,
+            "sa_block_kv_dq": 1024,
+        },
+        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
+    ),
+)
+
+gemma3_12b_32768_2x_v6e256 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="gemma3-12b-32768-2x-v6e256",
+        model_type="gemma3-12b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "num_vocab_tiling": 16,
+            "ici_fsdp_parallelism": 1,
+            "ici_fsdp_transpose_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "device",
+            "query_proj": "remat",
+            "key_proj": "remat",
+            "value_proj": "remat",
+            "max_target_length": 32768,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 2,
+            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+            "sa_block_q": 1024,
+            "sa_block_kv": 1024,
+            "sa_block_kv_compute": 1024,
+            "sa_block_q_dkv": 512,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 512,
+            "sa_block_q_dq": 1024,
+            "sa_block_kv_dq": 1024,
+        },
+        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
+    ),
+)
+
+gemma3_12b_32768_4x_v6e256 = _add_to_model_dictionary(
+    trillium_model_dict,
+    MaxTextModel(
+        model_name="gemma3-12b-32768-4x-v6e256",
+        model_type="gemma3-12b",
+        tuning_params={
+            "per_device_batch_size": 1,
+            "num_vocab_tiling": 16,
+            "ici_fsdp_parallelism": 1,
+            "ici_fsdp_transpose_parallelism": -1,
+            "remat_policy": "custom",
+            "decoder_layer_input": "device",
+            "query_proj": "remat",
+            "key_proj": "remat",
+            "value_proj": "remat",
+            "max_target_length": 32768,
+            "attention": "flash",
+            "gcs_metrics": True,
+            "use_iota_embed": True,
+            "dataset_path": "gs://max-datasets-rogue",
+            "dataset_type": "synthetic",
+            "reuse_example_batch": 1,
+            "enable_checkpointing": False,
+            "profiler": "xplane",
+            "skip_first_n_steps_for_profiler": 10,
+            "profiler_steps": 2,
+            "tokenizer_path": os.path.join("assets", "tokenizer.gemma3"),
+            "sa_block_q": 1024,
+            "sa_block_kv": 1024,
+            "sa_block_kv_compute": 1024,
+            "sa_block_q_dkv": 512,
+            "sa_block_kv_dkv": 2048,
+            "sa_block_kv_dkv_compute": 512,
+            "sa_block_q_dq": 1024,
+            "sa_block_kv_dq": 1024,
+        },
+        xla_flags=(xla_flags_library.CUSTOM_VMEM_LIMIT_FLAG(vmem_limit=122880)),
+    ),
+)
+
 # Config for Llama3.1 70B model with 131072 max target length aka context length
 llama3_1_70b_131072 = _add_to_model_dictionary(
     trillium_model_dict,