Model performs poorly if its state is restored from checkpoint #3259
Hi,
I am learning JAX and Flax and want to build a simple Masked Autoencoder (MAE) model. I have built the model and trained it on the MNIST dataset; it can reconstruct images from masked inputs (mask ratio = 0.6).
I want to save the model state and restart training from the checkpoint. However, I found that the model performs poorly when trained from a restored state.
The image below shows the loss metric of the 4 training loops.
From the image, it is clear that the model trained in 50-epoch sessions failed to learn much during the 2nd and 3rd sessions, and its performance could not match the model trained for 200 epochs in a single run. The images reconstructed by the latter model also look better.
I used `orbax.checkpoint.CheckpointManager` to manage saving and restoring the checkpoints. This is how I created the object.
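Roughly like this (a minimal sketch following the Flax/Orbax checkpointing guide; the directory path and the `CheckpointManagerOptions` values are placeholders, not necessarily what I used):

```python
import orbax.checkpoint

# A PyTree checkpointer handles Flax TrainState objects.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True)
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    '/tmp/mae_checkpoints',  # placeholder directory
    orbax_checkpointer,
    options,
)
```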
When restoring from a checkpoint, I first create a new `state` object using a typical `create_train_state` function, then call the `restore` function like this:
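(Again a sketch; the `create_train_state` arguments and the `{'model': ...}` wrapper are placeholders that match how I saved the state below.)

```python
# Build an "empty" state with the same structure as the saved one;
# Orbax uses it as the target PyTree for restoration.
empty_state = create_train_state(init_rng, model, config)  # placeholder args
step = checkpoint_manager.latest_step()  # step of the most recent checkpoint
restored = checkpoint_manager.restore(step, items={'model': empty_state})
state = restored['model']
```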
And this is how I saved the checkpoint:
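(Sketch; `step` is the current global training step.)

```python
from flax.training import orbax_utils

# Wrap the state so the saved PyTree matches the restore target above.
ckpt = {'model': state}
save_args = orbax_utils.save_args_from_target(ckpt)
checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args})
```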
I am not sure whether the state is saved and restored correctly by the code above. Another question is about the optimizer in the `create_train_state` function: do I need to adjust its parameters when restoring training from a checkpoint? I used a cosine warmup learning-rate schedule with the Adam optimizer.
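For context, my `create_train_state` looks roughly like this (the model input shape, schedule values, and argument names are placeholders):

```python
import jax.numpy as jnp
import optax
from flax.training import train_state

def create_train_state(rng, model, peak_lr, warmup_steps, total_steps):
    # Initialize parameters with a dummy MNIST-shaped batch.
    params = model.init(rng, jnp.ones((1, 28, 28, 1)))['params']
    # Cosine schedule with linear warmup; decay_steps is the total
    # schedule length, including the warmup portion.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
    )
    tx = optax.adam(learning_rate=schedule)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)
```

If I understand correctly, the schedule is a pure function of the optimizer's step count, which is part of the restored state, so the learning rate should continue where it left off as long as `warmup_steps` and `total_steps` are computed the same way in every session.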
Replies: 1 comment
-
nvm...some silly mistake with how I calculated the warm-up steps.