
Commit 6a5c4df

Add scope reuse support.
PiperOrigin-RevId: 415459931
1 parent 91c5436 commit 6a5c4df

File tree

2 files changed: +47 -7 lines changed

tensorflow_examples/lite/model_maker/third_party/efficientdet/efficientdet_arch.py (+10 -4)
@@ -440,10 +440,16 @@ def fuse_features(nodes, weight_method):
     nodes = tf.stack(nodes, axis=-1)
     new_node = tf.reduce_sum(nodes * normalized_weights, -1)
   elif weight_method == 'fastattn':
-    edge_weights = [
-        tf.nn.relu(tf.cast(tf.Variable(1.0, name='WSM'), dtype=dtype))
-        for _ in nodes
-    ]
+    edge_weights = []
+    for i, _ in enumerate(nodes):
+      edge_weights.append(
+          tf.nn.relu(
+              tf.get_variable(
+                  'WSM' if i == 0 else f'WSM_{i}',
+                  shape=[],
+                  initializer=tf.ones_initializer(),
+                  dtype=dtype)))
+
     weights_sum = tf.add_n(edge_weights)
     nodes = [nodes[i] * edge_weights[i] / (weights_sum + 0.0001)
              for i in range(len(nodes))]
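The hunk above replaces the single shared tf.Variable(1.0, name='WSM') with one explicitly named tf.get_variable per input node, so each fusion weight gets a stable, scope-qualified name. Below is a minimal sketch (not part of the commit; the scope name 'fnode0' and the node shapes are made up for illustration) showing the per-node variable names this produces under a TF1-style variable scope:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def make_edge_weights(nodes, dtype=tf.float32):
  """Creates one non-negative scalar fusion weight per node, as in the new code."""
  edge_weights = []
  for i, _ in enumerate(nodes):
    edge_weights.append(
        tf.nn.relu(
            tf.get_variable(
                'WSM' if i == 0 else f'WSM_{i}',
                shape=[],
                initializer=tf.ones_initializer(),
                dtype=dtype)))
  return edge_weights


# Hypothetical scope and feature shapes, for illustration only.
with tf.variable_scope('fnode0'):
  nodes = [tf.ones([1, 8, 8, 64]) for _ in range(3)]
  make_edge_weights(nodes)

print([v.name for v in tf.global_variables()])
# ['fnode0/WSM:0', 'fnode0/WSM_1:0', 'fnode0/WSM_2:0']

Because the names are deterministic, rebuilding the same part of the graph under the same scope (e.g. with reuse=tf.AUTO_REUSE) resolves to the same variables, which is what lets checkpoint weights map back correctly.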

tensorflow_examples/lite/model_maker/third_party/efficientdet/utils.py (+37 -3)
@@ -17,6 +17,7 @@
 import os
 from typing import Text, Tuple, Union
 from absl import logging
+import gin
 import numpy as np
 import tensorflow.compat.v1 as tf
 import tensorflow.compat.v2 as tf2
@@ -266,10 +267,43 @@ def batch_norm_class(is_training, strategy=None):
   else:
     return BatchNormalization
 
+# A cache of variable scope to BatchNorm layer.
+_BN_LAYER_CACHE = {}
 
-def batch_normalization(inputs, training=False, strategy=None, **kwargs):
-  """A wrapper for TpuBatchNormalization."""
-  bn_layer = batch_norm_class(training, strategy)(**kwargs)
+
+@gin.configurable
+def batch_normalization(inputs,
+                        training=False,
+                        strategy=None,
+                        reuse_scope: bool = False,
+                        **kwargs):
+  """A wrapper for TpuBatchNormalization.
+
+  Keras layers are incompatible with automatic tf scope reuse.
+
+  Supports reuse of the existing variable scope when a model is called multiple
+  times. Otherwise, checkpoint weights would not be restored correctly.
+
+  Args:
+    inputs: Input to BatchNorm layer.
+    training: Argument of Keras BatchNorm layer.
+    strategy: Argument of Keras BatchNorm layer.
+    reuse_scope: Whether to reuse the existing layer in the same scope.
+    **kwargs: Arguments passed to Keras BatchNorm layer.
+
+  Returns:
+    Result of BatchNorm applied to inputs.
+  """
+  if reuse_scope:
+    scope_name = tf.get_variable_scope().name
+    if scope_name in _BN_LAYER_CACHE:
+      bn_layer = _BN_LAYER_CACHE[scope_name]
+      logging.info('Reusing variable scope %s', scope_name)
+    else:
+      bn_layer = batch_norm_class(training, strategy)(**kwargs)
+      _BN_LAYER_CACHE[scope_name] = bn_layer
+  else:
+    bn_layer = batch_norm_class(training, strategy)(**kwargs)
   return bn_layer(inputs, training=training)
 
 
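A minimal usage sketch of the new reuse_scope path follows (assumptions: the scope name 'backbone/stem' and the input shape are illustrative, and the import path is inferred from the file location above). The first call inside a scope builds the Keras BatchNorm layer and caches it keyed by the scope name; a second call in the same scope returns the cached layer instead of constructing a new one, so the layer's weights stay tied to a single set of checkpointed variables:

import tensorflow.compat.v1 as tf

from tensorflow_examples.lite.model_maker.third_party.efficientdet import utils

tf.disable_v2_behavior()

images = tf.ones([2, 8, 8, 64])  # Hypothetical input batch.

# First call under 'backbone/stem': builds the BatchNorm layer and caches it
# keyed by the current variable-scope name.
with tf.variable_scope('backbone/stem'):
  out1 = utils.batch_normalization(images, training=False, reuse_scope=True)

# Re-entering the same scope (e.g. when the model function is called a second
# time) hits the cache, logs 'Reusing variable scope backbone/stem', and reuses
# the same layer rather than creating duplicate variables.
with tf.variable_scope('backbone/stem', reuse=tf.AUTO_REUSE):
  out2 = utils.batch_normalization(images, training=False, reuse_scope=True)

With reuse_scope left at its default of False, the wrapper behaves exactly as before and builds a fresh layer on every call.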