
Batch Normalization in Flax #2080

Answered by andsteing
marcvanzee asked this question in General
Jan 22, 2021 · 7 comments · 2 replies

Since @marcvanzee already answered "how to use batchnorm" in his original post above, I'll only cover the question "when and how to sync batch statistics". It was asked by @cgarciae and, more recently, by @laoreja (on an internal forum), and answered above by @jheek and by @levskaya (on an internal forum).

Our ImageNet example (examples/imagenet/) uses nn.BatchNorm in its ResNet model:

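```python
from functools import partial
import flax.linen as nn

# Inside ResNet.__call__(self, x, train: bool = True):
norm = partial(nn.BatchNorm,
               use_running_average=not train,  # use stored stats at eval time
               momentum=0.9,   # decay of the running mean/variance averages
               epsilon=1e-5,   # added to the variance for numerical stability
               dtype=self.dtype)
```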

Then we define a utility function sync_batch_stats()...
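A minimal sketch of what such a helper can look like, assuming the train state is replicated with jax.pmap: each device then updates its own running averages from its local shard of the batch, so the copies drift apart, and averaging them across devices with lax.pmean gives every replica the same statistics before evaluation. The signature below (taking the batch_stats tree directly) and the axis name 'devices' are assumptions for illustration, not necessarily what the example uses.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Average every leaf of a pytree across the devices mapped over by pmap.
# 'devices' is an arbitrary axis label chosen for this sketch.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'devices'),
                              axis_name='devices')


def sync_batch_stats(batch_stats):
    """Hypothetical helper: replace each device's running statistics with
    the mean over all devices.

    Assumes every leaf of `batch_stats` has a leading device axis, as it
    does when the state has been replicated for jax.pmap.
    """
    return cross_replica_mean(batch_stats)


n = jax.local_device_count()
# Fake replicated running means: device i holds the value i in all 4 slots.
stats = {'BatchNorm_0': {
    'mean': jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32)[:, None], (n, 4))}}
synced = sync_batch_stats(stats)  # every device now holds (n - 1) / 2
```

Because pmap and pmean both accept pytrees, the same one-liner syncs all BatchNorm layers in the model at once, however deeply they are nested.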

Answer selected by andsteing

This discussion was converted from issue #932 on April 27, 2022 19:40.