-
I have a module where one of the parameters is "shape-agnostic", i.e. it may vary slightly during inference (in my case it first doubles its size during a Riemannian optimization step and then shrinks back to its initial size). During this shape increase the model fails to run.

Consider the following simple example: we have a module with a "shape-agnostic" parameter:

```python
import jax
import jax.numpy as jnp
from flax import linen

class A(linen.Module):
    size: int

    def setup(self):
        self.array = self.param('array', lambda _: jnp.zeros(self.size))

    def __call__(self):
        return self.array.mean()

model = A(10)
params = model.init(jax.random.PRNGKey(0))
print(model.apply(params))
```

As expected, this works perfectly well. Next, we want to do some model surgery and call `apply` again:

```python
from flax.core import freeze, unfreeze

params = unfreeze(params)
params['params']['array'] = jnp.concatenate([params['params']['array'],
                                             params['params']['array']])
params = freeze(params)
print(model.apply(params))
```

That would raise a shape-mismatch error. As far as I understand, to rebuild the model from `params`, flax first re-evaluates the initializer and checks that its output has the same shape as the value stored in `params`.
-
I'm not sure I understand this question. If your model doubles in size, why don't you just construct it with the new size? Your last line would then look as follows:

```python
print(A(20).apply(params))
```

Or you derive the desired size from the input params:

```python
size = params['params']['array'].shape[0]
print(A(size).apply(params))
```

Also, as a general rule, I would suggest naming the return value of `model.init` something like `variables` rather than `params`, since it is a dict of variable collections of which `'params'` is only one.
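To make the second suggestion concrete, here is a minimal end-to-end sketch (reusing `class A` and the surgery code from the question above) that derives the new size from the modified params and rebuilds the module with it:

```python
import jax
import jax.numpy as jnp
from flax.core import freeze, unfreeze

# Reusing class A from the question above.
model = A(10)
params = model.init(jax.random.PRNGKey(0))

# Model surgery: double the parameter, exactly as in the question.
params = unfreeze(params)
params['params']['array'] = jnp.concatenate([params['params']['array'],
                                             params['params']['array']])
params = freeze(params)

# Derive the new size from the params themselves and rebuild the module.
new_size = params['params']['array'].shape[0]   # 20 after doubling
print(A(new_size).apply(params))                # initializer shape now matches the value
```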
-
For params we use shape inference to check that the initialiser and its value have the same shape. This avoids a lot of issues with hyperparameters and params being out of sync, for example after restoring a checkpoint. We might at some point add a keyword arg to disable this check, but for now an easy workaround is:

```python
import jax.numpy as jnp
from flax import linen

class A(linen.Module):
    size: int

    def setup(self):
        # self.variable is not shape-checked against its initializer like self.param is.
        self.array = self.variable('params', 'array', jnp.zeros, self.size)

    def __call__(self):
        # self.variable returns a Variable reference; read its contents via .value
        return self.array.value.mean()
```

Btw, I would consider putting such a variable in a separate collection than "params" anyway. Quite often you need to enforce shape invariance outside of the model as well, for example when the params are being optimized.