
Inconsistent shapes between value and initializer for linen.Module #1130

Answered by jheek
PgLoLo asked this question in Q&A

For params we use shape inference to check that the initializer and its value have the same shape. This avoids a lot of issues with hyperparameters and params getting out of sync, for example after restoring a checkpoint. We might at some point add a keyword arg to disable this check, but for now an easy workaround is:

import jax.numpy as jnp
from flax import linen

class A(linen.Module):
    size: int

    def setup(self):
        # self.variable bypasses the shape check that self.param performs.
        self.array = self.variable('params', 'array', jnp.zeros, self.size)

    def __call__(self):
        # self.variable returns a Variable wrapper; read it via .value
        return self.array.value.mean()
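
For reference, a minimal usage sketch under the assumptions of the snippet above (class name A, size of 4, and the 'params'/'array' names are just the ones used there):

import jax

model = A(size=4)
variables = model.init(jax.random.PRNGKey(0))
# variables == {'params': {'array': array of shape (4,) filled with zeros}}
out = model.apply(variables)  # mean of the zeros array, i.e. 0.0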

btw I would consider putting such a variable in a separate collection from "params" anyway. Quite often you need to enforce shape invariance outside of the model as well …
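
A sketch of that suggestion (the collection name 'constants' below is arbitrary, not something prescribed by Flax; any non-'params' collection name works the same way):

class B(linen.Module):
    size: int

    def setup(self):
        # Stored outside 'params', so it is kept apart from trainable
        # parameters and is not subject to the param shape check.
        self.array = self.variable('constants', 'array', jnp.zeros, self.size)

    def __call__(self):
        return self.array.value.mean()

variables = B(size=4).init(jax.random.PRNGKey(0))
# variables == {'constants': {'array': array of shape (4,)}}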

Answer selected by marcvanzee