Skip to content

Commit

Permalink
Merge pull request #1091 from csuter/r0.11
Browse files Browse the repository at this point in the history
Update jax to explicit version number on r0.11 branch
  • Loading branch information
csuter authored Sep 28, 2020
2 parents 7cf006d + db26e9a commit 994c944
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 35 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
'six >= 1.10.0',
'numpy >= 1.13.3',
'decorator',
'cloudpickle == 1.3', # TODO(b/155109696): Unpin cloudpickle version.
'cloudpickle >= 1.3',
'gast >= 0.3.2', # For autobatching
'dm-tree' # For NumPy/JAX backends (hence, also for prefer_static)
]
Expand Down Expand Up @@ -105,7 +105,7 @@ def has_ext_modules(self):
],
keywords='tensorflow probability statistics bayesian machine learning',
extras_require={ # e.g. `pip install tfp-nightly[jax]`
'jax': ['jax', 'jaxlib'],
'jax': ['jax==0.1.74', 'jaxlib==0.1.52'],
'tfds': [TFDS_PACKAGE],
}
)
1 change: 1 addition & 0 deletions tensorflow_probability/python/layers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ py_test(
shard_count = 5,
deps = [
# numpy dep,
# six dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:test_util",
Expand Down
53 changes: 24 additions & 29 deletions tensorflow_probability/python/layers/distribution_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pickle

# Dependency imports
from cloudpickle.cloudpickle import CloudPickler
from cloudpickle import CloudPickler
import numpy as np
import six
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -2020,37 +2020,32 @@ def _get_convert_to_tensor_fn(identifier):
'convert-to-tensor function identifier:', identifier)


class _TensorCloudPickler(CloudPickler):
"""Subclass of `CloudPickler` that includes pickling of `Tensor` objects."""

def __init__(self, out_file, protocol=None):
CloudPickler.__init__(self, out_file, protocol)

@staticmethod
def save_tensor(cloud_pickler, tensor, name=None):
val = tf.get_static_value(tensor)
if val is None:
raise ValueError('Cannot pickle Tensor -- '
'its value is not known statically: {}.'.format(tensor))
CloudPickler.save_reduce(cloud_pickler, np.array, (val,))

def inject_addons(self):
tensor_class = tf.convert_to_tensor(1.).__class__
CloudPickler.dispatch[tensor_class] = _TensorCloudPickler.save_tensor

@staticmethod
def dumps(obj, protocol=None):
out_file = io.BytesIO()
try:
_TensorCloudPickler(out_file, protocol).dump(obj)
return out_file.getvalue()
finally:
out_file.close()
def _reduce_tensor(tensor):
val = tf.get_static_value(tensor)
if val is None:
raise ValueError('Cannot pickle Tensor -- '
'its value is not known statically: {}.'.format(tensor))
return (tf.convert_to_tensor, (val,))


def _serialize_function(func):
raw_code = _TensorCloudPickler.dumps(func)
return codecs.encode(raw_code, 'base64').decode('ascii')
"""Serializes a function (using CloudPickle)."""
buffer = io.BytesIO()
pickler = CloudPickler(buffer)

# Serializing a DistributionLambda or other distribution layer may require
# serializaing a lambda or function that closes over a constant, graph-mode
# Tensor, but graph-mode Tensors do not support pickling. We modify
# `pickler.dispatch_table` so that a special reduction function will be used
# for graph-mode Tensors, which will:
# - Correctly serialize constant graph-mode Tensors.
# - Raise an explanatory error message for non-constant graph-mode Tensors.
if not hasattr(pickler, 'dispatch_table'):
pickler.dispatch_table = {}
pickler.dispatch_table[tf.Tensor] = _reduce_tensor

pickler.dump(func)
return codecs.encode(buffer.getvalue(), 'base64').decode('ascii')


def _deserialize_function(code):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# Dependency imports

import numpy as np
import six
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
Expand Down Expand Up @@ -396,8 +397,7 @@ def test_serialization_static_method(self):
lambda t: DistributionLambdaSerializationTest._make_distribution(t))
])
model.compile(optimizer='adam', loss='mse')
# TODO(b/138375951): Re-enable this test.
# self.assertSerializable(model, batch_size=3)
self.assertSerializable(model, batch_size=3)

model = tfk.Sequential([
tfkl.Dense(15, input_shape=(5,)),
Expand All @@ -410,6 +410,10 @@ def test_serialization_static_method(self):
self.assertExportable(model)

def test_serialization_closure_over_lambdas_tensors_and_numpy_array(self):
if six.PY2 and not tf.executing_eagerly():
self.skipTest('Serialization of constant graph-mode Tensors is not '
'supported under Python 2.')

num_components = np.array(3)
one = tf.convert_to_tensor(1)
mk_ind_norm = lambda event_shape: tfpl.IndependentNormal(event_shape + one)
Expand Down
4 changes: 2 additions & 2 deletions testing/install_test_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,14 @@ install_python_packages() {
python -m pip install $PIP_FLAGS $TF_NIGHTLY_PACKAGE==$TF_VERSION_STR

# For the JAX backend.
python -m pip install jax jaxlib
python -m pip install jax==0.1.74 jaxlib==0.1.52

# The following unofficial dependencies are used only by tests.
# TODO(b/148685448): Unpin Hypothesis and coverage versions.
python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock scipy

# Install additional TFP dependencies.
python -m pip install $PIP_FLAGS decorator cloudpickle==1.3 dm-tree # TODO(b/155109696): Unpin cloudpickle version.
python -m pip install $PIP_FLAGS decorator 'cloudpickle>=1.3' dm-tree

# Upgrade numpy to the latest to address issues that happen when testing with
# Python 3 (https://github.com/tensorflow/tensorflow/issues/16488).
Expand Down

0 comments on commit 994c944

Please sign in to comment.