NNX version of UNET #4578
-
Hello, I have seen an implementation of UNet using the Flax Linen API in discussion thread #2309. Is there a chance that someone has already cooked up an NNX version? Thanks.
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 1 reply
-
# In case it may interest someone, I've found this one from Y. Song and adapted it to NNX:
class Denoiser(nnx.Module):
    """Small 4-level U-Net denoiser (no time conditioning) written with Flax NNX.

    Encoder: four 3x3 'VALID' convolutions (stride 1, then stride 2 three
    times), each followed by GroupNorm and swish.  Decoder: three 3x3
    convolutions with ``input_dilation=(2, 2)`` that upsample back, the last
    two fed ``concat(previous decoder output, matching encoder activation)``
    (U-Net skip connections), then a final 3x3 conv back to ``din`` channels.

    Inputs are channels-last, ``(..., H, W, din)``; the output has the same
    shape.  With the hard-coded paddings the decoder only lines up with the
    encoder when ``H % 8 == 4`` and ``H >= 20`` (same for ``W``), e.g. 28x28
    MNIST images; other sizes fail in ``jnp.concatenate`` with a shape
    mismatch.
    """

    # Per-level feature counts.  The layers below are spelled out for exactly
    # four levels; GroupNorm's default of 32 groups (used by Norm2..Norm7)
    # needs those channel counts to be divisible by 32.
    channels: tuple[int, ...] = (32, 64, 128, 256)

    def __init__(self, rngs: nnx.Rngs, din: int = 1):
        """Create the encoder/decoder layers.

        Args:
            rngs: NNX RNG streams used to initialise the layers.
            din: number of channels of the input and of the output.
        """
        # Attribute names (conv1..conv8, Norm1..Norm7) are left unchanged:
        # they are the keys of the model's parameter state.
        self.act = nnx.swish
        ch = self.channels

        # --- encoder: 3x3 'VALID' convs, stride 1 then 2, 2, 2 ---
        self.conv1 = nnx.Conv(din, ch[0], (3, 3), (1, 1), padding='VALID',
                              use_bias=False, rngs=rngs)
        self.Norm1 = nnx.GroupNorm(ch[0], num_groups=4, use_bias=False, rngs=rngs)
        self.conv2 = nnx.Conv(ch[0], ch[1], (3, 3), (2, 2), padding='VALID',
                              use_bias=False, rngs=rngs)
        self.Norm2 = nnx.GroupNorm(ch[1], use_bias=False, rngs=rngs)  # num_groups=32 by default
        self.conv3 = nnx.Conv(ch[1], ch[2], (3, 3), (2, 2), padding='VALID',
                              use_bias=False, rngs=rngs)
        self.Norm3 = nnx.GroupNorm(ch[2], use_bias=False, rngs=rngs)
        self.conv4 = nnx.Conv(ch[2], ch[3], (3, 3), (2, 2), padding='VALID',
                              use_bias=False, rngs=rngs)
        self.Norm4 = nnx.GroupNorm(ch[3], use_bias=False, rngs=rngs)

        # --- decoder: input_dilation=(2, 2) upsamples by ~2x per step;
        # conv6..conv8 receive concat(decoder output, encoder skip), hence
        # their doubled input width ---
        self.conv5 = nnx.Conv(ch[3], ch[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                              input_dilation=(2, 2), use_bias=False, rngs=rngs)
        self.Norm5 = nnx.GroupNorm(ch[2], use_bias=False, rngs=rngs)
        self.conv6 = nnx.Conv(2 * ch[2], ch[1], (3, 3), (1, 1), padding=((2, 3), (2, 3)),
                              input_dilation=(2, 2), use_bias=False, rngs=rngs)
        self.Norm6 = nnx.GroupNorm(ch[1], use_bias=False, rngs=rngs)
        self.conv7 = nnx.Conv(2 * ch[1], ch[0], (3, 3), (1, 1), padding=((2, 3), (2, 3)),
                              input_dilation=(2, 2), use_bias=False, rngs=rngs)
        self.Norm7 = nnx.GroupNorm(ch[0], use_bias=False, rngs=rngs)
        # Output projection back to `din` channels: no dilation, and use_bias
        # is left at nnx.Conv's default (unlike the layers above).
        self.conv8 = nnx.Conv(2 * ch[0], din, (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                              rngs=rngs)

    def __call__(self, x):
        """Denoise channels-last ``x`` of shape ``(..., H, W, din)``."""
        # --- encoding ---
        h1 = self.act(self.Norm1(self.conv1(x)))
        h2 = self.act(self.Norm2(self.conv2(h1)))
        h3 = self.act(self.Norm3(self.conv3(h2)))
        h4 = self.act(self.Norm4(self.conv4(h3)))

        # --- decoding, with skip connections to h3, h2, h1 ---
        h = self.act(self.Norm5(self.conv5(h4)))
        h = self.act(self.Norm6(self.conv6(jnp.concatenate([h, h3], axis=-1))))
        h = self.act(self.Norm7(self.conv7(jnp.concatenate([h, h2], axis=-1))))
        # Final layer: no normalisation or activation.
        return self.conv8(jnp.concatenate([h, h1], axis=-1))
# If someone could show me how to use a block schema and loop over blocks, that would be nice...
Beta Was this translation helpful? Give feedback.
-
# Hi @jecampagne, in NNX you can store Modules inside `list`s, `dict`s, and
# `tuple`s, so you could easily rewrite the above to use a for loop, e.g.:
class Denoiser(nnx.Module):
    """Loop-based NNX version of the 4-level U-Net denoiser.

    Builds the same layers as the explicit version (conv1..conv8 /
    Norm1..Norm7) but stores them as lists of ``(conv, norm)`` pairs:

    * ``self.encoder`` -- one pair per entry of ``channels``.  Level 0 is a
      stride-1 conv with a 4-group GroupNorm; deeper levels use stride 2 and
      32 groups.  All convs are 3x3, 'VALID', without bias.
    * ``self.decoder`` -- ``len(channels)`` pairs walking back up.  Every step
      after the first consumes ``concat(previous output, encoder skip)``, so
      its input width is doubled.  All but the last upsample with
      ``input_dilation=(2, 2)``.  The last pair maps back to ``din`` channels
      with a non-dilated conv (bias left at nnx.Conv's default) and has
      ``norm=None``: no normalisation and no activation.

    Inputs are channels-last, ``(..., H, W, din)``.  The paddings are tuned for
    the default four levels and need ``H % 8 == 4`` and ``H >= 20`` (same for
    ``W``), e.g. 28x28 images.  NOTE(review): other numbers of levels have not
    been checked against the padding scheme.
    """

    # Per-level feature counts; GroupNorm uses 4 groups at level 0 and 32
    # groups elsewhere, so channels[0] must be divisible by 4 and the other
    # entries by 32.
    channels: tuple[int, ...] = (32, 64, 128, 256)

    def __init__(self, rngs: nnx.Rngs, din: int = 1):
        """Create the encoder/decoder layers.

        Args:
            rngs: NNX RNG streams used to initialise the layers.
            din: number of channels of the input and of the output.
        """
        self.act = nnx.swish
        chs = self.channels
        top = len(chs) - 1  # index of the bottleneck level

        encoder = []
        for i, ch in enumerate(chs):
            first = i == 0
            conv = nnx.Conv(din if first else chs[i - 1], ch, (3, 3),
                            (1, 1) if first else (2, 2), padding='VALID',
                            use_bias=False, rngs=rngs)
            # nnx.GroupNorm needs either num_groups or group_size, so an explicit
            # group count is given here (32 matches GroupNorm's default).
            norm = nnx.GroupNorm(ch, num_groups=4 if first else 32,
                                 use_bias=False, rngs=rngs)
            encoder.append((conv, norm))

        decoder = []
        for i in range(top, 0, -1):
            # At the bottleneck the input is the encoder output alone; above it
            # the input is concat(previous decoder output, encoder skip), both
            # with chs[i] features.
            in_ch = chs[i] if i == top else 2 * chs[i]
            padding = ((2, 2), (2, 2)) if i == top else ((2, 3), (2, 3))
            conv = nnx.Conv(in_ch, chs[i - 1], (3, 3), (1, 1), padding=padding,
                            input_dilation=(2, 2), use_bias=False, rngs=rngs)
            norm = nnx.GroupNorm(chs[i - 1], use_bias=False, rngs=rngs)
            decoder.append((conv, norm))
        # Output layer: back to `din` channels at full resolution; norm=None
        # tells __call__ to skip normalisation and activation.
        decoder.append((
            nnx.Conv(2 * chs[0], din, (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                     rngs=rngs),
            None,
        ))

        # NOTE(review): some newer Flax releases may require lists of Modules to
        # be wrapped (e.g. nnx.List) -- confirm for the Flax version in use.
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, x):
        """Denoise channels-last ``x`` of shape ``(..., H, W, din)``."""
        h = x
        skips = []
        for conv, norm in self.encoder:
            h = self.act(norm(conv(h)))
            skips.append(h)
        # skips[-1] is the bottleneck itself (already in `h`), so decoder step
        # k > 0 pairs with skips[-(k + 1)]: the level-3, level-2, level-1 ... outputs.
        for k, (conv, norm) in enumerate(self.decoder):
            if k > 0:
                h = jnp.concatenate([h, skips[-(k + 1)]], axis=-1)
            h = conv(h)
            if norm is not None:
                h = self.act(norm(h))
        return h
Beta Was this translation helpful? Give feedback.
Hi @jecampagne, in NNX you can store Modules inside `list`s, `dict`s, and `tuple`s, so you could easily rewrite the above to use a for loop, e.g.: