From a134096462813fad72a85a6291f6d62bdd15c3dd Mon Sep 17 00:00:00 2001 From: jburnim Date: Wed, 29 Jul 2020 11:48:28 -0700 Subject: [PATCH 1/2] Make distribution layer serialization compatible with CloudPickle >= 1.3. Checked that distribution_layer_test passes with: - CloudPickle 1.3.0, 1.4.1, and 1.5.0 . - Python 3.5 and 3.8 . Fixes https://github.com/tensorflow/probability/issues/991 . Thanks to https://github.com/matthewfeickert and https://github.com/ogrisel for helping with this issue! PiperOrigin-RevId: 323834575 --- setup.py | 2 +- tensorflow_probability/python/layers/BUILD | 1 + .../python/layers/distribution_layer.py | 53 +++++++++---------- .../python/layers/distribution_layer_test.py | 8 ++- testing/install_test_dependencies.sh | 2 +- 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/setup.py b/setup.py index 9c4167363c..7f57cd99a1 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ 'six >= 1.10.0', 'numpy >= 1.13.3', 'decorator', - 'cloudpickle == 1.3', # TODO(b/155109696): Unpin cloudpickle version. 
+ 'cloudpickle >= 1.3', 'gast >= 0.3.2', # For autobatching 'dm-tree' # For NumPy/JAX backends (hence, also for prefer_static) ] diff --git a/tensorflow_probability/python/layers/BUILD b/tensorflow_probability/python/layers/BUILD index 6e53087864..f93b9f172a 100644 --- a/tensorflow_probability/python/layers/BUILD +++ b/tensorflow_probability/python/layers/BUILD @@ -163,6 +163,7 @@ py_test( shard_count = 5, deps = [ # numpy dep, + # six dep, # tensorflow dep, "//tensorflow_probability", "//tensorflow_probability/python/internal:test_util", diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index 9c7e66059f..35bd8e039a 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -25,7 +25,7 @@ import pickle # Dependency imports -from cloudpickle.cloudpickle import CloudPickler +from cloudpickle import CloudPickler import numpy as np import six import tensorflow.compat.v2 as tf @@ -2020,37 +2020,32 @@ def _get_convert_to_tensor_fn(identifier): 'convert-to-tensor function identifier:', identifier) -class _TensorCloudPickler(CloudPickler): - """Subclass of `CloudPickler` that includes pickling of `Tensor` objects.""" - - def __init__(self, out_file, protocol=None): - CloudPickler.__init__(self, out_file, protocol) - - @staticmethod - def save_tensor(cloud_pickler, tensor, name=None): - val = tf.get_static_value(tensor) - if val is None: - raise ValueError('Cannot pickle Tensor -- ' - 'its value is not known statically: {}.'.format(tensor)) - CloudPickler.save_reduce(cloud_pickler, np.array, (val,)) - - def inject_addons(self): - tensor_class = tf.convert_to_tensor(1.).__class__ - CloudPickler.dispatch[tensor_class] = _TensorCloudPickler.save_tensor - - @staticmethod - def dumps(obj, protocol=None): - out_file = io.BytesIO() - try: - _TensorCloudPickler(out_file, protocol).dump(obj) - return out_file.getvalue() - 
finally: - out_file.close() +def _reduce_tensor(tensor): + val = tf.get_static_value(tensor) + if val is None: + raise ValueError('Cannot pickle Tensor -- ' + 'its value is not known statically: {}.'.format(tensor)) + return (tf.convert_to_tensor, (val,)) def _serialize_function(func): - raw_code = _TensorCloudPickler.dumps(func) - return codecs.encode(raw_code, 'base64').decode('ascii') + """Serializes a function (using CloudPickle).""" + buffer = io.BytesIO() + pickler = CloudPickler(buffer) + + # Serializing a DistributionLambda or other distribution layer may require + # serializing a lambda or function that closes over a constant, graph-mode + # Tensor, but graph-mode Tensors do not support pickling. We modify + # `pickler.dispatch_table` so that a special reduction function will be used + # for graph-mode Tensors, which will: + # - Correctly serialize constant graph-mode Tensors. + # - Raise an explanatory error message for non-constant graph-mode Tensors. + if not hasattr(pickler, 'dispatch_table'): + pickler.dispatch_table = {} + pickler.dispatch_table[tf.Tensor] = _reduce_tensor + + pickler.dump(func) + return codecs.encode(buffer.getvalue(), 'base64').decode('ascii') def _deserialize_function(code): diff --git a/tensorflow_probability/python/layers/distribution_layer_test.py b/tensorflow_probability/python/layers/distribution_layer_test.py index 253d00834d..77e7a6eb0e 100644 --- a/tensorflow_probability/python/layers/distribution_layer_test.py +++ b/tensorflow_probability/python/layers/distribution_layer_test.py @@ -21,6 +21,7 @@ # Dependency imports import numpy as np +import six import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf import tensorflow_probability as tfp @@ -396,8 +397,7 @@ def test_serialization_static_method(self): lambda t: DistributionLambdaSerializationTest._make_distribution(t)) ]) model.compile(optimizer='adam', loss='mse') - # TODO(b/138375951): Re-enable this test. 
- # self.assertSerializable(model, batch_size=3) + self.assertSerializable(model, batch_size=3) model = tfk.Sequential([ tfkl.Dense(15, input_shape=(5,)), @@ -410,6 +410,10 @@ def test_serialization_static_method(self): self.assertExportable(model) def test_serialization_closure_over_lambdas_tensors_and_numpy_array(self): + if six.PY2 and not tf.executing_eagerly(): + self.skipTest('Serialization of constant graph-mode Tensors is not ' + 'supported under Python 2.') + num_components = np.array(3) one = tf.convert_to_tensor(1) mk_ind_norm = lambda event_shape: tfpl.IndependentNormal(event_shape + one) diff --git a/testing/install_test_dependencies.sh b/testing/install_test_dependencies.sh index 2cadc4fa9a..04e823f800 100755 --- a/testing/install_test_dependencies.sh +++ b/testing/install_test_dependencies.sh @@ -157,7 +157,7 @@ install_python_packages() { python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock scipy # Install additional TFP dependencies. - python -m pip install $PIP_FLAGS decorator cloudpickle==1.3 dm-tree # TODO(b/155109696): Unpin cloudpickle version. + python -m pip install $PIP_FLAGS decorator 'cloudpickle>=1.3' dm-tree # Upgrade numpy to the latest to address issues that happen when testing with # Python 3 (https://github.com/tensorflow/tensorflow/issues/16488). From db26e9a3f5c05ebbec0e102f701ff0440e6437f2 Mon Sep 17 00:00:00 2001 From: Christopher Suter Date: Mon, 28 Sep 2020 19:17:09 -0400 Subject: [PATCH 2/2] Update to explicit jax version requirements for 0.11 branch --- setup.py | 2 +- testing/install_test_dependencies.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7f57cd99a1..85e7222ce7 100644 --- a/setup.py +++ b/setup.py @@ -105,7 +105,7 @@ def has_ext_modules(self): ], keywords='tensorflow probability statistics bayesian machine learning', extras_require={ # e.g. 
`pip install tfp-nightly[jax]` - 'jax': ['jax', 'jaxlib'], + 'jax': ['jax==0.1.74', 'jaxlib==0.1.52'], 'tfds': [TFDS_PACKAGE], } ) diff --git a/testing/install_test_dependencies.sh b/testing/install_test_dependencies.sh index 04e823f800..416c5dfd19 100755 --- a/testing/install_test_dependencies.sh +++ b/testing/install_test_dependencies.sh @@ -150,7 +150,7 @@ install_python_packages() { python -m pip install $PIP_FLAGS $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR # For the JAX backend. - python -m pip install jax jaxlib + python -m pip install jax==0.1.74 jaxlib==0.1.52 # The following unofficial dependencies are used only by tests. # TODO(b/148685448): Unpin Hypothesis and coverage versions.