I am one of the participants of the Hugging Face Community Week. I have a question about `pmap` and the training loop, and I would really appreciate some advice:
I have a memory- and time-expensive step that needs to run every N iterations, e.g.:
```python
for it, batch in enumerate(dataset):
    state = pmap_train(keys, model, state, batch, lr)
    if it % N == 0:
        run_expensive_step(keys, model, state, batch, clip_model, lr)
```
Questions
1. Will this `run_expensive_step` create a big bottleneck for the `pmap_train` call, or will it not be blocked (maybe because of JAX's async dispatch)?
2. Is there any way to force `run_expensive_step` to run only on the host? (My host has 300+ GB of RAM and `run_expensive_step` is very memory hungry.)
3. `model` has been passed through `flax.jax_utils.replicate` and `batch` through `shard`, but `clip_model` has not been replicated. Do I need to remove the leading device dimension from `model` and `batch` before each `run_expensive_step` call?
Yes, this would be blocking: `run_expensive_step()` needs to wait for the computation of `state` to finish before it can be transferred to host memory, and your training loop is then blocked until `run_expensive_step()` returns. If that's an issue, you could run `run_expensive_step()` in a separate thread.
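For example, a minimal sketch of the separate-thread approach, reusing the names from your snippet (note that `jax.device_get` still waits for `state` to be computed, but the loop is then free to keep dispatching device steps while the expensive work runs in the background):

```python
import concurrent.futures
import jax

# One worker is enough: at most one expensive step runs at a time.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

for it, batch in enumerate(dataset):
    state = pmap_train(keys, model, state, batch, lr)
    if it % N == 0:
        # device_get blocks until `state` is ready and copies it to host
        # memory, so the background thread works on a stable snapshot.
        host_state = jax.device_get(state)
        executor.submit(run_expensive_step, keys, model, host_state, batch,
                        clip_model, lr)
```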
Inside `run_expensive_step()` you could use numpy for the computation. That would automatically copy the values to host memory and use the CPU for the computation. Alternatively, you could use `jax.jit(run_expensive_step, backend='cpu')`.
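Concretely, the two options could look roughly like this (a sketch assuming `run_expensive_step` takes only array pytrees; Flax module objects such as `clip_model` would need to be closed over rather than passed as traced arguments):

```python
import jax

# Option 1: copy the device arrays back to host memory; numpy operations on
# the resulting arrays then run on the CPU.
host_state = jax.device_get(state)  # pytree of numpy arrays

# Option 2: keep the JAX implementation but pin its compilation to the CPU
# backend, so the computation and its intermediates stay in host RAM.
run_expensive_step_cpu = jax.jit(run_expensive_step, backend='cpu')
```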
In either case you would need to `flax.jax_utils.unreplicate()` the pytrees first to get rid of the leading device dimension. As for `batch`, you would need to transform it via `jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), batch)` or similar (i.e. reshaping, not slicing).
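A minimal sketch of that preprocessing, again reusing the names from your snippet:

```python
import jax
from flax import jax_utils

# Replicated pytrees: keep a single copy (drops the leading device axis).
single_state = jax_utils.unreplicate(state)
single_model = jax_utils.unreplicate(model)  # only if its params were replicated

# Sharded batch: merge the (n_devices, per_device_batch, ...) axes back into
# a single batch axis via reshape.
flat_batch = jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:]), batch)

run_expensive_step(keys, single_model, single_state, flat_batch, clip_model, lr)
```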