Hi, I am looking for a way to execute a subset of layers on one device (say, CPU) and feed the outputs of that computation to the next set of layers, executed on another device (say, GPU). How would I accomplish this with JIT? My jax.device_put() call to transfer data from CPU to GPU seems to be ignored, and I end up getting an error complaining that inputs on different devices cannot be executed with JIT.
Is there any way I can design my model to schedule layers across different devices?
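One possible approach, sketched here under assumptions rather than as an official recipe, is to split the model into two separately `jax.jit`-compiled stages. You commit each stage's parameters and inputs to its target device with `jax.device_put`, then transfer the intermediate activations explicitly between the two calls. A jitted call generally runs on the device where its committed inputs live, so passing arrays committed to *different* devices into a single jitted call is a likely cause of the error described above. Also note that `jax.device_put` returns a new array rather than moving its argument in place, so the returned value has to be used. The toy `dense` layer, the `stage_one`/`stage_two`/`init_params` helpers, and the layer sizes are all hypothetical. If no GPU is present, the sketch falls back to running both stages on the CPU.

```python
import jax
import jax.numpy as jnp

# Pick the two target devices. If no GPU backend is available,
# fall back to the CPU so the sketch still runs on a CPU-only install.
cpu = jax.devices("cpu")[0]
try:
    gpu = jax.devices("gpu")[0]
except RuntimeError:
    gpu = cpu


def dense(params, x):
    """Hypothetical toy layer: affine map followed by tanh."""
    w, b = params
    return jnp.tanh(x @ w + b)


@jax.jit
def stage_one(params, x):
    # Layers meant for the first device; runs where its inputs are committed.
    for p in params:
        x = dense(p, x)
    return x


@jax.jit
def stage_two(params, h):
    # Layers meant for the second device.
    for p in params:
        h = dense(p, h)
    return h


def init_params(key, sizes):
    """Hypothetical helper: random (w, b) pairs for consecutive layer widths."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in)
        params.append((w, jnp.zeros(n_out)))
    return params


key1, key2 = jax.random.split(jax.random.PRNGKey(0))

# Commit each stage's parameters (a pytree) to its device.
params1 = jax.device_put(init_params(key1, [8, 16, 16]), cpu)
params2 = jax.device_put(init_params(key2, [16, 16, 4]), gpu)

x = jax.device_put(jnp.ones((2, 8)), cpu)

h_cpu = stage_one(params1, x)      # every input is committed to `cpu`
h_gpu = jax.device_put(h_cpu, gpu) # explicit transfer; keep the returned array
y = stage_two(params2, h_gpu)      # every input is committed to `gpu`
```

The design choice here is to keep each jitted function single-device and do the cross-device copy outside of `jit`. This keeps placement explicit and avoids handing one compiled computation arguments that live on two devices.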