NNX version of UNET #4578

Answered by cgarciae
jecampagne asked this question in Q&A
Feb 26, 2025 · 2 comments · 1 reply

Hi @jecampagne, in NNX you can store Modules inside lists, dicts, and tuples, so you could easily rewrite the above to use a for loop, e.g.:

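from flax import nnx

class Denoiser(nnx.Module):
  channels: tuple[int, ...] = (32, 64, 128, 256)

  def __init__(self, rngs: nnx.Rngs, din: int = 1):
    self.act = nnx.swish

    # Encoder: one (Conv, GroupNorm) pair per entry in `channels`. The first conv
    # uses stride 1; every later conv downsamples with stride 2.
    # nnx.GroupNorm rejects num_groups=None unless group_size is also given, so the
    # later stages use 32 groups (the nnx.GroupNorm default) instead.
    self.encoder = [
      (nnx.Conv(din if i == 0 else self.channels[i-1], ch, (3, 3), (1 if i == 0 else 2, 1 if i == 0 else 2), padding='VALID', use_bias=False, rngs=rngs),
       nnx.GroupNorm(ch, num_groups=4 if i == 0 else 32, use_bias=False, rngs=rngs))
      for i, ch in enumerate(self.channels)
    ]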

    self.decoder = [
      (nnx.Conv(self.channels[i], self.channels[i-1] if i > 0 else din, (3, 3

Answer selected by jecampagne