Replies: 1 comment
-
Your non-nnx.vmap train_step is missing the 0.5 factor, and you didn't say what actually goes wrong when you run the code (only that you're blocked, not the error itself). I believe the problem is that under nnx.vmap each mapped call receives only one item, so BatchNorm has a single element to compute its mean and variance over (the variance will be zero). NB: when sharing your code, use the <> Code option (with 'python' at the start) so it renders properly; also, the change for the nnx.vmap variant occurs only in train_step, so you can omit the rest of the code.
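The variance issue can be illustrated in plain JAX, without nnx (a minimal sketch, assuming the batch axis is the one being vmapped over): once jax.vmap strips the batch axis, any per-call "batch" statistic is computed over a singleton axis, so the variance collapses to zero.

```python
import jax
import jax.numpy as jnp

# A toy batch-norm-style statistic: mean and variance over axis 0.
def batch_stats(x):
    return jnp.mean(x, axis=0), jnp.var(x, axis=0)

batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 samples, 2 features

# Applied to the whole batch, variance is taken across the 3 samples.
mean_full, var_full = batch_stats(batch)

# vmapped over the batch axis, each call sees a single sample of shape (2,);
# re-adding a length-1 batch axis makes the per-call variance identically zero.
mean_v, var_v = jax.vmap(lambda x: batch_stats(x[None, :]))(batch)

print(var_full)  # non-zero per-feature variance over the batch
print(var_v)     # all zeros: each mapped call only ever saw one sample
```

This mirrors what happens to BatchNorm's batch statistics inside a naively vmapped train_step.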
-
Update of question:
Dear Community,
I encountered a small challenge while using a batch normalization layer with nnx.vmap. To illustrate the issue, I have created a minimal example code snippet. Based on my reading of the flax nnx.vmap documentation, the issue seems to stem from the handling of BatchStat, which requires special consideration under vmap.
Currently, I am struggling to make the final example in the attached code work. Specifically, the first train_step I define, which uses nnx.vmap, fails to execute. Does anyone have suggestions on how to adjust the loss function so it works correctly with nnx.vmap?
Thank you very much!
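One common workaround for the loss function (a sketch in plain JAX, not the nnx API; the function names here are illustrative) is to compute the batch statistics once over the full batch, outside the vmapped function, and pass them in as broadcast arguments, so each mapped call normalizes a single sample with shared statistics instead of a degenerate length-1 batch:

```python
import jax
import jax.numpy as jnp

def normalize_sample(x, mean, var, eps=1e-5):
    # Per-sample normalization using statistics shared across the batch.
    return (x - mean) / jnp.sqrt(var + eps)

def loss_fn(batch, targets):
    # Batch statistics are computed once, over the whole batch...
    mean = jnp.mean(batch, axis=0)
    var = jnp.var(batch, axis=0)
    # ...and broadcast (in_axes=None) into the vmapped per-sample call,
    # so the mapped function never sees a length-1 "batch".
    normalized = jax.vmap(normalize_sample, in_axes=(0, None, None))(batch, mean, var)
    return jnp.mean((normalized - targets) ** 2)

batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = jnp.zeros_like(batch)
print(loss_fn(batch, targets))
```

The same idea applies under nnx: keep the BatchNorm computation over the full batch and restrict vmap to the genuinely per-sample parts of the loss.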