
Automatically Optimize grain_worker_count for Improved Data Loading Performance #2509

@bzantium

Description


What problem are you trying to solve?

The grain_worker_count parameter, which controls the number of parallel data loading workers, has a significant impact on training performance, especially when tokenizing raw text data on the fly. An incorrectly configured grain_worker_count can cause the data input and preprocessing pipeline to become a major bottleneck, leading to drastically reduced hardware utilization (TFLOP/s). Users must currently find the optimal value through manual trial and error, which is inefficient.
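
For context, here is a minimal sketch of where such a worker count typically ends up in a Grain input pipeline. The data source, sampler, and batch size below are placeholders for illustration, not MaxText's actual pipeline:

```python
import grain.python as grain

# Placeholder source and operations for illustration only; MaxText's real
# Grain pipeline (sharding, tokenization, packing) is more involved.
data_source = grain.ArrayRecordDataSource(["/path/to/shard.array_record"])
sampler = grain.IndexSampler(
    num_records=len(data_source),
    shard_options=grain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)

loader = grain.DataLoader(
    data_source=data_source,
    sampler=sampler,
    operations=[grain.Batch(batch_size=8, drop_remainder=True)],
    # This is the knob grain_worker_count maps to: the number of child
    # processes running preprocessing (e.g. tokenization) in parallel.
    worker_count=4,
)
```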

Why is this problem important?

Manually tuning grain_worker_count is a tedious process that requires multiple experimental runs to identify the best setting for a specific hardware and dataset configuration. This creates a poor user experience and can prevent users from achieving optimal training performance. An automatic solution would save significant time and effort, making it easier for users to maximize the efficiency of their training jobs right from the start.

Describe your requested feature or solution

We propose that MaxText automatically determine the optimal number of data loading workers. This can be implemented by defaulting grain_worker_count to -1, which would leverage a function like grain.experimental.pick_performance_config to dynamically select the best value based on the system's capabilities. This would eliminate the guesswork for the user and prevent the data pipeline from bottlenecking the training process.
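
A rough sketch of the proposed resolution logic is below. The helper is hypothetical, and because the exact interface of grain.experimental.pick_performance_config is not spelled out here, a simple CPU-count heuristic stands in for the auto path:

```python
import os


def resolve_grain_worker_count(configured_value: int) -> int:
    """Hypothetical helper: treat grain_worker_count == -1 as 'auto-select'."""
    if configured_value >= 0:
        # An explicit user setting always wins.
        return configured_value
    # Auto mode: the proposal is to delegate this choice to something like
    # grain.experimental.pick_performance_config. Its signature is not assumed
    # here, so a CPU-based heuristic stands in purely for illustration.
    cpus = os.cpu_count() or 1
    return max(1, min(8, cpus // 2))
```

With this in place, only configs that explicitly set grain_worker_count: -1 would opt into auto-selection; existing positive values would behave exactly as they do today.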

Describe alternatives you’ve considered (if any)

The only alternative is the current method: manually setting and testing different integer values for grain_worker_count. This is the inefficient process we are seeking to improve.

Additional context or examples

The performance impact is clear from training experiments on a v6e-32 pod. Below is a summary of the TeraFLOPs per second per device (TFLOP/s/device) and the average time per step observed when training a Llama3-8B model with varying grain_worker_count values.

As the table below shows, a low worker count leads to slow, erratic step times and low throughput. Performance stabilizes and improves dramatically with 4 or 8 workers, which achieve fast, consistent step times of around 4.3 seconds.

| grain_worker_count | Average TFLOP/s/device (Steps 3-9) | Average Time/Step (s) (Steps 3-9) | Stability |
| --- | --- | --- | --- |
| 1 | ~29 | ~30.6 | Unstable |
| 2 | ~60 | ~13.5 | Highly Unstable |
| 4 | ~195 | ~4.3 | Stable |
| 8 | ~195 | ~4.3 | Stable |

An automatic configuration would ideally select a value like 4 or 8, ensuring stable and efficient training with minimal step times.

Full Logs for Reference:


grain_worker_count = 1

# TFLOP/s/device values: 13.8, 78.8, 29.4, 27.7, 28.7, 25.1, 23.1, 28.1, 29.9, 32.0
# seconds per step: 61.1, 10.7, 28.7, 30.5, 29.4, 33.6, 36.5, 30.0, 28.2, 26.3

grain_worker_count = 2

# TFLOP/s/device values: 14.5, 3267.4, 73.2, 129458.3, 28.2, 129717.1, 31.9, 517.3, 33.5, 455.1
# seconds per step: 58.3, 0.2, 11.5, 0.007, 29.9, 0.007, 26.5, 1.6, 25.2, 1.9

grain_worker_count = 4

# TFLOP/s/device values: 15.2, 3385.9, 207.3, 195.9, 195.3, 195.0, 195.3, 195.4, 85.9, 137321.5
# seconds per step: 55.3, 0.2, 4.1, 4.3, 4.3, 4.3, 4.3, 4.3, 9.8, 0.006

grain_worker_count = 8

# TFLOP/s/device values: 17.7, 3253.9, 109.3, 196.1, 195.4, 195.5, 195.6, 195.4, 195.5, 195.6
# seconds per step: 47.6, 0.3, 7.7, 4.3, 4.3, 4.3, 4.3, 4.3, 4.3, 4.3

Additional Context

No response
