|
25 | 25 | # pylint: disable=g-import-not-at-top |
26 | 26 | # Use Keras 2. |
27 | 27 | version_fn = getattr(tf.keras, "version", None) |
28 | | -if version_fn and version_fn().startswith("3."): |
29 | | - import tf_keras as keras |
30 | | -else: |
31 | | - keras = tf.keras |
| 28 | +# Always align with tf.keras to avoid mismatched Layer types |
| 29 | +keras = tf.keras |
32 | 30 |
|
33 | 31 | # pylint: disable=g-direct-tensorflow-import |
34 | 32 | from tensorflow.python.framework import smart_cond |
@@ -210,11 +208,34 @@ def _setup_layer(self, trainable=False, **kwargs): |
210 | 208 | self.add_loss(self._call_loss_if_trainable(l)) # Supports callables. |
211 | 209 |
|
212 | 210 | def _add_existing_weight(self, weight, trainable=None): |
213 | | - """Calls add_weight() to register but not create an existing weight.""" |
214 | | - if trainable is None: trainable = weight.trainable |
215 | | - self.add_weight(name=weight.name, shape=weight.shape, dtype=weight.dtype, |
216 | | - trainable=trainable, experimental_autocast=False, |
217 | | - getter=lambda *_, **__: weight) |
| 211 | + """Registers an existing tf.Variable with this layer.""" |
| 212 | + if trainable is None: |
| 213 | + trainable = getattr(weight, "trainable", False) |
| 214 | + |
| 215 | + # Create custom weight lists if they don't exist |
| 216 | + if not hasattr(self, '_hub_trainable_weights'): |
| 217 | + self._hub_trainable_weights = [] |
| 218 | + if not hasattr(self, '_hub_non_trainable_weights'): |
| 219 | + self._hub_non_trainable_weights = [] |
| 220 | + |
| 221 | + # Add to appropriate list |
| 222 | + if trainable: |
| 223 | + self._hub_trainable_weights.append(weight) |
| 224 | + else: |
| 225 | + self._hub_non_trainable_weights.append(weight) |
| 226 | + @property |
| 227 | + def trainable_weights(self): |
| 228 | + """Override to include hub weights.""" |
| 229 | + base_weights = super().trainable_weights |
| 230 | + hub_weights = getattr(self, '_hub_trainable_weights', []) |
| 231 | + return base_weights + hub_weights |
| 232 | + |
| 233 | + @property |
| 234 | + def non_trainable_weights(self): |
| 235 | + """Override to include hub weights.""" |
| 236 | + base_weights = super().non_trainable_weights |
| 237 | + hub_weights = getattr(self, '_hub_non_trainable_weights', []) |
| 238 | + return base_weights + hub_weights |
218 | 239 |
|
219 | 240 | def _call_loss_if_trainable(self, loss): |
220 | 241 | """Returns `loss` conditioned on whether this layer is trainable.""" |
@@ -338,6 +359,7 @@ def get_config(self): |
338 | 359 | if not isinstance(self._handle, str): |
339 | 360 | # Need to raise this type in order for tf.saved_model.save() to fall back |
340 | 361 | # to not using config, instead of crashing. |
| 362 | + # TODO(b/134528831): Reconsider the usability implications. |
341 | 363 | raise NotImplementedError( |
342 | 364 | "Can only generate a valid config for `hub.KerasLayer(handle, ...)`" |
343 | 365 | " that uses a string `handle`.\n\n" |
|
0 commit comments