
How to create a param in Module with a shape dependent on the input? #1154

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @avital (reworded by me):

If you want to use `setup`, you would probably just add another attribute for the input dimension, such as `in_features` (this is how /all/ modules in PyTorch are defined):

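```python
import jax.numpy as jnp
import flax.linen as nn


class Foo(nn.Module):
  in_features: int
  out_features: int

  def setup(self):
    # The kernel shape is fixed by the module attributes, not by the input.
    self.kernel = self.param('kernel',
                             nn.initializers.lecun_normal(),
                             (self.in_features, self.out_features))

  def __call__(self, x):
    # x: (..., in_features) -> (..., out_features)
    return jnp.dot(x, self.kernel)
```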

But using `@nn.compact` is probably easier, and then you get shape inference for free, because parameters are declared inside `__call__` where the input is available. Suppose you want to get the shapes for your kernel from the last two dimensions of your…

Answer selected by marcvanzee