
Unable to see any speedup from multi-GPU training using JAX backend for LSTM / GRU #20998

Open
larschristensen opened this issue Mar 6, 2025 · 1 comment
larschristensen commented Mar 6, 2025

Using the JAX backend for LSTM / GRU models, I'm unable to see any speed-up when training on two Nvidia 3090s versus a single Nvidia 3090 (using keras-nightly and JAX 0.5.2). The distributed training across the two GPUs seems to work fine, but it is simply not faster, and may even be slightly slower. See the attached file for a modified version of the Keras timeseries weather forecasting example that demonstrates the problem.

I also can't seem to find any "official" Keras / Keras-IO example showing distributed training together with a measurement of the training time. Shouldn't there be such an "official" example to showcase the gain from multi-device training?

timeseries_weather_forecasting_LC.zip
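For reference, a minimal sketch of the kind of setup described above (not the attached script itself): Keras 3 on the JAX backend with data-parallel training via keras.distribution.DataParallel. The model size, data shapes, and batch size below are placeholders standing in for the weather-forecasting example.

```python
# Minimal data-parallel sketch for the JAX backend (placeholder shapes).
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import numpy as np
import keras

# Shard each batch across all visible GPUs (e.g. the two 3090s).
keras.distribution.set_distribution(keras.distribution.DataParallel())

# Placeholder data standing in for the timeseries windows.
x = np.random.rand(4096, 120, 7).astype("float32")
y = np.random.rand(4096, 1).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(120, 7)),
    keras.layers.LSTM(32),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# With 2 GPUs, each device processes batch_size / 2 samples per step.
model.fit(x, y, batch_size=256, epochs=2)
```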

larschristensen (Author) commented
Having looked more into this issue, it turns out I am able to see a speed-up for very large batch sizes, e.g. 65536. However, using such a large batch size is unlikely to be practical for most model training.
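For context, a timing loop along these lines (hypothetical, continuing the sketch above and reusing its model, x, and y; not measurements from the attached script) is how one could check at which batch size the two-GPU run starts to pull ahead of the single-GPU run:

```python
# Hypothetical wall-clock comparison across batch sizes.
import time

for batch_size in (256, 4096, 65536):
    start = time.perf_counter()
    model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)
    elapsed = time.perf_counter() - start
    print(f"batch_size={batch_size}: {elapsed:.2f} s/epoch")
```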

The lack of speed-up at "normal" batch sizes appears to be a consequence of how lax.scan is implemented in JAX / XLA; see e.g. jax-ml/jax#25336 and the links therein for a good overview. So this looks like a bottleneck in JAX / XLA rather than in Keras. It is probably still worth monitoring the development in JAX / XLA, since any improvements made there could directly benefit Keras.
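To illustrate the scan structure (a toy sketch, not the actual Keras RNN internals): a recurrence expressed with jax.lax.scan has a carry at step t that depends on step t-1, so the time steps run strictly in sequence no matter how the batch is sharded across GPUs. Splitting the batch only shrinks the work per step, and for "normal" batch sizes the per-step overhead dominates.

```python
# Toy recurrence via jax.lax.scan: T steps execute sequentially because each
# carry depends on the previous one; sharding the batch does not shorten
# this dependency chain.
import jax
import jax.numpy as jnp

def rnn_step(h, x_t):
    # Trivial recurrent cell: new state depends on previous state and input.
    h_new = jnp.tanh(x_t + 0.5 * h)
    return h_new, h_new  # (carry, per-step output)

def run_rnn(x_seq, h0):
    # x_seq: (T, batch, features); scan walks the leading (time) axis.
    _, outputs = jax.lax.scan(rnn_step, h0, x_seq)
    return outputs

T, batch, features = 120, 256, 7
x_seq = jnp.ones((T, batch, features))
h0 = jnp.zeros((batch, features))
outputs = jax.jit(run_rnn)(x_seq, h0)
print(outputs.shape)  # (120, 256, 7)
```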

@larschristensen larschristensen changed the title Unable to see any speedup from multi-GPU training using JAX backend Unable to see any speedup from multi-GPU training using JAX backend for LSTM / GRU Mar 11, 2025