Using the JAX backend for LSTM / GRU models, I'm unable to see any speed-up when training on two Nvidia 3090 GPUs versus a single one (keras-nightly, JAX 0.5.2). The distributed training across the two GPUs appears to work correctly; it is just no faster, and possibly even slower. The attached file contains a modified version of the Keras timeseries weather forecasting example that reproduces the problem; the distribution setup is along the lines sketched below.
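For reference, a minimal sketch of this kind of setup with the Keras 3 distribution API (the input shape and layer sizes here are placeholders, not the exact values from the attached example):

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import keras

# Replicate the model on every visible GPU and split each batch
# between them (pure data parallelism).
devices = keras.distribution.list_devices("gpu")
keras.distribution.set_distribution(
    keras.distribution.DataParallel(devices=devices)
)

# Stand-in model; the attached example uses an LSTM on the
# keras.io weather-forecasting data.
model = keras.Sequential(
    [
        keras.layers.Input(shape=(120, 7)),
        keras.layers.LSTM(32),
        keras.layers.Dense(1),
    ]
)
model.compile(optimizer="adam", loss="mse")
```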
I also can't seem to find any "official" Keras / keras.io example that demonstrates distributed training together with a measurement of the training time. Shouldn't there be such an example to showcase the gain from multi-device training?

timeseries_weather_forecasting_LC.zip
Having looked further into this issue, it turns out I do see a speed-up for very large batch sizes, e.g. 65536. However, such a large batch size is unlikely to be practical for most model trainings. A sketch of the kind of timing comparison I used is below.
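A simple way to see the effect is to time one epoch at several batch sizes (a sketch; `x_train` / `y_train` stand in for the prepared arrays from the weather example, and `model` is the distributed model from the snippet above):

```python
import time

for batch_size in (256, 4096, 65536):
    # First fit triggers XLA compilation; run it once as warm-up
    # so the timed epoch measures steady-state training only.
    model.fit(x_train, y_train, batch_size=batch_size, epochs=1, verbose=0)
    start = time.perf_counter()
    model.fit(x_train, y_train, batch_size=batch_size, epochs=1, verbose=0)
    elapsed = time.perf_counter() - start
    print(f"batch_size={batch_size}: {elapsed:.1f}s per epoch")
```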
The lack of speed-up at "normal" batch sizes appears to stem from the way lax.scan is implemented in JAX / XLA; see e.g. jax-ml/jax#25336 and the links therein for a good overview. So this looks like a genuine bottleneck in JAX / XLA rather than in Keras. Still, it is probably worth monitoring the development in JAX / XLA so that any improvements made there can directly benefit Keras.
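To illustrate the bottleneck: RNN layers on the JAX backend lower to a `lax.scan` over timesteps, which XLA compiles as a strictly sequential loop of small per-step kernels. Splitting the batch across GPUs shrinks each step's work but not the number of steps, so per-step overhead dominates until batches are very large. A toy recurrence with the same structure (the cell below is a hypothetical stand-in, not the actual LSTM math):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
Wx = jax.random.normal(k1, (7, 32))   # input-to-hidden weights
Wh = jax.random.normal(k2, (32, 32))  # hidden-to-hidden weights

def rnn_step(h, x):
    # One timestep: the next hidden state depends on the previous one,
    # so XLA cannot parallelize across the time axis.
    h_next = jnp.tanh(x @ Wx + h @ Wh)
    return h_next, h_next

xs = jnp.zeros((120, 7))   # (timesteps, features) for one sample
h0 = jnp.zeros((32,))
h_final, all_states = jax.lax.scan(rnn_step, h0, xs)
```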
larschristensen changed the title from "Unable to see any speedup from multi-GPU training using JAX backend" to "Unable to see any speedup from multi-GPU training using JAX backend for LSTM / GRU" on Mar 11, 2025.