Conversation


@NuojCheng NuojCheng commented Oct 15, 2025

MaxText Estimator: Automatic Batch Size & Remat Policy Search

This PR introduces an estimation tool (MaxText.estimator) that automatically searches for optimal training configurations to maximize performance and prevent OOM errors.

It can operate in several modes:

  • Search for the optimal batch size and rematerialization policy simultaneously.
  • Search for the optimal rematerialization policy for a fixed batch size.
  • Search for the best configuration while respecting a partially fixed policy (e.g., if you force context=offload).

Example Usage

Here are a few examples of running the estimator for a llama3.1-405b model on a v5p-1024 cluster.

Search for both Batch Size and Remat Policy

This is the most common use case. The estimator will find the best per_device_batch_size and the corresponding remat policy.

python -m MaxText.estimator \
MaxText/configs/base.yml \
steps=1 \
compile_topology=v5p-1024 \
compile_topology_num_slices=1 \
model_name=llama3.1-405b \
num_vocab_tiling=4

Search for Remat Policy with a Fixed Batch Size

If you know your target batch size (e.g., per_device_batch_size=4), the tool will find the "lightest" remat policy (least rematerialization) that allows that batch size to fit in memory.

python -m MaxText.estimator \
MaxText/configs/base.yml \
steps=1 \
compile_topology=v5p-1024 \
compile_topology_num_slices=1 \
model_name=llama3.1-405b \
num_vocab_tiling=4 \
per_device_batch_size=4

Search with a Partially Fixed Policy

If you want to enforce specific remat settings (e.g., you know you want to remat mlpwo and offload context), you can fix them. The estimator will then search for the best batch size and policy for the remaining tensors.

python -m MaxText.estimator \
MaxText/configs/base.yml \
steps=1 \
compile_topology=v5p-1024 \
compile_topology_num_slices=1 \
model_name=llama3.1-405b \
num_vocab_tiling=4 \
remat_policy=custom \
mlpwo=remat \
context=offload

Search with Fixed Batch Size AND Partially Fixed Policy

You can also combine a fixed batch size with a partially fixed policy:

python -m MaxText.estimator \
MaxText/configs/base.yml \
steps=1 \
compile_topology=v5p-1024 \
compile_topology_num_slices=1 \
model_name=llama3.1-405b \
num_vocab_tiling=4 \
per_device_batch_size=5 \
remat_policy=custom \
mlpwo=remat \
context=offload

Output

The program runs its search and saves all suggested runnable commands that it believes will yield high MFU without OOMing.

The results are written to a file named remat_commands_from_estimator.txt.
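Each line of that file is a complete, standalone command, so the file can be replayed directly from a shell. A minimal sketch of doing so (the `printf` below writes a stand-in file with placeholder `echo` commands purely for illustration; in a real run the estimator produces the file itself, populated with full `MaxText.train` commands):

```shell
# Stand-in for the estimator's output file; a real run of MaxText.estimator
# writes remat_commands_from_estimator.txt with full MaxText.train commands.
printf '%s\n' \
  'echo launching per_device_batch_size=1 candidate' \
  'echo launching per_device_batch_size=2 candidate' \
  > remat_commands_from_estimator.txt

# Replay each non-empty line of the file as a shell command.
while IFS= read -r cmd; do
  [ -n "$cmd" ] && eval "$cmd"
done < remat_commands_from_estimator.txt
```

In practice you would likely run the suggested commands one at a time and compare the resulting MFU before committing to a configuration.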

Example

python -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs run_name=estimation steps=1 compile_topology=v5p-1024 compile_topology_num_slices=1 model_name=llama3.1-405b num_vocab_tiling=4 per_device_batch_size=1 remat_policy=custom mlpwo=offload mlpwi_0=offload mlpwi_1=device query_proj=device out_proj=device key_proj=device value_proj=device context=device
python -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs run_name=estimation steps=1 compile_topology=v5p-1024 compile_topology_num_slices=1 model_name=llama3.1-405b num_vocab_tiling=4 per_device_batch_size=2 remat_policy=custom mlpwo=offload mlpwi_0=offload mlpwi_1=offload query_proj=device out_proj=device key_proj=device value_proj=device context=device

FIXES: b/449559587

Tests

Please describe how you tested this change, and include any instructions and/or commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng NuojCheng added the draft Draft PR label Oct 15, 2025
@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch 4 times, most recently from c35671e to 8a4f7e3 on October 16, 2025 22:46
@NuojCheng NuojCheng changed the title from "Draft" to "Remat policy + batch size estimation using AOT compilation" Oct 16, 2025
@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch 2 times, most recently from 221011b to 74918d7 on October 16, 2025 23:37
@NuojCheng NuojCheng marked this pull request as ready for review October 16, 2025 23:38
@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch from 74918d7 to 9191796 on October 16, 2025 23:40
@NuojCheng NuojCheng added gemini-review and removed draft Draft PR labels Oct 16, 2025
@github-actions

🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@khatwanimohit khatwanimohit left a comment

LGTM

@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch 2 times, most recently from 466284a to a05c9a2 on October 22, 2025 18:34
@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch from a05c9a2 to e234856 on October 22, 2025 18:51
@NuojCheng NuojCheng requested a review from parambole as a code owner October 22, 2025 18:51
@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch 2 times, most recently from 3813801 to 584312b on October 22, 2025 19:03
@NuojCheng NuojCheng force-pushed the chengnuojin-estimator branch from 584312b to f250b2f on October 22, 2025 19:15
@copybara-service copybara-service bot merged commit 33b8ac1 into main Oct 23, 2025
35 checks passed
@copybara-service copybara-service bot deleted the chengnuojin-estimator branch October 23, 2025 00:12