Batch Normalization in Flax #2080
-
Batch Normalization is more complicated than most layers because of the mutation of moving averages during training. The BatchNorm module is in normalization.py. The canonical example using it is ImageNet. In a multi-device setting, every device updates its own running statistics. If they aren't synced, they can theoretically diverge, but if your data is fairly uniform across shards they're likely to trend towards similar values. Syncing before eval is definitely a good idea though, since otherwise your eval results will depend on which devices process which examples.
Also take a look at the comments at #1489.

Let's consider an example. We define a trivial conv + BN layer:

```python
import jax.numpy as jnp
from flax import linen as nn
from jax import random

class Foo(nn.Module):
  train: bool
  filters: int

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=jnp.float32)(x)
    return x
```
```python
key = random.PRNGKey(0)
x = jnp.ones((5, 4, 4, 3))

# We instantiate the layer, then call its init function to get the initial variable collections.
foo_vars = Foo(filters=7, train=True).init(key, x)
foo_vars
```

This returns the following:
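The returned FrozenDict has two collections, 'params' and 'batch_stats'. Roughly (shapes only; the actual values depend on the random initialization):

```python
# FrozenDict({
#   'params': {
#     'Conv_0': {'kernel': f32[1, 1, 3, 7]},
#     'BatchNorm_0': {'scale': f32[7], 'bias': f32[7]},
#   },
#   'batch_stats': {
#     'BatchNorm_0': {'mean': f32[7], 'var': f32[7]},
#   },
# })
```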
We explicitly say which variable collections are to be mutated by the apply function; those are then returned as auxiliary variables.

```python
y1, new_batch_stats = Foo(filters=7, train=True).apply(foo_vars, x, mutable=['batch_stats'])
new_batch_stats
```

This returns the following:
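Note that the mutated collections come back keyed by collection name; schematically (again shapes only):

```python
# FrozenDict({
#   'batch_stats': {
#     'BatchNorm_0': {'mean': f32[7], 'var': f32[7]},
#   },
# })
```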
We stitch together the params and batch stats collections to evaluate again:

```python
# new_batch_stats is keyed by collection name, so we take its 'batch_stats' entry.
new_foo_vars = {'params': foo_vars['params'], 'batch_stats': new_batch_stats['batch_stats']}
y2, even_newer_batch_stats = Foo(filters=7, train=True).apply(new_foo_vars, x, mutable=['batch_stats'])
```
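To connect this to a full training step: a minimal sketch that threads the batch statistics alongside an optax parameter update (the optimizer choice and the squared-output loss are placeholders, not part of the original post):

```python
import jax
import optax

tx = optax.sgd(learning_rate=0.1)
opt_state = tx.init(foo_vars['params'])

@jax.jit
def train_step(params, batch_stats, opt_state, x):
  def loss_fn(params):
    y, mutated = Foo(filters=7, train=True).apply(
        {'params': params, 'batch_stats': batch_stats}, x,
        mutable=['batch_stats'])
    # Dummy loss; a real model would compare y against labels.
    return jnp.mean(y ** 2), mutated['batch_stats']

  (loss, new_batch_stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, new_opt_state = tx.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  return new_params, new_batch_stats, new_opt_state, loss

params, batch_stats, opt_state, loss = train_step(
    foo_vars['params'], foo_vars['batch_stats'], opt_state, x)
```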
-
Hey, thanks for the guide. This is useful. I found something unintuitive:
and
Why are they different? BTW, I found this usage here: flax/examples/imagenet/train.py, line 69 in d068512.
-
Hi @cccntu -- what do you mean by "different"?

```python
variables = model.init(...)
# assume variables['params'] and variables['batch_stats'] are present here
other_variables, params = variables.pop('params')
# here params == variables['params'], and other_variables['batch_stats'] == variables['batch_stats']
```

Maybe the docstring for FrozenDict.pop should make this clearer. Does this help? What could we improve in our documentation so that this would be less confusing?
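For anyone else landing here, a tiny self-contained illustration of the pop semantics on a FrozenDict (toy values, not taken from the example above):

```python
from flax.core import freeze

variables = freeze({'params': {'w': 1.0}, 'batch_stats': {'mean': 0.0}})
other_variables, params = variables.pop('params')
print(params)           # FrozenDict({'w': 1.0})
print(other_variables)  # FrozenDict({'batch_stats': {'mean': 0.0}})
```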
-
@avital Thanks for the reply. I couldn't find the docs for FrozenDict via Google search, and the readthedocs search is extremely slow, so I didn't read the doc.
-
I don't think we even have reference docs for FrozenDict on RTD; we probably should. I filed an issue: #969
-
Hey! I am very interested in the best practices for BatchNorm (e.g. when combined with pmap). Thanks!
-
See the ImageNet example for the canonical example of combining BatchNorm and pmap. There we indeed sync the statistics before evaluation with sync_batch_stats().
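For readers who want to see what that looks like, here is a sketch along the lines of the ImageNet example's utility (assuming a replicated TrainState-like object with a batch_stats field; see examples/imagenet/train.py for the real code):

```python
import jax
from jax import lax

# Average every leaf of the batch_stats pytree over the device axis.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')

def sync_batch_stats(state):
  """Make all replicas agree on the batch statistics before evaluation."""
  return state.replace(batch_stats=cross_replica_mean(state.batch_stats))
```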
-
Since @marcvanzee already provides an answer to "how to use batchnorm" in his original post above, I'm only going to cover the question "when and how to normalize statistics" asked by @cgarciae and more recently by @laoreja (on an internal forum), and answered above by @jheek and by @levskaya (on an internal forum).

Our examples/imagenet/ uses nn.BatchNorm in the ResNet model:

flax/examples/imagenet/models.py, lines 95 to 99 in b3236ce

Then we define a utility function sync_batch_stats()...

flax/examples/imagenet/train.py, lines 211 to 218 in b3236ce

...which we then call before evaluating...

flax/examples/imagenet/train.py, lines 357 to 361 in b3236ce

...and before writing a checkpoint:

flax/examples/imagenet/train.py, lines 369 to 371 in b3236ce

Alternatively, we could have specified axis_name on nn.BatchNorm (matching the axis name passed to pmap), so that the batch statistics are computed across all devices during training. Syncing would then happen here:

flax/flax/linen/normalization.py, lines 84 to 90 in b3236ce
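To make the axis_name alternative concrete, here is a minimal sketch (the 'batch' axis name, the SyncedFoo module, and the shapes are illustrative, not code from the ImageNet example):

```python
import functools
import jax
import jax.numpy as jnp
from flax import linen as nn

class SyncedFoo(nn.Module):
  train: bool
  filters: int

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(self.filters, (1, 1), use_bias=False)(x)
    # axis_name must match the axis name given to pmap below. With it set,
    # BatchNorm averages its batch statistics across that mapped axis, so
    # every device normalizes with the same mean/var during training.
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9, epsilon=1e-5,
                     axis_name='batch')(x)
    return x

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 5, 4, 4, 3))  # leading axis = devices

# Because BatchNorm with axis_name calls lax.pmean internally, both init and
# apply run under pmap so that the 'batch' axis name is bound.
@functools.partial(jax.pmap, axis_name='batch')
def init_fn(rng, x):
  return SyncedFoo(filters=7, train=True).init(rng, x)

@functools.partial(jax.pmap, axis_name='batch')
def forward(variables, x):
  return SyncedFoo(filters=7, train=True).apply(variables, x, mutable=['batch_stats'])

rng = jax.random.PRNGKey(0)
variables = init_fn(jnp.stack([rng] * n_dev), x)  # same key on every device
y, new_batch_stats = forward(variables, x)
```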