How to train sharded model using multi-host TPUs? #3318
Unanswered
Quasar-Kim asked this question in Q&A
Hello FLAX community,

I've been experimenting with the parallel training features in JAX/Flax. I was able to utilize a TPU v3-8 by annotating parameters with `nn.with_partitioning` and activations with `jax.lax.with_sharding_constraint`, as described in the parallel training guide.
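For reference, here is roughly the single-host setup I have working on the v3-8 (a stripped-down sketch; the `MLP` module, the 1x8 `data`/`model` mesh, and the sizes are just placeholders from my experiments):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

# 1x8 mesh over the 8 local devices of a single v3-8 host.
devices = mesh_utils.create_device_mesh((1, 8))
mesh = Mesh(devices, axis_names=('data', 'model'))


class MLP(nn.Module):
    hidden: int = 8192

    @nn.compact
    def __call__(self, x):
        # The parameter carries partitioning metadata via nn.with_partitioning.
        w1 = self.param(
            'w1',
            nn.with_partitioning(nn.initializers.lecun_normal(), (None, 'model')),
            (x.shape[-1], self.hidden),
        )
        y = x @ w1
        # The activation's sharding is pinned with a sharding constraint.
        y = jax.lax.with_sharding_constraint(
            y, NamedSharding(mesh, P('data', 'model')))
        return y


x = jnp.ones((16, 1024))
variables = MLP().init(jax.random.PRNGKey(0), x)
```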
Now I want to scale the model to multi-host TPUs (e.g. v3-32), but I wasn't able to find any guide or example covering this. So my question is: is passing shardings to `jit()` via its `in_shardings` and `out_shardings` parameters sufficient? I'm concerned that this could cause the parameters to be replicated, since each host might try to place the parameters on its own.