I can use a regular `ocp.StandardCheckpointer()` to save and load the model state, but using `ocp.CheckpointManager()` fails. I believe I am following the standard approach, as in the minimal example below:
```python
import pathlib

import jax
import numpy as np
import orbax.checkpoint as ocp
from flax import nnx

# Define a very simple NNX model.
class OneLayerMLP(nnx.Module):
    def __init__(self, dim, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim, dim, rngs=rngs, use_bias=False)

    def __call__(self, x):
        return self.linear(x)

# Create the model
model = OneLayerMLP(4, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(42), (3, 4))
assert model(x).shape == (3, 4)

# Retrieve the state
_, state = nnx.split(model)

# A Path object (not a plain str) so that `ckpt_dir / 'state'` works.
ckpt_dir = pathlib.Path('/some/path/...')

# Create the checkpoint manager and save the state.
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer,
    ocp.CheckpointManagerOptions(
        save_interval_steps=1, max_to_keep=5
    ),
)
checkpoint_manager.save(123, state)

# Restore the state.
abstract_model = nnx.eval_shape(lambda: OneLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
handler = ocp.StandardCheckpointHandler()
checkpointer = ocp.Checkpointer(handler)
checkpoint_manager = ocp.CheckpointManager(
    str(ckpt_dir / 'state'),
    checkpointer)
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
# jax.tree.map(np.testing.assert_array_equal, state, state_restored)  # This call would already fail, see below [2]

# Run the restored model.
model = nnx.merge(graphdef, state_restored)
assert model(x).shape == (3, 4)  # This call fails on error [1], see below
```
The error [1] which I see when trying to feed-forward through the loaded model is:
TypeError: Unexpected input type for array: <class 'dict'>
When uncommenting the line above which compares the original and loaded state, I get the error [2] already indicating a problem:
ValueError: Custom node type mismatch: expected type: <class 'flax.nnx.statelib.State'>, ...
What am I missing? How should one use the CheckpointManager correctly together with the nnx models?
When you call `checkpoint_manager.restore()` without giving a target, Orbax by default restores the checkpoint using only Python built-in containers (nested dicts), because it has no way of knowing that it should rebuild an `nnx.State` like `state`. That is why both errors above involve a `dict` where an `nnx.State` was expected.
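With the `Checkpointer`-based setup from the question, you therefore need to pass the `items` argument to `restore()`, giving Orbax a target with the same structure as `state` (for example `abstract_state`).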
I hope this helps. Check out our checkpointing guide for more code examples. For example, an alternative is to always save the stripped-out pure dict instead of the raw `nnx.State`.
System information:
flax version: 0.10.2
orbax.checkpoint version: 0.11.0