Restoring a checkpoint with Orbax CheckpointManager fails #4521

Open
bednarikjan opened this issue Feb 1, 2025 · 1 comment

@bednarikjan

I can save and load the model state with a regular ocp.StandardCheckpointer(), but the same round trip fails with ocp.CheckpointManager(). I believe I am following the standard approach, as shown in the minimal example below:

import pathlib

import jax
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

# Define a very simple NNX model.
class OneLayerMLP(nnx.Module):
  def __init__(self, dim, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

  def __call__(self, x):
    return self.linear(x)

# Create the model
model = OneLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

# Retrieve the state
_, state = nnx.split(model)
ckpt_dir = pathlib.Path('/some/path/...')  # A Path, so that `ckpt_dir / 'state'` below works.

# Create the checkpoint manager and save the state.
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer,
    ocp.CheckpointManagerOptions(
        save_interval_steps=1, max_to_keep=5
    ),
)
checkpoint_manager.save(123, state)

# Restore the state.
abstract_model = nnx.eval_shape(lambda: OneLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)

handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer)
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
# jax.tree.map(np.testing.assert_array_equal, state, state_restored)  # This call would already fail, see below [2]

# Run the restored model.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)  # This call fails with error [1], see below

Error [1], raised when I run a forward pass through the restored model, is:

TypeError: Unexpected input type for array: <class 'dict'>

Uncommenting the line above that compares the original and restored state raises error [2], which already points to a problem:

ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.statelib.State'>, ...

What am I missing? How should one use the CheckpointManager correctly together with NNX models?

System information

flax version: 0.10.2
orbax.checkpoint version: 0.11.0

@IvyZX
Collaborator

IvyZX commented Feb 6, 2025

When you call checkpoint_manager.restore() without giving a target, Orbax by default reads the checkpoint into plain Python builtin containers, because it doesn't know that the result should be an nnx.State like state.

So you need to pass the items argument to tell Orbax to restore into the same structure as state:

state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step(), items=abstract_state)

I hope this helps. Check out our checkpoint guide for more code examples. For example, an alternative is to always save the stripped-out pure dict instead of the raw nnx.State.
