From ce07070394c732e3fa497d99657be168b298deb5 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Sun, 8 Nov 2020 16:16:48 -0800 Subject: [PATCH 1/3] Revert "Remove subclassing of `tf.Module` by `DistributionLambda`, since `tfkl.Layer` now subclasses `tf.Module`. Re-enable tests that were broken by Keras change." This reverts commit 9fa0534bed7ab66e9d27042838cf62eda50eff39. NOTE: `tfkl.Layer` does not subclass `tf.Module` in TF 2.4.0rc0. --- .../python/layers/distribution_layer.py | 34 +++++++++++++++---- .../python/layers/distribution_layer_test.py | 3 +- .../distribution_tensor_coercible_test.py | 3 +- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index c57a688b31..383ca3fca4 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -97,7 +97,8 @@ def _event_size(event_shape, name=None): return tf.reduce_prod(event_shape) -class DistributionLambda(tf.keras.layers.Lambda): +# We mix-in `tf.Module` since Keras base class doesn't track tf.Modules. +class DistributionLambda(tf.keras.layers.Lambda, tf.Module): """Keras layer enabling plumbing TFP distributions through Keras models. A `DistributionLambda` is minimially characterized by a function that returns @@ -204,12 +205,12 @@ def _fn(*fargs, **fkwargs): super(DistributionLambda, self).__init__(_fn, **kwargs) # We need to ensure Keras tracks variables (eg, from activity regularizers - # for type-II MLE). To accomplish this, we add the built distribution - # variables and kwargs as members so `vars` picks them up (this is how - # tf.Module implements its introspection). + # for type-II MLE). To accomplish this, we add the built distribution and + # kwargs as members so `vars` picks them up (this is how tf.Module + # implements its introspection). # Note also that we track all variables to support the user pattern: # `v.initializer for v in model.variable]`. - self._most_recently_built_distribution_vars = None + self._most_recently_built_distribution = None self._kwargs = kwargs self._make_distribution_fn = make_distribution_fn @@ -220,6 +221,25 @@ def _fn(*fargs, **fkwargs): # `keras.Sequential` way. self._enter_dunder_call = False + @property + def trainable_weights(self): + # We will append additional weights to what is already discovered from + # tensorflow/python/keras/engine/base_layer.py. + # Note: that in Keras-land "weights" is the source of truth for "variables." + from_keras = super(DistributionLambda, self).trainable_weights + from_module = list(tf.Module.trainable_variables.fget(self)) + return self._dedup_weights(from_keras + from_module) + + @property + def non_trainable_weights(self): + # We will append additional weights to what is already discovered from + # tensorflow/python/keras/engine/base_layer.py. + # Note: that in Keras-land "weights" is the source of truth for "variables." + from_keras = super(DistributionLambda, self).non_trainable_weights + from_module = [v for v in tf.Module.variables.fget(self) + if not getattr(v, 'trainable', True)] + return self._dedup_weights(from_keras + from_module) + def __call__(self, inputs, *args, **kwargs): self._enter_dunder_call = True distribution, _ = super(DistributionLambda, self).__call__( @@ -230,9 +250,9 @@ def __call__(self, inputs, *args, **kwargs): def call(self, inputs, *args, **kwargs): distribution, value = super(DistributionLambda, self).call( inputs, *args, **kwargs) - # We always save the most recently built distribution variables for tracking + # We always save the most recently built distribution for variable tracking # purposes. - self._most_recently_built_distribution_vars = distribution.variables + self._most_recently_built_distribution = distribution if self._enter_dunder_call: # Its critical to return both distribution and concretization # so Keras can inject `_keras_history` to both. This is what enables diff --git a/tensorflow_probability/python/layers/distribution_layer_test.py b/tensorflow_probability/python/layers/distribution_layer_test.py index 4cf61ca76c..8b58b2a4af 100644 --- a/tensorflow_probability/python/layers/distribution_layer_test.py +++ b/tensorflow_probability/python/layers/distribution_layer_test.py @@ -1603,7 +1603,8 @@ def kernel(self): @test_util.test_graph_and_eager_modes class JointDistributionLayer(test_util.TestCase): - def test_works(self): + # TODO(b/171812768): Investigate failure caused by Keras tracking tf.Modules. + def DISABLED_test_works(self): x = tf.keras.Input(shape=()) y = tfp.layers.VariableLayer(shape=[2, 4, 3], dtype=tf.float32)(x) y = tf.keras.layers.Dense(5, use_bias=False)(y) diff --git a/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py b/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py index 2296d1392b..062039b8fe 100644 --- a/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py +++ b/tensorflow_probability/python/layers/internal/distribution_tensor_coercible_test.py @@ -274,7 +274,8 @@ def testWhileLoopWithControlFlowV2(self): @test_util.test_all_tf_execution_regimes class MemoryLeakTest(test_util.TestCase): - def testTypeObjectLeakage(self): + # TODO(b/171812768): Investigate failure caused by Keras tracking tf.Modules. + def DISABLED_testTypeObjectLeakage(self): if not tf.executing_eagerly(): self.skipTest('only relevant to eager') From 01449aac502237cba59535408b13a63d78e10a9c Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Sun, 8 Nov 2020 16:25:00 -0800 Subject: [PATCH 2/3] Increase TF version dependency from 2.3 to 2.4. --- tensorflow_probability/python/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index 5c79173fb5..bfccbe631e 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -50,7 +50,7 @@ def _ensure_tf_install(): # # Update this whenever we need to depend on a newer TensorFlow release. # - required_tensorflow_version = '2.3' + required_tensorflow_version = '2.4' # required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport if (distutils.version.LooseVersion(tf.__version__) < From 2b7ed30efc90819094d78c28ae1cc2e77eaa9669 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Sun, 8 Nov 2020 16:26:20 -0800 Subject: [PATCH 3/3] Set the version for the TFP 0.12-rc0 release. --- tensorflow_probability/python/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index 5b200cb89e..a76151793c 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -24,7 +24,7 @@ # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a # release branch, the current version is by default assumed to be a # 'development' version, labeled 'dev'. -_VERSION_SUFFIX = 'dev' +_VERSION_SUFFIX = 'rc0' # Example, '0.4.0-dev' __version__ = '.'.join([