 import os
 from typing import Text, Tuple, Union
 from absl import logging
+import gin
 import numpy as np
 import tensorflow.compat.v1 as tf
 import tensorflow.compat.v2 as tf2
@@ -266,10 +267,43 @@ def batch_norm_class(is_training, strategy=None):
   else:
     return BatchNormalization
 
+# A cache of variable scope to BatchNorm layer.
+_BN_LAYER_CACHE = {}
 
-def batch_normalization(inputs, training=False, strategy=None, **kwargs):
-  """A wrapper for TpuBatchNormalization."""
-  bn_layer = batch_norm_class(training, strategy)(**kwargs)
+
+@gin.configurable
+def batch_normalization(inputs,
+                        training=False,
+                        strategy=None,
+                        reuse_scope: bool = False,
+                        **kwargs):
+  """A wrapper for TpuBatchNormalization.
+
+  Keras layers are incompatible with automatic tf scope reuse.
+
+  Supports reuse of the existing variable scope when a model is called multiple
+  times. Otherwise, checkpoint weights would not be restored correctly.
+
+  Args:
+    inputs: Input to BatchNorm layer.
+    training: Argument of Keras BatchNorm layer.
+    strategy: Argument of Keras BatchNorm layer.
+    reuse_scope: Whether to reuse existing layer in same scope.
+    **kwargs: Arguments passed to Keras BatchNorm layer.
+
+  Returns:
+    Result of BatchNorm applied to inputs.
+  """
+  if reuse_scope:
+    scope_name = tf.get_variable_scope().name
+    if scope_name in _BN_LAYER_CACHE:
+      bn_layer = _BN_LAYER_CACHE[scope_name]
+      logging.info('Reusing variable scope %s', scope_name)
+    else:
+      bn_layer = batch_norm_class(training, strategy)(**kwargs)
+      _BN_LAYER_CACHE[scope_name] = bn_layer
+  else:
+    bn_layer = batch_norm_class(training, strategy)(**kwargs)
   return bn_layer(inputs, training=training)
 
 
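For context, below is a minimal sketch of how the new `reuse_scope` flag is meant to be exercised when a model function is traced more than once under the same variable scope. The module name `utils`, the scope name `'stem'`, and the tensor shape are illustrative assumptions, not part of this change.

```python
import tensorflow.compat.v1 as tf

import utils  # hypothetical: the module patched in the diff above

tf.disable_eager_execution()  # graph mode, matching the TF1-style code above


def stem(inputs):
  # Entering a variable scope with the same name twice yields the same
  # tf.get_variable_scope().name, so the cache key matches across calls.
  with tf.variable_scope('stem'):  # illustrative scope name
    return utils.batch_normalization(inputs, training=True, reuse_scope=True)


x = tf.placeholder(tf.float32, [8, 32, 32, 16])  # illustrative shape
y_first = stem(x)   # builds a new BatchNorm layer and caches it under 'stem'
y_second = stem(x)  # fetches the cached layer, so no duplicate '*_1'
                    # variables are created and checkpoints restore cleanly
```

Since the wrapper is now `@gin.configurable`, the same behavior can also be switched on from a gin config, e.g. with a binding like `batch_normalization.reuse_scope = True` (the exact binding path depends on how the module is registered with gin).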