train.py OOM on TPUv3-8 #10
Comments
I got the same error.
I guess TPU v3-8 doesn't have enough memory. Is there an easy way to use PEFT for fine-tuning?
I've tested on TPU v3-32 and there is no OOM error.
I haven't started to implement PEFT in this project yet.
Will it work if you use …?
@zhangzx-uiuc Using …

Besides, I contacted the OP of google-deepmind/optax#377, and I learnt that "it is bad practice to keep the actual params in bf16 during training". I think the performance would be better if we stick to ….

Another thing that is worth noticing is the precision of the rotary embedding: https://www.qbitai.com/2023/08/78565.html. I haven't fixed this yet.
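(For context, a minimal JAX sketch of the "keep the actual params in fp32" pattern referred to above: the master params held by the optimizer stay in float32, and a bfloat16 copy is made only for the forward/backward compute. The toy model, shapes, and names are illustrative assumptions, not this repo's train.py.)

```python
import jax
import jax.numpy as jnp

# Toy one-layer "model", just to have something to differentiate.
def forward(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params_fp32, x, y):
    # Cast a bf16 copy of the params for the compute; the master copy stays fp32.
    params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params_fp32)
    preds = forward(params_bf16, x.astype(jnp.bfloat16))
    # Compute the loss in float32 for numerical stability.
    return jnp.mean((preds.astype(jnp.float32) - y) ** 2)

params = {"w": jnp.zeros((8, 4), jnp.float32), "b": jnp.zeros((4,), jnp.float32)}
x = jnp.ones((2, 8), jnp.float32)
y = jnp.ones((2, 4), jnp.float32)

# Gradients come back in fp32 (the dtype of the master params), so the optimizer
# state also stays in fp32.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
```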
@ayaka14732 If you're using MultiSteps, you need to add a context manager so that it is initialised in CPU memory instead of on TPU device 0:

```python
with jax.default_device(jax.devices("cpu")[0]):
    # your optax multisteps code here
```
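Here is a self-contained sketch of that suggestion; the dummy params, the inner optimizer, and the accumulation count are placeholders, not the repo's actual config:

```python
import jax
import jax.numpy as jnp
import optax

# Placeholder params and inner optimizer, only for illustration.
params = {"w": jnp.zeros((1024, 1024), jnp.float32)}
inner = optax.adamw(learning_rate=1e-4)

# Build MultiSteps and initialise its state while the CPU is the default device,
# so the gradient-accumulation buffers are allocated in host memory rather than
# on TPU device 0.
with jax.default_device(jax.devices("cpu")[0]):
    optimizer = optax.MultiSteps(inner, every_k_schedule=8)
    opt_state = optimizer.init(params)
```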
@lodestone-rock My current code is:

```python
optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)
```

Should I use the MultiSteps optimizer like this to avoid TPU OOM?

```python
with jax.default_device(jax.devices("cpu")[0]):
    # your optax multisteps code here
    optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)
```
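For what it's worth, here is a hedged sketch of how the wrapped optimizer is then used in a train step (the toy `loss_fn`, params, and batch are stand-ins for whatever train.py actually does). MultiSteps accumulates gradients internally and only lets the inner optimizer produce a real update every `accumulate_gradient_steps` calls; in between, the returned updates are zeros:

```python
import jax
import jax.numpy as jnp
import optax

# Stand-ins for the real model/config in train.py.
params = {"w": jnp.zeros((4, 4), jnp.float32)}
def loss_fn(p, batch):
    return jnp.mean((batch["x"] @ p["w"] - batch["y"]) ** 2)

accumulate_gradient_steps = 4  # stand-in for config.accumulate_gradient_steps
optimizer = optax.MultiSteps(optax.adamw(1e-4), accumulate_gradient_steps)
opt_state = optimizer.init(params)  # ideally under the CPU default_device context, as above

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Updates are zero until `accumulate_gradient_steps` gradients have been accumulated.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

batch = {"x": jnp.ones((2, 4), jnp.float32), "y": jnp.ones((2, 4), jnp.float32)}
for _ in range(accumulate_gradient_steps):
    params, opt_state, loss = train_step(params, opt_state, batch)
```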
@Beomi