How to create a param in Module with a shape dependent on the input? #1154
-
Original question by @psc-g: The value for `features` for one of my layers needs to depend on the shape of the input. How can I do this in `setup` if the input doesn't seem to be passed in explicitly as a parameter? For initialization the input is passed in, so my guess is that it would have to be set up during the `init` call, but I'm not sure how to go about this. Is the recommended approach to decorate …
-
Answer by @avital (reworded by me):

If you want to use `setup`, then you'd probably just add another attribute for the input dimensions (this is how /all/ modules in PyTorch are defined), like `in_features` or something:

```python
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import initializers

class Foo(nn.Module):
  in_features: int
  out_features: int

  def setup(self):
    self.kernel = self.param('kernel',
                             initializers.lecun_normal(),
                             (self.in_features, self.out_features))

  def __call__(self, x):
    # x has shape (..., in_features); the result has shape (..., out_features).
    return jnp.dot(x, self.kernel)
```

But using `@nn.compact` is probably easier, and then you get shape inference for free. Suppose you want to get the shapes for your kernel from the last two dimensions of your…

```python
class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    # The kernel shape is read from x.shape when the parameter is created.
    kernel = self.param('kernel',
                        initializers.lecun_normal(),
                        (x.shape[-1], x.shape[-2]))
    return jnp.dot(kernel, x)
```