Skip to content

Commit 3908042

Browse files
committed
Fixes #101201
- Replaced the deprecated `_add_existing_weight` implementation, which previously relied on unsupported arguments (`experimental_autocast`, `getter`) that were removed in Keras 3.x.
- Introduced custom weight lists (`_hub_trainable_weights`, `_hub_non_trainable_weights`) instead of the deprecated `_track_variable` approach, and overrode the `trainable_weights` and `non_trainable_weights` properties to properly expose hub weights to the Keras layer system.
- Resolves the `ValueError: Only instances of keras.Layer can be added to a Sequential model` and related integration failures when using `hub.KerasLayer` in Sequential/Functional models.
1 parent 26c94f6 commit 3908042

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

tensorflow_hub/keras_layer.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@
2525
# pylint: disable=g-import-not-at-top
2626
# Use Keras 2.
2727
version_fn = getattr(tf.keras, "version", None)
28-
if version_fn and version_fn().startswith("3."):
29-
import tf_keras as keras
30-
else:
31-
keras = tf.keras
28+
# Always align with tf.keras to avoid mismatched Layer types
29+
keras = tf.keras
3230

3331
# pylint: disable=g-direct-tensorflow-import
3432
from tensorflow.python.framework import smart_cond
@@ -210,11 +208,34 @@ def _setup_layer(self, trainable=False, **kwargs):
210208
self.add_loss(self._call_loss_if_trainable(l)) # Supports callables.
211209

212210
def _add_existing_weight(self, weight, trainable=None):
213-
"""Calls add_weight() to register but not create an existing weight."""
214-
if trainable is None: trainable = weight.trainable
215-
self.add_weight(name=weight.name, shape=weight.shape, dtype=weight.dtype,
216-
trainable=trainable, experimental_autocast=False,
217-
getter=lambda *_, **__: weight)
211+
"""Registers an existing tf.Variable with this layer."""
212+
if trainable is None:
213+
trainable = getattr(weight, "trainable", False)
214+
215+
# Create custom weight lists if they don't exist
216+
if not hasattr(self, '_hub_trainable_weights'):
217+
self._hub_trainable_weights = []
218+
if not hasattr(self, '_hub_non_trainable_weights'):
219+
self._hub_non_trainable_weights = []
220+
221+
# Add to appropriate list
222+
if trainable:
223+
self._hub_trainable_weights.append(weight)
224+
else:
225+
self._hub_non_trainable_weights.append(weight)
226+
@property
227+
def trainable_weights(self):
228+
"""Override to include hub weights."""
229+
base_weights = super().trainable_weights
230+
hub_weights = getattr(self, '_hub_trainable_weights', [])
231+
return base_weights + hub_weights
232+
233+
@property
234+
def non_trainable_weights(self):
235+
"""Override to include hub weights."""
236+
base_weights = super().non_trainable_weights
237+
hub_weights = getattr(self, '_hub_non_trainable_weights', [])
238+
return base_weights + hub_weights
218239

219240
def _call_loss_if_trainable(self, loss):
220241
"""Returns `loss` conditioned on whether this layer is trainable."""
@@ -338,6 +359,7 @@ def get_config(self):
338359
if not isinstance(self._handle, str):
339360
# Need to raise this type in order for tf.saved_model.save() to fall back
340361
# to not using config, instead of crashing.
362+
# TODO(b/134528831): Reconsider the usability implications.
341363
raise NotImplementedError(
342364
"Can only generate a valid config for `hub.KerasLayer(handle, ...)`"
343365
"that uses a string `handle`.\n\n"

0 commit comments

Comments (0)