Skip to content

Commit 994c944

Browse files
authored
Merge pull request #1091 from csuter/r0.11
Update jax to explicit version number on r0.11 branch
2 parents 7cf006d + db26e9a commit 994c944

File tree

5 files changed

+35
-35
lines changed

5 files changed

+35
-35
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
'six >= 1.10.0',
3232
'numpy >= 1.13.3',
3333
'decorator',
34-
'cloudpickle == 1.3', # TODO(b/155109696): Unpin cloudpickle version.
34+
'cloudpickle >= 1.3',
3535
'gast >= 0.3.2', # For autobatching
3636
'dm-tree' # For NumPy/JAX backends (hence, also for prefer_static)
3737
]
@@ -105,7 +105,7 @@ def has_ext_modules(self):
105105
],
106106
keywords='tensorflow probability statistics bayesian machine learning',
107107
extras_require={ # e.g. `pip install tfp-nightly[jax]`
108-
'jax': ['jax', 'jaxlib'],
108+
'jax': ['jax==0.1.74', 'jaxlib==0.1.52'],
109109
'tfds': [TFDS_PACKAGE],
110110
}
111111
)

tensorflow_probability/python/layers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ py_test(
163163
shard_count = 5,
164164
deps = [
165165
# numpy dep,
166+
# six dep,
166167
# tensorflow dep,
167168
"//tensorflow_probability",
168169
"//tensorflow_probability/python/internal:test_util",

tensorflow_probability/python/layers/distribution_layer.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pickle
2626

2727
# Dependency imports
28-
from cloudpickle.cloudpickle import CloudPickler
28+
from cloudpickle import CloudPickler
2929
import numpy as np
3030
import six
3131
import tensorflow.compat.v2 as tf
@@ -2020,37 +2020,32 @@ def _get_convert_to_tensor_fn(identifier):
20202020
'convert-to-tensor function identifier:', identifier)
20212021

20222022

2023-
class _TensorCloudPickler(CloudPickler):
2024-
"""Subclass of `CloudPickler` that includes pickling of `Tensor` objects."""
2025-
2026-
def __init__(self, out_file, protocol=None):
2027-
CloudPickler.__init__(self, out_file, protocol)
2028-
2029-
@staticmethod
2030-
def save_tensor(cloud_pickler, tensor, name=None):
2031-
val = tf.get_static_value(tensor)
2032-
if val is None:
2033-
raise ValueError('Cannot pickle Tensor -- '
2034-
'its value is not known statically: {}.'.format(tensor))
2035-
CloudPickler.save_reduce(cloud_pickler, np.array, (val,))
2036-
2037-
def inject_addons(self):
2038-
tensor_class = tf.convert_to_tensor(1.).__class__
2039-
CloudPickler.dispatch[tensor_class] = _TensorCloudPickler.save_tensor
2040-
2041-
@staticmethod
2042-
def dumps(obj, protocol=None):
2043-
out_file = io.BytesIO()
2044-
try:
2045-
_TensorCloudPickler(out_file, protocol).dump(obj)
2046-
return out_file.getvalue()
2047-
finally:
2048-
out_file.close()
2023+
def _reduce_tensor(tensor):
2024+
val = tf.get_static_value(tensor)
2025+
if val is None:
2026+
raise ValueError('Cannot pickle Tensor -- '
2027+
'its value is not known statically: {}.'.format(tensor))
2028+
return (tf.convert_to_tensor, (val,))
20492029

20502030

20512031
def _serialize_function(func):
2052-
raw_code = _TensorCloudPickler.dumps(func)
2053-
return codecs.encode(raw_code, 'base64').decode('ascii')
2032+
"""Serializes a function (using CloudPickle)."""
2033+
buffer = io.BytesIO()
2034+
pickler = CloudPickler(buffer)
2035+
2036+
# Serializing a DistributionLambda or other distribution layer may require
2037+
# serializaing a lambda or function that closes over a constant, graph-mode
2038+
# Tensor, but graph-mode Tensors do not support pickling. We modify
2039+
# `pickler.dispatch_table` so that a special reduction function will be used
2040+
# for graph-mode Tensors, which will:
2041+
# - Correctly serialize constant graph-mode Tensors.
2042+
# - Raise an explanatory error message for non-constant graph-mode Tensors.
2043+
if not hasattr(pickler, 'dispatch_table'):
2044+
pickler.dispatch_table = {}
2045+
pickler.dispatch_table[tf.Tensor] = _reduce_tensor
2046+
2047+
pickler.dump(func)
2048+
return codecs.encode(buffer.getvalue(), 'base64').decode('ascii')
20542049

20552050

20562051
def _deserialize_function(code):

tensorflow_probability/python/layers/distribution_layer_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
# Dependency imports
2222

2323
import numpy as np
24+
import six
2425
import tensorflow.compat.v1 as tf1
2526
import tensorflow.compat.v2 as tf
2627
import tensorflow_probability as tfp
@@ -396,8 +397,7 @@ def test_serialization_static_method(self):
396397
lambda t: DistributionLambdaSerializationTest._make_distribution(t))
397398
])
398399
model.compile(optimizer='adam', loss='mse')
399-
# TODO(b/138375951): Re-enable this test.
400-
# self.assertSerializable(model, batch_size=3)
400+
self.assertSerializable(model, batch_size=3)
401401

402402
model = tfk.Sequential([
403403
tfkl.Dense(15, input_shape=(5,)),
@@ -410,6 +410,10 @@ def test_serialization_static_method(self):
410410
self.assertExportable(model)
411411

412412
def test_serialization_closure_over_lambdas_tensors_and_numpy_array(self):
413+
if six.PY2 and not tf.executing_eagerly():
414+
self.skipTest('Serialization of constant graph-mode Tensors is not '
415+
'supported under Python 2.')
416+
413417
num_components = np.array(3)
414418
one = tf.convert_to_tensor(1)
415419
mk_ind_norm = lambda event_shape: tfpl.IndependentNormal(event_shape + one)

testing/install_test_dependencies.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,14 @@ install_python_packages() {
150150
python -m pip install $PIP_FLAGS $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR
151151

152152
# For the JAX backend.
153-
python -m pip install jax jaxlib
153+
python -m pip install jax==0.1.74 jaxlib==0.1.52
154154

155155
# The following unofficial dependencies are used only by tests.
156156
# TODO(b/148685448): Unpin Hypothesis and coverage versions.
157157
python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock scipy
158158

159159
# Install additional TFP dependencies.
160-
python -m pip install $PIP_FLAGS decorator cloudpickle==1.3 dm-tree # TODO(b/155109696): Unpin cloudpickle version.
160+
python -m pip install $PIP_FLAGS decorator 'cloudpickle>=1.3' dm-tree
161161

162162
# Upgrade numpy to the latest to address issues that happen when testing with
163163
# Python 3 (https://github.com/tensorflow/tensorflow/issues/16488).

0 commit comments

Comments
 (0)