From bbb60712548032db16158f0c5ecbd14cacb16b76 Mon Sep 17 00:00:00 2001 From: bjp Date: Tue, 8 Dec 2020 11:35:16 -0800 Subject: [PATCH 01/36] Enforce PY3 for multi-substrate targets. Change PY2AND3 to PY3 throughout. PiperOrigin-RevId: 346372466 --- tensorflow_probability/python/build_defs.bzl | 3 +++ tensorflow_probability/python/experimental/distribute/BUILD | 2 +- tensorflow_probability/python/experimental/mcmc/BUILD | 6 +++--- tensorflow_probability/python/experimental/nn/BUILD | 2 +- tensorflow_probability/python/experimental/util/BUILD | 4 ++-- tensorflow_probability/python/internal/backend/numpy/BUILD | 2 +- tensorflow_probability/python/math/BUILD | 6 +++--- 7 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tensorflow_probability/python/build_defs.bzl b/tensorflow_probability/python/build_defs.bzl index 7e66c9c826..7559e764c7 100644 --- a/tensorflow_probability/python/build_defs.bzl +++ b/tensorflow_probability/python/build_defs.bzl @@ -183,6 +183,9 @@ def multi_substrate_py_library( srcs_version: As with `py_library`. """ + if srcs_version != "PY3": + fail("Must use PY3 for srcs_version", srcs_version) + native.py_library( name = name, srcs = srcs, diff --git a/tensorflow_probability/python/experimental/distribute/BUILD b/tensorflow_probability/python/experimental/distribute/BUILD index dd24845a69..f0c7898f50 100644 --- a/tensorflow_probability/python/experimental/distribute/BUILD +++ b/tensorflow_probability/python/experimental/distribute/BUILD @@ -26,7 +26,7 @@ package( py_library( name = "distribute", srcs = ["__init__.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":distribute_lib", ":joint_distribution", diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 8c87d50e75..035b8cda78 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -134,7 +134,7 @@ py_test( py_library( name = "kernel_outputs", srcs = ["kernel_outputs.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":tracing_reducer", "//tensorflow_probability/python/internal:unnest", @@ -191,7 +191,7 @@ py_test( py_library( name = "preconditioned_hmc", srcs = ["preconditioned_hmc.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # tensorflow dep, "//tensorflow_probability/python/distributions:independent", @@ -225,7 +225,7 @@ py_test( py_library( name = "progress_bar_reducer", srcs = ["progress_bar_reducer.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":reducer", # tensorflow dep, diff --git a/tensorflow_probability/python/experimental/nn/BUILD b/tensorflow_probability/python/experimental/nn/BUILD index 959b560aa8..5b831a5b29 100644 --- a/tensorflow_probability/python/experimental/nn/BUILD +++ b/tensorflow_probability/python/experimental/nn/BUILD @@ -105,7 +105,7 @@ py_test( py_library( name = "convolutional_layers_v2", srcs = ["convolutional_layers_v2.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":layers", ":variational_base", diff --git a/tensorflow_probability/python/experimental/util/BUILD b/tensorflow_probability/python/experimental/util/BUILD index f605c210be..849114ba9e 100644 --- a/tensorflow_probability/python/experimental/util/BUILD +++ b/tensorflow_probability/python/experimental/util/BUILD @@ -33,7 +33,7 @@ exports_files(["LICENSE"]) multi_substrate_py_library( name = "util", srcs = ["__init__.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", substrates_omit_deps = [ ":deferred_module", ], @@ -45,7 +45,7 @@ multi_substrate_py_library( py_library( name = "deferred_module", srcs = ["deferred_module.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # numpy dep, # tensorflow dep, diff --git a/tensorflow_probability/python/internal/backend/numpy/BUILD b/tensorflow_probability/python/internal/backend/numpy/BUILD index 5e7c466f46..d157f36c79 100644 --- a/tensorflow_probability/python/internal/backend/numpy/BUILD +++ b/tensorflow_probability/python/internal/backend/numpy/BUILD @@ -363,7 +363,7 @@ py_library( py_library( name = "numpy_testlib", testonly = 1, - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":numpy", # absl/testing:parameterized dep, diff --git a/tensorflow_probability/python/math/BUILD b/tensorflow_probability/python/math/BUILD index ad8b133d0e..eb66b47871 100644 --- a/tensorflow_probability/python/math/BUILD +++ b/tensorflow_probability/python/math/BUILD @@ -66,7 +66,7 @@ multi_substrate_py_library( srcs = [ "bessel.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # numpy dep, # tensorflow dep, @@ -119,7 +119,7 @@ multi_substrate_py_library( srcs = [ "gram_schmidt.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # tensorflow dep, "//tensorflow_probability/python/internal:prefer_static", @@ -199,7 +199,7 @@ multi_substrate_py_library( srcs = [ "hypergeometric.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # numpy dep, # tensorflow dep, From b76652e59da5d75c66e326bae2549f984b26aa86 Mon Sep 17 00:00:00 2001 From: Dave Moore Date: Tue, 8 Dec 2020 11:43:58 -0800 Subject: [PATCH 02/36] Add arg for custom parameter-generating strategy in distribution Hypothesis tests. PiperOrigin-RevId: 346374625 --- .../distributions/hypothesis_testlib.py | 80 +++++++++---------- .../python/internal/hypothesis_testlib.py | 21 +++-- 2 files changed, 55 insertions(+), 46 deletions(-) diff --git a/tensorflow_probability/python/distributions/hypothesis_testlib.py b/tensorflow_probability/python/distributions/hypothesis_testlib.py index 3247d8aef6..62f2a2a529 100644 --- a/tensorflow_probability/python/distributions/hypothesis_testlib.py +++ b/tensorflow_probability/python/distributions/hypothesis_testlib.py @@ -572,31 +572,6 @@ def stringify_slices(slices): return pretty_slices -@hps.composite -def broadcasting_params(draw, - dist_name, - batch_shape, - event_dim=None, - enable_vars=False): - """Strategy for drawing parameters broadcasting to `batch_shape`.""" - if dist_name not in INSTANTIABLE_BASE_DISTS: - raise ValueError('Unknown Distribution name {}'.format(dist_name)) - - params_event_ndims = INSTANTIABLE_BASE_DISTS[dist_name].params_event_ndims - - def _constraint(param): - return constraint_for(dist_name, param) - - return draw( - tfp_hps.broadcasting_params( - batch_shape, - params_event_ndims, - event_dim=event_dim, - enable_vars=enable_vars, - constraint_fn_for=_constraint, - mutex_params=MUTEX_PARAMS)) - - def prime_factors(v): """Compute the prime factors of v.""" factors = [] @@ -639,6 +614,7 @@ def base_distribution_unconstrained_params(draw, batch_shape=None, event_dim=None, enable_vars=False, + param_strategy_fn=None, params=None): """Strategy for drawing unconstrained parameters of a base Distribution. @@ -660,6 +636,10 @@ def base_distribution_unconstrained_params(draw, initialization in slicing_test. If `False`, the returned parameters are all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor` `tfp.util.TransformedVariable`}. + param_strategy_fn: Optional callable with signature + `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. If provided, + overrides the default strategy for generating float-valued parameters. + Default value: `None`. params: An optional set of Distribution parameters. If params are not provided, Hypothesis will choose a set of parameters. @@ -675,11 +655,21 @@ def base_distribution_unconstrained_params(draw, batch_shape = draw(tfp_hps.shapes()) # Draw raw parameters + if dist_name not in INSTANTIABLE_BASE_DISTS: + raise ValueError('Unknown Distribution name {}'.format(dist_name)) + params_event_ndims = INSTANTIABLE_BASE_DISTS[dist_name].params_event_ndims + params_kwargs = draw( - broadcasting_params( - dist_name, batch_shape, event_dim=event_dim, enable_vars=enable_vars)) - hp.note('Forming dist {} with raw parameters {}'.format( - dist_name, params_kwargs)) + tfp_hps.broadcasting_params( + batch_shape, + params_event_ndims, + event_dim=event_dim, + enable_vars=enable_vars, + constraint_fn_for=lambda param: constraint_for(dist_name, param), + mutex_params=MUTEX_PARAMS, + param_strategy_fn=param_strategy_fn)) + hp.note('Forming dist {} with raw parameters {}'.format(dist_name, + params_kwargs)) return params_kwargs, batch_shape @@ -732,8 +722,9 @@ def base_distributions(draw, event_dim=None, enable_vars=False, eligibility_filter=lambda name: True, - validate_args=True, - params=None): + params=None, + param_strategy_fn=None, + validate_args=True): """Strategy for drawing arbitrary base Distributions. This does not draw compound distributions like `Independent`, @@ -756,9 +747,13 @@ def base_distributions(draw, `tfp.util.TransformedVariable`}. eligibility_filter: Optional Python callable. Blacklists some Distribution class names so they will not be drawn at the top level. - validate_args: Python `bool`; whether to enable runtime assertions. params: An optional set of Distribution parameters. If params are not provided, Hypothesis will choose a set of parameters. + param_strategy_fn: Optional callable with signature + `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. If provided, + overrides the default strategy for generating float-valued parameters. + Default value: `None`. + validate_args: Python `bool`; whether to enable runtime assertions. Returns: dists: A strategy for drawing Distributions with the specified `batch_shape` @@ -780,13 +775,14 @@ class names so they will not be drawn at the top level. if params is None: params_unconstrained, batch_shape = draw( - base_distribution_unconstrained_params(dist_name, - batch_shape=batch_shape, - event_dim=event_dim, - enable_vars=enable_vars)) + base_distribution_unconstrained_params( + dist_name, + batch_shape=batch_shape, + event_dim=event_dim, + enable_vars=enable_vars, + param_strategy_fn=param_strategy_fn)) params = constrain_params(params_unconstrained, dist_name) - params = modify_params( - params, dist_name, validate_args=validate_args) + params = modify_params(params, dist_name, validate_args=validate_args) # Actually construct the distribution dist_cls = INSTANTIABLE_BASE_DISTS[dist_name].cls result_dist = dist_cls(**params) @@ -1436,8 +1432,12 @@ class names so they will not be drawn. or dist_name in INSTANTIABLE_BASE_DISTS or dist_name == 'Empirical'): return draw(base_distributions( - dist_name, batch_shape, event_dim, enable_vars, - eligibility_filter, validate_args)) + dist_name, + batch_shape=batch_shape, + event_dim=event_dim, + enable_vars=enable_vars, + eligibility_filter=eligibility_filter, + validate_args=validate_args)) if dist_name == 'BatchReshape': return draw(batch_reshapes( batch_shape, event_dim, enable_vars, depth, diff --git a/tensorflow_probability/python/internal/hypothesis_testlib.py b/tensorflow_probability/python/internal/hypothesis_testlib.py index a7b50bafa3..2b9c12c1cf 100644 --- a/tensorflow_probability/python/internal/hypothesis_testlib.py +++ b/tensorflow_probability/python/internal/hypothesis_testlib.py @@ -435,6 +435,7 @@ def broadcasting_params(draw, enable_vars=False, constraint_fn_for=lambda param: identity_fn, mutex_params=(), + param_strategy_fn=None, dtype=np.float32): """Streategy for drawing parameters which jointly have the given batch shape. @@ -467,6 +468,10 @@ def broadcasting_params(draw, mutually exclusive parameters (e.g., the 'probs' and 'logits' of a Categorical). At most one parameter from each set will appear in the result. + param_strategy_fn: Optional callable with signature + `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. If provided, + overrides the default strategy for generating float-valued parameters. + Default value: `constrained_tensors`. dtype: Dtype for generated parameters. Returns: @@ -479,6 +484,8 @@ def broadcasting_params(draw, """ if event_dim is None: event_dim = draw(hps.integers(min_value=2, max_value=6)) + if param_strategy_fn is None: + param_strategy_fn = constrained_tensors params_event_ndims = params_event_ndims or {} remaining_params = set(params_event_ndims.keys()) @@ -504,12 +511,14 @@ def broadcasting_params(draw, hp.assume(len(param_shape) < 6) # TODO(axch): Can I replace `params_event_ndims` and `constraint_fn_for` - # with a map from params to `Suppport`s, and use `tensors_in_support` here - # instead of this explicit `constrained_tensors` function? - param_strategy = constrained_tensors( - constraint_fn_for(param), param_shape, dtype=dtype) - params_kwargs[param] = draw(maybe_variable( - param_strategy, enable_vars=enable_vars, dtype=dtype, name=param)) + # with a map from params to `Suppport`s, and use `tensors_in_support` here? + param_strategy = param_strategy_fn(constraint_fn=constraint_fn_for(param), + shape=param_shape, + dtype=dtype) + params_kwargs[param] = draw(maybe_variable(param_strategy, + enable_vars=enable_vars, + dtype=dtype, + name=param)) return params_kwargs From aae8cc413186a3ffd3ab8e30fbf4b938d482fa53 Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 8 Dec 2020 11:50:24 -0800 Subject: [PATCH 03/36] Implement stopping ratio logistic distribution Hello, this PR again implements the stopping ratio logistic distribution which was already discussed and reviewed [here](https://github.com/tensorflow/probability/pull/963). @srvasude , sorry for closing the other PR. I worked in your requested changes other than the transpose that needs to be done when sampling. Thanks. Cheers, Simon COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/probability/pull/990 from dirmeier:stopping_ratio 0743b5421344d656109b21411e48e31f2943aab1 PiperOrigin-RevId: 346376187 --- .../python/distributions/BUILD | 90 +++-- .../python/distributions/__init__.py | 2 + .../distributions/stopping_ratio_logistic.py | 363 ++++++++++++++++++ .../stopping_ratio_logistic_test.py | 142 +++++++ 4 files changed, 567 insertions(+), 30 deletions(-) create mode 100644 tensorflow_probability/python/distributions/stopping_ratio_logistic.py create mode 100644 tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index 7c0c1fb2e6..36eb9dc026 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -131,6 +131,7 @@ multi_substrate_py_library( ":sinh_arcsinh", ":skellam", ":spherical_uniform", + ":stopping_ratio_logistic", ":student_t", ":student_t_process", ":transformed_distribution", @@ -573,6 +574,21 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "exponentially_modified_gaussian", + srcs = ["exponentially_modified_gaussian.py"], + deps = [ + ":distribution", + ":exponential", + ":normal", + # tensorflow dep, + "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:reparameterization", + "//tensorflow_probability/python/internal:tensor_util", + ], +) + multi_substrate_py_library( name = "finite_discrete", srcs = ["finite_discrete.py"], @@ -1748,6 +1764,23 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "stopping_ratio_logistic", + srcs = ["stopping_ratio_logistic.py"], + deps = [ + ":distribution", + ":kullback_leibler", + # tensorflow dep, + "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:distribution_util", + "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:reparameterization", + "//tensorflow_probability/python/internal:tensor_util", + "//tensorflow_probability/python/internal:tensorshape_util", + ], +) + multi_substrate_py_library( name = "half_student_t", srcs = ["half_student_t.py"], @@ -2414,6 +2447,21 @@ multi_substrate_py_test( ], ) +multi_substrate_py_test( + name = "exponentially_modified_gaussian_test", + srcs = ["exponentially_modified_gaussian_test.py"], + jax_size = "medium", + # Disable numpy test for now because a bug in the types returned by special_math.ndtr + numpy_tags = ["notap"], + deps = [ + # numpy dep, + # scipy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + ], +) + multi_substrate_py_test( name = "finite_discrete_test", size = "medium", @@ -3383,6 +3431,18 @@ multi_substrate_py_test( ], ) +multi_substrate_py_test( + name = "stopping_ratio_logistic_test", + srcs = ["stopping_ratio_logistic_test.py"], + deps = [ + # absl/testing:parameterized dep, + # numpy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + ], +) + multi_substrate_py_test( name = "half_student_t_test", size = "medium", @@ -3750,33 +3810,3 @@ py_binary( "//tensorflow_probability/python/distributions:hypothesis_testlib", ], ) - -multi_substrate_py_library( - name = "exponentially_modified_gaussian", - srcs = ["exponentially_modified_gaussian.py"], - deps = [ - ":distribution", - ":exponential", - ":normal", - # tensorflow dep, - "//tensorflow_probability/python/internal:dtype_util", - "//tensorflow_probability/python/internal:prefer_static", - "//tensorflow_probability/python/internal:reparameterization", - "//tensorflow_probability/python/internal:tensor_util", - ], -) - -multi_substrate_py_test( - name = "exponentially_modified_gaussian_test", - srcs = ["exponentially_modified_gaussian_test.py"], - jax_size = "medium", - # Disable numpy test for now because a bug in the types returned by special_math.ndtr - numpy_tags = ["notap"], - deps = [ - # numpy dep, - # scipy dep, - # tensorflow dep, - "//tensorflow_probability", - "//tensorflow_probability/python/internal:test_util", - ], -) diff --git a/tensorflow_probability/python/distributions/__init__.py b/tensorflow_probability/python/distributions/__init__.py index 8a9f81e225..e97b16edf4 100644 --- a/tensorflow_probability/python/distributions/__init__.py +++ b/tensorflow_probability/python/distributions/__init__.py @@ -109,6 +109,7 @@ from tensorflow_probability.python.distributions.sinh_arcsinh import SinhArcsinh from tensorflow_probability.python.distributions.skellam import Skellam from tensorflow_probability.python.distributions.spherical_uniform import SphericalUniform +from tensorflow_probability.python.distributions.stopping_ratio_logistic import StoppingRatioLogistic from tensorflow_probability.python.distributions.student_t import StudentT from tensorflow_probability.python.distributions.student_t_process import StudentTProcess from tensorflow_probability.python.distributions.transformed_distribution import TransformedDistribution @@ -223,6 +224,7 @@ 'SinhArcsinh', 'Skellam', 'SphericalUniform', + 'StoppingRatioLogistic', 'StudentT', 'StudentTProcess', 'Triangular', diff --git a/tensorflow_probability/python/distributions/stopping_ratio_logistic.py b/tensorflow_probability/python/distributions/stopping_ratio_logistic.py new file mode 100644 index 0000000000..49d4d7266c --- /dev/null +++ b/tensorflow_probability/python/distributions/stopping_ratio_logistic.py @@ -0,0 +1,363 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""The stopping ratio logistic distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python import math as tfp_math +from tensorflow_probability.python.distributions import categorical +from tensorflow_probability.python.distributions import distribution +from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import distribution_util +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static +from tensorflow_probability.python.internal import reparameterization +from tensorflow_probability.python.internal import tensor_util +from tensorflow_probability.python.internal import tensorshape_util + + +class StoppingRatioLogistic(distribution.Distribution): + """Stopping ratio logistic distribution. + + The StoppingRatioLogistic distribution is parameterized by a location and a + set of non-decreasing cutpoints. It is defined over the integers + `{0, 1, ..., K}` for `K` non-decreasing cutpoints. + + The difference to the OrderedLogistic is that categories can only be reached + one after another, i.e., sequentially. Specifically, while the probability + of an ordinal random variable `X` to be in category `c` + for the OrderedLogistic reads as + + ```none + P(X = c; cutpoints, loc) = P(X > c - 1) - P(X > c) + = sigmoid(loc - concat([-inf, cutpoints, inf])[c]) - + sigmoid(loc - concat([-inf, cutpoints, inf])[c + 1]) + ``` + + the StoppingRatioLogistic distribution models the probability of an ordinal + random variable `X` to be in category `c` given `X >= c` as + + ```none + P(X = c; X >= c, cutpoints, loc) = sigmoid(cutpoints[c] - loc) + ``` + + The sequential mechanism for `X` starts in category `c = 0` where a binary + decision between `c = 0` and `c > 0` is made: + + ```none + P(X = 0; cutpoints, loc) = sigmoid(cutpoints[0] - loc) + ``` + + If `X = 0`, the process stops. Otherwise the process continues with + + ```none + P(X = 1; X >= 1, cutpoints, loc) = sigmoid(cutpoints[1] - loc) + ``` + + The process continues to move on to higher level categories until it stops at + some category `X = c`. + + This distribution is useful for ordinal variables where lower categories + need to be reached first, for instance modelling the degree of a person + where the categories are `[Bachelor, Master, PhD]`. In order to obtain a PhD + title, first the degrees `Bachelor` and `Master` need to be reached. + + #### Mathematical Details + + The probability mass function (pmf) is + + ```none + pmf(x; cutpoints, loc) = + sigmoid(cutpoints[x] - loc) * + prod_{s=0}^{x - 1} (1 - sigmoid(cutpoints[s] - loc)) + ``` + + where `loc` is the location of a latent logistic distribution and + `cutpoints` define points to split up this latent distribution. + + #### Examples + + To expand on the `[Bachelor, Master, PhD]` from above, create a distribution + of three ordered categories: + + ```python + import tensorflow_probability as tfp + tfd = tfp.distributions + + dist = tfd.StoppingRatioLogistic(cutpoints=[-1.0, 1.0], loc=0.) + + dist.categorical_probs() + # ==> array([0.2689414 0.53444666 0.19661193], dtype=float32) + ``` + + Here, the probability of finishing one's education with a Bachelor would be + approx. 26% in this example, while the probability of continuing to pursue + a Master's would be approx. 53% and the probability of even attaining a PhD + would be 20%. + + Some further functionality: + + ```python + dist = tfd.StoppingRatioLogistic(cutpoints=[-2., 0., 2.], loc=0.) + + dist.prob([0, 3]) + # ==> array([0.11920291, 0.05249681], dtype=float32) + + dist.log_prob(1) + # ==> -0.82007515 + + dist.sample(3) + # ==> array([2, 1, 2], dtype=int32) + ``` + + """ + + def __init__( + self, + cutpoints, + loc, + dtype=tf.int32, + validate_args=False, + allow_nan_stats=True, + name='StoppingRatioLogistic', + ): + """Initialize Stopping Ratio Logistic distributions. + + Args: + cutpoints: A floating-point `Tensor` with shape `(K,)` where + `K` is the number of cutpoints. The vector of cutpoints should be + non-decreasing, which is only checked if `validate_args=True`. + loc: A floating-point `Tensor` with shape `()`. The entry represents the + mean of the latent logistic distribution. + dtype: The type of the event samples (default: int32). + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g. mode) use the value "`NaN`" to indicate the result is + undefined. When `False`, an exception is raised if one or more of the + statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + + parameters = dict(locals()) + + with tf.name_scope(name) as name: + + float_dtype = dtype_util.common_dtype( + [cutpoints, loc], + dtype_hint=tf.float32) + + self._cutpoints = tensor_util.convert_nonref_to_tensor( + cutpoints, dtype_hint=float_dtype, name='cutpoints') + self._loc = tensor_util.convert_nonref_to_tensor( + loc, dtype_hint=float_dtype, name='loc') + + super(StoppingRatioLogistic, self).__init__( + dtype=dtype, + reparameterization_type=reparameterization.NOT_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + name=name) + + @classmethod + def _params_event_ndims(cls): + return dict(cutpoints=1, loc=0) + + @staticmethod + def _param_shapes(sample_shape): + return dict( + zip(('loc', 'scale'), + ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2))) + + @property + def cutpoints(self): + """Cutpoints param separating the latent distribution into categories.""" + return self._cutpoints + + @property + def loc(self): + """Mean parameter of the latent logistic distribution.""" + return self._loc + + def categorical_log_probs(self): + """Log probabilities for the `K + 1` sequential categories.""" + + cutpoints = tf.convert_to_tensor(self.cutpoints) + loc = tf.convert_to_tensor(self.loc) + num_cat = self._num_categories() + + # For the StoppingRatioLogistic, we have: + # P(X = c; X >= c, cutpoints, loc) = sigmoid(cutpoints[c] - loc) + # Given these conditional probabilities, we would like to retrieve + # P(X = c; cutpoints, loc). + # Let F(c) = P(X = c; X >= c, cutpoints, loc) and + # G(c) = P(X = c; cutpoints, loc) + + # Conditional probabilities. These are log(F(k)) and log(1 - F(k)) + conditional_log_probs = tf.math.log_sigmoid( + cutpoints - loc[..., tf.newaxis]) + conditional_log_probs_complement = tfp_math.log1mexp(conditional_log_probs) + + # Note that F(0) = G(0). + # G(1) = P(X = 1; cutpoints, loc) = + # P(X = 1; X >= 1, cutpoints, loc) * P(X >= 1) = F(1) * (1 - G(0)) + # G(2) = P(X = 2; cutpoints, loc) = + # P(X = 2; X >= 2, cutpoints, loc) * P(X >= 2) = F(2) * (1 - G(0) - G(1)) + # In general, G(k) = F(k) * (1 - \sum_{k-1} G(i)) + + # We rewrite this recurrence in terms of F(k) + # G(1) = F(1) * (1 - G(0)) = F(1) * (1 - F(0)) + # G(2) = F(2) * (1 - G(0) - G(1)) = (1 - F(0) - F(1) * (1 - F(0)) + # = F(2) * (1 - F(0)) * (1 - F(1)) + # G(k) = F(k) * \prod_{k-1} (1 - F(i)) + + # log(F(k)) + log(\prod (1 - F(i))) + categorical_log_probs = conditional_log_probs + tf.math.cumsum( + conditional_log_probs_complement[..., :(num_cat - 1)], + axis=-1, exclusive=True) + # Finally we need to handle the last category. + return tf.concat([ + categorical_log_probs, + tf.math.reduce_sum( + conditional_log_probs_complement[ + ..., :num_cat], axis=-1, keepdims=True)], axis=-1) + + def categorical_probs(self): + """Probabilities for the `K + 1` sequential categories.""" + return tf.math.exp(self.categorical_log_probs()) + + def _num_categories(self): + return prefer_static.shape(self.cutpoints, out_type=self.dtype)[-1] + 1 + + def _sample_n(self, n, seed=None): + return categorical.Categorical( + logits=self.categorical_log_probs()).sample(n, seed=seed) + + def _batch_shape_tensor(self, cutpoints=None, loc=None): + cutpoints = self.cutpoints if cutpoints is None else cutpoints + loc = self.loc if loc is None else loc + return prefer_static.broadcast_shape( + prefer_static.shape(cutpoints)[:-1], + prefer_static.shape(loc)) + + def _batch_shape(self): + return tf.broadcast_static_shape( + self.loc.shape, self.cutpoints.shape[:-1]) + + def _event_shape_tensor(self): + return tf.constant([], dtype=tf.int32) + + def _event_shape(self): + return tf.TensorShape([]) + + def _log_prob(self, x): + return categorical.Categorical( + logits=self.categorical_log_probs()).log_prob(x) + + def _cdf(self, x): + return categorical.Categorical( + logits=self.categorical_log_probs()).cdf(x) + + def _mode(self): + log_probs = self.categorical_log_probs() + mode = tf.argmax(log_probs, axis=-1, output_type=self.dtype) + tensorshape_util.set_shape(mode, log_probs.shape[:-1]) + return mode + + def _default_event_space_bijector(self): + return + + def _parameter_control_dependencies(self, is_init): + assertions = [] + + # In init, we can always build shape and dtype checks because + # we assume shape doesn't change for Variable backed args. + if is_init: + + if not dtype_util.is_floating(self.cutpoints.dtype): + raise TypeError('Argument `cutpoints` must having floating type.') + + if not dtype_util.is_floating(self.loc.dtype): + raise TypeError('Argument `loc` must having floating type.') + + cutpoint_dims = tensorshape_util.rank(self.cutpoints.shape) + msg = 'Argument `cutpoints` must have rank at least 1.' + if cutpoint_dims is not None: + if cutpoint_dims < 1: + raise ValueError(msg) + elif self.validate_args: + cutpoints = tf.convert_to_tensor(self.cutpoints) + assertions.append( + assert_util.assert_rank_at_least(cutpoints, 1, message=msg)) + + if not self.validate_args: + return [] + + if is_init != tensor_util.is_ref(self.cutpoints): + cutpoints = tf.convert_to_tensor(self.cutpoints) + assertions.append(distribution_util.assert_nondecreasing( + cutpoints, message='Argument `cutpoints` must be non-decreasing.')) + + return assertions + + def _sample_control_dependencies(self, x): + assertions = [] + if not self.validate_args: + return assertions + assertions.extend(distribution_util.assert_nonnegative_integer_form(x)) + assertions.append( + assert_util.assert_less_equal( + x, tf.cast(self._num_categories(), x.dtype), + message=('StoppingRatioLogistic samples must be `>= 0` and `<= K` ' + 'where `K` is the number of cutpoints.'))) + return assertions + + +@kullback_leibler.RegisterKL(StoppingRatioLogistic, StoppingRatioLogistic) +def _kl_stopping_ratio_logistic_stopping_ratio_logistic(a, b, name=None): + """Calculate the batched KL divergence KL(a || b), both StoppingRatioLogistic. + + This function utilises the `StoppingRatioLogistic` `categorical_log_probs` + member function to implement KL divergence for discrete probability + distributions as described in + e.g. [Wikipedia](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence). + + Args: + a: instance of a StoppingRatioLogistic distribution object. + b: instance of a StoppingRatioLogistic distribution object. + name: Python `str` name to use for created operations. + Default value: `None` + + Returns: + Batchwise KL(a || b) + """ + with tf.name_scope(name or + 'kl_stopping_ratio_logistic_stopping_ratio_logistic'): + a_log_probs = a.categorical_log_probs() + b_log_probs = b.categorical_log_probs() + return tf.reduce_sum( + tf.math.multiply_no_nan( + tf.math.exp(a_log_probs), + a_log_probs - b_log_probs), + axis=-1) + diff --git a/tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py b/tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py new file mode 100644 index 0000000000..419548bbb1 --- /dev/null +++ b/tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py @@ -0,0 +1,142 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +# Dependency imports +from absl.testing import parameterized +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.internal import test_util + +tfd = tfp.distributions +tfb = tfp.bijectors + + +@test_util.test_all_tf_execution_regimes +class StoppingRatioLogisticTest(test_util.TestCase): + + def _random_cutpoints(self, shape): + return self._ordered.inverse(self._rng.randn(*shape)) + + def _random_location(self, shape): + return self._rng.randn(*shape) + + def _random_rvs(self, shape): + return self._rng.multinomial(1, *shape) + + def setUp(self): + self._ordered = tfb.Ordered() + self._rng = np.random.RandomState(test_util.test_seed()) + super(StoppingRatioLogisticTest, self).setUp() + + @parameterized.parameters( + itertools.product(['cutpoints', 'loc', 'both'], [[], [1], [1, 2, 3]]) + ) + def testBatchShapes(self, test, batch_shape): + if test == 'cutpoints': + cutpoints = self._random_cutpoints(batch_shape + [2]) + loc = tf.constant(0., dtype=tf.float64) + elif test == 'loc': + cutpoints = tf.constant([1., 2.], dtype=tf.float64) + loc = self._random_location(batch_shape) + elif test == 'both': + cutpoints = self._random_cutpoints(batch_shape + [2]) + loc = self._random_location(batch_shape) + + dist = tfd.StoppingRatioLogistic(cutpoints=cutpoints, loc=loc) + + self.assertAllEqual(dist.batch_shape, batch_shape) + self.assertAllEqual( + self.evaluate(dist.batch_shape_tensor()), batch_shape) + + self.assertAllEqual(dist.event_shape, []) + self.assertAllEqual(self.evaluate(dist.event_shape_tensor()), []) + + categorical_probs = dist.categorical_probs() + categorical_probs_shape = tf.shape(categorical_probs) + self.assertAllEqual( + self.evaluate(categorical_probs_shape), batch_shape + [3]) + + samples = dist.sample(seed=test_util.test_seed()) + sample_shape = tf.shape(samples) + self.assertAllEqual(self.evaluate(sample_shape), batch_shape) + + probs = dist.prob(samples) + probs_shape = tf.shape(probs) + self.assertAllEqual(self.evaluate(probs_shape), batch_shape) + + samples = dist.sample([4, 5], seed=test_util.test_seed()) + sample_shape_n = tf.shape(samples) + self.assertAllEqual(self.evaluate(sample_shape_n), [4, 5] + batch_shape) + + probs = dist.prob(samples) + probs_shape = tf.shape(probs) + self.assertAllEqual(self.evaluate(probs_shape), [4, 5] + batch_shape) + + mode = dist.mode() + mode_shape = tf.shape(mode) + self.assertAllEqual(self.evaluate(mode_shape), batch_shape) + + def testProbs(self): + expected_probs = [0.11920291, 0.44039854, 0.38790172, 0.05249681] + dist = tfd.StoppingRatioLogistic(cutpoints=[-2., 0., 2.], loc=0.) + + categorical_probs = self.evaluate(dist.categorical_probs()) + self.assertAllClose(expected_probs, categorical_probs, atol=1e-4) + + probs = self.evaluate(dist.prob([0, 1, 2, 3])) + self.assertAllClose(expected_probs, probs, atol=1e-4) + + def testMode(self): + dist = tfd.StoppingRatioLogistic(cutpoints=[-10., 10.], loc=[-20., 0., 20.]) + mode = self.evaluate(dist.mode()) + self.assertAllEqual([0, 1, 2], mode) + + def testSample(self): + dist = tfd.StoppingRatioLogistic(cutpoints=[-1., 0., 1.], loc=0.) + samples = self.evaluate(dist.sample(int(1e5), seed=test_util.test_seed())) + expected_probs = [0.2689414, 0.3655293, 0.26722333, 0.09830596] + for k, p in enumerate(expected_probs): + self.assertAllClose(np.mean(samples == k), p, atol=0.01) + + def testKLAgainstSampling(self): + a_cutpoints = self._random_cutpoints([4]) + b_cutpoints = self._random_cutpoints([4]) + loc = self._random_location([]) + + a = tfd.StoppingRatioLogistic(cutpoints=a_cutpoints, loc=loc) + b = tfd.StoppingRatioLogistic(cutpoints=b_cutpoints, loc=loc) + + samples = a.sample(int(1e5), seed=test_util.test_seed()) + sampled_kl = self.evaluate( + tf.reduce_mean(a.log_prob(samples) - b.log_prob(samples))) + kl = self.evaluate(tfd.kl_divergence(a, b)) + + self.assertAllClose(sampled_kl, kl, atol=2e-2) + + def testUnorderedCutpointsFails(self): + with self.assertRaisesRegexp( + ValueError, 'Argument `cutpoints` must be non-decreasing.'): + dist = tfd.StoppingRatioLogistic( + cutpoints=[1., 0.9], loc=0.0, validate_args=True) + self.evaluate(dist.mode()) + +if __name__ == '__main__': + tf.test.main() From 9d3addd8606a499872f0ddd0012346a3c37f23a5 Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Tue, 8 Dec 2020 12:02:00 -0800 Subject: [PATCH 04/36] Add mpmath as a test dependency. PiperOrigin-RevId: 346378914 --- testing/install_test_dependencies.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/install_test_dependencies.sh b/testing/install_test_dependencies.sh index 5007fdb30f..1bdc028a67 100755 --- a/testing/install_test_dependencies.sh +++ b/testing/install_test_dependencies.sh @@ -178,7 +178,7 @@ install_python_packages() { # The following unofficial dependencies are used only by tests. # TODO(b/148685448): Unpin Hypothesis and coverage versions. - python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock scipy + python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock mpmath scipy # Install additional TFP dependencies. python -m pip install $PIP_FLAGS decorator 'cloudpickle>=1.3' dm-tree From 44431ddb2507e7c782dfb6ff72ea32014ebb1f3f Mon Sep 17 00:00:00 2001 From: TensorFlower Gardener Date: Tue, 8 Dec 2020 13:22:10 -0800 Subject: [PATCH 05/36] Merge pull request #1188 from SamuelMarks:args-for-google-style-docstrings PiperOrigin-RevId: 346395749 --- .../latent_dirichlet_allocation_distributions.py | 6 +++--- tensorflow_probability/examples/vae.py | 2 +- tensorflow_probability/python/bijectors/composition.py | 4 ++-- .../python/bijectors/masked_autoregressive.py | 4 ++-- tensorflow_probability/python/bijectors/real_nvp.py | 2 +- .../python/distributions/mixture_same_family.py | 6 +++--- .../python/distributions/von_mises.py | 8 ++++---- tensorflow_probability/python/distributions/zipf.py | 2 +- .../experimental/lazybones/utils/weak_container.py | 2 +- .../python/internal/backend/numpy/gen/tensor_shape.py | 4 ++-- tensorflow_probability/python/internal/cache_util.py | 10 +++++----- .../python/internal/prefer_static.py | 4 ++-- .../python/internal/test_combinations.py | 8 ++++---- .../python/layers/dense_variational_v2.py | 2 +- .../python/layers/distribution_layer.py | 2 +- .../python/layers/distribution_layer_test.py | 4 ++-- tensorflow_probability/python/layers/initializers.py | 2 +- .../python/layers/masked_autoregressive.py | 2 +- tensorflow_probability/python/layers/variable_input.py | 2 +- 19 files changed, 38 insertions(+), 38 deletions(-) diff --git a/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py b/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py index 5366f7659d..239239594a 100644 --- a/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py +++ b/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py @@ -248,7 +248,7 @@ def make_prior(num_topics, initial_value): def model_fn(features, labels, mode, params, config): """Build the model function for use in an estimator. - Arguments: + Args: features: The input features for the estimator. labels: The labels, unused here. mode: Signifies whether it is train or test or predict. @@ -353,7 +353,7 @@ def get_topics_strings(topics_words, alpha, vocabulary, topics_to_print=10, words_per_topic=10): """Returns the summary of the learned topics. - Arguments: + Args: topics_words: KxV tensor with topics as rows and words as columns. alpha: 1xK tensor of prior Dirichlet concentrations for the topics. @@ -464,7 +464,7 @@ def build_input_fns(data_dir, batch_size): Each object is represented as a bag-of-words vector. - Arguments: + Args: data_dir: Folder in which to store the data. batch_size: Batch size for both train and evaluation. Returns: diff --git a/tensorflow_probability/examples/vae.py b/tensorflow_probability/examples/vae.py index 43e3ce9e5c..66d7b5a2a2 100644 --- a/tensorflow_probability/examples/vae.py +++ b/tensorflow_probability/examples/vae.py @@ -325,7 +325,7 @@ def image_tile_summary(name, tensor, rows=8, cols=8): def model_fn(features, labels, mode, params, config): """Builds the model function for use in an estimator. - Arguments: + Args: features: The input features for the estimator. labels: The labels, unused here. mode: Signifies whether it is train or test or predict. diff --git a/tensorflow_probability/python/bijectors/composition.py b/tensorflow_probability/python/bijectors/composition.py index f26277a4c6..d4fd398e28 100644 --- a/tensorflow_probability/python/bijectors/composition.py +++ b/tensorflow_probability/python/bijectors/composition.py @@ -300,7 +300,7 @@ def _walk_forward(self, step_fn, argument, **kwargs): The `_walk_{direction}` methods define how arguments are routed through nested bijectors, expressing the directed topology of the underlying graph. - Arguments: + Args: step_fn: A method taking a bijector, a single positional argument matching `bijector.forward_min_event_ndims`, and arbitrary **kwargs, and returning a structure matching `bijector.inverse_min_event_ndims`. @@ -319,7 +319,7 @@ def _walk_inverse(self, step_fn, argument, **kwargs): The `_walk_{direction}` methods define how arguments are routed through nested bijectors, expressing the directed topology of the underlying graph. - Arguments: + Args: step_fn: A method taking a bijector, a single positional argument matching `bijector.inverse_min_event_ndims`, and arbitrary **kwargs, and returning a structure matching `bijector.forward_min_event_ndims`. diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive.py b/tensorflow_probability/python/bijectors/masked_autoregressive.py index b6ee69e146..1230522692 100644 --- a/tensorflow_probability/python/bijectors/masked_autoregressive.py +++ b/tensorflow_probability/python/bijectors/masked_autoregressive.py @@ -435,7 +435,7 @@ def masked_dense(inputs, See [Germain et al. (2015)][1] for detailed explanation. - Arguments: + Args: inputs: Tensor input. units: Python `int` scalar representing the dimensionality of the output space. @@ -894,7 +894,7 @@ def __init__(self, **kwargs): """Constructs the MADE layer. - Arguments: + Args: params: Python integer specifying the number of parameters to output per input. event_shape: Python `list`-like of positive integers (or a single int), diff --git a/tensorflow_probability/python/bijectors/real_nvp.py b/tensorflow_probability/python/bijectors/real_nvp.py index 7645195606..9a9feeee8b 100644 --- a/tensorflow_probability/python/bijectors/real_nvp.py +++ b/tensorflow_probability/python/bijectors/real_nvp.py @@ -344,7 +344,7 @@ def real_nvp_default_template(hidden_layers, Real NVP bijector, implement a conditioned shift/scale template that handles the `condition_kwargs`. - Arguments: + Args: hidden_layers: Python `list`-like of non-negative integer, scalars indicating the number of units in each hidden layer. Default: `[512, 512]`. diff --git a/tensorflow_probability/python/distributions/mixture_same_family.py b/tensorflow_probability/python/distributions/mixture_same_family.py index 4c5c9d5f47..2d27201c6b 100644 --- a/tensorflow_probability/python/distributions/mixture_same_family.py +++ b/tensorflow_probability/python/distributions/mixture_same_family.py @@ -453,7 +453,7 @@ def _reparameterize_sample(self, x, event_shape): 3. Distributional transform currently only works for known rank of the batch tensor. - Arguments: + Args: x: Sample of mixture distribution event_shape: The event shape of this distribution @@ -512,7 +512,7 @@ def _distributional_transform(self, x, event_shape): w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1) and w_0^k = w_k is the mixture probability of the k-th component. - Arguments: + Args: x: Sample of mixture distribution event_shape: The event shape of this distribution @@ -682,7 +682,7 @@ def _prevent_2nd_derivative(x): NB: you need to apply a non-identity function to the output tensor for the exception to be raised. - Arguments: + Args: x: A tensor. Returns: diff --git a/tensorflow_probability/python/distributions/von_mises.py b/tensorflow_probability/python/distributions/von_mises.py index b712f6a896..0a64dab963 100644 --- a/tensorflow_probability/python/distributions/von_mises.py +++ b/tensorflow_probability/python/distributions/von_mises.py @@ -371,7 +371,7 @@ def von_mises_cdf(x, concentration): using automatic differentiation. We use forward mode for the series case (which allows to save memory) and backward mode for the Normal approximation. - Arguments: + Args: x: The point at which to evaluate the CDF. concentration: The concentration parameter of the von Mises distribution. @@ -498,7 +498,7 @@ def cdf_func(concentration): def _von_mises_sample_no_gradient(shape, concentration, seed): """Performs rejection sampling for standardized von Mises. - Arguments: + Args: shape: The output sample shape. concentration: The concentration parameter of the distribution. seed: The random seed. @@ -641,7 +641,7 @@ def _von_mises_sample_jvp(shape, primals, tangents): def _von_mises_sample_with_gradient(shape, concentration, seed): """Performs rejection sampling for standardized von Mises. - Arguments: + Args: shape: The output sample shape. concentration: The concentration parameter of the distribution. seed: (optional) The random seed. @@ -662,7 +662,7 @@ def random_von_mises(shape, concentration, dtype=tf.float32, seed=None): The sampling algorithm is rejection sampling with wrapped Cauchy proposal [1]. The samples are pathwise differentiable using the approach of [2]. - Arguments: + Args: shape: The output sample shape. concentration: The concentration parameter of the von Mises distribution. dtype: The data type of concentration and the outputs. diff --git a/tensorflow_probability/python/distributions/zipf.py b/tensorflow_probability/python/distributions/zipf.py index ec18c6c6e3..92f7d49dae 100644 --- a/tensorflow_probability/python/distributions/zipf.py +++ b/tensorflow_probability/python/distributions/zipf.py @@ -353,7 +353,7 @@ def _hat_integral(self, x, power): pmf. This function implements `hat` integral: H(x) = int_x^inf h(t) dt; which is needed for sampling purposes. - Arguments: + Args: x: A Tensor of points x at which to evaluate H(x). power: Power that parameterized hat function. diff --git a/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py b/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py index 4f385d54fd..055269503c 100644 --- a/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py +++ b/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py @@ -114,7 +114,7 @@ class HashableWeakRef(weakref.ref): def __init__(self, referrent, callback=None): """weakref.ref which makes any object hashable. - Arguments: + Args: referrent: Object that is being referred to. callback: Optional callback to invoke when object is GCed. """ diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py index d918546a0e..b9c53973ff 100755 --- a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py @@ -143,7 +143,7 @@ def dimension_value(dimension): value = tensor_shape[i] # Warning: this will return the dim value in V2! ``` - Arguments: + Args: dimension: Either a `Dimension` instance, an integer, or None. Returns: @@ -189,7 +189,7 @@ def dimension_at_index(shape, index): # instantiated on the fly. ``` - Arguments: + Args: shape: A TensorShape instance. index: An integer index. diff --git a/tensorflow_probability/python/internal/cache_util.py b/tensorflow_probability/python/internal/cache_util.py index d2f3607492..7dda003016 100644 --- a/tensorflow_probability/python/internal/cache_util.py +++ b/tensorflow_probability/python/internal/cache_util.py @@ -91,7 +91,7 @@ class HashableWeakRef(weakref.ref): def __init__(self, referrent, callback=None): """weakref.ref which makes tf.Tensor and np.array objects hashable. - Arguments: + Args: referrent: Object that is being referred to. callback: Optional callback to invoke when object is GCed. """ @@ -327,7 +327,7 @@ def bijector_class(self): def forward(self, x, **kwargs): """Invokes the 'forward' transformation, or looks up previous results. - Arguments: + Args: x: The singular argument passed to `bijector._forward`. **kwargs: Any auxiliary arguments passed to the function. These reflect shared context to the function, and are associated @@ -340,7 +340,7 @@ def forward(self, x, **kwargs): def inverse(self, y, **kwargs): """Invokes the 'inverse' transformation, or looks up previous results. - Arguments: + Args: y: The singular argument passed to `bijector._inverse`. **kwargs: Any auxiliary arguments passed to the function. These reflect shared context to the function, and are associated @@ -431,7 +431,7 @@ def _attributes(self, input, fn_name, **kwargs): == 0) ``` - Arguments: + Args: input: The singular ordered argument passed to the wrapped function. fn_name: `str`, name of the directed function to which `input` is an arg (typically `'_forward'` or `'_inverse'`). @@ -469,7 +469,7 @@ def _lookup(self, input, forward_name, inverse_name, **kwargs): assert cache.inverse._lookup(y, '_inverse', '_forward') == (x, attrs) ``` - Arguments: + Args: input: The singular ordered argument passed to the wrapped function. forward_name: `str`, the name of the function implementing the bijector's forward transformation (typically `'_forward'`). diff --git a/tensorflow_probability/python/internal/prefer_static.py b/tensorflow_probability/python/internal/prefer_static.py index 48d576f9e8..b1d30bbd0d 100644 --- a/tensorflow_probability/python/internal/prefer_static.py +++ b/tensorflow_probability/python/internal/prefer_static.py @@ -205,7 +205,7 @@ def broadcast_shape(x_shape, y_shape): computed statically and returned as a `TensorShape`. Otherwise, a rank-1 `Tensor` will be returned. - Arguments: + Args: x_shape: A `TensorShape` or rank-1 integer `Tensor`. The input `Tensor` is broadcast against this shape. y_shape: A `TensorShape` or rank-1 integer `Tensor`. The input `Tensor` is @@ -230,7 +230,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): If `pred` is a bool or has a constant value, we return either `true_fn()` or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. - Arguments: + Args: pred: A scalar determining whether to return the result of `true_fn` or `false_fn`. true_fn: The callable to be performed if pred is true. diff --git a/tensorflow_probability/python/internal/test_combinations.py b/tensorflow_probability/python/internal/test_combinations.py index f436c8946e..127848f8d4 100644 --- a/tensorflow_probability/python/internal/test_combinations.py +++ b/tensorflow_probability/python/internal/test_combinations.py @@ -78,7 +78,7 @@ def should_execute_combination(self, kwargs): If the environment doesn't satisfy the dependencies of the test combination, then it can be skipped. - Arguments: + Args: kwargs: Arguments that are passed to the test combination. Returns: @@ -100,7 +100,7 @@ def context_managers(self, kwargs): The test combination will run under all context managers that all `TestCombination` instances return. - Arguments: + Args: kwargs: Arguments and their values that are passed to the test combination. @@ -119,7 +119,7 @@ class ParameterModifier(object): def __init__(self, parameter_name=None): """Construct a parameter modifier that may be specific to a parameter. - Arguments: + Args: parameter_name: A `ParameterModifier` instance may operate on a class of parameters or on a parameter with a particular name. Only `ParameterModifier` instances that are of a unique type or were @@ -135,7 +135,7 @@ def modified_arguments(self, kwargs, requested_parameters): This makes it possible to adjust user-provided arguments before passing them to the test method. - Arguments: + Args: kwargs: The combined arguments for the test. requested_parameters: The set of parameters that are defined in the signature of the test method. diff --git a/tensorflow_probability/python/layers/dense_variational_v2.py b/tensorflow_probability/python/layers/dense_variational_v2.py index 2bc6614a0c..116cb2a29d 100644 --- a/tensorflow_probability/python/layers/dense_variational_v2.py +++ b/tensorflow_probability/python/layers/dense_variational_v2.py @@ -53,7 +53,7 @@ def __init__(self, **kwargs): """Creates the `DenseVariational` layer. - Arguments: + Args: units: Positive integer, dimensionality of the output space. make_posterior_fn: Python callable taking `tf.size(kernel)`, `tf.size(bias)`, `dtype` and returns another callable which takes an diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index 383ca3fca4..191a33d0a5 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -1510,7 +1510,7 @@ def new(params, num_components, component_layer, def params_size(num_components, component_params_size, name=None): """Number of `params` needed to create a `MixtureSameFamily` distribution. - Arguments: + Args: num_components: Number of component distributions in the mixture distribution. component_params_size: Number of parameters needed to create a single diff --git a/tensorflow_probability/python/layers/distribution_layer_test.py b/tensorflow_probability/python/layers/distribution_layer_test.py index 8b58b2a4af..abe4683515 100644 --- a/tensorflow_probability/python/layers/distribution_layer_test.py +++ b/tensorflow_probability/python/layers/distribution_layer_test.py @@ -320,7 +320,7 @@ class DistributionLambdaSerializationTest(test_util.TestCase): def assertSerializable(self, model, batch_size=1): """Assert that a model can be saved/loaded via Keras Model.save/load_model. - Arguments: + Args: model: A Keras model that outputs a `tfd.Distribution`. batch_size: The batch size to use when checking that the model produces the same results as a serialized/deserialized copy. Default value: 1. @@ -348,7 +348,7 @@ def assertSerializable(self, model, batch_size=1): def assertExportable(self, model, batch_size=1): """Assert a Keras model supports export_saved_model/load_from_saved_model. - Arguments: + Args: model: A Keras model with Tensor output. batch_size: The batch size to use when checking that the model produces the same results as a serialized/deserialized copy. Default value: 1. diff --git a/tensorflow_probability/python/layers/initializers.py b/tensorflow_probability/python/layers/initializers.py index 530706f5ed..f3e9906dcd 100644 --- a/tensorflow_probability/python/layers/initializers.py +++ b/tensorflow_probability/python/layers/initializers.py @@ -30,7 +30,7 @@ class BlockwiseInitializer(tf.keras.initializers.Initializer): def __init__(self, initializers, sizes, validate_args=False): """Creates the `BlockwiseInitializer`. - Arguments: + Args: initializers: `list` of Keras initializers, e.g., `"glorot_uniform"` or `tf.keras.initializers.Constant(0.5413)`. sizes: `list` of `int` scalars representing the number of elements diff --git a/tensorflow_probability/python/layers/masked_autoregressive.py b/tensorflow_probability/python/layers/masked_autoregressive.py index 0a5f4b9316..4699222412 100644 --- a/tensorflow_probability/python/layers/masked_autoregressive.py +++ b/tensorflow_probability/python/layers/masked_autoregressive.py @@ -123,7 +123,7 @@ def f_inverse(x): def __init__(self, made, **kwargs): """Constructs the AutoregressiveTransform layer. - Arguments: + Args: made: A `Made` layer, which must output two parameters for each input. **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. """ diff --git a/tensorflow_probability/python/layers/variable_input.py b/tensorflow_probability/python/layers/variable_input.py index 63dd74355a..f0e7a9b74b 100644 --- a/tensorflow_probability/python/layers/variable_input.py +++ b/tensorflow_probability/python/layers/variable_input.py @@ -83,7 +83,7 @@ def __init__(self, **kwargs): """Creates the `VariableLayer`. - Arguments: + Args: shape: integer or integer vector specifying the shape of the output of this layer. dtype: TensorFlow `dtype` of the variable created by this layer. From 6a8bd09cb1998802b2bee3bdac2ecfccb79614e2 Mon Sep 17 00:00:00 2001 From: jburnim Date: Tue, 8 Dec 2020 16:02:17 -0800 Subject: [PATCH 06/36] Fix bug in tfp.math.bracket_root for negative roots. PiperOrigin-RevId: 346430469 --- tensorflow_probability/python/math/root_search.py | 2 +- .../python/math/root_search_test.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/math/root_search.py b/tensorflow_probability/python/math/root_search.py index 9d55368bdb..c442b92ab2 100644 --- a/tensorflow_probability/python/math/root_search.py +++ b/tensorflow_probability/python/math/root_search.py @@ -592,7 +592,7 @@ def bracket_root(objective_fn, xs_positive = tf.exp(tf.linspace(tf.cast(-10., dtype), tf.math.log(dtype_info.max), num_points // 2)) - xs = tf.concat([-xs_positive, xs_positive], axis=0) + xs = tf.concat([tf.reverse(-xs_positive, axis=[0]), xs_positive], axis=0) # Evaluate the objective at all points. The objective function may return # a batch of values (e.g., `objective(x) = x - batch_of_roots`). diff --git a/tensorflow_probability/python/math/root_search_test.py b/tensorflow_probability/python/math/root_search_test.py index db998878f3..7a894d49df 100644 --- a/tensorflow_probability/python/math/root_search_test.py +++ b/tensorflow_probability/python/math/root_search_test.py @@ -288,6 +288,20 @@ def objective_fn(x): self.assertAllTrue(low < roots) self.assertAllTrue(high > roots) + def test_negative_root(self): + root = -17.314 + low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root))) + self.assertLess(low, root) + self.assertGreater(high, root) + + def test_root_near_zero(self): + root = tf.exp(-13.) + low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root))) + self.assertLess(low, np.exp(-13.)) + self.assertGreater(high, np.exp(-13)) + self.assertAllClose(low, root, atol=1e-4) + self.assertAllClose(high, root, atol=1e-4) + def test_returns_zero_width_bracket_at_root(self): root = tf.exp(-10.) low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root))) From 13ceab7869aa8984b584610f8e4ec4f8e0313c95 Mon Sep 17 00:00:00 2001 From: kateslin Date: Tue, 8 Dec 2020 16:59:25 -0800 Subject: [PATCH 07/36] Update `build_asvi_surrogate_posterior` to handle substituted distributions, some nested distributions, and a mean field option. PiperOrigin-RevId: 346440980 --- .../python/experimental/vi/BUILD | 17 +- .../experimental/vi/surrogate_posteriors.py | 204 ++++++++++++++---- .../vi/surrogate_posteriors_test.py | 126 ++++++++++- 3 files changed, 296 insertions(+), 51 deletions(-) diff --git a/tensorflow_probability/python/experimental/vi/BUILD b/tensorflow_probability/python/experimental/vi/BUILD index f443c07f4d..9072388b64 100644 --- a/tensorflow_probability/python/experimental/vi/BUILD +++ b/tensorflow_probability/python/experimental/vi/BUILD @@ -40,12 +40,21 @@ py_library( srcs = ["surrogate_posteriors.py"], srcs_version = "PY3", deps = [ - # numpy dep, # tensorflow dep, - "//tensorflow_probability/python/distributions", - "//tensorflow_probability/python/internal:nest_util", + "//tensorflow_probability/python/bijectors", + "//tensorflow_probability/python/bijectors:softplus", + "//tensorflow_probability/python/distributions:beta", + "//tensorflow_probability/python/distributions:independent", + "//tensorflow_probability/python/distributions:joint_distribution", + "//tensorflow_probability/python/distributions:joint_distribution_auto_batched", + "//tensorflow_probability/python/distributions:joint_distribution_coroutine", + "//tensorflow_probability/python/distributions:joint_distribution_util", + "//tensorflow_probability/python/distributions:normal", + "//tensorflow_probability/python/distributions:sample", + "//tensorflow_probability/python/distributions:transformed_distribution", + "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", - "//tensorflow_probability/python/monte_carlo", + "//tensorflow_probability/python/util", ], ) diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py index c8c40563f8..7d847daf28 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py @@ -23,23 +23,34 @@ import functools import tensorflow.compat.v2 as tf -import tensorflow_probability as tfp from tensorflow_probability.python import bijectors as tfb from tensorflow_probability.python import util as tfp_util from tensorflow_probability.python.bijectors import softplus as softplus_lib +from tensorflow_probability.python.distributions import beta +from tensorflow_probability.python.distributions import half_normal from tensorflow_probability.python.distributions import independent +from tensorflow_probability.python.distributions import joint_distribution from tensorflow_probability.python.distributions import joint_distribution_auto_batched from tensorflow_probability.python.distributions import joint_distribution_coroutine from tensorflow_probability.python.distributions import joint_distribution_util from tensorflow_probability.python.distributions import normal +from tensorflow_probability.python.distributions import sample +from tensorflow_probability.python.distributions import transformed_distribution +from tensorflow_probability.python.distributions import truncated_normal +from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import + Root = joint_distribution_coroutine.JointDistributionCoroutine.Root -_NON_STATISTICAL_PARAMS = ['name', 'validate_args', 'allow_nan_stats'] +_NON_STATISTICAL_PARAMS = [ + 'name', 'validate_args', 'allow_nan_stats', 'experimental_use_kahan_sum', + 'reinterpreted_batch_ndims' +] +_NON_TRAINABLE_PARAMS = ['low', 'high'] ASVIParameters = collections.namedtuple( 'ASVIParameters', ['prior_weight', 'mean_field_parameter']) @@ -306,38 +317,94 @@ def model_fn(): component_distributions, validate_args=validate_args)) -def _make_asvi_trainable_variables(prior): +def _as_trainable_family(distribution): + """Substitutes prior distributions with more easily trainable ones.""" + with tf.name_scope('as_trainable_family'): + + if isinstance(distribution, half_normal.HalfNormal): + return truncated_normal.TruncatedNormal( + loc=0., + scale=distribution.scale, + low=0., + high=distribution.scale * 10.) + elif isinstance(distribution, uniform.Uniform): + return tfb.Shift(distribution.low)( + tfb.Scale(distribution.high - distribution.low)(beta.Beta( + concentration0=tf.ones( + distribution.event_shape_tensor(), dtype=distribution.dtype), + concentration1=1.))) + else: + return distribution + + +def _make_asvi_trainable_variables(prior, + mean_field=False, + initial_prior_weight=0.5): """Generates parameter dictionaries given a prior distribution and list.""" with tf.name_scope('make_asvi_trainable_variables'): param_dicts = [] prior_dists = prior._get_single_sample_distributions() # pylint: disable=protected-access for dist in prior_dists: - actual_dist = dist.distribution if isinstance(dist, Root) else dist - dist_params = actual_dist.parameters + original_dist = dist.distribution if isinstance(dist, Root) else dist + + substituted_dist = _as_trainable_family(original_dist) + + # Grab the base distribution if it exists + try: + actual_dist = substituted_dist.distribution + except AttributeError: + actual_dist = substituted_dist + new_params_dict = {} # Build trainable ASVI representation for each distribution's parameters. - for param, value in dist_params.items(): - if param in _NON_STATISTICAL_PARAMS or value is None: + parameter_properties = actual_dist.parameter_properties( + dtype=actual_dist.dtype) + sample_shape = tf.concat( + [dist.batch_shape_tensor(), + dist.event_shape_tensor()], axis=0) + for param, value in actual_dist.parameters.items(): + if param in (_NON_STATISTICAL_PARAMS + + _NON_TRAINABLE_PARAMS) or value is None: continue - new_params_dict[param] = ASVIParameters( - prior_weight=tfp.util.TransformedVariable( - 0.5, - bijector=tfb.Sigmoid(), - name='prior_weight/{}/{}'.format(dist.name, param)), - mean_field_parameter=tfp.util.TransformedVariable( - 0.5, - bijector=dist.parameter_properties( - dtype=dist.dtype)[param].default_constraining_bijector_fn(), - name='mean_field_parameter/{}/{}'.format(dist.name, param)) - ) + try: + bijector = parameter_properties[ + param].default_constraining_bijector_fn() + except NotImplementedError: + bijector = tfb.Identity() + unconstrained_ones = tf.ones( + shape=bijector.inverse_event_shape_tensor( + parameter_properties[param].shape_fn( + sample_shape=sample_shape)), + dtype=actual_dist.dtype) + + if mean_field: + new_params_dict[param] = ASVIParameters( + prior_weight=None, + mean_field_parameter=tfp_util.TransformedVariable( + value, + bijector=bijector, + name='mean_field_parameter/{}/{}'.format(dist.name, param))) + else: + new_params_dict[param] = ASVIParameters( + prior_weight=tfp_util.TransformedVariable( + initial_prior_weight * unconstrained_ones, + bijector=tfb.Sigmoid(), + name='prior_weight/{}/{}'.format(dist.name, param)), + mean_field_parameter=tfp_util.TransformedVariable( + value, + bijector=bijector, + name='mean_field_parameter/{}/{}'.format(dist.name, param))) param_dicts.append(new_params_dict) return param_dicts # TODO(kateslin): Add support for models with prior+likelihood written as # a single JointDistribution. -def build_asvi_surrogate_posterior(prior, name=None): +def build_asvi_surrogate_posterior(prior, + mean_field=False, + initial_prior_weight=0.5, + name=None): """Builds a structured surrogate posterior inspired by conjugate updating. ASVI, or Automatic Structured Variational Inference, was proposed by @@ -360,12 +427,22 @@ def build_asvi_surrogate_posterior(prior, name=None): Args: prior: tfd.JointDistribution instance of the prior. - name: Optional string. + mean_field: Optional Python boolean. If `True`, creates a degenerate + surrogate distribution in which all variables are independent, + ignoring the prior dependence structure. Default value: `False`. + initial_prior_weight: Optional float value (either static or tensor value) + on the interval [0, 1]. A larger value creates an initial surrogate + distribution with more dependence on the prior structure. Default value: + `0.5`. + name: Optional string. Default value: `build_asvi_surrogate_posterior`. Returns: surrogate_posterior: A `tfd.JointDistributionCoroutineAutoBatched` instance whose samples have shape and structure matching that of `prior`. + Raises: + TypeError: The `prior` argument cannot be a nested `JointDistribution`. + ### Examples Consider a Brownian motion model expressed as a JointDistribution: @@ -417,9 +494,10 @@ def model_fn(): """ with tf.name_scope(name or 'build_asvi_surrogate_posterior'): - - param_dicts = _make_asvi_trainable_variables(prior) - + param_dicts = _make_asvi_trainable_variables( + prior=prior, + mean_field=mean_field, + initial_prior_weight=initial_prior_weight) def posterior_generator(): prior_gen = prior._model_coroutine() # pylint: disable=protected-access @@ -428,25 +506,61 @@ def posterior_generator(): i = 0 try: while True: - actual_dist = dist.distribution if isinstance(dist, Root) else dist - dist_params = actual_dist.parameters - temp_params_dict = {} + original_dist = dist.distribution if isinstance(dist, Root) else dist - for param, value in dist_params.items(): - if param in _NON_STATISTICAL_PARAMS or value is None: - temp_params_dict[param] = value + if isinstance(original_dist, joint_distribution.JointDistribution): + # TODO(kateslin): Build inner JD surrogate in + # _make_asvi_trainable_variables to avoid rebuilding variables. + raise TypeError( + 'Argument `prior` cannot be a nested `JointDistribution`.') + + else: + + original_dist = _as_trainable_family(original_dist) + + try: + actual_dist = original_dist.distribution + except AttributeError: + actual_dist = original_dist + + dist_params = actual_dist.parameters + temp_params_dict = {} + + for param, value in dist_params.items(): + if param in (_NON_STATISTICAL_PARAMS + + _NON_TRAINABLE_PARAMS) or value is None: + temp_params_dict[param] = value + else: + prior_weight = param_dicts[i][param].prior_weight + mean_field_parameter = param_dicts[i][ + param].mean_field_parameter + if mean_field: + temp_params_dict[param] = mean_field_parameter + else: + temp_params_dict[param] = prior_weight * value + ( + 1. - prior_weight) * mean_field_parameter + + if isinstance(original_dist, sample.Sample): + surrogate_dist = sample.Sample( + type(actual_dist)(**temp_params_dict)) else: - prior_weight = param_dicts[i][param].prior_weight - mean_field_parameter = param_dicts[i][param].mean_field_parameter - temp_params_dict[param] = prior_weight * value + ( - 1. - prior_weight) * mean_field_parameter + surrogate_dist = type(actual_dist)(**temp_params_dict) - surrogate_dist = type(actual_dist)(**temp_params_dict) + if isinstance(original_dist, + transformed_distribution.TransformedDistribution): + surrogate_dist = transformed_distribution.TransformedDistribution( + surrogate_dist, bijector=original_dist.bijector) - if isinstance(dist, Root): - value_out = yield Root(surrogate_dist) - else: - value_out = yield surrogate_dist + if isinstance(original_dist, independent.Independent): + surrogate_dist = independent.Independent( + surrogate_dist, + reinterpreted_batch_ndims=original_dist + .reinterpreted_batch_ndims) + + if isinstance(dist, Root): + value_out = yield Root(surrogate_dist) + else: + value_out = yield surrogate_dist dist = prior_gen.send(value_out) i += 1 @@ -456,5 +570,19 @@ def posterior_generator(): surrogate_posterior = ( joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched( posterior_generator)) + + # Ensure that the surrogate posterior structure matches that of the prior + try: + tf.nest.assert_same_structure(prior.dtype, surrogate_posterior.dtype) + except TypeError: + tokenize = lambda structure: tf.nest.pack_sequence_as( # pylint:disable=g-long-lambda + structure, [i for (i, _) in enumerate(tf.nest.flatten(structure))]) + surrogate_posterior = tfb.Restructure( + output_structure=tokenize(prior.dtype), + input_structure=tokenize(surrogate_posterior.dtype))( + surrogate_posterior) + surrogate_posterior.also_track = param_dicts return surrogate_posterior + + diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py index 7b7f8e8f90..9cfdb20d6c 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py @@ -28,6 +28,7 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.python.experimental.vi import surrogate_posteriors +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util tfb = tfp.bijectors @@ -265,10 +266,16 @@ def test_dims_and_gradients(self): # Test that the correct number of trainable variables are being tracked prior_dists = prior_dist._get_single_sample_distributions() # pylint: disable=protected-access expected_num_trainable_vars = 0 - for dist in prior_dists: + for original_dist in prior_dists: + try: + original_dist = original_dist.distribution + except AttributeError: + pass + dist = surrogate_posteriors._as_trainable_family(original_dist) dist_params = dist.parameters for param, value in dist_params.items(): - if param not in surrogate_posteriors._NON_STATISTICAL_PARAMS and value is not None: + if (param not in surrogate_posteriors._NON_STATISTICAL_PARAMS + and value is not None and param not in ('low', 'high')): expected_num_trainable_vars += 2 # prior_weight, mean_field_parameter self.assertLen(surrogate_posterior.trainable_variables, @@ -285,9 +292,9 @@ def test_dims_and_gradients(self): # Test that the sample shape is correct three_posterior_samples = surrogate_posterior.sample(3) three_prior_samples = prior_dist.sample(3) - - self.assertAllEqualNested([s.shape for s in three_prior_samples], - [s.shape for s in three_posterior_samples]) + self.assertAllEqualNested( + [s.shape for s in tf.nest.flatten(three_prior_samples)], + [s.shape for s in tf.nest.flatten(three_posterior_samples)]) def test_fitting_surrogate_posterior(self): @@ -308,8 +315,9 @@ def test_fitting_surrogate_posterior(self): # Compute posterior statistics. with tf.control_dependencies([losses]): posterior_samples = surrogate_posterior.sample(100) - posterior_mean = [tf.reduce_mean(x) for x in posterior_samples] - posterior_stddev = [tf.math.reduce_std(x) for x in posterior_samples] + posterior_mean = tf.nest.map_structure(tf.reduce_mean, posterior_samples) + posterior_stddev = tf.nest.map_structure(tf.math.reduce_std, + posterior_samples) self.evaluate(tf1.global_variables_initializer()) _ = self.evaluate(losses) @@ -328,8 +336,16 @@ def test_make_asvi_trainable_variables(self): # Confirm that there exists correct number of trainable variables. for (prior_distribution, trained_vars_dict) in zip(prior_dists, trained_vars): - for param_name, prior_value in prior_distribution.parameters.items(): - if param_name not in surrogate_posteriors._NON_STATISTICAL_PARAMS and prior_value is not None: + substituted_dist = surrogate_posteriors._as_trainable_family( + prior_distribution) + try: + posterior_distribution = substituted_dist.distribution + except AttributeError: + posterior_distribution = substituted_dist + + for param_name, prior_value in posterior_distribution.parameters.items(): + if (param_name not in surrogate_posteriors._NON_STATISTICAL_PARAMS + and prior_value is not None and param_name not in ('low', 'high')): self.assertIsInstance(trained_vars_dict[param_name], surrogate_posteriors.ASVIParameters) @@ -375,9 +391,101 @@ def target_log_prob(*x): return target_log_prob +@test_util.test_all_tf_execution_regimes +class ASVISurrogatePosteriorTestEightSchools(test_util.TestCase, + _TrainableASVISurrogate): + + def make_prior_dist(self): + treatment_effects = tf.constant([28, 8, -3, 7, -1, 1, 18, 12], + dtype=tf.float32) + num_schools = ps.shape(treatment_effects)[-1] + + return tfd.JointDistributionNamed({ + 'avg_effect': + tfd.Normal(loc=0., scale=10., name='avg_effect'), + 'log_stddev': + tfd.Normal(loc=5., scale=1., name='log_stddev'), + 'school_effects': + lambda log_stddev, avg_effect: ( # pylint: disable=g-long-lambda + tfd.Independent( + tfd.Normal( + loc=avg_effect[..., None] * tf.ones(num_schools), + scale=tf.exp(log_stddev[..., None]) * tf.ones( + num_schools), + name='school_effects'), + reinterpreted_batch_ndims=1)) + }) + + def make_likelihood_model(self, x, observation_noise=None): + treatment_stddevs = tf.constant([15, 10, 16, 11, 9, 11, 10, 18], + dtype=tf.float32) + + return tfd.Independent( + tfd.Normal(loc=x['school_effects'], scale=treatment_stddevs), + reinterpreted_batch_ndims=1) + + def get_observations(self, prior_dist): + ground_truth = self.evaluate(prior_dist.sample()) + likelihood = self.make_likelihood_model(x=ground_truth) + return likelihood.sample(1) + + def get_target_log_prob(self, observations, prior_dist): + + def target_log_prob(**x): + likelihood_dist = self.make_likelihood_model(x=x) + return likelihood_dist.log_prob(observations) + prior_dist.log_prob(x) + + return target_log_prob + + +@test_util.test_all_tf_execution_regimes +class ASVISurrogatePosteriorTestHalfNormal(test_util.TestCase, + _TrainableASVISurrogate): + + def make_prior_dist(self): + + def _prior_model_fn(): + innovation_noise = 1. + yield tfd.HalfNormal( + scale=innovation_noise, validate_args=True, allow_nan_stats=False) + + return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn) + + def make_likelihood_model(self, x, observation_noise): + + def _likelihood_model(): + yield tfd.Normal( + loc=x, + scale=observation_noise, + validate_args=True, + allow_nan_stats=False) + + return tfd.JointDistributionCoroutineAutoBatched(_likelihood_model) + + def get_observations(self, prior_dist): + observation_noise = 1. + ground_truth = prior_dist.sample() + likelihood = self.make_likelihood_model( + x=ground_truth, observation_noise=observation_noise) + return likelihood.sample(1) + + def get_target_log_prob(self, observations, prior_dist): + + obs = observations + def target_log_prob(*x): + observation_noise = 0.15 + likelihood_dist = self.make_likelihood_model( + x=x, observation_noise=observation_noise) + + return likelihood_dist.log_prob(obs) + prior_dist.log_prob(x) + + return target_log_prob + # TODO(kateslin): Add an ASVI surrogate posterior test for gamma distributions. # TODO(kateslin): Add an ASVI surrogate posterior test with for a model with # missing observations. +# TODO(kateslin): Add an ASVI surrogate posterior test for Uniform distribution +# to check that Beta substitution works properly if __name__ == '__main__': tf.test.main() From 30bdacfa59f5554e030d75d7519cf103d3929243 Mon Sep 17 00:00:00 2001 From: emilyaf Date: Wed, 9 Dec 2020 12:43:30 -0800 Subject: [PATCH 08/36] Add utilities for batched (transpose) convolutions to `tfp.experimental.nn`. PiperOrigin-RevId: 346614698 --- .../python/experimental/nn/util/BUILD | 18 +- .../python/experimental/nn/util/__init__.py | 12 +- .../experimental/nn/util/convolution_util.py | 910 ++++++++++++++++++ .../nn/util/convolution_util_test.py | 561 +++++++++++ .../python/experimental/nn/util/im2row.py | 187 ---- .../experimental/nn/util/im2row_test.py | 60 -- 6 files changed, 1495 insertions(+), 253 deletions(-) create mode 100644 tensorflow_probability/python/experimental/nn/util/convolution_util.py create mode 100644 tensorflow_probability/python/experimental/nn/util/convolution_util_test.py delete mode 100644 tensorflow_probability/python/experimental/nn/util/im2row.py delete mode 100644 tensorflow_probability/python/experimental/nn/util/im2row_test.py diff --git a/tensorflow_probability/python/experimental/nn/util/BUILD b/tensorflow_probability/python/experimental/nn/util/BUILD index 7fefb0db42..3661c714e3 100644 --- a/tensorflow_probability/python/experimental/nn/util/BUILD +++ b/tensorflow_probability/python/experimental/nn/util/BUILD @@ -28,7 +28,7 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY3", deps = [ - ":im2row", + ":convolution_util", ":random_variable", ":utils", "//tensorflow_probability/python/internal:all_util", @@ -36,24 +36,32 @@ py_library( ) py_library( - name = "im2row", - srcs = ["im2row.py"], + name = "convolution_util", + srcs = ["convolution_util.py"], srcs_version = "PY3", deps = [ + ":utils", # tensorflow dep, + "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", ], ) py_test( - name = "im2row_test", + name = "convolution_util_test", size = "medium", - srcs = ["im2row_test.py"], + srcs = ["convolution_util_test.py"], python_version = "PY3", + shard_count = 4, srcs_version = "PY3", deps = [ + ":convolution_util", + # absl/testing:parameterized dep, + # numpy dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", ], ) diff --git a/tensorflow_probability/python/experimental/nn/util/__init__.py b/tensorflow_probability/python/experimental/nn/util/__init__.py index 3c0c4c4417..8678e5c9ba 100644 --- a/tensorflow_probability/python/experimental/nn/util/__init__.py +++ b/tensorflow_probability/python/experimental/nn/util/__init__.py @@ -17,7 +17,12 @@ from __future__ import division from __future__ import print_function -from tensorflow_probability.python.experimental.nn.util.im2row import im2row +from tensorflow_probability.python.experimental.nn.util.convolution_util import im2row +from tensorflow_probability.python.experimental.nn.util.convolution_util import im2row_index +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_fn +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_transpose_fn_with_dilation +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_transpose_fn_with_subkernels +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_transpose_fn_with_subkernels_matrix from tensorflow_probability.python.experimental.nn.util.random_variable import CallOnce from tensorflow_probability.python.experimental.nn.util.random_variable import RandomVariable from tensorflow_probability.python.experimental.nn.util.utils import batchify_op @@ -53,10 +58,15 @@ 'flatten_rightmost', 'halflife_decay', 'im2row', + 'im2row_index', 'make_fit_op', 'make_kernel_bias', 'make_kernel_bias_posterior_mvn_diag', 'make_kernel_bias_prior_spike_and_slab', + 'make_convolution_fn', + 'make_convolution_transpose_fn_with_dilation', + 'make_convolution_transpose_fn_with_subkernels', + 'make_convolution_transpose_fn_with_subkernels_matrix', 'negloglik', 'prepare_conv_args', 'prepare_strides', diff --git a/tensorflow_probability/python/experimental/nn/util/convolution_util.py b/tensorflow_probability/python/experimental/nn/util/convolution_util.py new file mode 100644 index 0000000000..da4c00fdba --- /dev/null +++ b/tensorflow_probability/python/experimental/nn/util/convolution_util.py @@ -0,0 +1,910 @@ +# Lint as: python2, python3 +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Functions for framing `conv` as `matmul`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.experimental.nn.util import utils +from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps + +__all__ = [ + 'im2row', + 'im2row_index', + 'make_convolution_fn', + 'make_convolution_transpose_fn_with_dilation', + 'make_convolution_transpose_fn_with_subkernels', + 'make_convolution_transpose_fn_with_subkernels_matrix', +] + + +def im2row(x, + block_shape, + slice_step=(1, 1), + padding='VALID', + name=None): + """Rearrange image blocks into rows. + + This function can be used to implement 2D convolution as a `matmul`, e.g., + + `tf.nn.conv2d(x, k) = tf.matmul( + tf.experimental.nn.util.im2row(x), tf.reshape(k, shape=[-1, out_size]))`. + + Args: + x: Rank 3 (or more) Tensor representing 2D images. + block_shape: Length-2 vector representing the block or "filter" shape. + slice_step: Length-2 vector specifying the convolution stride length. + Default value: `(1, 1)`. + padding: One of `'VALID'` or `'SAME'` (case insensitive). + Default value: `'VALID'`. + name: Python `str` used to describe ops created by this function. + Default value: `None` (i.e., `'im2col'`). + + Returns: + im2row_x: batch of matrices representing subblock copies of `x`. + Same batch shape as `x` but with rightmost shape: + `batch_shape + [oh * ow, block_shape[0] * block_shape[1] * channels]`, + where `oh = (h - block_shape[0] + 1) // slice_step[0]` and + `ow = (w - block_shape[1] + 1) // slice_step[1]` when `padding = 'VALID'` + and `oh = h` and `ow = w` when `padding = 'SAME'`. + shape: shape `Tensor` equivalent to: + `batch_shape + [oh, ow, block_shape[0] * block_shape[1] * channels]` where + `oh, ow` are defined as above. + """ + with tf.name_scope(name or 'im2row'): + padding = _validate_padding(padding) + if padding == 'VALID': + pass # Do nothing. + elif padding == 'SAME': + raise NotImplementedError( + 'Argument padding="SAME" not implemented.') + # TODO(jvdillon): See if the following works: + # fh, fw = block_shape + # o = 1 if data_format == 'NHWC' else 0 + # n = ps.maximum(0, ps.rank(x) - 3) + # paddings = ps.pad( + # [[0, fh - 1], [0, fw - 1]], + # paddings=[[n + 1 - o, o], [0, 0]], + # constant_values=0) + # x = tf.pad(x, paddings=paddings, constant_values=0) + # padding = 'VALID' + else: + assert False # Can't be here. + x_shape = ps.shape(x) + idx, s = im2row_index( + x_shape, block_shape=block_shape, slice_step=slice_step) + flat_shape = ps.pad( + x_shape[:-3], paddings=[[0, 1]], constant_values=-1) + x = tf.gather(tf.reshape(x, flat_shape), idx, axis=-1) # == np.take + return tf.reshape(x, s) + + +def im2row_index(input_shape, + block_shape, + rank=2, + slice_step=(1, 1), + dilations=(1, 1), + dtype=tf.int32, + transpose=False, + validate_args=False, + name=None): + """Computes indexes into a flattened image for building `im2row`.""" + with tf.name_scope(name or 'im2row_index'): + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + fh, fw = prepare_tuple_argument( + block_shape, n=rank, arg_name='block_shape', + validate_args=validate_args) + sh, sw = prepare_tuple_argument( + slice_step, n=rank, arg_name='slice_step', validate_args=validate_args) + dh, dw = prepare_tuple_argument( + dilations, n=rank, arg_name='dilations', validate_args=validate_args) + + # 1) Process input arguments. + batch_shape, h, w, c = ps.split( + ps.reshape(ps.cast(input_shape, dtype=dtype), shape=[-1]), + num_or_size_splits=[-1, 1, 1, 1]) + h, w, c = h[0], w[0], c[0] + + tot_fh = dh * (fh - 1) + 1 + tot_fw = dw * (fw - 1) + 1 + + # 2) Assemble all block start positions as indexes into the flattened image. + # start_idx.shape = [fh, fw, c] + if transpose: + last_element = lambda size, step: size - (size - 1) % step - 1 + w_step = c * dw + h_step = c * w * dh + last_w = last_element(c * tot_fw, w_step) + last_h = last_element(c * w * tot_fh, h_step) + start_idx = cartesian_add([ + ps.range(last_h, -1, delta=-h_step, dtype=dtype), + ps.range(last_w, -1, delta=-w_step, dtype=dtype), + ps.range(c, delta=1, dtype=dtype), + ]) + else: + start_idx = cartesian_add([ + ps.range(c * w * tot_fh, delta=c * w * dh, dtype=dtype), + ps.range(c * tot_fw, delta=c * dw, dtype=dtype), + ps.range(c, delta=1, dtype=dtype), + ]) + + # 3) Assemble all block offsets (into flattened image). + eh = h - tot_fh + 1 + ew = w - tot_fw + 1 + + offset_idx = cartesian_add([ + ps.range(w * eh, delta=w * sh, dtype=dtype), + ps.range(ew, delta=sw, dtype=dtype), + ]) + + offset_idx = offset_idx * c + oh = (eh - 1) // sh + 1 # out height + ow = (ew - 1) // sw + 1 # out width + + # 4) Combine block start/offset pairs. + # shape = [(eh // sh) * (ew // sw), fh * fw * c] + idx = cartesian_add([offset_idx, start_idx]) + new_shape = ps.concat( + [batch_shape, ps.convert_to_shape_tensor([oh, ow, fh * fw * c])], + axis=0) + return idx, new_shape + + +def cartesian_add(xs): + """Adds a list of vectors by cumulatively expanding a dimension.""" + return sum(ps.reshape(x, shape=[-1] + [1] * (len(xs) - 1 - i)) + for i, x in enumerate(xs)) + + +def _validate_padding(padding): + """Verify correctness of `padding` argument.""" + padding_ = str(padding).upper() + if padding_ in {'SAME', 'VALID'}: + return padding_ + raise ValueError( + 'Argument padding="{}" not recognized; must be one of ' + '{{"VALID", "SAME"}} (case insensitive).'.format(padding)) + + +# TODO(emilyaf): Finish docstrings. +def make_convolution_fn( + filter_shape, rank, strides, padding, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" + with tf.name_scope(name or 'conv2d'): + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + [ + filter_shape, + rank, + strides, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + fh, fw = filter_shape + + assertions = _maybe_validate_input_shapes( + ps.shape(kernel), channels_in=c_in, filter_height=fh, + filter_width=fw, validate_args=validate_args) + + with tf.control_dependencies(assertions): + if tf.get_static_value(ps.rank(kernel)) == 2: + flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0)) + flat_y = tf.nn.conv2d( + x, + filters=tf.reshape(kernel, shape=[fh, fw, c_in, -1]), + strides=strides, + padding=padding, + data_format='NHWC', + dilations=dilations) + output_shape = ps.shape(flat_y)[-3:] + return tf.reshape( + flat_y, shape=ps.concat([batch_shape, output_shape], axis=0)) + + pad_values = [ + _get_conv_padding( + xdim, filter_dim=k, stride=s, dilation=d, padding=padding) + for (xdim, k, s, d) in zip((xh, xw), filter_shape, strides, dilations) + ] + + idx, shape = im2row_index( + (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in), + block_shape=filter_shape, slice_step=strides, dilations=dilations, + dtype=dtype) + + if padding == 'SAME': + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + pad_values, paddings=[[n, 1], [0, 0]], constant_values=0) + x = tf.pad(x, paddings=paddings, constant_values=0) + + flat_shape = ps.pad( + batch_shape, paddings=[[0, 1]], constant_values=-1) + flat_x = tf.gather(tf.reshape(x, shape=flat_shape), indices=idx, axis=-1) + im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0)) + return tf.matmul(im_x, kernel[..., tf.newaxis, :, :]) + return op + + +def _get_conv_padding(xdim, filter_dim, stride, dilation, padding): + """Returns the number of zeros to pad at the start and end of an axis.""" + if padding == 'VALID': + return (0, 0) + elif padding == 'SAME': + tot_k = dilation * (filter_dim - 1) + 1 + tot_pad = tf.maximum(tot_k - ((xdim - 1) % stride + 1), 0) + pad_start = tot_pad // 2 + return pad_start, tot_pad - pad_start + + +def make_convolution_transpose_fn_with_dilation( + filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`. + + This version tends to be fastest on GPU. It implements the transposed + convolution as a regular convolution of an image that is dilated by + interleaving rows and columns of zeros equal to the number of strides. + + Args: + filter_shape: ... + strides: ... + padding: ... + rank: ... + dilations: ... + dtype: ... + validate_args: ... + name: ... + Returns: + convolution_transpose_fn: A callable that takes an input `Tensor` and kernel + and applies the transpose convolution operation. + """ + with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): + + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + [ + filter_shape, + rank, + strides, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + sh, sw = strides + fh, fw = filter_shape + + pad_values = [ + _get_transpose_conv_dilated_padding( + k, stride=s, dilation=d, padding=padding) + for (k, s, d) in zip(filter_shape, strides, dilations)] + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + kernel_shape = ps.shape(kernel) + assertions = _maybe_validate_input_shapes( + kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, + validate_args=validate_args) + + with tf.control_dependencies(assertions): + # If the kernel does not have batch shape, fall back to + # `conv2d_transpose` (unless dilations > 1, which is not implemented in + # `conv2d_transpose`). + if (tf.get_static_value(ps.rank(kernel)) == 2 + and all(d == 1 for d in dilations)): + return _call_conv2d_transpose( + x, kernel, filter_shape, strides, padding, dilations, + kernel_shape[-1], batch_shape, event_shape) + + idx, shape = im2row_index( + (xh * sh + sum(pad_values[0]), xw * sw + sum(pad_values[1]), c_in), + block_shape=filter_shape, slice_step=(1, 1), dilations=dilations, + dtype=dtype, transpose=True) + + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + pad_values, paddings=[[n, 1], [0, 0]], constant_values=0) + + # Interleave the rows and columns of the input with rows and columns of + # zeros equal to the number of strides. + x_half_dilated = tf.concat( + [tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)], axis=0), + dtype=input_dtype), + tf.reshape( + x, shape=ps.concat([batch_shape, (xh * xw, 1, c_in)], axis=0)) + ], axis=-2) + y = tf.reshape( + x_half_dilated, + shape=ps.concat([batch_shape, (xh, 1, xw * sw, c_in)], axis=0)) + + x = tf.reshape( + tf.concat( + [tf.zeros( + ps.concat( + [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0), + dtype=input_dtype), y], axis=-3), + shape=ps.concat([batch_shape, (xh * sh, xw * sw, c_in)], axis=0)) + x_pad = tf.pad(x, paddings=paddings, constant_values=0) + flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1) + flat_x = tf.gather( + tf.reshape(x_pad, shape=flat_shape), indices=idx, axis=-1) + im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0)) + return tf.matmul(im_x, kernel[..., tf.newaxis, :, :]) + return op + + +def make_convolution_transpose_fn_with_subkernels_matrix( + filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" + with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): + + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + + strides = tf.get_static_value(strides) + if not isinstance(strides, int): + raise ValueError('Argument `strides` must be a statically known integer.' + 'Saw: {}'.format(strides)) + + [ + filter_shape, + rank, + _, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + fh, fw = filter_shape + dh, dw = dilations + + # Determine maximum filter height and filter width of sub-kernels. + sub_fh = (fh - 1) // strides + 1 + sub_fw = (fw - 1) // strides + 1 + + def loop_body(i_, event_ind): + i = i_ // strides + j = i_ % strides + + i_ind = ps.range(i * fw, fw * fh, delta=strides * fw, dtype=dtype) + j_ind = ps.range(j, fw, delta=strides, dtype=dtype) + + nc = cartesian_add([i_ind, j_ind]) + ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0]) + + k = ps.reshape( + cartesian_add( + [ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype), + ps.range(ps.shape(nc)[1], dtype=dtype)]), + shape=[-1]) + last_j = strides - (fw - j - 1) % strides - 1 + last_i = strides - (fh - i - 1) % strides - 1 + kernel_ind = ps.stack( + [k, ps.ones_like(k) * last_i * strides + last_j], axis=1) + event_ind = ps.tensor_scatter_nd_update( + event_ind, ind[..., tf.newaxis], kernel_ind) + + return i_ + 1, event_ind + + event_ind = ps.zeros((fh * fw, 2), dtype=dtype) + _, event_ind = tf.while_loop( + lambda i, _: i < strides ** 2, + loop_body, + [tf.zeros([], dtype=dtype), event_ind]) + + tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding( + fh, stride=strides, dilation=dh, padding=padding) + tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding( + fw, stride=strides, dilation=dw, padding=padding) + + pad_bottom = (tot_pad_bottom - 1) // strides + 1 + pad_top = (tot_pad_top - 1) // strides + 1 + pad_right = (tot_pad_right - 1) // strides + 1 + pad_left = (tot_pad_left - 1) // strides + 1 + padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right)) + + truncate_top = pad_top * strides - tot_pad_top + truncate_left = pad_left * strides - tot_pad_left + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + + kernel_shape = ps.shape(kernel) + c_out = kernel_shape[-1] + kernel_batch = kernel_shape[:-2] + assertions = _maybe_validate_input_shapes( + kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, + validate_args=validate_args) + + with tf.control_dependencies(assertions): + + # If the kernel does not have batch shape, fall back to + # `conv2d_transpose` (unless dilations > 1, which is not implemented in + # `conv2d_transpose`). + if (tf.get_static_value(ps.rank(kernel)) == 2 + and all(d == 1 for d in dilations)): + return _call_conv2d_transpose( + x, kernel=kernel, filter_shape=filter_shape, + strides=(strides,) * rank, padding=padding, dilations=dilations, + c_out=c_out, batch_shape=batch_shape, event_shape=event_shape) + + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + padding_vals, + paddings=[[n, 1], [0, 0]], + constant_values=0) + + x_pad = tf.pad(x, paddings=paddings, constant_values=0) + x_pad_shape = ps.shape(x_pad)[:-3] + flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1) + flat_x = tf.reshape(x_pad, shape=flat_shape) + + idx, s = im2row_index( + (xh + tf.reduce_sum(padding_vals[0]), + xw + tf.reduce_sum(padding_vals[1]), c_in), + block_shape=(sub_fh, sub_fw), slice_step=(1, 1), dilations=dilations + ) + + x_ = tf.gather(flat_x, indices=idx, axis=-1) + im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0)) + + # Add channels to subkernel indices + idx_event = event_ind * [[c_in, 1]] + idx_event_channels = ( + idx_event[tf.newaxis] + + tf.stack([ps.range(c_in), tf.zeros((c_in,), dtype=dtype)], + axis=-1)[:, tf.newaxis, :]) + idx_event = tf.squeeze( + tf.batch_to_space( + idx_event_channels, block_shape=[c_in], crops=[[0, 0]]), axis=0) + idx_event_broadcast = tf.broadcast_to( + idx_event, + shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0)) + + # Add cartesian product of batch indices, since scatter_nd can only be + # applied to leading dimensions. + idx_batch = tf.stack( + tf.meshgrid( + *[ps.range(b_, delta=1, dtype=dtype) + for b_ in tf.unstack(kernel_batch)], indexing='ij'), + axis=ps.size(kernel_batch)) + + idx_batch = tf.cast(idx_batch, dtype=dtype) # empty tensor is float + + idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros( + (ps.shape(idx_event)[0], 1), dtype=dtype) + idx_kernel = tf.concat( + [idx_batch_broadcast, idx_event_broadcast], axis=-1) + + kernel_mat = tf.scatter_nd( + idx_kernel, + updates=kernel, + shape=ps.cast( + ps.concat([kernel_batch, + [sub_fh * sub_fw * c_in, strides ** 2, c_out]], + axis=0), + dtype=dtype)) + + kernel_mat = tf.reshape( + kernel_mat, + shape=ps.concat( + [ps.shape(kernel_mat)[:-2], [strides ** 2 * c_out]], axis=0)) + + kernel_mat = kernel_mat[..., tf.newaxis, :, :] + out = tf.matmul(im_x, kernel_mat) + broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch) + + if strides > 1: + tot_size = tf.reduce_prod(broadcast_batch_shape) + flat_out = tf.reshape( + out, + shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0)) + out = tf.nn.depth_to_space(flat_out, block_size=strides) + + if padding == 'VALID': + out_height = fh + strides * (xh - 1) + out_width = fw + strides * (xw - 1) + elif padding == 'SAME': + out_height = xh * strides + out_width = xw * strides + + out = out[..., truncate_top:truncate_top + out_height, + truncate_left:truncate_left + out_width, :] + out = tf.reshape( + out, shape=ps.concat( + [broadcast_batch_shape, [out_height, out_width, c_out]], + axis=0)) + return out + return op + + +def make_convolution_transpose_fn_with_subkernels( + filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" + with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): + + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + [ + filter_shape, + rank, + strides, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + sh, sw = strides + fh, fw = filter_shape + dh, dw = dilations + + # Determine maximum filter height and filter width of sub-kernels. + sub_fh = (fh - 1) // sh + 1 + sub_fw = (fw - 1) // sw + 1 + + def loop_body(i_, kernels_ind): + i = i_ // sw + j = i_ % sw + i_ind = ps.range((sh - i - 1)*fw, fw * fh, delta=sh*fw, dtype=dtype) + j_ind = ps.range((sw - j - 1), fw, delta=sw, dtype=dtype) + + last_j = sw - (fw - j - 1) % sw - 1 + last_i = sh - (fh - i - 1) % sh - 1 + pos = last_i * sw + last_j + + nc = cartesian_add([i_ind, j_ind]) + kernels_ind = kernels_ind.write( + sh * sw - pos - 1, ps.reverse(ps.reverse(nc, [0]), [1])) + + return i_ + 1, kernels_ind + + kernels_ind = tf.TensorArray(dtype=dtype, infer_shape=False, size=1, + dynamic_size=True) + + _, kernels_ind = tf.while_loop( + lambda i, _: i < sh * sw, + loop_body, + [0, kernels_ind]) + + tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding( + fh, stride=sh, dilation=dh, padding=padding) + tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding( + fw, stride=sw, dilation=dw, padding=padding) + + pad_bottom = (tot_pad_bottom - 1) // sh + 1 + pad_top = (tot_pad_top - 1) // sh + 1 + pad_right = (tot_pad_right - 1) // sw + 1 + pad_left = (tot_pad_left - 1) // sw + 1 + padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right)) + + truncate_top = pad_top * sh - tot_pad_top + truncate_left = pad_left * sw - tot_pad_left + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + + kernel_shape = ps.shape(kernel) + c_out = kernel_shape[-1] + kernel_batch = kernel_shape[:-2] + assertions = _maybe_validate_input_shapes( + kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, + validate_args=validate_args) + + with tf.control_dependencies(assertions): + # If the kernel does not have batch shape, fall back to + # `conv2d_transpose` (unless dilations > 1, which is not implemented in + # `conv2d_transpose`). + if (tf.get_static_value(ps.rank(kernel)) == 2 + and all(d == 1 for d in dilations)): + return _call_conv2d_transpose( + x, kernel, filter_shape, strides, padding, dilations, c_out, + batch_shape, event_shape) + + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + padding_vals, + paddings=[[n, 1], [0, 0]], + constant_values=0) + x_pad = tf.pad(x, paddings=paddings, constant_values=0) + + ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1 + ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1 + + def loop_body(i, outputs): + subkernel_ind = kernels_ind.read(i) + fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2) + eh = ex_h + fh_ - 1 + ew = ex_w + fw_ - 1 + + subkernel_ind = ps.reshape( + ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] + + ps.range(c_in), shape=[-1]) + + k = tf.gather(kernel, subkernel_ind, axis=-2) + ind, shape = im2row_index( + [eh, ew, c_in], + block_shape=(fh_, fw_), + slice_step=(1, 1), + dilations=dilations) + x_i = x_pad[..., :eh, :ew, :] + x_i_shape = ps.shape(x_i) + flat_shape = ps.pad( + x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1) + flat_x = tf.reshape(x_i, flat_shape) + x_ = tf.gather(flat_x, ind, axis=-1) + im_x = tf.reshape(x_, ps.concat([x_i_shape[:-3], shape], axis=0)) + outputs = outputs.write( + i, + tf.matmul( + im_x, + tf.reshape( + k, ps.concat( + [kernel_batch, [1, fh_ * fw_* c_in, c_out]], axis=0))) + ) + return i + 1, outputs + + outputs = tf.TensorArray(dtype=input_dtype, infer_shape=False, size=1, + dynamic_size=True) + + _, outputs = tf.while_loop( + lambda i, _: i < sh * sw, + loop_body, + [0, outputs]) + + y = outputs.concat() + + m = tf.reduce_prod(ps.shape(y)[:-3]) + y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0)) + y2 = tf.batch_to_space( + y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64)) + broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch) + y2 = tf.reshape(y2, ps.concat( + [broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0)) + + if padding == 'VALID': + out_height = fh + sh * (xh - 1) + out_width = fw + sw * (xw - 1) + elif padding == 'SAME': + out_height = xh * sh + out_width = xw * sw + + return y2[..., truncate_top:truncate_top+out_height, + truncate_left:truncate_left+out_width, :] + return op + + +def _maybe_validate_input_shapes( + kernel_shape, channels_in, filter_height, filter_width, validate_args): + """Validate shapes of inputs to convolution op.""" + k_dim = kernel_shape[-2] + k_dim_ = tf.get_static_value(k_dim) + expected_k_dim = filter_height * filter_width * channels_in + expected_k_dim_ = tf.get_static_value(expected_k_dim) + assertions = [] + if expected_k_dim_ is not None and k_dim_ is not None: + if expected_k_dim_ != k_dim_: + raise ValueError( + 'The size of the second-to-rightmost dimension of `kernel` ( ={}) ' + ' must equal `filter_height * filter_width * channels_in` ( ={}), ' + 'where `channels_in` is the size of the rightmost dimension of the ' + 'input.'.format(k_dim_, expected_k_dim_)) + elif validate_args: + assertions.append( + assert_util.assert_equal( + k_dim, expected_k_dim, + message=('The size of the second-to-rightmost dimension of `kernel`' + ' must equal `filter_height * filter_width * channels_in`,' + ' where `channels_in` is the size of the rightmost ' + 'dimension of the input.'))) + return assertions + + +def _get_transpose_conv_dilated_padding(filter_dim, stride, dilation, padding): + """Zero-padding for inputs dilated by strides.""" + tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1) + if padding == 'VALID': + tot_pad = 2 * (tot_filter_dim - 1) + elif padding == 'SAME': + tot_pad = tot_filter_dim + stride - 2 + + # TODO(emilyaf): Don't need to consider case where stride > kernel_dim, right? + # if filter_dim > 1: + pad_end = tot_pad // 2 + pad_start = tot_pad - pad_end - (stride - 1) # implicit pad + # else: + # pad_end = pad_start = 0 + return pad_start, pad_end + + +def _get_output_shape(rank, strides, padding, dilations, input_shape, + output_size, filter_shape, output_padding=None): + """Compute the `output_shape` and `strides` arg used by `conv_transpose`.""" + if output_padding is None: + output_padding = (None,) * rank + else: + output_padding = utils.prepare_tuple_argument( + output_padding, n=rank, arg_name='output_padding') + for stride, out_pad in zip(strides, output_padding): + if out_pad >= stride: + raise ValueError('Stride {} must be greater than output ' + 'padding {}.'.format(strides, output_padding)) + event_shape = [] + for i in range(-rank, 0): + event_shape.append(_deconv_output_length( + input_shape[i - 1], + filter_size=filter_shape[i], + padding=padding, + output_padding=output_padding[i], + stride=strides[i], + dilation=dilations[i])) + event_shape.append(output_size) + batch_shape = input_shape[:-rank-1] + output_shape = ps.concat([batch_shape, event_shape], axis=0) + strides = ps.pad(strides, paddings=[[1, 1]], constant_values=1) + return output_shape, strides + + +def _deconv_output_length(input_size, filter_size, padding, output_padding, + stride, dilation): + """Determines output length of a transposed convolution given input length. + + Args: + input_size: `int`. + filter_size: `int`. + padding: one of `"SAME"`, `"VALID"`, `"FULL"`. + output_padding: `int`, amount of padding along the output dimension. Can + be set to `None` in which case the output length is inferred. + stride: `int`. + dilation: `int`. + + Returns: + output_length: The output length (`int`). + """ + assert padding in {'SAME', 'VALID', 'FULL'} + if input_size is None: + return None + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == 'VALID': + return input_size * stride + max(filter_size - stride, 0) + elif padding == 'FULL': + return input_size * stride - (stride + filter_size - 2) + elif padding == 'SAME': + return input_size * stride + if padding == 'SAME': + pad = filter_size // 2 + elif padding == 'VALID': + pad = 0 + elif padding == 'FULL': + pad = filter_size - 1 + return (input_size - 1) * stride + filter_size - 2 * pad + output_padding + + +def prepare_conv_args( + filter_shape, rank, strides, padding, dilations, validate_args=False): + """Sanitizes use provided input.""" + padding = _validate_padding(padding) # pylint: disable=protected-access + try: + rank = int(tf.get_static_value(rank)) + except TypeError: + raise TypeError('Argument `rank` must be statically known `int`.') + valid_rank = {1, 2, 3} + if rank not in valid_rank: + raise ValueError('Argument `rank` must be in {}.'.format(valid_rank)) + filter_shape = prepare_tuple_argument( + filter_shape, n=rank, arg_name='filter_shape', + validate_args=validate_args) + strides = prepare_tuple_argument( + strides, n=rank, arg_name='strides', validate_args=validate_args) + padding = utils._prepare_padding_argument(padding) # pylint: disable=protected-access + dilations = prepare_tuple_argument( + dilations, n=rank, arg_name='dilations', validate_args=validate_args) + return filter_shape, rank, strides, padding, dilations + + +# TODO(emilyaf): Replace the version in `utils` with this. +def prepare_tuple_argument(arg, n, arg_name, validate_args): + """Helper which processes `Tensor`s to tuples in standard form.""" + arg_size = ps.size(arg) + arg_size_ = tf.get_static_value(arg_size) + assertions = [] + if arg_size_ is not None: + if arg_size_ not in (1, n): + raise ValueError('The size of `{}` must be equal to `1` or to the rank ' + 'of the convolution (={}). Saw size = {}'.format( + arg_name, n, arg_size_)) + elif validate_args: + assertions.append(assert_util.assert_equal( + ps.logical_or(arg_size == 1, arg_size == n), + True, + message=('The size of `{}` must be equal to `1` or to the rank of the ' + 'convolution (={})'.format(arg_name, n)))) + + with tf.control_dependencies(assertions): + arg = ps.broadcast_to(arg, shape=[n]) + arg = ps.unstack(arg, num=n) + return arg + + +def _call_conv2d_transpose(x, kernel, filter_shape, strides, padding, dilations, + c_out, batch_shape, event_shape): + """Call `tf.nn.conv2d_transpose` (for kernels with no batch dimensions).""" + fh, fw = filter_shape + flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0)) + output_shape, strides_ = _get_output_shape( + rank=2, strides=strides, padding=padding, dilations=dilations, + input_shape=ps.shape(flat_x), output_size=c_out, + filter_shape=filter_shape) + flat_y = tf.nn.conv2d_transpose( + flat_x, + filters=tf.transpose( + tf.reshape( + kernel, shape=[fh, fw, event_shape[-1], -1]), + perm=[0, 1, 3, 2]), + output_shape=output_shape, + strides=strides_, + padding=padding, + data_format='NHWC', + dilations=dilations) + return tf.reshape( + flat_y, shape=ps.concat([batch_shape, output_shape[-3:]], axis=0)) diff --git a/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py b/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py new file mode 100644 index 0000000000..6dbf9d473b --- /dev/null +++ b/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py @@ -0,0 +1,561 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for batched convolutions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from absl.testing import parameterized + +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.experimental.nn.util import convolution_util +from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.internal import test_util + +tfn = tfp.experimental.nn + + +# TODO(emilyaf): Test that gradients work. +# pylint: disable=bad-whitespace +_CONV_TEST_CASES = ( + # input dim filter c_out strides padding dilations + ((1, 32, 32, 3), (3, 4), 2, (1, 1), 'VALID', (1, 1)), + ((5, 2, 32, 32, 3), (2, 2), 4, (1, 2), 'SAME', (1, 1)), + ((5, 2, 7, 7, 3), (2, 2), 4, (1, 2), 'SAME', (2, 1)), + ((5, 2, 13, 13, 3), (2, 2), 4, (1, 2), 'SAME', (1, 1)), + ((4, 28, 28, 2), (2, 3), 2, (2, 2), 'VALID', (1, 2)) + ) + +_CONV_TRANSPOSE_TEST_CASES = ( + # input dim filter c_out strides padding dilations + ((2, 16, 16, 3), (3, 3), 4, (2, 2), 'SAME', (1, 1)), + ((2, 16, 16, 3), (4, 4), 3, (2, 2), 'SAME', (1, 1)), + ((2, 8, 8, 2), (3, 3), 3, (1, 2), 'SAME', (1, 1)), + ((4, 9, 9, 3), (3, 3), 2, (1, 1), 'SAME', (2, 2)), + ((4, 12, 9, 3), (3, 3), 1, (2, 2), 'VALID', (1, 1)), + ((2, 12, 12, 2), (2, 3), 1, (2, 2), 'VALID', (1, 1)), + ) +# pylint: enable=bad-whitespace + + +def _make_input_and_kernel( + make_input, input_batch_shape, input_shape, kernel_batch_shape, + filter_shape, channels_out, dtype): + total_input_shape = ps.concat([input_batch_shape, input_shape], axis=0) + total_kernel_shape = ps.concat( + [kernel_batch_shape, [filter_shape[0] * filter_shape[1] * input_shape[-1], + channels_out]], axis=0) + # Use integers for numerical stability. + sample_fn = lambda s: make_input(tf.cast( # pylint: disable=g-long-lambda + tf.random.uniform( + ps.cast(s, tf.int32), minval=-10, maxval=10, dtype=tf.int32), + dtype=dtype)) + return sample_fn(total_input_shape), sample_fn(total_kernel_shape) + + +def _get_conv_transpose_fn(method): + if method == 'subkernels': + return tfn.util.make_convolution_transpose_fn_with_subkernels + elif method == 'subkernels_matrix': + return tfn.util.make_convolution_transpose_fn_with_subkernels_matrix + elif method == 'dilation': + return tfn.util.make_convolution_transpose_fn_with_dilation + else: + raise ValueError('Unsupported method for `_get_conv_transpose_fn`: {}.' + ''.format(method)) + + +class _Common(object): + """Common methods for Conv/ConvTranspose tests.""" + + def assertRaisesMaybeStaticError(self, msg): + if tf.executing_eagerly() or self.use_static_shape: + return self.assertRaisesRegex(ValueError, msg) + return self.assertRaisesOpError(msg) + + def make_integer_input(self, number): + if self.use_static_shape: + return number + output = tf.Variable(number, dtype=tf.int32) + self.evaluate(output.initializer) + return output + + +@test_util.test_all_tf_execution_regimes +class Im2RowTest(test_util.TestCase): + + def test_works_like_conv2d(self): + x = tf.constant([[ + [[2], [1], [2], [0], [1]], + [[1], [3], [2], [2], [3]], + [[1], [1], [3], [3], [0]], + [[2], [2], [0], [1], [1]], + [[0], [0], [3], [1], [2]], + ]], tf.float32) # shape=[1, 5, 5, 1] + x = tf.concat([x, x], axis=-1) + k = tf.constant([ + [[[2, 0.1]], [[3, 0.2]]], + [[[0, 0.3]], [[1, 0.4]]], + ], tf.float32) # shape=[2, 2, 1, 2] + k = tf.concat([k, k], axis=-2) + strides = [1, 2] + im2row_x = tfn.util.im2row( + x, + block_shape=ps.shape(k)[:2], + slice_step=strides, + padding='VALID') + y_expected = tf.nn.conv2d(x, k, strides=strides, padding='VALID') + y_actual = tf.matmul(im2row_x, tf.reshape(k, shape=[-1, k.shape[-1]])) + [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + @parameterized.parameters((tf.int32, np.int32), (tf.int64, np.int64)) + def test_dtype(self, tf_dtype, np_dtype): + ind, _ = tfn.util.im2row_index( + input_shape=(1, 12, 16, 3), + block_shape=(2, 3), + dtype=tf_dtype) + self.assertDTypeEqual(ind, np_dtype) + + +@test_util.test_all_tf_execution_regimes +class ConvolutionUtilsTest(test_util.TestCase, _Common): + + use_static_shape = False + + def test_prepare_tuple_argument(self): + + rank = 3 + + # Test that scalars are processed to tuples. + arg = convolution_util.prepare_tuple_argument( + self.make_integer_input(2), n=rank, arg_name='arg', validate_args=True) + self.assertIsInstance(arg, list) + self.assertLen(arg, rank) + + # Test that `Tensor` args are processed correctly. + arg = convolution_util.prepare_tuple_argument( + self.make_integer_input( + [2, 3, 4]), n=rank, arg_name='arg_2', validate_args=True) + self.assertIsInstance(arg, list) + self.assertLen(arg, rank) + + with self.assertRaisesRegex( + ValueError, 'must be equal to `1` or to the rank'): + convolution_util.prepare_tuple_argument( + self.make_integer_input([1, 2]), n=rank, arg_name='invalid_arg', + validate_args=True) + + def test_prepare_conv_args(self): + [filter_shape, + rank, + strides, + padding, + dilations] = convolution_util.prepare_conv_args( + (3, 3), + rank=2, + strides=2, + padding='same', + dilations=(1, 1)) + + for arg in [filter_shape, strides, dilations]: + self.assertLen(arg, rank) + + self.assertEqual(padding, 'SAME') + + +@test_util.test_all_tf_execution_regimes +class _BatchedConvTest(test_util.TestCase, _Common): + + @parameterized.parameters(*_CONV_TEST_CASES) + def test_works_like_conv2d( + self, input_shape, filter_shape, channels_out, + strides, padding, dilations): + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=[], + input_shape=input_shape, + # Use singleton kernel_batch_shape to bypass the short circuit to tf.nn. + kernel_batch_shape=[1], + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + + conv_fn = tfn.util.make_convolution_fn( + self.make_integer_input(filter_shape), + rank=2, + strides=self.make_integer_input(strides), + padding=padding, + dilations=self.make_integer_input(dilations), + validate_args=True) + y_actual = conv_fn(x, k) + + tf_kernel = tf.reshape( + k, shape=(filter_shape) + (input_shape[-1], channels_out)) + y_expected = tf.nn.conv2d( + x, tf_kernel, strides=strides, padding=padding, dilations=dilations) + [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + @parameterized.parameters( + ((1,), ()), # scalar input batch, scalar kernel batch + ((1,), (2, 3)), # non-scalar kernel batch + ((3, 4), ()), # non-scalar input batch + ((3, 1), (2,)), # broadcasting kernel and input batch shapes + ((2, 3), (2, 3),)) # same kernel and input batch shapes + def test_batching(self, input_batch_shape, kernel_batch_shape): + input_shape = (12, 12, 2) + filter_shape = (2, 2) + channels_out = 3 + strides = (1, 1) + dilations = (1, 1) + padding = 'SAME' + + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=input_batch_shape, + input_shape=input_shape, + kernel_batch_shape=kernel_batch_shape, + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + + conv_fn = tfn.util.make_convolution_fn( + filter_shape, rank=2, strides=strides, padding=padding, + dilations=dilations, validate_args=True) + y_batched = conv_fn(x, k) + + broadcast_batch_shape = ps.broadcast_shape( + input_batch_shape, kernel_batch_shape) + broadcasted_input = tf.broadcast_to( + x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0)) + broadcasted_kernel = tf.broadcast_to( + k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0)) + + flat_y = tf.reshape( + y_batched, + shape=ps.pad( + ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1)) + flat_x = tf.reshape( + broadcasted_input, + shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1)) + flat_tf_kernel = tf.reshape( + broadcasted_kernel, + shape=ps.concat([(-1,), filter_shape, (input_shape[-1], channels_out)], + axis=0)) + + y_expected = tf.vectorized_map( + lambda args: tf.nn.conv2d( # pylint: disable=g-long-lambda + args[0][tf.newaxis], + args[1], + strides=strides, + padding=padding), + elems=(flat_x, flat_tf_kernel)) + + [y_actual_, y_expected_] = self.evaluate( + [flat_y, tf.squeeze(y_expected, axis=1)]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + def test_incompatible_shapes_raises(self): + filter_shape = (3, 3) + + # Inconsistent channels in for kernel and image. + c_in_kernel = 6 + c_in_image = 8 + c_out = 12 + + k_dim = np.prod(filter_shape) * c_in_kernel + kernel = self.make_input(tf.ones((2, k_dim, c_out), dtype=tf.float32)) + x = self.make_input(tf.ones((3, 2, 16, 16, c_in_image), dtype=tf.float32)) + conv_fn = tfn.util.make_convolution_fn( + self.make_integer_input(filter_shape), + rank=2, + strides=self.make_integer_input((1, 1)), + padding='SAME', + dilations=self.make_integer_input((1, 1)), + validate_args=True) + with self.assertRaisesMaybeStaticError('size of the rightmost dimension'): + self.evaluate(conv_fn(x, kernel)) + + def test_dtype(self): + # Test int64 indices. + conv_fn = tfn.util.make_convolution_fn( + (2, 2), rank=2, strides=(1, 1), padding='SAME', dilations=(1, 1), + dtype=tf.int64, validate_args=True) + x = tf.ones((2, 8, 8, 2), dtype=tf.float32) + kernel = tf.ones((2, 8, 2), dtype=tf.float32) + _ = self.evaluate(conv_fn(x, kernel)) + + # Test f64 input. + conv_fn = tfn.util.make_convolution_fn( + self.make_integer_input((2, 2)), + rank=2, + strides=self.make_integer_input((1, 1)), + padding='SAME', + dilations=self.make_integer_input((1, 1)), + validate_args=True) + x = tf.ones((2, 8, 8, 2), dtype=tf.float64) + kernel = tf.ones((2, 8, 2), dtype=tf.float64) + y = self.evaluate(conv_fn(x, kernel)) + self.assertDTypeEqual(y, np.float64) + + +@test_util.test_all_tf_execution_regimes +class _BatchedConvTransposeTest(test_util.TestCase, _Common): + + dynamic_strides_ok = True + unequal_strides_ok = True + + def make_conv_fn(self, filter_shape, strides, padding, dilations): + return _get_conv_transpose_fn(self.method)( + self.make_integer_input(filter_shape), + strides=(self.make_integer_input(strides) + if self.dynamic_strides_ok else strides), + padding=padding, + dilations=self.make_integer_input(dilations), + validate_args=True) + + @parameterized.parameters(*_CONV_TRANSPOSE_TEST_CASES) + def test_works_like_conv2d_transpose( + self, input_shape, filter_shape, channels_out, strides, padding, + dilations): + + strides_tuple = strides + if not self.unequal_strides_ok: + if strides[0] != strides[1]: + # Skip this test case if the method does not support unequal strides. + return + else: + strides = strides[0] + + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=[], + input_shape=input_shape, + # Use singleton kernel_batch_shape to avoid the short circuit to + # `conv2d_transpose`. + kernel_batch_shape=[1], + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations) + y_actual = conv_fn(x, k) + output_shape, strides_ = convolution_util._get_output_shape( + rank=2, strides=strides_tuple, padding=padding, dilations=dilations, + input_shape=input_shape, output_size=channels_out, + filter_shape=filter_shape) + + tf_kernel = tf.transpose( + tf.reshape(k, ps.concat( + [filter_shape, [input_shape[-1], channels_out]], axis=0)), + perm=[0, 1, 3, 2]) + # conv2d_transpose does not support dilations > 1; use Keras instead. + if any(d > 1 for d in dilations): + keras_convt = tf.keras.layers.Conv2DTranspose( + filters=channels_out, + kernel_size=filter_shape, + strides=strides, + padding=padding, + dilation_rate=dilations, + use_bias=False) + _ = keras_convt(x) # build kernel + keras_convt.kernel = tf_kernel + y_expected = keras_convt(x) + else: + y_expected = tf.nn.conv2d_transpose( + x, tf_kernel, output_shape=output_shape, + strides=strides_, padding=padding, dilations=dilations) + + [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + @parameterized.parameters( + ((1,), ()), # scalar input batch, scalar kernel batch + ((1,), (2, 3)), # non-scalar kernel batch + ((3, 4), ()), # non-scalar input batch + ((3, 1), (2,)), # broadcasting kernel and input batch shapes + ((2, 3), (2, 3),)) # same kernel and input batch shapes + def test_batching(self, input_batch_shape, kernel_batch_shape): + input_shape = (12, 12, 2) + filter_shape = (3, 3) + channels_out = 4 + strides = 2 + dilations = (1, 1) + padding = 'SAME' + + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=input_batch_shape, + input_shape=input_shape, + kernel_batch_shape=kernel_batch_shape, + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + + conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations) + y_batched = conv_fn(x, k) + + broadcast_batch_shape = ps.broadcast_shape( + input_batch_shape, kernel_batch_shape) + broadcasted_input = tf.broadcast_to( + x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0)) + broadcasted_kernel = tf.broadcast_to( + k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0)) + + flat_y = tf.reshape( + y_batched, + shape=ps.pad( + ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1)) + flat_x = tf.reshape( + broadcasted_input, + shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1)) + flat_tf_kernel = tf.einsum( + '...ij->...ji', + tf.reshape( + broadcasted_kernel, + shape=ps.concat( + [(-1,), filter_shape, (input_shape[-1], channels_out)], + axis=0))) + + rank = 2 + output_shape, strides_ = convolution_util._get_output_shape( + rank=rank, strides=(strides,) * rank, padding=padding, + dilations=dilations, input_shape=input_shape, output_size=channels_out, + filter_shape=filter_shape) + + y_expected = tf.vectorized_map( + lambda args: tf.nn.conv2d_transpose( # pylint: disable=g-long-lambda + args[0][tf.newaxis], + args[1], + output_shape=ps.concat([[1], output_shape], axis=0), + strides=strides_, + padding=padding), + elems=(flat_x, flat_tf_kernel)) + + [y_actual_, y_expected_] = self.evaluate( + [flat_y, tf.squeeze(y_expected, axis=1)]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + def test_incompatible_shapes_raises(self): + filter_shape = (3, 3) + + # Inconsistent channels in for kernel and image. + c_in_kernel = 6 + c_in_image = 8 + c_out = 12 + + k_dim = np.prod(filter_shape) * c_in_kernel + kernel = self.make_input(tf.ones((2, k_dim, c_out), dtype=self.dtype)) + x = self.make_input(tf.ones((3, 2, 16, 16, c_in_image), dtype=self.dtype)) + conv_fn = self.make_conv_fn( + filter_shape, strides=1, padding='SAME', dilations=1) + + with self.assertRaisesMaybeStaticError('size of the rightmost dimension'): + self.evaluate(conv_fn(x, kernel)) + + def test_dtype(self): + # Test int64 indices. + conv_fn = self.make_conv_fn((2, 2), strides=1, padding='SAME', dilations=1) + x = tf.ones((2, 8, 8, 2), dtype=tf.float32) + kernel = tf.ones((2, 8, 2), dtype=tf.float32) + _ = self.evaluate(conv_fn(x, kernel)) + + # Test f64 input. + conv_fn = self.make_conv_fn((2, 2), strides=1, padding='SAME', dilations=1) + x = tf.ones((2, 8, 8, 2), dtype=tf.float64) + kernel = tf.ones((2, 8, 2), dtype=tf.float64) + y = self.evaluate(conv_fn(x, kernel)) + self.assertDTypeEqual(y, np.float64) + + +@test_util.test_all_tf_execution_regimes +class BatchedConvStaticTest(_BatchedConvTest): + + dtype = tf.float32 + use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class BatchedConvDynamicTest(_BatchedConvTest): + + dtype = tf.float32 + use_static_shape = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithDilationsStaticTest(_BatchedConvTransposeTest): + + method = 'dilation' + dtype = tf.float32 + use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsMatrixStaticTest( + _BatchedConvTransposeTest): + + method = 'subkernels_matrix' + dtype = tf.float32 + use_static_shape = True + unequal_strides_ok = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsStaticTest(_BatchedConvTransposeTest): + + method = 'subkernels' + dtype = tf.float32 + use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithDilationsDynamicTest(_BatchedConvTransposeTest): + + method = 'dilation' + dtype = tf.float32 + use_static_shape = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsMatrixDynamicTest( + _BatchedConvTransposeTest): + + method = 'subkernels_matrix' + dtype = tf.float32 + use_static_shape = False + dynamic_strides_ok = False + unequal_strides_ok = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsDynamicTest(_BatchedConvTransposeTest): + + method = 'subkernels' + dtype = tf.float32 + use_static_shape = False + + +del _BatchedConvTest +del _BatchedConvTransposeTest + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_probability/python/experimental/nn/util/im2row.py b/tensorflow_probability/python/experimental/nn/util/im2row.py deleted file mode 100644 index a67f46f309..0000000000 --- a/tensorflow_probability/python/experimental/nn/util/im2row.py +++ /dev/null @@ -1,187 +0,0 @@ -# Lint as: python2, python3 -# Copyright 2020 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Functions for framing `conv` as `matmul`.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.compat.v2 as tf - -from tensorflow_probability.python.internal import prefer_static - - -__all__ = [ - 'im2row', -] - - -def im2row(x, - block_shape, - slice_step=(1, 1), - data_format='NHWC', - padding='VALID', - name=None): - """Rearrange image blocks into rows. - - This function can be used to implement 2D convolution as a `matml`, e.g., - - `tf.nn.conv2d(x, k) = tf.matmul(im2row(x), tf.reshape(k, [-1, out_size]))`. - - Args: - x: Rank 3 (or more) Tensor representing 2D images. - block_shape: Length-2 vector representing the block or "filter" shape. - slice_step: Length-2 vector specifying the convolution stride length. - Default value: `(1, 1)`. - data_format: One of `'NHWC'` or `'NCHW'` (case insensitive). - Default value: `'NHWC'`. - padding: One of `'VALID'` or `'SAME'` (case insensitive). - Default value: `'VALID'`. - name: Python `str` used to describe ops created by this function. - Default value: `None` (i.e., `'im2col'`). - - Returns: - im2row_x: batch of matrices representing subblock copies of `x`. - Same batch shape as `x` but with rightmost shape: - `batch_shape + [oh * ow, block_shape[0] * block_shape[1] * channels]`, - where `oh = (h - block_shape[0] + 1) // slice_step[0]` and - `ow = (w - block_shape[1] + 1) // slice_step[1]` when `padding = 'VALID'` - and `oh = h` and `ow = w` when `padding = 'SAME'`. - shape: shape `Tensor` equivalent to: - `batch_shape + [oh, ow, block_shape[0] * block_shape[1] * channels]` where - `oh, ow` are defined as above. - """ - with tf.name_scope(name or 'im2row'): - data_format = _validate_data_format(data_format) - padding = _validate_padding(padding) - if padding == 'VALID': - pass # Do nothing. - elif padding == 'SAME': - raise NotImplementedError( - 'Argument padding="SAME" not implemented.') - # TODO(jvdillon): See if the following works: - # fh, fw = block_shape - # o = 1 if data_format == 'NHWC' else 0 - # n = prefer_static.maximum(0, prefer_static.rank(x) - 3) - # paddings = prefer_static.pad( - # [[0, fh - 1], [0, fw - 1]], - # paddings=[[n + 1 - o, o], [0, 0]], - # constant_values=0) - # x = tf.pad(x, paddings=paddings, constant_values=0) - # padding = 'VALID' - else: - assert False # Can't be here. - x_shape = prefer_static.shape(x) - idx, s = _im2row_index( - x_shape, block_shape, slice_step, data_format, padding) - flat_shape = prefer_static.pad( - x_shape[:-3], paddings=[[0, 1]], constant_values=-1) - x = tf.gather(tf.reshape(x, flat_shape), idx, axis=-1) # == np.take - return tf.reshape(x, s) - - -def _im2row_index(input_shape, - block_shape, - slice_step=(1, 1), - data_format='NHWC', - padding='VALID', - dtype=tf.int64, - name=None): - """Computes indexes into a flattened image for building `im2col`.""" - with tf.name_scope(name or 'im2row_index'): - # 1) Process input arguments. - batch_shape, s3, s2, s1 = prefer_static.split( - prefer_static.cast(input_shape, tf.int32), - num_or_size_splits=[-1, 1, 1, 1]) - fh, fw = _split_pair(block_shape) - sh, sw = _split_pair(slice_step) - data_format = _validate_data_format(data_format) - padding = _validate_padding(padding) - - # 2) Assemble all block start positions as indexes into the flattened image. - if data_format == 'NHWC': - h, w, c = s3[0], s2[0], s1[0] - # start_idx.shape = [fh, fw, c] - start_idx = _cartesian_add([ - prefer_static.range(c * w * fh, delta=c * w, dtype=dtype), - prefer_static.range(c * fw, delta=c, dtype=dtype), - prefer_static.range(c, delta=1, dtype=dtype), - ]) - elif data_format == 'NCHW': - c, h, w = s3[0], s2[0], s1[0] - # start_idx.shape = [c, fh, fw] - start_idx = _cartesian_add([ - prefer_static.range(w * h * c, delta=w * h, dtype=dtype), - prefer_static.range(w * fh, delta=w, dtype=dtype), - prefer_static.range(fw, delta=1, dtype=dtype), - ]) - else: - assert False # Can't be here. - - # 3) Assemble all block offsets (into flattened image). - if padding == 'VALID': - eh = h - fh + 1 # extent height - ew = w - fw + 1 # extent width - # offset_idx.shape = [eh // sh, ew // sw] - offset_idx = _cartesian_add([ - prefer_static.range(w * eh, delta=w * sh, dtype=dtype), - prefer_static.range(ew, delta=sw, dtype=dtype), - ]) - if data_format == 'NHWC': - offset_idx *= c - oh = eh // sh # out height - ow = ew // sw # out width - else: - assert False # Can't be here. - - # 4) Combine block start/offset pairs. - # shape = [(eh // sh) * (ew // sw), fh * fw * c] - idx = _cartesian_add([offset_idx, start_idx]) - new_shape = [oh, ow, fh * fw * c] - new_shape = prefer_static.concat([batch_shape, new_shape], axis=0) - return idx, new_shape - - -def _split_pair(x): - """Splits a length two vector into two scalars.""" - x = prefer_static.cast(x, dtype=tf.int32) - a, b = prefer_static.split(x, num_or_size_splits=[1, 1]) - return a[0], b[0] - - -def _cartesian_add(xs): - """Adds a list of vectors by cumulatively expanding a dimension.""" - return sum(prefer_static.reshape(x, shape=[-1] + [1]*(len(xs) - 1 - i)) - for i, x in enumerate(xs)) - - -def _validate_data_format(data_format): - """Verify correctness of `data_format` argument.""" - data_format_ = str(data_format).upper() - if data_format_ in {'NHWC', 'NCHW'}: - return data_format_ - raise ValueError( - 'Argument data_format="{}" not recognized; must be one of ' - '{{"NHWC", "NCHW"}} (case insensitive).'.format(data_format)) - - -def _validate_padding(padding): - """Verify correctness of `padding` argument.""" - padding_ = str(padding).upper() - if padding_ in {'SAME', 'VALID'}: - return padding_ - raise ValueError( - 'Argument padding="{}" not recognized; must be one of ' - '{{"VALID", "SAME"}} (case insensitive).'.format(padding)) diff --git a/tensorflow_probability/python/experimental/nn/util/im2row_test.py b/tensorflow_probability/python/experimental/nn/util/im2row_test.py deleted file mode 100644 index c5591db944..0000000000 --- a/tensorflow_probability/python/experimental/nn/util/im2row_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Tests for im2col.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports -import tensorflow.compat.v2 as tf -import tensorflow_probability as tfp - -from tensorflow_probability.python.internal import test_util - -tfn = tfp.experimental.nn - - -@test_util.test_all_tf_execution_regimes -class Im2ColTest(test_util.TestCase): - - def test_works_like_conv2d(self): - x = tf.constant([[ - [[2], [1], [2], [0], [1]], - [[1], [3], [2], [2], [3]], - [[1], [1], [3], [3], [0]], - [[2], [2], [0], [1], [1]], - [[0], [0], [3], [1], [2]], - ]], tf.float32) # shape=[1, 5, 5, 1] - x = tf.concat([x, x], axis=-1) - k = tf.constant([ - [[[2, 0.1]], [[3, 0.2]]], - [[[0, 0.3]], [[1, 0.4]]], - ], tf.float32) # shape=[2, 2, 1, 2] - k = tf.concat([k, k], axis=-2) - strides = [1, 2] - im2row_x = tfn.util.im2row( - x, - block_shape=k.shape[:2], - slice_step=strides, - padding='VALID') - y_expected = tf.nn.conv2d(x, k, strides=strides, padding='VALID') - y_actual = tf.matmul(im2row_x, tf.reshape(k, [-1, k.shape[-1]])) - [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) - self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) - - -if __name__ == '__main__': - tf.test.main() From 605a1ec31645fd50a224d847397fbefbbfcce02a Mon Sep 17 00:00:00 2001 From: ebrevdo Date: Wed, 9 Dec 2020 14:33:48 -0800 Subject: [PATCH 09/36] [TFP] Do non-lazy loading of symbols if TF is already loaded. This means one must only 'import tensorflow_probability' to load saved models that have serialized TFP keras layers and TFP CompositeTensor specs. PiperOrigin-RevId: 346639386 --- tensorflow_probability/python/__init__.py | 39 +++++++++++++++++------ 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index 345dbe2511..d5fd9afa7e 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -19,6 +19,7 @@ from __future__ import print_function import functools +import sys import types from tensorflow_probability.python.internal import all_util @@ -39,7 +40,7 @@ def _validate_tf_environment(package): """ try: import tensorflow.compat.v1 as tf - except ImportError: + except (ImportError, ModuleNotFoundError): # Print more informative error message, then reraise. print('\n\nFailed to import TensorFlow. Please note that TensorFlow is not ' 'installed by default when you install TensorFlow Probability. This ' @@ -96,13 +97,11 @@ def _validate_tf_environment(package): util: types.ModuleType vi: types.ModuleType -_allowed_symbols = [ +_lazy_load = [ 'bijectors', 'debugging', 'distributions', - 'experimental', 'glm', - 'layers', 'math', 'mcmc', 'monte_carlo', @@ -114,11 +113,33 @@ def _validate_tf_environment(package): 'vi', ] -for pkg in _allowed_symbols: - globals()[pkg] = lazy_loader.LazyLoader( - pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg), +# If TensorFlow is already imported, we should non-lazily load modules which +# include registrations (e.g., Keras layer registrations and CompositeTensor +# registrations) -- which must be loaded when deserializing tensorflow +# saved models. +_maybe_nonlazy_load = [ + 'experimental', + 'layers', +] + + +def _tf_loaded(): + return 'compat' in dir(sys.modules.get('tensorflow', None)) + + +# To start with, lazy-load everything. Later we may replace some of the +# lazy-loaded modules by forcing a load. +for pkg_name in _lazy_load + _maybe_nonlazy_load: + globals()[pkg_name] = lazy_loader.LazyLoader( + pkg_name, globals(), 'tensorflow_probability.python.{}'.format(pkg_name), # These checks need to happen before lazy-loading, since the modules # themselves will try to import tensorflow, too. - on_first_access=functools.partial(_validate_tf_environment, pkg)) + on_first_access=functools.partial(_validate_tf_environment, pkg_name)) + +if _tf_loaded(): + # Non-lazy load of packages that register with tensorflow or keras. + for pkg_name in _maybe_nonlazy_load: + dir(globals()[pkg_name]) # Forces loading the package from its lazy loader. + -all_util.remove_undocumented(__name__, _allowed_symbols) +all_util.remove_undocumented(__name__, _lazy_load + _maybe_nonlazy_load) From a1de748db273d4529c00a461424554e7d33455da Mon Sep 17 00:00:00 2001 From: emilyaf Date: Wed, 9 Dec 2020 17:04:39 -0800 Subject: [PATCH 10/36] Fix bug in the `batch_shape` of `JointDistributions` transformed by the `Restructure` bijector. PiperOrigin-RevId: 346669515 --- .../distributions/transformed_distribution.py | 10 +++--- .../transformed_distribution_test.py | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py index 68daa369a7..a2e83936cf 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution.py +++ b/tensorflow_probability/python/distributions/transformed_distribution.py @@ -289,12 +289,11 @@ def _batch_shape_tensor(self): # dtype.) if tf.nest.is_nested(base_batch_shape_tensor): if self._is_joint: - return base_batch_shape_tensor - + return tf.nest.pack_sequence_as( + self.dtype, tf.nest.flatten(base_batch_shape_tensor)) base_batch_shape_tensor = functools.reduce( ps.broadcast_shape, tf.nest.flatten(base_batch_shape_tensor)) - return base_batch_shape_tensor def _batch_shape(self): @@ -308,7 +307,10 @@ def _batch_shape(self): # the batch shape components of the base distribution are broadcast to # obtain the batch shape of the transformed distribution. batch_shape = self.distribution.batch_shape - if tf.nest.is_nested(batch_shape) and not self._is_joint: + if tf.nest.is_nested(batch_shape): + if self._is_joint: + return tf.nest.pack_sequence_as( + self.dtype, tf.nest.flatten(batch_shape)) batch_shape = functools.reduce( tf.broadcast_static_shape, tf.nest.flatten(batch_shape)) return batch_shape diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py index 04e3ea43e6..712c09fe3b 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution_test.py +++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py @@ -1047,6 +1047,37 @@ def test_transform_joint_to_joint(self, split_sizes): self.assertAllEqual(tf.nest.map_structure(lambda y: y.shape, y), tf.nest.map_structure(lambda y: y.shape, y_sampled)) + # Test that a `Restructure` bijector applied to a `JointDistribution` works + # as expected. + num_components = len(split_sizes) + input_keys = (split_sizes.keys() if isinstance(split_sizes, dict) + else range(num_components)) + output_keys = [str(i) for i in range(num_components)] + output_structure = {k: v for k, v in zip(output_keys, input_keys)} + restructure = tfb.Restructure(output_structure) + restructured_dist = tfd.TransformedDistribution( + base_dist, bijector=restructure, validate_args=True) + + # Check that attributes of the restructured distribution have the same + # nested structure as the `output_structure` of the bijector. Pass a no-op + # as the `assert_fn` since the contents of the structures are not + # required to be the same. + noop_assert_fn = lambda *_: None + self.assertAllAssertsNested( + noop_assert_fn, restructured_dist.event_shape, output_structure) + self.assertAllAssertsNested( + noop_assert_fn, restructured_dist.batch_shape, output_structure) + self.assertAllAssertsNested( + noop_assert_fn, + self.evaluate(restructured_dist.event_shape_tensor()), + output_structure) + self.assertAllAssertsNested( + noop_assert_fn, + self.evaluate(restructured_dist.batch_shape_tensor()), + output_structure) + self.assertAllAssertsNested( + noop_assert_fn, + self.evaluate(restructured_dist.sample(seed=test_util.test_seed()))) if __name__ == '__main__': tf.test.main() From cb2ed6972ac7f24fe61c8c64dc9b2d20ea57b680 Mon Sep 17 00:00:00 2001 From: sharadmv Date: Wed, 9 Dec 2020 17:43:44 -0800 Subject: [PATCH 11/36] Enable `tfp.experimental.distribute` in JAX backend PiperOrigin-RevId: 346676496 --- .../python/experimental/BUILD | 4 +- .../python/experimental/distribute/BUILD | 39 +++- .../experimental/distribute/distribute_lib.py | 208 ++++++++++++------ .../distribute/distribute_lib_test.py | 100 ++++----- .../distribute/distribute_test_lib.py | 74 +++++++ .../distribute/joint_distribution.py | 115 +++++++++- .../distribute/joint_distribution_test.py | 108 ++++----- .../python/experimental/distribute/sharded.py | 100 +++++++-- .../experimental/distribute/sharded_test.py | 39 ++-- .../python/internal/samplers.py | 21 +- 10 files changed, 570 insertions(+), 238 deletions(-) create mode 100644 tensorflow_probability/python/experimental/distribute/distribute_test_lib.py diff --git a/tensorflow_probability/python/experimental/BUILD b/tensorflow_probability/python/experimental/BUILD index 42f2e5e98a..414ce58fdc 100644 --- a/tensorflow_probability/python/experimental/BUILD +++ b/tensorflow_probability/python/experimental/BUILD @@ -33,11 +33,13 @@ exports_files(["LICENSE"]) multi_substrate_py_library( name = "experimental", srcs = ["__init__.py"], + numpy_omit_deps = [ + "//tensorflow_probability/python/experimental/distribute", + ], srcs_version = "PY3", substrates_omit_deps = [ ":composite_tensor", "//tensorflow_probability/python/experimental/auto_batching", - "//tensorflow_probability/python/experimental/distribute", "//tensorflow_probability/python/experimental/lazybones", "//tensorflow_probability/python/experimental/linalg", "//tensorflow_probability/python/experimental/marginalize", diff --git a/tensorflow_probability/python/experimental/distribute/BUILD b/tensorflow_probability/python/experimental/distribute/BUILD index f0c7898f50..610f32e50f 100644 --- a/tensorflow_probability/python/experimental/distribute/BUILD +++ b/tensorflow_probability/python/experimental/distribute/BUILD @@ -14,6 +14,11 @@ # ============================================================================ # Description: # Contains utilities for writing distributed TFP code. +load( + "//tensorflow_probability/python:build_defs.bzl", + "multi_substrate_py_library", + "multi_substrate_py_test", +) licenses(["notice"]) @@ -23,7 +28,7 @@ package( ], ) -py_library( +multi_substrate_py_library( name = "distribute", srcs = ["__init__.py"], srcs_version = "PY3", @@ -34,7 +39,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "distribute_lib", srcs = ["distribute_lib.py"], srcs_version = "PY3", @@ -43,7 +48,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "sharded", srcs = ["sharded.py"], deps = [ @@ -55,7 +60,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "joint_distribution", srcs = ["joint_distribution.py"], deps = [ @@ -66,39 +71,59 @@ py_library( ], ) -py_test( +multi_substrate_py_library( + name = "distribute_test_lib", + testonly = 1, + srcs = ["distribute_test_lib.py"], + srcs_version = "PY3", + deps = [ + # tensorflow dep, + "//tensorflow_probability/python/internal:test_util", + ], +) + +multi_substrate_py_test( name = "sharded_test", srcs = ["sharded_test.py"], + disabled_substrates = ["numpy"], python_version = "PY3", deps = [ + ":distribute_lib", + ":distribute_test_lib", ":sharded", # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", ], ) -py_test( +multi_substrate_py_test( name = "joint_distribution_test", srcs = ["joint_distribution_test.py"], + disabled_substrates = ["numpy"], python_version = "PY3", deps = [ + ":distribute_test_lib", ":joint_distribution", ":sharded", # absl/testing:parameterized dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", ], ) -py_test( +multi_substrate_py_test( name = "distribute_lib_test", srcs = ["distribute_lib_test.py"], + disabled_substrates = ["numpy"], python_version = "PY3", srcs_version = "PY3", deps = [ ":distribute_lib", + ":distribute_test_lib", # tensorflow dep, "//tensorflow_probability", "//tensorflow_probability/python/internal:test_util", diff --git a/tensorflow_probability/python/experimental/distribute/distribute_lib.py b/tensorflow_probability/python/experimental/distribute/distribute_lib.py index bab1b6b37e..574c5e489f 100644 --- a/tensorflow_probability/python/experimental/distribute/distribute_lib.py +++ b/tensorflow_probability/python/experimental/distribute/distribute_lib.py @@ -19,25 +19,66 @@ from __future__ import print_function import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient +JAX_MODE = False -def psum(x): +if JAX_MODE: + import jax # pylint: disable=g-import-not-at-top + from jax import lax # pylint: disable=g-import-not-at-top + + +def psum(x, axis_name=None): + if JAX_MODE: + return lax.psum(x, axis_name) ctx = tf.distribute.get_replica_context() return ctx.all_reduce('sum', x) -def pmean(x): +def pmean(x, axis_name=None): + if JAX_MODE: + return lax.pmean(x, axis_name) ctx = tf.distribute.get_replica_context() return ctx.all_reduce('mean', x) +def get_replica_id(axis_name=None): + if JAX_MODE: + return lax.axis_index(axis_name) + ctx = tf.distribute.get_replica_context() + return ctx.replica_id_in_sync_group + + +def get_num_replicas(axis_name=None): + if JAX_MODE: + return lax.psum(1, axis_name) + ctx = tf.distribute.get_replica_context() + return ctx.num_replicas_in_sync + + class _DummyGrads(object): + """Wraps gradients to preserve structure when computing a custom gradient.""" def __init__(self, grads): self.grads = grads + def tree_flatten(self): + return (self.grads,), () + + @classmethod + def tree_unflatten(cls, _, xs): + return cls(*xs) + + def __repr__(self): + return f'_DummyGrads({self.grads})' + + +if JAX_MODE: + from jax import tree_util # pylint: disable=g-import-not-at-top + tree_util.register_pytree_node_class(_DummyGrads) -def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded): + +def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded, axis_name=None): """Constructs a log prob parts function that all-reduces over terms. Given a log_prob_parts function, this function will return a new one that @@ -55,81 +96,116 @@ def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded): add an all-reduce sum for its term in the log prob calculation. If it is `False`, the returned function will have an all-reduce sum over the gradient of sharded terms w.r.t. to the unsharded value. + axis_name: a `str` used for the axis name in the JAX backend. Unused in the + TensorFlow backend. Returns: A new log prob parts function that can be run inside of strategy. """ - @tf.custom_gradient - def sharded_log_prob_parts(value): + def _sharded_log_prob_parts_fwd(value): tf.nest.assert_same_structure(value, is_sharded) - with tf.GradientTape(persistent=True) as tape: - tape.watch(value) + if JAX_MODE: + def flat_log_prob_parts_fn(flat_args): + args = tf.nest.pack_sequence_as(is_sharded, flat_args) + log_prob_parts = log_prob_parts_fn(args) + return tf.nest.flatten(log_prob_parts) + + def wrapped_log_prob(value): + flat_sharded = tf.nest.flatten(is_sharded) + return tf.nest.pack_sequence_as( + is_sharded, + [ + _DummyGrads(tf.nest.pack_sequence_as(is_sharded, [ # pylint: disable=g-complex-comprehension + jax.grad(lambda v: flat_log_prob_parts_fn(v)[i]) # pylint: disable=cell-var-from-loop + (tf.nest.flatten(value))[j] + for i in range(len(flat_sharded)) + ])) + for j in range(len(flat_sharded)) + ]) + log_prob_parts = log_prob_parts_fn(value) - tf.nest.assert_same_structure(log_prob_parts, is_sharded) + local_grads = wrapped_log_prob(value) + else: + with tf.GradientTape(persistent=True) as tape: + tape.watch(value) + log_prob_parts = log_prob_parts_fn(value) + tf.nest.assert_same_structure(log_prob_parts, is_sharded) + + def local_grad(v): + return _DummyGrads( + tf.nest.map_structure( + lambda log_prob_part: tape.gradient(log_prob_part, v), + log_prob_parts)) + local_grads = tf.nest.map_structure(local_grad, value) total_log_prob_parts = tf.nest.map_structure( lambda log_prob_part, sharded: ( # pylint: disable=g-long-lambda - psum(log_prob_part) if sharded else log_prob_part), + psum(log_prob_part, axis_name=axis_name) + if sharded else log_prob_part), log_prob_parts, is_sharded) - def vjp(*gs): - gs = tf.nest.pack_sequence_as(log_prob_parts, gs) - - def local_grad(v, g): - return _DummyGrads( - tf.nest.map_structure( - lambda lp: tape.gradient(lp, v, output_gradients=g), - log_prob_parts)) - - local_grads = tf.nest.map_structure(local_grad, value, gs) - - def value_grad(v, value_sharded, term_grads): - """Computes reductions of output gradients. - - A `log_prob_parts` function takes in a list of values and outputs - a log density for each input to the function. The vector-Jacobian - product (VJP) of a `log_prob_parts` function thus needs to compute the - gradient of each output term w.r.t. each input value. This function - overrides the default VJP of an output term `j` w.r.t to an input - value `i` to include an all-reduce-sum when: - 1) The gradient of `j` w.r.t. `i` is connected. - 2) `j` is a sharded term and `i` is an unsharded value. - - If these conditions do not hold, the gradient remains the same and - either corresponds to: - 1) The gradient of a sharded term w.r.t to a sharded value - 2) The gradient of an unsharded term w.r.t. to an unsharded value. - 3) The gradient of an unsharded term w.r.t. to an sharded value. - In any of these cases, no all-reduce-sum is necessary. - Args: - v: The output term of a `log_prob_part` function. - value_sharded: A boolean indicating whether or not the output term is - is sharded or not. - term_grads: The gradient of the output term w.r.t. to each of the - input values to the `log_prob_part` function. - Returns: - The vector Jacobian product of `v` w.r.t. the input parts of the - `log_prob_parts` function. - """ - term_grads = term_grads.grads - - def psum_grads(term_grad, term_sharded): - if term_grad is not None: - if not value_sharded and term_sharded: - term_grad = psum(term_grad) - return term_grad - - total_grad = tf.nest.map_structure(psum_grads, term_grads, - is_sharded) - if all([grad is None for grad in tf.nest.flatten(total_grad)]): - return None - return tf.add_n( - [v for v in tf.nest.flatten(total_grad) if v is not None]) - - return tf.nest.map_structure(value_grad, value, is_sharded, local_grads) - - return total_log_prob_parts, vjp + return total_log_prob_parts, (value, local_grads) + + def _sharded_log_prob_parts_bwd(res, gs): + value, local_grads = res + + def grad_mul(vs, g): + return tf.nest.map_structure(lambda v: v * g if v is not None else v, vs) + + local_grads = tf.nest.map_structure( + lambda v, g: _DummyGrads(grad_mul(v.grads, g)), local_grads, gs) + + def value_grad(v, value_sharded, term_grads): + """Computes reductions of output gradients. + + A `log_prob_parts` function takes in a list of values and outputs + a log density for each input to the function. The vector-Jacobian + product (VJP) of a `log_prob_parts` function thus needs to compute the + gradient of each output term w.r.t. each input value. This function + overrides the default VJP of an output term `j` w.r.t to an input + value `i` to include an all-reduce-sum when: + 1) The gradient of `j` w.r.t. `i` is connected. + 2) `j` is a sharded term and `i` is an unsharded value. + + If these conditions do not hold, the gradient remains the same and + either corresponds to: + 1) The gradient of a sharded term w.r.t to a sharded value + 2) The gradient of an unsharded term w.r.t. to an unsharded value. + 3) The gradient of an unsharded term w.r.t. to an sharded value. + In any of these cases, no all-reduce-sum is necessary. + Args: + v: The output term of a `log_prob_part` function. + value_sharded: A boolean indicating whether or not the output term is + sharded or not. + term_grads: The gradient of the output term w.r.t. to each of the input + values to the `log_prob_part` function. + + Returns: + The vector Jacobian product of `v` w.r.t. the input parts of the + `log_prob_parts` function. + """ + term_grads = term_grads.grads + def psum_grads(term_grad, term_sharded): + if term_grad is not None: + if not value_sharded and term_sharded: + term_grad = psum(term_grad, axis_name=axis_name) + return term_grad + + total_grad = tf.nest.map_structure(psum_grads, term_grads, + is_sharded) + if all([grad is None for grad in tf.nest.flatten(total_grad)]): + return None + return tf.add_n( + [v for v in tf.nest.flatten(total_grad) if v is not None]) + + out = tf.nest.map_structure(value_grad, value, is_sharded, local_grads) + return (out,) + + @tfp_custom_gradient.custom_gradient( + vjp_fwd=_sharded_log_prob_parts_fwd, vjp_bwd=_sharded_log_prob_parts_bwd) + def sharded_log_prob_parts(value): + return _sharded_log_prob_parts_fwd(value)[0] return sharded_log_prob_parts diff --git a/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py b/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py index 3f65709a54..9bd095962f 100644 --- a/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py +++ b/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py @@ -20,33 +20,17 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.python.experimental.distribute import distribute_lib +from tensorflow_probability.python.experimental.distribute import distribute_test_lib as test_lib from tensorflow_probability.python.internal import test_util tfd = tfp.distributions -NUM_DEVICES = 4 - - -def per_replica_to_tensor(value): - return tf.nest.map_structure( - lambda per_replica: tf.stack(per_replica.values, axis=0), value) - @test_util.test_all_tf_execution_regimes -class LogProbPartsTest(test_util.TestCase): - - def setUp(self): - super(LogProbPartsTest, self).setUp() - self.strategy = tf.distribute.MirroredStrategy( - devices=tf.config.list_logical_devices()) - - def shard_values(self, values): - - def value_fn(ctx): - return values[ctx.replica_id_in_sync_group] - - return self.strategy.experimental_distribute_values_from_function(value_fn) +class LogProbPartsTest(test_lib.DistributedTest): + @test_util.disable_test_for_backend( + disable_jax=True, reason='Behavior supported natively') def test_can_shard_values_across_logical_devices(self): @tf.function(autograph=False) @@ -59,9 +43,12 @@ def add_one(x): values = self.strategy.experimental_distribute_values_from_function( value_fn) out_values = self.evaluate( - per_replica_to_tensor(self.strategy.run(add_one, (values,)))) + self.per_replica_to_tensor(self.strategy_run(add_one, (values,)))) self.assertAllEqual(out_values, [1., 2., 3., 4.]) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot use sharded distributions outside of pmap.') def test_correct_log_prob_for_global_variable_no_strategy(self): data = tf.ones(4) @@ -73,7 +60,7 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True]) + log_prob_parts, [False, True], axis_name=None) self.assertAllEqualNested( self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])), self.evaluate([ @@ -81,6 +68,9 @@ def log_prob_parts(value): tf.reduce_sum(tfd.Normal(0., 1.).log_prob(data)) ])) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot use sharded distributions outside of pmap.') def test_correct_log_prob_for_local_variable_no_strategy(self): data = tf.ones(4) @@ -93,7 +83,7 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [True, True]) + log_prob_parts, [True, True], axis_name=None) self.assertAllEqualNested( self.evaluate(sharded_log_prob_parts([tf.ones(4), data])), self.evaluate([ @@ -103,7 +93,6 @@ def log_prob_parts(value): def test_correct_log_prob_for_global_variable(self): - @tf.function(autograph=False) def run(x, data): def log_prob_parts(value): @@ -114,14 +103,15 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True]) + log_prob_parts, [False, True], axis_name=self.axis_name) return sharded_log_prob_parts([x, data]) x = tf.constant(0.) data = tf.ones(4) sharded_data = self.shard_values(data) - out_parts = per_replica_to_tensor(self.strategy.run(run, (x, sharded_data))) + out_parts = self.per_replica_to_tensor( + self.strategy_run(run, (x, sharded_data), in_axes=(None, 0))) self.assertAllEqualNested( self.evaluate(out_parts), @@ -132,7 +122,6 @@ def log_prob_parts(value): def test_correct_log_prob_for_local_variable(self): - @tf.function(autograph=False) def run(x, data): def log_prob_parts(value): @@ -143,7 +132,7 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [True, True]) + log_prob_parts, [True, True], axis_name=self.axis_name) return sharded_log_prob_parts([x, data]) @@ -151,8 +140,8 @@ def log_prob_parts(value): sharded_x = self.shard_values(x) data = tf.ones(4) sharded_data = self.shard_values(data) - out_parts = per_replica_to_tensor( - self.strategy.run(run, (sharded_x, sharded_data))) + out_parts = self.per_replica_to_tensor( + self.strategy_run(run, (sharded_x, sharded_data))) self.assertAllEqualNested( self.evaluate(out_parts), @@ -163,7 +152,6 @@ def log_prob_parts(value): def test_correct_log_prob_for_global_and_local_variable(self): - @tf.function(autograph=False) def run(w, x, data): def log_prob_parts(values): @@ -175,7 +163,7 @@ def log_prob_parts(values): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True, True]) + log_prob_parts, [False, True, True], axis_name=self.axis_name) return sharded_log_prob_parts([w, x, data]) @@ -184,8 +172,9 @@ def log_prob_parts(values): sharded_x = self.shard_values(x) data = 3 * tf.ones(4) sharded_data = self.shard_values(data) - out_parts = per_replica_to_tensor( - self.strategy.run(run, (w, sharded_x, sharded_data))) + out_parts = self.per_replica_to_tensor( + self.strategy_run( + run, (w, sharded_x, sharded_data), in_axes=(None, 0, 0))) self.assertAllEqualNested( self.evaluate(out_parts), @@ -197,7 +186,6 @@ def log_prob_parts(values): def test_correct_gradient_for_global_variable(self): - @tf.function(autograph=False) def run(x, data): def log_prob_parts(value): @@ -209,7 +197,7 @@ def log_prob_parts(value): def log_prob(x): sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True]) + log_prob_parts, [False, True], axis_name=self.axis_name) parts = sharded_log_prob_parts([x, data]) return tf.add_n(parts) @@ -218,7 +206,8 @@ def log_prob(x): x = tf.constant(1.) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (x, sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run(run, (x, sharded_data), in_axes=(None, 0))) def true_log_prob(x): return (tfd.Normal(0., 1.).log_prob(x) + @@ -242,7 +231,7 @@ def log_prob_parts(value): def log_prob(x): sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [True, True]) + log_prob_parts, [True, True], axis_name=self.axis_name) parts = sharded_log_prob_parts([x, data]) return tf.add_n(parts) @@ -252,8 +241,8 @@ def log_prob(x): sharded_x = self.shard_values(x) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (sharded_x, - sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run(run, (sharded_x, sharded_data))) def true_log_prob(x): return (tf.reduce_sum(tfd.Normal(0., 1.).log_prob(x)) + @@ -265,7 +254,6 @@ def true_log_prob(x): def test_correct_gradient_for_global_and_local_variable(self): - @tf.function(autograph=False) def run(w, x, data): def log_prob_parts(value): @@ -279,7 +267,7 @@ def log_prob_parts(value): def log_prob(*value): w, x = value sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True, True]) + log_prob_parts, [False, True, True], axis_name=self.axis_name) parts = sharded_log_prob_parts([w, x, data]) return tf.add_n(parts) @@ -290,8 +278,9 @@ def log_prob(*value): sharded_x = self.shard_values(x) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (w, sharded_x, - sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run( + run, (w, sharded_x, sharded_data), in_axes=(None, 0, 0))) def true_log_prob(*value): w, x = value @@ -302,8 +291,8 @@ def true_log_prob(*value): true_grad = tfp.math.value_and_gradient(true_log_prob, [w, x])[1] true_grad[0] = tf.ones(4) * true_grad[0] - self.assertAllEqualNested(self.evaluate(out_grads), - self.evaluate(true_grad)) + self.assertAllEqualNested( + self.evaluate(out_grads), self.evaluate(true_grad)) def test_correct_gradient_for_global_and_local_variable_dict(self): @@ -313,14 +302,15 @@ def run(w, x, data): def log_prob_parts(value): return { 'w': tfd.Normal(0., 1.).log_prob(value['w']), - 'x': tfd.Normal(w, 1.).log_prob(value['x']), - 'data': tfd.Normal(x, 1.).log_prob(value['data']), + 'x': tfd.Normal(value['w'], 1.).log_prob(value['x']), + 'data': tfd.Normal(value['x'], 1.).log_prob(value['data']), } def log_prob(*value): w, x = value sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, {'w': False, 'x': True, 'data': True}) + log_prob_parts, {'w': False, 'x': True, 'data': True}, + axis_name=self.axis_name) parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data}) return tf.add_n(tf.nest.flatten(parts)) @@ -331,8 +321,9 @@ def log_prob(*value): sharded_x = self.shard_values(x) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (w, sharded_x, - sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run(run, (w, sharded_x, sharded_data), + in_axes=(None, 0, 0))) def true_log_prob(*value): w, x = value @@ -347,11 +338,4 @@ def true_log_prob(*value): self.evaluate(true_grad)) if __name__ == '__main__': - tf.enable_v2_behavior() - physical_devices = tf.config.experimental.list_physical_devices() - - num_logical_devices = 4 - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], - [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) tf.test.main() diff --git a/tensorflow_probability/python/experimental/distribute/distribute_test_lib.py b/tensorflow_probability/python/experimental/distribute/distribute_test_lib.py new file mode 100644 index 0000000000..9957a174fc --- /dev/null +++ b/tensorflow_probability/python/experimental/distribute/distribute_test_lib.py @@ -0,0 +1,74 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utilities for distributed testing.""" +import os + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.internal import test_util + +tf.enable_v2_behavior() +JAX_MODE = False +NUM_DEVICES = 4 + +if JAX_MODE: + import jax # pylint: disable=g-import-not-at-top + + +class DistributedTest(test_util.TestCase): + """Sets up distributed devices and sharding.""" + + def setUp(self): + super(DistributedTest, self).setUp() + if JAX_MODE: + os.environ['XLA_FLAGS'] = ( + '--xla_force_host_platform_device_count={}'.format(NUM_DEVICES)) + assert jax.device_count() == NUM_DEVICES + self.key = jax.random.PRNGKey(0) + else: + physical_devices = tf.config.experimental.list_physical_devices() + + tf.config.experimental.set_virtual_device_configuration( + physical_devices[0], + [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) + self.strategy = tf.distribute.MirroredStrategy( + devices=tf.config.list_logical_devices()) + self.key = [0, 0] + self.axis_name = 'i' + + def per_replica_to_tensor(self, value): + if JAX_MODE: + return value + return tf.nest.map_structure( + lambda per_replica: tf.stack(per_replica.values, axis=0), value) + + def strategy_run(self, f, args, in_axes=0): + if JAX_MODE: + if in_axes is None: + return jax.pmap( + lambda _, args: f(*args), + in_axes=(0, None), + axis_name=self.axis_name)(tf.ones(NUM_DEVICES), args) + return jax.pmap(f, axis_name=self.axis_name, in_axes=in_axes)(*args) + return self.strategy.run(tf.function(f, autograph=False), args) + + def shard_values(self, values): + if JAX_MODE: + return jax.pmap(lambda x: x)(values) + + def value_fn(ctx): + return values[ctx.replica_id_in_sync_group] + + return self.strategy.experimental_distribute_values_from_function(value_fn) diff --git a/tensorflow_probability/python/experimental/distribute/joint_distribution.py b/tensorflow_probability/python/experimental/distribute/joint_distribution.py index 6f9b0034cd..d2e05f4e44 100644 --- a/tensorflow_probability/python/experimental/distribute/joint_distribution.py +++ b/tensorflow_probability/python/experimental/distribute/joint_distribution.py @@ -18,8 +18,11 @@ from __future__ import division from __future__ import print_function +import functools + import tensorflow.compat.v2 as tf from tensorflow_probability.python import distributions as distribution_lib +from tensorflow_probability.python.distributions import joint_distribution as jd_lib from tensorflow_probability.python.experimental.distribute import distribute_lib from tensorflow_probability.python.experimental.distribute import sharded @@ -31,9 +34,13 @@ class JointDistributionDistributedMixin(object): def get_sharded_distributions(self): """Indicates for each part distribution whether or not it is sharded.""" ds = self._get_single_sample_distributions() - return self._model_unflatten(( - isinstance(d, (sharded.ShardedIndependent, sharded.ShardedSample)) - for d in ds)) + return self._model_unflatten( + (isinstance(d, (sharded.ShardedIndependent, sharded.ShardedSample)) + for d in ds)) + + @property + def shard_axis_name(self): + return self._parameters['shard_axis_name'] def _map_measure_over_dists(self, attr, value): """Overrides the default implementation to shard its log_prob calculation.""" @@ -44,34 +51,122 @@ def _map_measure_over_dists(self, attr, value): def inner_log_prob_parts(flat_value): unflat_value = self._model_unflatten(flat_value) ds, xs = self._call_flat_sample_distributions( - value=unflat_value, seed=42) + value=unflat_value, seed=jd_lib.dummy_seed()) + # For sharded distributions, we need to make sure not to do an + # all-reduce. + flat_sharded = self._model_flatten(self.get_sharded_distributions()) + log_prob_fns = [ + functools.partial(d.log_prob, reduce_over_shards=False) + if s else d.log_prob for d, s in zip(ds, flat_sharded)] # We need to flatten and unflatten here to ensure the output structure # matches `flat_sharded_distributions`. vals = self._model_unflatten( - [getattr(d, attr)(x) for d, x in zip(ds, xs)]) + [log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)]) return self._model_flatten(vals) flat_value = self._model_flatten(value) flat_sharded_distributions = self._model_flatten( self.get_sharded_distributions()) flat_xs = distribute_lib.make_sharded_log_prob_parts( - inner_log_prob_parts, flat_sharded_distributions)( + inner_log_prob_parts, + flat_sharded_distributions, + axis_name=self.shard_axis_name)( flat_value) return iter(flat_xs) - ds, xs = self._call_flat_sample_distributions(value=value, seed=42) + ds, xs = self._call_flat_sample_distributions( + value=value, seed=jd_lib.dummy_seed()) return (getattr(d, attr)(x) for d, x in zip(ds, xs)) class JointDistributionSequential(JointDistributionDistributedMixin, distribution_lib.JointDistributionSequential): - pass + """A sharding-aware JointDistributionSequential.""" + + def __init__(self, + model, + validate_args=False, + shard_axis_name=None, + name=None): + """Construct the `JointDistributionSequential` distribution. + + Args: + model: Python list of either tfd.Distribution instances and/or lambda + functions which take the `k` previous distributions and returns a new + tfd.Distribution instance. + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + Default value: `False`. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `None` (i.e., `"JointDistributionSequential"`). + """ + super(JointDistributionSequential, self).__init__( + model, validate_args=validate_args, name=name) + self._parameters['shard_axis_name'] = shard_axis_name class JointDistributionNamed(JointDistributionDistributedMixin, distribution_lib.JointDistributionNamed): - pass + """A sharding-aware JointDistributionNamed.""" + + def __init__(self, + model, + validate_args=False, + shard_axis_name=None, + name=None): + """Construct the `JointDistributionNamed` distribution. + + Args: + model: Python `dict`, `collections.OrderedDict`, or `namedtuple` of + distribution-making functions each with required args corresponding only + to other keys. + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + Default value: `False`. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `None` (i.e., `"JointDistributionNamed"`). + """ + super(JointDistributionNamed, + self).__init__(model, validate_args, name or 'JointDistributionNamed') + self._parameters['shard_axis_name'] = shard_axis_name class JointDistributionCoroutine(JointDistributionDistributedMixin, distribution_lib.JointDistributionCoroutine): - pass + """A sharding-aware JointDistributionCoroutine.""" + + def __init__( + self, + model, + sample_dtype=None, + validate_args=False, + shard_axis_name=None, + name=None, + ): + """Construct the `JointDistributionCoroutine` distribution. + + Args: + model: A generator that yields a sequence of `tfd.Distribution`-like + instances. + sample_dtype: Samples from this distribution will be structured like + `tf.nest.pack_sequence_as(sample_dtype, list_)`. `sample_dtype` is only + used for `tf.nest.pack_sequence_as` structuring of outputs, never + casting (which is the responsibility of the component distributions). + Default value: `None` (i.e. `namedtuple`). + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + Default value: `False`. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `None` (i.e., `JointDistributionCoroutine`). + """ + super(JointDistributionCoroutine, self).__init__( + model, + sample_dtype=sample_dtype, + validate_args=validate_args, + name=name) + self._parameters['shard_axis_name'] = shard_axis_name diff --git a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py index 14867fda3b..37a0b3799e 100644 --- a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py +++ b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py @@ -21,77 +21,88 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.distribute import distribute_test_lib as test_lib from tensorflow_probability.python.experimental.distribute import joint_distribution as jd from tensorflow_probability.python.experimental.distribute import sharded from tensorflow_probability.python.internal import test_util tfd = tfp.distributions -NUM_DEVICES = 4 +def make_jd_sequential(axis_name): + return jd.JointDistributionSequential([ + tfd.Normal(0., 1.), + lambda w: sharded.ShardedSample( # pylint: disable=g-long-lambda + tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name), + lambda x: sharded.ShardedIndependent( # pylint: disable=g-long-lambda + tfd.Normal(x, 1.), 1, shard_axis_name=axis_name), + ], shard_axis_name=axis_name) -def per_replica_to_tensor(value): - return tf.nest.map_structure( - lambda per_replica: tf.stack(per_replica.values, axis=0), value) +def make_jd_named(axis_name): + return jd.JointDistributionNamed( # pylint: disable=g-long-lambda + dict( + w=tfd.Normal(0., 1.), + x=lambda w: sharded.ShardedSample( # pylint: disable=g-long-lambda + tfd.Normal(w, 1.), + test_lib.NUM_DEVICES, + shard_axis_name=axis_name), + data=lambda x: sharded.ShardedIndependent( # pylint: disable=g-long-lambda + tfd.Normal(x, 1.), + 1, + shard_axis_name=axis_name), + ), shard_axis_name=axis_name) -def model_coroutine(): - w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.)) - x = yield sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES) - yield sharded.ShardedIndependent(tfd.Normal(x, 1.), 1) +def make_jd_coroutine(axis_name): -distributions = ( - ('coroutine', lambda: jd.JointDistributionCoroutine(model_coroutine)), - ('sequential', lambda: jd.JointDistributionSequential([ # pylint: disable=g-long-lambda - tfd.Normal(0., 1.), - lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES), - lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1), - ])), - ('named', lambda: jd.JointDistributionNamed( # pylint: disable=g-long-lambda - dict( - w=tfd.Normal(0., 1.), - x=lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES), - data=lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1), - ))), -) + def model_coroutine(): + w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.)) + x = yield sharded.ShardedSample( + tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name) + yield sharded.ShardedIndependent( + tfd.Normal(x, 1.), 1, shard_axis_name=axis_name) + return jd.JointDistributionCoroutine( + model_coroutine, shard_axis_name=axis_name) -@test_util.test_all_tf_execution_regimes -class JointDistributionTest(test_util.TestCase): - - def setUp(self): - super(JointDistributionTest, self).setUp() - self.strategy = tf.distribute.MirroredStrategy( - devices=tf.config.list_logical_devices()) - def shard_values(self, values): +distributions = ( + ('coroutine', make_jd_coroutine), + ('sequential', make_jd_sequential), + ('named', make_jd_named), +) - def value_fn(ctx): - return values[ctx.replica_id_in_sync_group] - return self.strategy.experimental_distribute_values_from_function(value_fn) +@test_util.test_all_tf_execution_regimes +class JointDistributionTest(test_lib.DistributedTest): + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot call `get_sharded_distributions` outside of pmap.') def test_get_sharded_distribution_coroutine(self): - dist = distributions[0][1]() - self.assertTupleEqual(dist.get_sharded_distributions(), - (False, True, True)) + dist = distributions[0][1](self.axis_name) + self.assertTupleEqual(dist.get_sharded_distributions(), (False, True, True)) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot call `get_sharded_distributions` outside of pmap.') def test_get_sharded_distribution_sequential(self): - dist = distributions[1][1]() - self.assertListEqual(dist.get_sharded_distributions(), - [False, True, True]) + dist = distributions[1][1](self.axis_name) + self.assertListEqual(dist.get_sharded_distributions(), [False, True, True]) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot call `get_sharded_distributions` outside of pmap.') def test_get_sharded_distribution_named(self): - dist = distributions[2][1]() + dist = distributions[2][1](self.axis_name) self.assertDictEqual(dist.get_sharded_distributions(), dict(w=False, x=True, data=True)) @parameterized.named_parameters(*distributions) def test_jd(self, dist_fn): - dist = dist_fn() + dist = dist_fn(self.axis_name) - @tf.function(autograph=False) def run(key): sample = dist.sample(seed=key) # The identity is to prevent reparameterization gradients from kicking in. @@ -99,9 +110,9 @@ def run(key): dist.log_prob, (tf.nest.map_structure(tf.identity, sample),)) return sample, log_prob, log_prob_grads - sample, log_prob, log_prob_grads = self.strategy.run( - run, (tf.ones(2, tf.int32),)) - sample, log_prob, log_prob_grads = per_replica_to_tensor( + sample, log_prob, log_prob_grads = self.strategy_run( + run, (self.key,), in_axes=None) + sample, log_prob, log_prob_grads = self.per_replica_to_tensor( (sample, log_prob, log_prob_grads)) def true_log_prob_fn(w, x, data): @@ -130,11 +141,4 @@ def true_log_prob_fn(w, x, data): if __name__ == '__main__': - tf.enable_v2_behavior() - physical_devices = tf.config.experimental.list_physical_devices() - - num_logical_devices = 4 - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], - [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) tf.test.main() diff --git a/tensorflow_probability/python/experimental/distribute/sharded.py b/tensorflow_probability/python/experimental/distribute/sharded.py index 31976f5304..31fa3ab05a 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded.py +++ b/tensorflow_probability/python/experimental/distribute/sharded.py @@ -24,10 +24,14 @@ from tensorflow_probability.python.distributions import independent as independent_lib from tensorflow_probability.python.distributions import sample as sample_lib +from tensorflow_probability.python.experimental.distribute import distribute_lib from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers +JAX_MODE = False + + class ShardedSample(sample_lib.Sample): """A version of `tfd.Sample` that shards its output across devices.""" @@ -35,6 +39,7 @@ def __init__(self, distribution, sample_shape=(), shard_axis=0, + shard_axis_name=None, validate_args=False, experimental_use_kahan_sum=False, name=None): @@ -47,6 +52,7 @@ def __init__(self, single sample. shard_axis: `int` representing which axis of `sample_shape` will be sharded across devices. + shard_axis_name: `str` for axis name for use in JAX backend. validate_args: Python `bool`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. @@ -62,7 +68,7 @@ def __init__(self, with tf.name_scope(name or 'ShardedSample' + distribution.name) as name: self._shard_axis = shard_axis - + self._shard_axis_name = shard_axis_name super(ShardedSample, self).__init__( distribution, validate_args=validate_args, @@ -82,39 +88,107 @@ def sample_shape(self): sample_shape = ps.concat([ sample_shape[:self.shard_axis], [shard_size], sample_shape[self.shard_axis + 1:] - ], - axis=0) + ], axis=0) return sample_shape + @property + def shard_axis_name(self): + return self._shard_axis_name + @property def shard_axis(self): return self._shard_axis @property def replica_id(self): - ctx = tf.distribute.get_replica_context() - return ctx.replica_id_in_sync_group + return distribute_lib.get_replica_id(axis_name=self.shard_axis_name) @property def num_devices(self): - ctx = tf.distribute.get_replica_context() - return ctx.num_replicas_in_sync + return distribute_lib.get_num_replicas(axis_name=self.shard_axis_name) def _sample_n(self, n, seed, **kwargs): seed = samplers.sanitize_seed(seed, salt='sharded_sample_sample') - return super(ShardedSample, self)._sample_n(n, seed + self.replica_id, - **kwargs) + seed = samplers.fold_in(seed, tf.cast(self.replica_id, tf.int32)) + return super(ShardedSample, self)._sample_n(n, seed, **kwargs) + + def _log_prob(self, value, reduce_over_shards=True, **kwargs): + out_log_prob = super(ShardedSample, self)._log_prob(value, **kwargs) + if reduce_over_shards: + return distribute_lib.psum(out_log_prob, axis_name=self.shard_axis_name) + return out_log_prob + + def _parameter_control_dependencies(self, is_init=False): + if not self.validate_args: + return [] + return super(ShardedSample, self)._parameter_control_dependencies( + is_init=is_init) class ShardedIndependent(independent_lib.Independent): """A version of `tfd.Independent` that folds device id into its randomness.""" + def __init__(self, + distribution, + reinterpreted_batch_ndims=None, + validate_args=False, + shard_axis_name=None, + name=None): + """Construct a `ShardedIndependent` distribution. + + Args: + distribution: The base distribution instance to transform. Typically an + instance of `Distribution`. + reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims + which will be regarded as event dims. When `None` all but the first + batch axis (batch axis 0) will be transferred to event dimensions + (analogous to `tf.layers.flatten`). + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `Independent + distribution.name`. + + Raises: + ValueError: if `reinterpreted_batch_ndims` exceeds + `distribution.batch_ndims` + """ + with tf.name_scope(name or + 'ShardedIndependent' + distribution.name) as name: + self._shard_axis_name = shard_axis_name + super(ShardedIndependent, self).__init__( + distribution, + reinterpreted_batch_ndims=reinterpreted_batch_ndims, + validate_args=validate_args, + name=name) + self._parameters['shard_axis_name'] = shard_axis_name + + @property + def shard_axis_name(self): + return self._shard_axis_name + + def _log_prob(self, value, reduce_over_shards=True, **kwargs): + out_log_prob = super(ShardedIndependent, self)._log_prob(value, **kwargs) + if reduce_over_shards: + return distribute_lib.psum(out_log_prob, axis_name=self.shard_axis_name) + return out_log_prob + @property def replica_id(self): - ctx = tf.distribute.get_replica_context() - return ctx.replica_id_in_sync_group + return distribute_lib.get_replica_id(axis_name=self.shard_axis_name) + + @property + def num_devices(self): + return distribute_lib.get_num_replicas(axis_name=self.shard_axis_name) def _sample_n(self, n, seed, **kwargs): seed = samplers.sanitize_seed(seed, salt='sharded_independent_sample') - return super(ShardedIndependent, self)._sample_n(n, seed + self.replica_id, - **kwargs) + seed = samplers.fold_in(seed, tf.cast(self.replica_id, tf.int32)) + return super(ShardedIndependent, self)._sample_n(n, seed, **kwargs) + + def _parameter_control_dependencies(self, is_init): + if JAX_MODE: + return [] + return super(ShardedIndependent, self)._parameter_control_dependencies( + is_init=is_init) diff --git a/tensorflow_probability/python/experimental/distribute/sharded_test.py b/tensorflow_probability/python/experimental/distribute/sharded_test.py index fdb7387130..03c17f07c6 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded_test.py +++ b/tensorflow_probability/python/experimental/distribute/sharded_test.py @@ -19,35 +19,28 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.distribute import distribute_test_lib as test_lib from tensorflow_probability.python.experimental.distribute import sharded from tensorflow_probability.python.internal import test_util tfd = tfp.distributions -NUM_DEVICES = 4 - -def per_replica_to_tensor(value): - return tf.nest.map_structure( - lambda per_replica: tf.stack(per_replica.values, axis=0), value) - - -class ShardedDistributionTest(test_util.TestCase): - - def setUp(self): - super(ShardedDistributionTest, self).setUp() - self.strategy = tf.distribute.MirroredStrategy( - devices=tf.config.list_logical_devices()) +@test_util.test_all_tf_execution_regimes +class ShardedDistributionTest(test_lib.DistributedTest): def test_sharded_sample_samples_differently_across_shards(self): @tf.function(autograph=False) def run(key): - return sharded.ShardedSample(tfd.Normal(0., 1.), - NUM_DEVICES).sample(seed=key) + return sharded.ShardedSample( + tfd.Normal(0., 1.), + test_lib.NUM_DEVICES, + shard_axis_name=self.axis_name).sample(seed=key) sample = self.evaluate( - per_replica_to_tensor(self.strategy.run(run, (tf.zeros(2, tf.int32),)))) + self.per_replica_to_tensor( + self.strategy_run(run, (self.key,), in_axes=None))) for i in range(4): for j in range(4): if i == j: @@ -59,10 +52,13 @@ def test_sharded_independent_samples_differently_across_shards(self): @tf.function(autograph=False) def run(key): return sharded.ShardedIndependent( - tfd.Normal(tf.zeros(1), tf.ones(1)), 1).sample(seed=key) + tfd.Normal(tf.zeros(1), tf.ones(1)), + 1, + shard_axis_name=self.axis_name).sample(seed=key) sample = self.evaluate( - per_replica_to_tensor(self.strategy.run(run, (tf.zeros(2, tf.int32),)))) + self.per_replica_to_tensor( + self.strategy_run(run, (self.key,), in_axes=None))) for i in range(4): for j in range(4): if i == j: @@ -71,11 +67,4 @@ def run(key): if __name__ == "__main__": - tf.enable_v2_behavior() - physical_devices = tf.config.experimental.list_physical_devices() - - num_logical_devices = 4 - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], - [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) tf.test.main() diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index 2f472527c1..71f5442b33 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -34,6 +34,7 @@ __all__ = [ 'categorical', + 'fold_in', 'gamma', 'is_stateful_seed', 'normal', @@ -88,16 +89,24 @@ def sanitize_seed(seed, salt=None, name=None): if salt is not None: salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16) - if JAX_MODE: - from jax import random as jaxrand # pylint: disable=g-import-not-at-top - seed = jaxrand.fold_in(seed, salt & (2**32 - 1)) - else: - seed = tf.bitwise.bitwise_xor( - seed, np.uint64([salt & (2**64 - 1)]).view(np.int32)) + seed = fold_in(seed, salt) return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed') +def fold_in(seed, salt): + """Folds salt into seed to form a new seed.""" + if JAX_MODE: + from jax import random as jaxrand # pylint: disable=g-import-not-at-top + return jaxrand.fold_in(seed, salt & (2**32 - 1)) + if isinstance(salt, (six.integer_types)): + seed = tf.bitwise.bitwise_xor( + seed, np.uint64([salt & (2**64 - 1)]).view(np.int32)) + else: + seed = tf.random.experimental.stateless_fold_in(seed, salt) + return seed + + def split_seed(seed, n=2, salt=None, name=None): """Splits a seed into `n` derived seeds. From 2c190f41cc38e95612d6cb98011963cbda5afe4e Mon Sep 17 00:00:00 2001 From: leben Date: Wed, 9 Dec 2020 20:42:43 -0800 Subject: [PATCH 12/36] Add `ThinningKernel` to `experimental.mcmc`. This is adapted from `SampleDiscardingKernel`, sans burn-in; it has the advantage of not wrapping `KernelResults`. PiperOrigin-RevId: 346698326 --- .../python/experimental/mcmc/BUILD | 30 +++ .../python/experimental/mcmc/__init__.py | 2 + .../experimental/mcmc/thinning_kernel.py | 120 +++++++++++ .../experimental/mcmc/thinning_kernel_test.py | 188 ++++++++++++++++++ 4 files changed, 340 insertions(+) create mode 100644 tensorflow_probability/python/experimental/mcmc/thinning_kernel.py create mode 100644 tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 035b8cda78..639f4d2976 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -55,6 +55,7 @@ multi_substrate_py_library( ":sample", ":sample_discarding_kernel", ":sample_fold", + ":thinning_kernel", ":tracing_reducer", ":with_reductions", ], @@ -79,6 +80,7 @@ multi_substrate_py_library( ":sample_fold", ":sample_sequential_monte_carlo", ":sequential_monte_carlo_kernel", + ":thinning_kernel", ":tracing_reducer", ":weighted_resampling", ":with_reductions", @@ -714,6 +716,34 @@ py_test( ], ) +py_library( + name = "thinning_kernel", + srcs = ["thinning_kernel.py"], + srcs_version = "PY3", + deps = [ + ":sample", + # tensorflow dep, + "//tensorflow_probability/python/mcmc:kernel", + "//tensorflow_probability/python/mcmc/internal", + ], +) + +py_test( + name = "thinning_kernel_test", + size = "small", + srcs = ["thinning_kernel_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":thinning_kernel", + # numpy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/experimental/mcmc/internal:test_fixtures", + "//tensorflow_probability/python/internal:test_util", + ], +) + py_library( name = "tracing_reducer", srcs = ["tracing_reducer.py"], diff --git a/tensorflow_probability/python/experimental/mcmc/__init__.py b/tensorflow_probability/python/experimental/mcmc/__init__.py index a13b9bc59d..dc8c6c6d7f 100644 --- a/tensorflow_probability/python/experimental/mcmc/__init__.py +++ b/tensorflow_probability/python/experimental/mcmc/__init__.py @@ -56,6 +56,7 @@ from tensorflow_probability.python.experimental.mcmc.sequential_monte_carlo_kernel import SequentialMonteCarlo from tensorflow_probability.python.experimental.mcmc.sequential_monte_carlo_kernel import SequentialMonteCarloResults from tensorflow_probability.python.experimental.mcmc.sequential_monte_carlo_kernel import WeightedParticles +from tensorflow_probability.python.experimental.mcmc.thinning_kernel import ThinningKernel from tensorflow_probability.python.experimental.mcmc.tracing_reducer import TracingReducer from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_deterministic_minimum_error from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_independent @@ -116,6 +117,7 @@ 'simple_heuristic_tuning', 'StateWithHistory', 'step_kernel', + 'ThinningKernel', 'TracingReducer', 'VarianceReducer', 'WeightedParticles', diff --git a/tensorflow_probability/python/experimental/mcmc/thinning_kernel.py b/tensorflow_probability/python/experimental/mcmc/thinning_kernel.py new file mode 100644 index 0000000000..53429a235a --- /dev/null +++ b/tensorflow_probability/python/experimental/mcmc/thinning_kernel.py @@ -0,0 +1,120 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Kernel for Thinning.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow_probability.python.experimental.mcmc import sample +from tensorflow_probability.python.mcmc import kernel as kernel_base +from tensorflow_probability.python.mcmc.internal import util as mcmc_util + + +__all__ = [ + 'ThinningKernel', +] + + +class ThinningKernel(kernel_base.TransitionKernel): + """Discards samples to perform thinning. + + `ThinningKernel` is a composable `TransitionKernel` that thins samples + returned by its `inner_kernel`. All Transition Kernels wrapping it will only + see non-discarded samples. + """ + + def __init__( + self, + inner_kernel, + num_steps_to_skip, + name=None): + """Instantiates this object. + + Args: + inner_kernel: `TransitionKernel` whose `one_step` will generate + MCMC results. + num_steps_to_skip: Integer or scalar `Tensor` representing + the number of chain steps skipped before collecting a result. + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "thinning_kernel"). + """ + self._parameters = dict( + inner_kernel=inner_kernel, + num_steps_to_skip=num_steps_to_skip, + name=name or 'thinning_kernel' + ) + + def one_step(self, current_state, previous_kernel_results, seed=None): + """Collects one non-thinned chain state. + + Args: + current_state: `Tensor` or Python `list` of `Tensor`s + representing the current state(s) of the Markov chain(s), + previous_kernel_results: `collections.namedtuple` containing `Tensor`s + representing values from previous calls to this function (or from the + `bootstrap_results` function). + seed: Optional seed for reproducible sampling. + + Returns: + new_chain_state: Newest non-discarded MCMC chain state drawn from + the `inner_kernel`. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. + """ + with tf.name_scope( + mcmc_util.make_name(self.name, 'thinned_kernel', 'one_step')): + return sample.step_kernel( + num_steps=self.num_steps_to_skip + 1, + current_state=current_state, + previous_kernel_results=previous_kernel_results, + kernel=self.inner_kernel, + return_final_kernel_results=True, + seed=seed, + name=self.name) + + def bootstrap_results(self, init_state): + """Instantiates a new kernel state with no calls. + + Args: + init_state: `Tensor` or Python `list` of `Tensor`s representing the + state(s) of the Markov chain(s). + + Returns: + kernel_results: `collections.namedtuple` of `Tensor`s representing + internal calculations made within this function. + """ + return self.inner_kernel.bootstrap_results(init_state) + + @property + def is_calibrated(self): + return self.inner_kernel.is_calibrated + + @property + def inner_kernel(self): + return self._parameters['inner_kernel'] + + @property + def num_steps_to_skip(self): + return self._parameters['num_steps_to_skip'] + + @property + def name(self): + return self._parameters['name'] + + @property + def parameters(self): + return self._parameters diff --git a/tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py new file mode 100644 index 0000000000..87bfe21d63 --- /dev/null +++ b/tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py @@ -0,0 +1,188 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for ThinningKernel TransitionKernel.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +import numpy as np + +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.mcmc.internal import test_fixtures +from tensorflow_probability.python.internal import test_util + + +@test_util.test_all_tf_execution_regimes +class ThinningTest(test_util.TestCase): + + def test_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=1,) + first_state, kernel_results = thinner.one_step( + 0., thinner.bootstrap_results(0.)) + second_state, kernel_results = thinner.one_step( + first_state, kernel_results) + first_state, second_state, kernel_results = self.evaluate([ + first_state, second_state, kernel_results]) + self.assertEqual(2, first_state) + self.assertEqual(4, second_state) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + def test_no_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=0,) + first_state, kernel_results = thinner.one_step( + 0., thinner.bootstrap_results(0.)) + second_state, kernel_results = thinner.one_step( + first_state, kernel_results) + first_state, second_state, kernel_results = self.evaluate([ + first_state, second_state, kernel_results]) + self.assertEqual(1, first_state) + self.assertEqual(2, second_state) + self.assertEqual(2, kernel_results.counter_1) + self.assertEqual(4, kernel_results.counter_2) + + def test_cold_start(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=1,) + first_state, _ = thinner.one_step( + 0., thinner.bootstrap_results(0.)) + second_state, kernel_results = thinner.one_step( + first_state, thinner.bootstrap_results(first_state)) + first_state, second_state, kernel_results = self.evaluate([ + first_state, second_state, kernel_results]) + self.assertEqual(2, first_state) + self.assertEqual(4, second_state) + self.assertEqual(2, kernel_results.counter_1) + self.assertEqual(4, kernel_results.counter_2) + + def test_is_calibrated(self): + calibrated_kernel = test_fixtures.TestTransitionKernel() + uncalibrated_kernel = test_fixtures.TestTransitionKernel( + is_calibrated=False) + calibrated_thinner = tfp.experimental.mcmc.ThinningKernel( + calibrated_kernel, 0) + uncalibrated_thinner = tfp.experimental.mcmc.ThinningKernel( + uncalibrated_kernel, 0) + self.assertTrue(calibrated_thinner.is_calibrated) + self.assertFalse(uncalibrated_thinner.is_calibrated) + + def test_with_composed_kernel(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + cov_reducer = tfp.experimental.mcmc.CovarianceReducer() + reducer_kernel = tfp.experimental.mcmc.WithReductions( + inner_kernel=tfp.experimental.mcmc.ThinningKernel( + inner_kernel=fake_inner_kernel, + num_steps_to_skip=2,), + reducer=cov_reducer + ) + current_state, kernel_results = 0., reducer_kernel.bootstrap_results(0.) + for _ in range(2): + current_state, kernel_results = reducer_kernel.one_step( + current_state, kernel_results) + cov = self.evaluate(cov_reducer.finalize(kernel_results.reduction_results)) + self.assertAllEqual(6, current_state) + self.assertAllEqual(6, kernel_results.inner_results.counter_1) + self.assertAllEqual(12, kernel_results.inner_results.counter_2) + self.assertNear(np.var([3, 6]), cov, err=1e-6) + + def test_tf_while(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=1,) + + def _loop_body(i, curr_state, pkr): + new_state, kernel_results = thinner.one_step( + curr_state, pkr, + ) + return (i + 1, new_state, kernel_results) + + pkr = thinner.bootstrap_results(0.) + _, final_sample, kernel_results = tf.while_loop( + lambda i, *_: i < 2, + _loop_body, + (0., 0., pkr), + ) + final_sample, kernel_results = self.evaluate([ + final_sample, kernel_results]) + self.assertEqual(4, final_sample) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + def test_tensor_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=tf.convert_to_tensor(1),) + + def _loop_body(i, curr_state, pkr): + new_state, kernel_results = thinner.one_step( + curr_state, pkr, + ) + return (i + 1, new_state, kernel_results) + + pkr = thinner.bootstrap_results(0.) + _, final_sample, kernel_results = tf.while_loop( + lambda i, _, __: i < 2, + _loop_body, + (0., 0., pkr), + ) + + final_sample, kernel_results = self.evaluate([ + final_sample, kernel_results]) + self.assertEqual(4, final_sample) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + def test_non_static_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + num_steps_to_skip = tf.Variable(1, dtype=tf.int32) + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=num_steps_to_skip) + + def _loop_body(i, curr_state, pkr): + new_state, kernel_results = thinner.one_step( + curr_state, pkr, + ) + return (i + 1, new_state, kernel_results) + + pkr = thinner.bootstrap_results(0.) + _, final_sample, kernel_results = tf.while_loop( + lambda i, _, __: i < 2, + _loop_body, + (0., 0., pkr), + ) + self.evaluate([num_steps_to_skip.initializer]) + final_sample, kernel_results = self.evaluate([ + final_sample, kernel_results]) + self.assertEqual(4, final_sample) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + +if __name__ == '__main__': + tf.test.main() From d0e533ae13d6fec9050ad7e427a97e96ad52be38 Mon Sep 17 00:00:00 2001 From: bjp Date: Thu, 10 Dec 2020 08:51:54 -0800 Subject: [PATCH 13/36] Small fixes to docs. PiperOrigin-RevId: 346793654 --- tensorflow_probability/python/internal/custom_gradient.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/internal/custom_gradient.py b/tensorflow_probability/python/internal/custom_gradient.py index 07bc930881..19d5a19d16 100644 --- a/tensorflow_probability/python/internal/custom_gradient.py +++ b/tensorflow_probability/python/internal/custom_gradient.py @@ -41,8 +41,11 @@ def custom_gradient(vjp_fwd=None, vjp_bwd=None, jvp_fn=None, Args: vjp_fwd: A function (*args) => (output, auxiliaries). - vjp_bwd: A function (auxiliaries, output_gradient) => args_gradients. - jvp_fn: A function (primals, tangents) => (primal_out, tangent_out). + vjp_bwd: A function (auxiliaries, output_gradient) => + nondiff_args_gradients. `None` gradients will be inserted into the correct + positions for `nondiff_argnums`. + jvp_fn: A function (*nondiff_args, primals, tangents) => + (primal_out, tangent_out). nondiff_argnums: Tuple of argument indices which are not differentiable. Returns: From 67e6c7b42051ad6f6c168a25adac4f80b6846a94 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 10 Dec 2020 14:40:19 -0800 Subject: [PATCH 14/36] Add experimental_use_kahan_sum to ShardedIndependent + fix its parameters. PiperOrigin-RevId: 346868095 --- .../python/experimental/distribute/sharded.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/experimental/distribute/sharded.py b/tensorflow_probability/python/experimental/distribute/sharded.py index 31fa3ab05a..166704506f 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded.py +++ b/tensorflow_probability/python/experimental/distribute/sharded.py @@ -133,6 +133,7 @@ def __init__(self, reinterpreted_batch_ndims=None, validate_args=False, shard_axis_name=None, + experimental_use_kahan_sum=False, name=None): """Construct a `ShardedIndependent` distribution. @@ -147,6 +148,11 @@ def __init__(self, `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. shard_axis_name: `str` for axis name for use in JAX backend. + experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan + summation to aggregate independent underlying log_prob values, which + improves against the precision of a naive float32 sum. This can be + noticeable in particular for large dimensions in float32. See CPU caveat + on `tfp.math.reduce_kahan_sum`. name: The name for ops managed by the distribution. Default value: `Independent + distribution.name`. @@ -154,6 +160,8 @@ def __init__(self, ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ + parameters = dict(locals()) + with tf.name_scope(name or 'ShardedIndependent' + distribution.name) as name: self._shard_axis_name = shard_axis_name @@ -161,8 +169,9 @@ def __init__(self, distribution, reinterpreted_batch_ndims=reinterpreted_batch_ndims, validate_args=validate_args, + experimental_use_kahan_sum=experimental_use_kahan_sum, name=name) - self._parameters['shard_axis_name'] = shard_axis_name + self._parameters = parameters @property def shard_axis_name(self): From a46e9d4059c0b51ab9fa517d83baed89902d2d90 Mon Sep 17 00:00:00 2001 From: bjp Date: Thu, 10 Dec 2020 18:44:15 -0800 Subject: [PATCH 15/36] Fix docstring description of default name. PiperOrigin-RevId: 346910502 --- .../python/experimental/distribute/sharded.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/experimental/distribute/sharded.py b/tensorflow_probability/python/experimental/distribute/sharded.py index 166704506f..14ed7c2ab2 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded.py +++ b/tensorflow_probability/python/experimental/distribute/sharded.py @@ -62,7 +62,7 @@ def __init__(self, noticeable in particular for large dimensions in float32. See CPU caveat on `tfp.math.reduce_kahan_sum`. name: The name for ops managed by the distribution. - Default value: `None` (i.e., `'Sample' + distribution.name`). + Default value: `None` (i.e., `'ShardedSample' + distribution.name`). """ parameters = dict(locals()) @@ -154,7 +154,7 @@ def __init__(self, noticeable in particular for large dimensions in float32. See CPU caveat on `tfp.math.reduce_kahan_sum`. name: The name for ops managed by the distribution. - Default value: `Independent + distribution.name`. + Default value: `'ShardedIndependent' + distribution.name`. Raises: ValueError: if `reinterpreted_batch_ndims` exceeds From e132aa47da07c9db54aba41f114ff60ce04d68c2 Mon Sep 17 00:00:00 2001 From: sharadmv Date: Fri, 11 Dec 2020 15:20:08 -0800 Subject: [PATCH 16/36] Fix cast creating Tracer in JAX backend. PiperOrigin-RevId: 347085672 --- tensorflow_probability/python/random/random_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/random/random_ops.py b/tensorflow_probability/python/random/random_ops.py index f4ed20da2a..80c957e8ad 100644 --- a/tensorflow_probability/python/random/random_ops.py +++ b/tensorflow_probability/python/random/random_ops.py @@ -131,7 +131,7 @@ def spherical_uniform( """ with tf.name_scope(name or 'spherical_uniform'): seed = samplers.sanitize_seed(seed) - dimension = ps.convert_to_shape_tensor(tf.cast(dimension, dtype=tf.int32)) + dimension = ps.convert_to_shape_tensor(ps.cast(dimension, dtype=tf.int32)) shape = ps.convert_to_shape_tensor(shape, dtype=tf.int32) dimension_static = tf.get_static_value(dimension) sample_shape = ps.concat([shape, [dimension]], axis=0) From 8854db07132b65b2795d41e7beebf2fa3a333f56 Mon Sep 17 00:00:00 2001 From: bjp Date: Mon, 14 Dec 2020 12:41:22 -0800 Subject: [PATCH 17/36] Modify auto_composite_tensor to prefer static values for fields annotated as "_composite_tensor_shape_parameters". Prior to the fix, test fails with "ValueError: Input tensor 'IndependentNormal_2/log_prob/Sum:0' enters the loop with shape (), but has shape after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape." PiperOrigin-RevId: 347446310 --- .../python/internal/auto_composite_tensor.py | 43 ++++++++++++++----- .../internal/auto_composite_tensor_test.py | 24 ++++++++++- 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/tensorflow_probability/python/internal/auto_composite_tensor.py b/tensorflow_probability/python/internal/auto_composite_tensor.py index 26ab6b3ac5..82e8cf8f22 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor.py @@ -36,10 +36,11 @@ _SENTINEL = object() -_AUTO_COMPOSITE_TENSOR_VERSION = 1 +_AUTO_COMPOSITE_TENSOR_VERSION = 2 -def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None): +def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None, + prefer_static_value=()): """Extract constructor kwargs to reconstruct `obj`.""" argspec = inspect.getfullargspec(obj.__init__) if argspec.varargs or argspec.varkw: @@ -61,6 +62,10 @@ def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None): raise ValueError( 'Object did not have an attr corresponding to constructor argument ' '{k}. (Tried both `obj.{k}` and obj._{k}`).'.format(k=k)) + if k in prefer_static_value and kwargs[k] is not None: + static_val = tf.get_static_value(kwargs[k]) + if static_val is not None: + kwargs[k] = static_val return kwargs @@ -101,16 +106,22 @@ def _extract_type_spec_recursively(value): class _AutoCompositeTensorTypeSpec(tf.TypeSpec): """A tf.TypeSpec for `AutoCompositeTensor` objects.""" - __slots__ = ('_param_specs', '_non_tensor_params', '_omit_kwargs') + __slots__ = ('_param_specs', '_non_tensor_params', '_omit_kwargs', + '_prefer_static_value') - def __init__(self, param_specs, non_tensor_params, omit_kwargs): + def __init__(self, param_specs, non_tensor_params, omit_kwargs, + prefer_static_value): self._param_specs = param_specs self._non_tensor_params = non_tensor_params self._omit_kwargs = omit_kwargs + self._prefer_static_value = prefer_static_value @classmethod def from_instance(cls, instance, omit_kwargs=()): - kwargs = _extract_init_kwargs(instance, omit_kwargs) + prefer_static_value = tuple( + getattr(instance, '_composite_tensor_shape_params', ())) + kwargs = _extract_init_kwargs(instance, omit_kwargs=omit_kwargs, + prefer_static_value=prefer_static_value) non_tensor_params = {} param_specs = {} @@ -125,7 +136,8 @@ def from_instance(cls, instance, omit_kwargs=()): # Construct the spec. return cls(param_specs=param_specs, non_tensor_params=non_tensor_params, - omit_kwargs=omit_kwargs) + omit_kwargs=omit_kwargs, + prefer_static_value=prefer_static_value) def _to_components(self, obj): return _extract_init_kwargs(obj, limit_to=list(self._param_specs)) @@ -142,16 +154,20 @@ def _serialize(self): result = (_AUTO_COMPOSITE_TENSOR_VERSION, self._param_specs, self._non_tensor_params, - self._omit_kwargs) + self._omit_kwargs, + self._prefer_static_value) return result @classmethod def _deserialize(cls, encoded): - version, param_specs, non_tensor_params, omit_kwargs = encoded + version = encoded[0] + if version == 1: + encoded = encoded + ((),) + version = 2 if version != _AUTO_COMPOSITE_TENSOR_VERSION: raise ValueError('Expected version {}, but got {}' .format(_AUTO_COMPOSITE_TENSOR_VERSION, version)) - return cls(param_specs, non_tensor_params, omit_kwargs) + return cls(*encoded[1:]) _TypeSpecCodec = nested_structure_coder._TypeSpecCodec # pylint: disable=protected-access @@ -193,6 +209,12 @@ def auto_composite_tensor(cls=None, omit_kwargs=()): - object.attribute = [tf.constant(1.), [tf.constant(2.)]] # valid - object.attribute = ['abc', tf.constant(1.)] # invalid + If the object has a `_composite_tensor_shape_parameters` field (presumed to + have `tuple` of `str` value), the flattening code will use + `tf.get_static_value` to attempt to preserve shapes as static metadata, for + fields whose name matches a name specified in that field. Preserving static + values can be important to correctly propagating shapes through a loop. + If the decorated class `A` does not subclass `CompositeTensor`, a *new class* will be generated, which mixes in `A` and `CompositeTensor`. @@ -277,7 +299,8 @@ def body(obj): composite_tensor_subclass: A subclass of `cls` and TF CompositeTensor. """ if cls is None: - return functools.partial(auto_composite_tensor, omit_kwargs=omit_kwargs) + return functools.partial(auto_composite_tensor, + omit_kwargs=omit_kwargs) # If the declared class is already a CompositeTensor subclass, we can avoid # affecting the actual type of the returned class. Otherwise, we need to diff --git a/tensorflow_probability/python/internal/auto_composite_tensor_test.py b/tensorflow_probability/python/internal/auto_composite_tensor_test.py index 231f6846eb..418a995e32 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor_test.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor_test.py @@ -24,6 +24,9 @@ from tensorflow_probability.python.internal import test_util +tfd = tfp.distributions + + AutoIdentity = tfp.experimental.auto_composite_tensor( tf.linalg.LinearOperatorIdentity, omit_kwargs=('name',)) AutoDiag = tfp.experimental.auto_composite_tensor( @@ -33,6 +36,11 @@ AutoTriL = tfp.experimental.auto_composite_tensor( tf.linalg.LinearOperatorLowerTriangular, omit_kwargs=('name',)) +AutoNormal = tfp.experimental.auto_composite_tensor( + tfd.Normal, omit_kwargs=('name',)) +AutoIndependent = tfp.experimental.auto_composite_tensor( + tfd.Independent, omit_kwargs=('name',)) + @test_util.test_all_tf_execution_regimes class AutoCompositeTensorTest(test_util.TestCase): @@ -77,6 +85,18 @@ def body(lop): maximum_iterations=3) self.assertAllClose(2.**3 * tf.ones([3]), lop.matvec(tf.ones([3]))) + def test_shape_parameters(self): + dist = AutoIndependent(AutoNormal(0, tf.ones([1])), + reinterpreted_batch_ndims=1) + stream = test_util.test_seed_stream() + lp = dist.log_prob(dist.sample(seed=stream())) + lp, _ = tf.while_loop( + lambda *_: True, + lambda lp, d: (d.log_prob(d.sample(seed=stream())), d), + (lp, dist), + maximum_iterations=2) + self.evaluate(lp) + def test_nested(self): lop = AutoBlockDiag([AutoDiag(tf.ones([2]) * 2), AutoIdentity(1)]) self.assertAllClose( @@ -90,9 +110,9 @@ def test_preconditioner(self): is_self_adjoint=True, is_positive_definite=True) - tfd = tfp.experimental.distributions + tfed = tfp.experimental.distributions auto_ct_mvn_prec_linop = tfp.experimental.auto_composite_tensor( - tfd.MultivariateNormalPrecisionFactorLinearOperator, + tfed.MultivariateNormalPrecisionFactorLinearOperator, omit_kwargs=('name',)) tril = AutoTriL(**cov_linop.cholesky().parameters) momentum_distribution = auto_ct_mvn_prec_linop(precision_factor=tril) From 51057934c798c05a962b4c66ad1a8f011c43e338 Mon Sep 17 00:00:00 2001 From: emilyaf Date: Mon, 14 Dec 2020 12:53:48 -0800 Subject: [PATCH 18/36] Enable `build_factored_surrogate_posterior` to take a multipart unconstraining bijector. PiperOrigin-RevId: 347448699 --- .../experimental/vi/surrogate_posteriors.py | 122 ++++++++++-------- .../vi/surrogate_posteriors_test.py | 40 ++++++ 2 files changed, 110 insertions(+), 52 deletions(-) diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py index 7d847daf28..1a4a4c37e1 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py @@ -25,6 +25,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python import bijectors as tfb from tensorflow_probability.python import util as tfp_util +from tensorflow_probability.python.bijectors import identity as identity_bijector from tensorflow_probability.python.bijectors import softplus as softplus_lib from tensorflow_probability.python.distributions import beta from tensorflow_probability.python.distributions import half_normal @@ -41,8 +42,10 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static -from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import - +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.util import deprecation +from tensorflow.python.util import nest +# pylint: enable=g-direct-tensorflow-import Root = joint_distribution_coroutine.JointDistributionCoroutine.Root @@ -122,8 +125,13 @@ def _not_list_of_ints(s): build_trainable_location_scale_distribution, distribution_fn=normal.Normal) +@deprecation.deprecated_args( + '2021-03-15', + '`constraining_bijectors` is deprecated, use `bijector` instead', + 'constraining_bijectors') def build_factored_surrogate_posterior( event_shape=None, + bijector=None, constraining_bijectors=None, initial_unconstrained_loc=_sample_uniform_initial_loc, initial_unconstrained_scale=1e-2, @@ -142,6 +150,13 @@ def build_factored_surrogate_posterior( Args: event_shape: `Tensor` shape, or nested structure of `Tensor` shapes, specifying the event shape(s) of the posterior variables. + bijector: Optional `tfb.Bijector` instance, or nested structure of such + instances, defining support(s) of the posterior variables. The structure + must match that of `event_shape` and may contain `None` values. A + posterior variable will be modeled as + `tfd.TransformedDistribution(underlying_dist, bijector)` if a + corresponding constraining bijector is specified, otherwise it is modeled + as supported on the unconstrained real line. constraining_bijectors: Optional `tfb.Bijector` instance, or nested structure of such instances, defining support(s) of the posterior variables. The structure must match that of `event_shape` and may @@ -156,8 +171,8 @@ def build_factored_surrogate_posterior( variable. May alternately be a nested structure of `Tensor`s, giving specific initial locations for each variable; these must have structure matching `event_shape` and shapes determined by the - inverse image of `event_shape` under `constraining_bijectors`, which - may optionally be prefixed with a common batch shape. + inverse image of `event_shape` under `bijector`, which may optionally be + prefixed with a common batch shape. Default value: `functools.partial(tf.random.uniform, minval=-2., maxval=2., dtype=tf.float32)`. initial_unconstrained_scale: Optional scalar float `Tensor` initial @@ -209,8 +224,8 @@ def model_fn(): ```python surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=model.event_shape_tensor()[:-1], # Omit the observed `y`. - constraining_bijectors=[tfb.Softplus(), # Rate is positive. - tfb.Softplus()]) # Concentration is positive. + bijector=[tfb.Softplus(), # Rate is positive. + tfb.Softplus()]) # Concentration is positive. ``` This creates a trainable joint distribution, defined by variables in @@ -241,14 +256,13 @@ def model_fn(): ```python initial_loc = {'concentration': 0.4, 'rate': 0.2} - constraining_bijectors={'concentration': tfb.Softplus(), # Rate is positive. - 'rate': tfb.Softplus()} # Concentration is positive. + bijector={'concentration': tfb.Softplus(), # Rate is positive. + 'rate': tfb.Softplus()} # Concentration is positive. initial_unconstrained_loc = tf.nest.map_fn( - lambda b, x: b.inverse(x) if b is not None else x, - constraining_bijectors, initial_loc) + lambda b, x: b.inverse(x) if b is not None else x, bijector, initial_loc) surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=tf.nest.map_fn(tf.shape, initial_loc), - constraining_bijectors=constraining_bijectors, + bijector=bijector, initial_unconstrained_loc=initial_unconstrained_state, initial_unconstrained_scale=1e-4) ``` @@ -256,6 +270,9 @@ def model_fn(): """ with tf.name_scope(name or 'build_factored_surrogate_posterior'): + bijector = deprecation.deprecated_argument_lookup( + 'bijector', bijector, 'constraining_bijectors', constraining_bijectors) + seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior') # Convert event shapes to Tensors. @@ -263,58 +280,59 @@ def model_fn(): event_shape = nest.map_structure_up_to( shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape) - flat_event_shapes = tf.nest.flatten(event_shape) - # For simplicity, we'll work with flattened lists of state parts and - # repack the structure at the end. - if constraining_bijectors is not None: - flat_bijectors = tf.nest.flatten(constraining_bijectors) + if nest.is_nested(bijector): + bijector = nest.map_structure( + lambda b: identity_bijector.Identity() if b is None else b, + bijector) + + # Support mismatched nested structures for backwards compatibility (e.g. + # non-nested `event_shape` and a single-element list of `bijector`s). + bijector = nest.pack_sequence_as(event_shape, nest.flatten(bijector)) + + event_space_bijector = tfb.JointMap(bijector, validate_args=validate_args) else: - flat_bijectors = [None for _ in flat_event_shapes] - flat_unconstrained_event_shapes = [ - b.inverse_event_shape_tensor(s) if b is not None else s - for s, b in zip(flat_event_shapes, flat_bijectors)] + event_space_bijector = bijector + + if event_space_bijector is None: + unconstrained_event_shape = event_shape + else: + unconstrained_event_shape = ( + event_space_bijector.inverse_event_shape_tensor(event_shape)) # Construct initial locations for the internal unconstrained dists. if callable(initial_unconstrained_loc): # Sample random initialization. - flat_unconstrained_locs = [initial_unconstrained_loc( - shape=s, seed=seed()) for s in flat_unconstrained_event_shapes] - else: # Use provided initialization. - flat_unconstrained_locs = nest.flatten_up_to( - shallow_structure, initial_unconstrained_loc, check_types=False) - - if nest.is_nested(initial_unconstrained_scale): - flat_unconstrained_scales = nest.flatten_up_to( - shallow_structure, initial_unconstrained_scale, check_types=False) - else: - flat_unconstrained_scales = [ - initial_unconstrained_scale for _ in flat_unconstrained_locs] + initial_unconstrained_loc = nest.map_structure( + lambda s: initial_unconstrained_loc(shape=s, seed=seed()), + unconstrained_event_shape) + + if not nest.is_nested(initial_unconstrained_scale): + initial_unconstrained_scale = nest.map_structure( + lambda _: initial_unconstrained_scale, + unconstrained_event_shape) # Extract the rank of each event, so that we build distributions with the # correct event shapes. - flat_unconstrained_event_ndims = [prefer_static.rank_from_shape(s) - for s in flat_unconstrained_event_shapes] + unconstrained_event_ndims = nest.map_structure( + prefer_static.rank_from_shape, + unconstrained_event_shape) # Build the component surrogate posteriors. - flat_component_dists = [] - for initial_loc, initial_scale, event_ndims, bijector in zip( - flat_unconstrained_locs, - flat_unconstrained_scales, - flat_unconstrained_event_ndims, - flat_bijectors): - unconstrained_dist = trainable_distribution_fn( - initial_loc=initial_loc, initial_scale=initial_scale, - event_ndims=event_ndims, validate_args=validate_args) - flat_component_dists.append( - bijector(unconstrained_dist) if bijector is not None - else unconstrained_dist) - component_distributions = tf.nest.pack_sequence_as( - event_shape, flat_component_dists) - - # Return a `Distribution` object whose events have the specified structure. - return ( + unconstrained_distributions = nest.map_structure_up_to( + unconstrained_event_shape, + lambda loc, scale, ndims: trainable_distribution_fn( # pylint: disable=g-long-lambda + loc, scale, ndims, validate_args=validate_args), + initial_unconstrained_loc, + initial_unconstrained_scale, + unconstrained_event_ndims) + + base_distribution = ( joint_distribution_util.independent_joint_distribution_from_structure( - component_distributions, validate_args=validate_args)) + unconstrained_distributions, validate_args=validate_args)) + if event_space_bijector is None: + return base_distribution + return transformed_distribution.TransformedDistribution( + base_distribution, event_space_bijector) def _as_trainable_family(distribution): diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py index 9cfdb20d6c..b7652d8514 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py @@ -245,6 +245,46 @@ def model_fn(): _ = self.evaluate(posterior_mean) _ = self.evaluate(posterior_stddev) + def test_multipart_bijector(self): + dist = tfd.JointDistributionNamed({ + 'a': tfd.Exponential(1.), + 'b': tfd.Normal(0., 1.), + 'c': lambda b, a: tfd.Sample(tfd.Normal(b, a), sample_shape=[5])}) + + seed = test_util.test_seed_stream() + surrogate_posterior = ( + tfp.experimental.vi.build_factored_surrogate_posterior( + event_shape=dist.event_shape, + constraining_bijectors=( + dist.experimental_default_event_space_bijector()), + initial_unconstrained_loc=functools.partial( + tf.random.uniform, minval=-2., maxval=2.), + seed=seed(), + validate_args=True)) + self.evaluate([v.initializer + for v in surrogate_posterior.trainable_variables]) + + # Test that the posterior has the specified event shape(s). + self.assertAllEqualNested( + self.evaluate(dist.event_shape_tensor()), + self.evaluate(surrogate_posterior.event_shape_tensor())) + + posterior_sample_ = self.evaluate(surrogate_posterior.sample(seed=seed())) + posterior_logprob_ = self.evaluate( + surrogate_posterior.log_prob(posterior_sample_)) + + # Test that all sample Tensors have the expected shapes. + check_shape = lambda s, x: self.assertAllEqual(s, x.shape) + self.assertAllAssertsNested( + check_shape, dist.event_shape, posterior_sample_) + + # Test that samples are finite and not NaN. + self.assertAllAssertsNested(self.assertAllFinite, posterior_sample_) + + # Test that logprob is scalar, finite, and not NaN. + self.assertEmpty(posterior_logprob_.shape) + self.assertAllFinite(posterior_logprob_) + def _build_tensor(ndarray, dtype, use_static_shape): # Enforce parameterized dtype and static/dynamic testing. From e0ca329f0b523134247ddf2db60863293f42c2cb Mon Sep 17 00:00:00 2001 From: bjp Date: Mon, 14 Dec 2020 13:44:09 -0800 Subject: [PATCH 19/36] Add support for non-list initial variance structures. PiperOrigin-RevId: 347460285 --- .../experimental/mcmc/diagonal_mass_matrix_adaptation.py | 2 +- tensorflow_probability/python/mcmc/sample_test.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py index 0c728cff9f..0e8c966fc9 100644 --- a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py +++ b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py @@ -240,7 +240,7 @@ def bootstrap_results(self, init_state): sample_stats.RunningVariance): variance_parts = [self.initial_running_variance] else: - variance_parts = self.initial_running_variance + variance_parts = list(self.initial_running_variance) diags = [variance_part.variance() for variance_part in variance_parts] diff --git a/tensorflow_probability/python/mcmc/sample_test.py b/tensorflow_probability/python/mcmc/sample_test.py index b3bb921046..1228768362 100644 --- a/tensorflow_probability/python/mcmc/sample_test.py +++ b/tensorflow_probability/python/mcmc/sample_test.py @@ -433,11 +433,17 @@ def model(): momentum_distribution=momentum_dist) bijector = pinned.experimental_default_event_space_bijector() kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector) + pullback_shape = bijector.inverse_event_shape(pinned.event_shape) + kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation( + kernel, + initial_running_variance=struct._make( + tfp.experimental.stats.RunningVariance.from_shape(t) + for t in pullback_shape)) state = bijector(struct._make( tfd.Uniform(-2., 2.).sample(shp) for shp in bijector.inverse_event_shape(pinned.event_shape))) self.evaluate(tfp.mcmc.sample_chain( - 3, current_state=state, kernel=kernel, seed=stream())) + 3, current_state=state, kernel=kernel, seed=stream()).all_states) if __name__ == '__main__': From 2dbfdaf5fc3f36ae370aefed855a8b4c205ab4ee Mon Sep 17 00:00:00 2001 From: emilyaf Date: Mon, 14 Dec 2020 16:56:30 -0800 Subject: [PATCH 20/36] Replace deprecated arg name in `surrogate_posteriors_test`. PiperOrigin-RevId: 347499168 --- .../vi/surrogate_posteriors_test.py | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py index b7652d8514..a58fa61e87 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py @@ -84,28 +84,27 @@ class FactoredSurrogatePosterior(test_util.TestCase): @parameterized.named_parameters( {'testcase_name': 'TensorEvent', 'event_shape': tf.TensorShape([4]), - 'constraining_bijectors': [tfb.Sigmoid()], + 'bijector': [tfb.Sigmoid()], 'dtype': np.float64, 'use_static_shape': True}, {'testcase_name': 'ListEvent', 'event_shape': [tf.TensorShape([3]), tf.TensorShape([]), tf.TensorShape([2, 2])], - 'constraining_bijectors': [tfb.Softplus(), None, tfb.FillTriangular()], + 'bijector': [tfb.Softplus(), None, tfb.FillTriangular()], 'dtype': np.float32, 'use_static_shape': False}, {'testcase_name': 'DictEvent', 'event_shape': {'x': tf.TensorShape([1]), 'y': tf.TensorShape([])}, - 'constraining_bijectors': None, + 'bijector': None, 'dtype': np.float64, 'use_static_shape': True}, {'testcase_name': 'NestedEvent', 'event_shape': {'x': [tf.TensorShape([1]), tf.TensorShape([1, 2])], 'y': tf.TensorShape([])}, - 'constraining_bijectors': { + 'bijector': { 'x': [tfb.Identity(), tfb.Softplus()], 'y': tfb.Sigmoid()}, 'dtype': np.float32, 'use_static_shape': True}, ) - def test_specifying_event_shape(self, event_shape, - constraining_bijectors, - dtype, use_static_shape): + def test_specifying_event_shape( + self, event_shape, bijector, dtype, use_static_shape): seed = test_util.test_seed_stream() surrogate_posterior = ( tfp.experimental.vi.build_factored_surrogate_posterior( @@ -114,7 +113,7 @@ def test_specifying_event_shape(self, event_shape, dtype=np.int32, use_static_shape=use_static_shape), event_shape), - constraining_bijectors=constraining_bijectors, + bijector=bijector, initial_unconstrained_loc=functools.partial( tf.random.uniform, minval=-2., maxval=2., dtype=dtype), seed=seed(), @@ -150,7 +149,7 @@ def test_specifying_event_shape(self, event_shape, 'event_shape': [4], 'initial_loc': np.array([[[0.9, 0.1, 0.5, 0.7]]]), 'implicit_batch_shape': [1, 1], - 'constraining_bijectors': tfb.Sigmoid(), + 'bijector': tfb.Sigmoid(), 'dtype': np.float32, 'use_static_shape': False}, {'testcase_name': 'ListEvent', 'event_shape': [[3], [], [2, 2]], @@ -158,29 +157,28 @@ def test_specifying_event_shape(self, event_shape, 0.1, np.array([[1., 0], [-4., 2.]])], 'implicit_batch_shape': [], - 'constraining_bijectors': [tfb.Softplus(), None, tfb.FillTriangular()], + 'bijector': [tfb.Softplus(), None, tfb.FillTriangular()], 'dtype': np.float64, 'use_static_shape': True}, {'testcase_name': 'DictEvent', 'event_shape': {'x': [2], 'y': []}, 'initial_loc': {'x': np.array([[0.9, 1.2]]), 'y': np.array([-4.1])}, 'implicit_batch_shape': [1], - 'constraining_bijectors': None, + 'bijector': None, 'dtype': np.float32, 'use_static_shape': False}, ) def test_specifying_initial_loc(self, event_shape, initial_loc, - implicit_batch_shape, - constraining_bijectors, + implicit_batch_shape, bijector, dtype, use_static_shape): initial_loc = tf.nest.map_structure( lambda s: _build_tensor(s, dtype=dtype, # pylint: disable=g-long-lambda use_static_shape=use_static_shape), initial_loc) - if constraining_bijectors is not None: + if bijector is not None: initial_unconstrained_loc = tf.nest.map_structure( lambda x, b: x if b is None else b.inverse(x), - initial_loc, constraining_bijectors) + initial_loc, bijector) else: initial_unconstrained_loc = initial_loc @@ -189,7 +187,7 @@ def test_specifying_initial_loc(self, event_shape, initial_loc, event_shape=event_shape, initial_unconstrained_loc=initial_unconstrained_loc, initial_unconstrained_scale=1e-6, - constraining_bijectors=constraining_bijectors, + bijector=bijector, validate_args=True)) self.evaluate([v.initializer for v in surrogate_posterior.trainable_variables]) @@ -223,7 +221,7 @@ def model_fn(): surrogate_posterior = ( tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=model.event_shape_tensor()[:-1], - constraining_bijectors=[tfb.Softplus(), tfb.Softplus()])) + bijector=[tfb.Softplus(), tfb.Softplus()])) # Fit model. y = [0.2, 0.5, 0.3, 0.7] @@ -255,7 +253,7 @@ def test_multipart_bijector(self): surrogate_posterior = ( tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=dist.event_shape, - constraining_bijectors=( + bijector=( dist.experimental_default_event_space_bijector()), initial_unconstrained_loc=functools.partial( tf.random.uniform, minval=-2., maxval=2.), From 37c94552df37e49bdaae36e79374dc4e20bbf0b9 Mon Sep 17 00:00:00 2001 From: Dave Moore Date: Mon, 14 Dec 2020 17:35:05 -0800 Subject: [PATCH 21/36] Update docstring to reflect `constraining_bijectors` deprecation. PiperOrigin-RevId: 347505196 --- .../python/experimental/vi/surrogate_posteriors.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py index 1a4a4c37e1..00654bb529 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py @@ -157,14 +157,7 @@ def build_factored_surrogate_posterior( `tfd.TransformedDistribution(underlying_dist, bijector)` if a corresponding constraining bijector is specified, otherwise it is modeled as supported on the unconstrained real line. - constraining_bijectors: Optional `tfb.Bijector` instance, or nested - structure of such instances, defining support(s) of the posterior - variables. The structure must match that of `event_shape` and may - contain `None` values. A posterior variable will - be modeled as `tfd.TransformedDistribution(underlying_dist, - constraining_bijector)` if a corresponding constraining bijector is - specified, otherwise it is modeled as supported on the - unconstrained real line. + constraining_bijectors: Deprecated alias for `bijector`. initial_unconstrained_loc: Optional Python `callable` with signature `tensor = initial_unconstrained_loc(shape, seed)` used to sample real-valued initializations for the unconstrained representation of each From 9ce021ab0a8eff1b9e67cb27c8f98120a69e9fab Mon Sep 17 00:00:00 2001 From: leben Date: Mon, 14 Dec 2020 20:37:27 -0800 Subject: [PATCH 22/36] Exclude `StoppingRatioLogistic` distribution from `testCanConstructAndSampleDistribution`. PiperOrigin-RevId: 347528419 --- .../python/distributions/distribution_properties_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/distributions/distribution_properties_test.py b/tensorflow_probability/python/distributions/distribution_properties_test.py index 02f9064f06..f06d7c8198 100644 --- a/tensorflow_probability/python/distributions/distribution_properties_test.py +++ b/tensorflow_probability/python/distributions/distribution_properties_test.py @@ -282,7 +282,8 @@ def testCanConstructAndSampleDistribution(self, data): 'Empirical|event_ndims=2', 'FiniteDiscrete', 'MultivariateStudentTLinearOperator', 'PoissonLogNormalQuadratureCompound', - 'SphericalUniform', 'SinhArcsinh') + 'SphericalUniform', 'SinhArcsinh', + 'StoppingRatioLogistic',) non_trainable_dists = ( high_gt_low_constraint_dists + not_annotated_dists + dhps.INSTANTIABLE_META_DISTS) From a2e2b598db7ab9159416a94e43a4bb84cff90816 Mon Sep 17 00:00:00 2001 From: leben Date: Mon, 14 Dec 2020 21:39:32 -0800 Subject: [PATCH 23/36] Disable `GeneralizedExtremeValue` distribution from `jax_transformation_test.testLogProbSample`. Re-enable after http://b/175654800 is resolved. PiperOrigin-RevId: 347535479 --- .../python/distributions/jax_transformation_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow_probability/python/distributions/jax_transformation_test.py b/tensorflow_probability/python/distributions/jax_transformation_test.py index 72aee0c4b9..8a5183028c 100644 --- a/tensorflow_probability/python/distributions/jax_transformation_test.py +++ b/tensorflow_probability/python/distributions/jax_transformation_test.py @@ -80,6 +80,7 @@ JVP_SAMPLE_BLOCKLIST = () JVP_LOGPROB_SAMPLE_BLOCKLIST = ( + 'GeneralizedExtremeValue', # http://b/175654800 'Skellam', # http://b/171079052 ) JVP_LOGPROB_PARAM_BLOCKLIST = ( @@ -89,6 +90,7 @@ VJP_SAMPLE_BLOCKLIST = () VJP_LOGPROB_SAMPLE_BLOCKLIST = ( + 'GeneralizedExtremeValue', # http://b/175654800 'Skellam', # http://b/171079052 ) VJP_LOGPROB_PARAM_BLOCKLIST = ( From 7792a90ca53f6326fae1b6e5f556b9268439ab95 Mon Sep 17 00:00:00 2001 From: Dave Moore Date: Tue, 15 Dec 2020 10:09:57 -0800 Subject: [PATCH 24/36] Remove TFP references to deprecated Affine bijectors. PiperOrigin-RevId: 347638177 --- discussion/fun_mcmc/fun_mcmc_test.py | 8 +- .../Bayesian_Gaussian_Mixture_Model.ipynb | 96 +++++-------------- ...ity_Case_Study_Covariance_Estimation.ipynb | 8 +- .../python/bijectors/bijector.py | 8 +- .../bijectors/bijector_composition_test.py | 6 +- .../bijectors/bijector_properties_test.py | 8 -- .../python/bijectors/blockwise_test.py | 12 +-- .../python/bijectors/chain_test.py | 29 ++++-- .../bijectors/discrete_cosine_transform.py | 2 +- .../python/bijectors/expm1.py | 2 +- .../python/bijectors/hypothesis_testlib.py | 3 - .../python/bijectors/invert_test.py | 2 +- .../bijectors/masked_autoregressive_test.py | 2 +- .../python/bijectors/real_nvp.py | 3 +- .../python/bijectors/real_nvp_test.py | 2 +- .../python/bijectors/softfloor.py | 2 +- .../python/bijectors/tanh.py | 5 +- .../python/bijectors/tanh_test.py | 7 +- .../python/distributions/BUILD | 4 +- .../distributions/quantized_distribution.py | 2 +- .../python/distributions/sample_test.py | 3 +- .../python/distributions/sinh_arcsinh.py | 11 +-- .../distributions/transformed_distribution.py | 4 +- .../vector_exponential_diag_test.py | 2 +- .../vector_exponential_linear_operator.py | 21 ++-- .../mcmc/nuts_autobatching_test.py | 6 +- .../psd_kernels/feature_transformed_test.py | 4 +- .../python/mcmc/transformed_kernel_test.py | 6 +- .../python/sts/autoregressive.py | 2 +- .../python/sts/dynamic_regression.py | 2 +- .../python/sts/local_level.py | 2 +- .../python/sts/local_linear_trend.py | 2 +- tensorflow_probability/python/sts/seasonal.py | 2 +- .../python/sts/semilocal_linear_trend.py | 2 +- .../python/sts/smooth_seasonal.py | 2 +- tensorflow_probability/python/sts/sum.py | 2 +- 36 files changed, 115 insertions(+), 169 deletions(-) diff --git a/discussion/fun_mcmc/fun_mcmc_test.py b/discussion/fun_mcmc/fun_mcmc_test.py index 28ac5a950a..2dae3d1bba 100644 --- a/discussion/fun_mcmc/fun_mcmc_test.py +++ b/discussion/fun_mcmc/fun_mcmc_test.py @@ -371,8 +371,8 @@ def log_prob_fn(x, y): tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), () bijectors = [ - tfp.bijectors.AffineScalar(scale=self._constant(2.)), - tfp.bijectors.AffineScalar(scale=self._constant(3.)) + tfp.bijectors.Scale(scale=self._constant(2.)), + tfp.bijectors.Scale(scale=self._constant(3.)) ] (transformed_log_prob_fn, @@ -398,8 +398,8 @@ def log_prob_fn(x, y): tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), () bijectors = { - 'x': tfp.bijectors.AffineScalar(scale=self._constant(2.)), - 'y': tfp.bijectors.AffineScalar(scale=self._constant(3.)) + 'x': tfp.bijectors.Scale(scale=self._constant(2.)), + 'y': tfp.bijectors.Scale(scale=self._constant(3.)) } (transformed_log_prob_fn, diff --git a/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb index a87a610d77..eacd036a67 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb @@ -3,7 +3,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "htW5SiGzeXYm" }, "source": [ @@ -14,11 +13,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", "id": "9HGeUNoteaSm" }, "outputs": [], @@ -39,7 +35,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "JJ3UDciDVcB5" }, "source": [ @@ -64,7 +59,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "lin40yCC6eBo" }, "source": [ @@ -74,7 +68,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "eZs1ShikNBK2" }, "source": [ @@ -84,25 +77,23 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "7JjokKMbk2hJ" }, "source": [ "For $k\\in\\{1,\\ldots, K\\}$ mixture components each of dimension $D$, we'd like to model $i\\in\\{1,\\ldots,N\\}$ iid samples using the following Bayesian Gaussian Mixture Model:\n", "\n", "$$\\begin{align*}\n", - "\\theta &\\sim \\text{Dirichlet}(\\text{concentration}=\\alpha_0)\\\\\n", - "\\mu_k &\\sim \\text{Normal}(\\text{loc}=\\mu_{0k}, \\text{scale}=I_D)\\\\\n", - "T_k &\\sim \\text{Wishart}(\\text{df}=5, \\text{scale}=I_D)\\\\\n", - "Z_i &\\sim \\text{Categorical}(\\text{probs}=\\theta)\\\\\n", - "Y_i &\\sim \\text{Normal}(\\text{loc}=\\mu_{z_i}, \\text{scale}=T_{z_i}^{-1/2})\\\\\n", + "\\theta \u0026\\sim \\text{Dirichlet}(\\text{concentration}=\\alpha_0)\\\\\n", + "\\mu_k \u0026\\sim \\text{Normal}(\\text{loc}=\\mu_{0k}, \\text{scale}=I_D)\\\\\n", + "T_k \u0026\\sim \\text{Wishart}(\\text{df}=5, \\text{scale}=I_D)\\\\\n", + "Z_i \u0026\\sim \\text{Categorical}(\\text{probs}=\\theta)\\\\\n", + "Y_i \u0026\\sim \\text{Normal}(\\text{loc}=\\mu_{z_i}, \\text{scale}=T_{z_i}^{-1/2})\\\\\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "iySRABi0qZnQ" }, "source": [ @@ -112,7 +103,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Y6X_Beihwzyi" }, "source": [ @@ -131,10 +121,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "uswTWdgNu46j" }, "outputs": [], @@ -163,7 +151,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Uj9uHZN2yUqz" }, "source": [ @@ -180,10 +167,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "nc4yy6vW-lC_" }, "outputs": [], @@ -197,9 +182,9 @@ " scale=tf.ones_like(loc)),\n", " reinterpreted_batch_ndims=1),\n", " bijector=tfb.Chain([\n", - " tfb.Affine(shift=loc),\n", - " tfb.Invert(tfb.Affine(scale_tril=chol_precision_tril,\n", - " adjoint=True)),\n", + " tfb.Shift(shift=loc),\n", + " tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,\n", + " adjoint=True)),\n", " ]),\n", " name=name)" ] @@ -207,7 +192,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "JDOkWhDQg4ZG" }, "source": [ @@ -219,7 +203,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Pfkc8cmhh2Qz" }, "source": [ @@ -228,12 +211,11 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": { "height": 152 }, - "colab_type": "code", "id": "GhqbjwlIh1Vn", "outputId": "3ea12c10-cb9b-4558-aedd-386b37adc909" }, @@ -291,7 +273,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "N60z8scN1v6E" }, "source": [ @@ -300,10 +281,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "xhzxySDjL2-S" }, "outputs": [], @@ -316,10 +295,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "xAOmHhZ7LzDQ" }, "outputs": [], @@ -353,10 +330,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "CpLnRJr2TXYD" }, "outputs": [], @@ -385,7 +360,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "7jTMXdymV1QJ" }, "source": [ @@ -395,7 +369,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "rl4brz3G3pS7" }, "source": [ @@ -404,10 +377,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "1AJZAtwXV8RQ" }, "outputs": [], @@ -425,7 +396,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "zVOvMh7MV37A" }, "source": [ @@ -435,7 +405,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "cdN3iKFT32Jp" }, "source": [ @@ -446,10 +415,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "tVoaDFSf7L_j" }, "outputs": [], @@ -459,10 +426,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "a0OMIWIYeMmQ" }, "outputs": [], @@ -482,7 +447,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "TVpiT3LLyfcO" }, "source": [ @@ -492,7 +456,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "JS8XOsxiyiBV" }, "source": [ @@ -507,7 +470,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Vt9SXJzO0Cks" }, "source": [ @@ -526,10 +488,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "_atEQrDR7JvG" }, "outputs": [], @@ -545,10 +505,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "0zq6QJJ-NSPJ" }, "outputs": [], @@ -575,7 +533,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "QLEz96mg6fpZ" }, "source": [ @@ -584,10 +541,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "_ceX1A3-ZFiN" }, "outputs": [], @@ -601,12 +556,11 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": { "height": 270 }, - "colab_type": "code", "id": "bqJ6RSJxegC6", "outputId": "e0867545-0509-4077-d89d-74e1d5280062" }, @@ -642,12 +596,11 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": { "height": 289 }, - "colab_type": "code", "id": "zFOU0j9kPdUy", "outputId": "17f4ce0c-24c3-4cf4-ebe8-b932caac7ba4" }, @@ -676,7 +629,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "NmfNIM1c6mwc" }, "source": [ @@ -686,7 +638,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "t8LeIeMn6ot4" }, "source": [ @@ -698,7 +649,6 @@ "colab": { "collapsed_sections": [], "name": "Bayesian Gaussian Mixture Model", - "private_outputs": false, "provenance": [], "toc_visible": true }, diff --git a/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb b/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb index 66720fde64..1b53a54cda 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb @@ -2672,8 +2672,8 @@ "\n", "Our approach (courtesy of [this notebook](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb)):\n", "1. Use `tfd.Independent()` to combine a batch of 1-D `Normal` random variables into a single multi-dimensional random variable. The `reinterpreted_batch_ndims` parameter for `Independent()` specifies the number of batch dimensions that should be reinterpreted as event dimensions. In our case we create a 1-D batch of length 2 that we transform into a 1-D event of length 2, so `reinterpreted_batch_ndims=1`.\n", - "2. Apply a bijector to add the desired covariance: `tfb.Invert(tfb.Affine(scale_tril=precision_cholesky, adjoint=True))`. Note that above we're multiplying our iid normal random variables by the transpose of the inverse of the Cholesky factor of the precision matrix $(B^{-T}X)$. The `tfb.Invert` takes care of inverting $B$, and the `adjoint=True` flag performs the transpose.\n", - "3. Apply a bijector to add the desired offset: `tfb.Affine(shift=shift)` Note that we have to do the shift as a separate step from the initial inverted affine transform because otherwise the inverted scale is applied to the shift (since the inverse of $y=Ax+b$ is $x=A^{-1}y - A^{-1}b$).\n" + "2. Apply a bijector to add the desired covariance: `tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=precision_cholesky, adjoint=True))`. Note that above we're multiplying our iid normal random variables by the transpose of the inverse of the Cholesky factor of the precision matrix $(B^{-T}X)$. The `tfb.Invert` takes care of inverting $B$, and the `adjoint=True` flag performs the transpose.\n", + "3. Apply a bijector to add the desired offset: `tfb.Shift(shift=shift)` Note that we have to do the shift as a separate step from the initial inverted affine transform because otherwise the inverted scale is applied to the shift (since the inverse of $y=Ax+b$ is $x=A^{-1}y - A^{-1}b$).\n" ] }, { @@ -2694,8 +2694,8 @@ " scale=tf.ones_like(loc)),\n", " reinterpreted_batch_ndims=1),\n", " bijector=tfb.Chain([\n", - " tfb.Affine(shift=loc),\n", - " tfb.Invert(tfb.Affine(scale_tril=precision_cholesky,\n", + " tfb.Shift(shift=loc),\n", + " tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=precision_cholesky,\n", " adjoint=True)),\n", " ]),\n", " name=name)" diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py index 1bbaddc9b7..4de390e18c 100644 --- a/tensorflow_probability/python/bijectors/bijector.py +++ b/tensorflow_probability/python/bijectors/bijector.py @@ -718,14 +718,14 @@ def __call__(self, value, name=None, **kwargs): ```python sigmoid = tfb.Reciprocal()( - tfb.AffineScalar(shift=1.)( + tfb.Shift(shift=1.)( tfb.Exp()( - tfb.AffineScalar(scale=-1.)))) + tfb.Scale(scale=-1.)))) # ==> `tfb.Chain([ # tfb.Reciprocal(), - # tfb.AffineScalar(shift=1.), + # tfb.Shift(shift=1.), # tfb.Exp(), - # tfb.AffineScalar(scale=-1.), + # tfb.Scale(scale=-1.), # ])` # ie, `tfb.Sigmoid()` log_normal = tfb.Exp()(tfd.Normal(0, 1)) diff --git a/tensorflow_probability/python/bijectors/bijector_composition_test.py b/tensorflow_probability/python/bijectors/bijector_composition_test.py index 5e4b80e547..ea89959302 100644 --- a/tensorflow_probability/python/bijectors/bijector_composition_test.py +++ b/tensorflow_probability/python/bijectors/bijector_composition_test.py @@ -38,9 +38,9 @@ def testComposeFromChainBijector(self): x = tf.constant([-5., 0., 5.]) sigmoid = functools.reduce(lambda chain, f: chain(f), [ tfb.Reciprocal(), - tfb.AffineScalar(shift=1.), + tfb.Shift(shift=1.), tfb.Exp(), - tfb.AffineScalar(scale=-1.), + tfb.Scale(scale=-1.), ]) self.assertIsInstance(sigmoid, tfb.Chain) self.assertAllClose( @@ -50,7 +50,7 @@ def testComposeFromChainBijector(self): def testComposeFromTransformedDistribution(self): actual_log_normal = tfb.Exp()(tfd.TransformedDistribution( distribution=tfd.Normal(0, 1), - bijector=tfb.AffineScalar(shift=0.5, scale=2.))) + bijector=tfb.Shift(shift=0.5)(tfb.Scale(scale=2.)))) expected_log_normal = tfd.LogNormal(0.5, 2.) x = tf.constant([0.1, 1., 5.]) self.assertAllClose( diff --git a/tensorflow_probability/python/bijectors/bijector_properties_test.py b/tensorflow_probability/python/bijectors/bijector_properties_test.py index c6b73b37f1..c5d73a8dad 100644 --- a/tensorflow_probability/python/bijectors/bijector_properties_test.py +++ b/tensorflow_probability/python/bijectors/bijector_properties_test.py @@ -36,7 +36,6 @@ TF2_FRIENDLY_BIJECTORS = ( - 'AffineScalar', 'Ascending', 'BatchNormalization', # 'CategoricalToDiscrete', TODO(b/137956955): Add support @@ -59,7 +58,6 @@ 'KumaraswamyCDF', 'Log', 'Log1p', - 'MatvecLU', 'MatrixInverseTriL', 'MoyalCDF', 'NormalCDF', @@ -92,13 +90,11 @@ ) BIJECTOR_PARAMS_NDIMS = { - 'AffineScalar': dict(shift=0, scale=0, log_scale=0), 'FrechetCDF': dict(loc=0, scale=0, concentration=0), 'GompertzCDF': dict(concentration=0, rate=0), 'GumbelCDF': dict(loc=0, scale=0), 'GeneralizedExtremeValueCDF': dict(loc=0, scale=0, concentration=0), 'KumaraswamyCDF': dict(concentration1=0, concentration0=0), - 'MatvecLU': dict(lower_upper=2, permutation=1), 'MoyalCDF': dict(loc=0, scale=0), 'Power': dict(power=0), 'RayleighCDF': dict(scale=0), @@ -125,7 +121,6 @@ INVERT_LDJ = {FLDJ: ILDJ, ILDJ: FLDJ} NO_LDJ_GRADS_EXPECTED = { - 'AffineScalar': dict(shift={FLDJ, ILDJ}), 'BatchNormalization': dict(beta={FLDJ, ILDJ}), 'FrechetCDF': dict(loc={ILDJ}), 'GeneralizedExtremeValueCDF': dict(loc={ILDJ}), @@ -135,7 +130,6 @@ } TRANSFORM_DIAGONAL_ALLOWLIST = { - 'AffineScalar', 'BatchNormalization', 'DiscreteCosineTransform', 'Exp', @@ -813,8 +807,6 @@ def ensure_nonzero(x): tfp_hps.softplus_plus_eps(), 'temperature': tfp_hps.softplus_plus_eps(eps=0.5), - 'AffineScalar.scale': - tfp_hps.softplus_plus_eps(), 'Scale.scale': tfp_hps.softplus_plus_eps(), 'ScaleMatvecDiag.scale_diag': diff --git a/tensorflow_probability/python/bijectors/blockwise_test.py b/tensorflow_probability/python/bijectors/blockwise_test.py index 31dc836190..b64b47f53c 100644 --- a/tensorflow_probability/python/bijectors/blockwise_test.py +++ b/tensorflow_probability/python/bijectors/blockwise_test.py @@ -48,7 +48,7 @@ def testExplicitBlocks(self, dynamic_shape, batch_shape): block_sizes.shape)) exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise( bijectors=[exp, sp, aff], block_sizes=block_sizes, @@ -123,7 +123,7 @@ def testSizeChangingExplicitBlocks(self, dynamic_shape, batch_shape): block_sizes, shape=block_sizes.shape) exp = tfb.Exp() sc = tfb.SoftmaxCentered() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise( bijectors=[exp, sc, aff], block_sizes=block_sizes, @@ -201,7 +201,7 @@ def testSizeChangingExplicitBlocks(self, dynamic_shape, batch_shape): def testBijectiveAndFinite(self): exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise(bijectors=[exp, sp, aff], block_sizes=[2, 1, 3]) x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32) @@ -219,17 +219,17 @@ def testBijectiveAndFinite(self): def testImplicitBlocks(self): exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2.]) blockwise = tfb.Blockwise(bijectors=[exp, sp, aff]) self.assertAllEqual(self.evaluate(blockwise.block_sizes), [1, 1, 1]) def testName(self): exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise(bijectors=[exp, sp, aff], block_sizes=[2, 1, 3]) self.assertStartsWith(blockwise.name, - 'blockwise_of_exp_and_softplus_and_affine') + 'blockwise_of_exp_and_softplus_and_scale_matvec_diag') def testNameOneBijector(self): exp = tfb.Exp() diff --git a/tensorflow_probability/python/bijectors/chain_test.py b/tensorflow_probability/python/bijectors/chain_test.py index 1bf97c1a0f..30aa14eb87 100644 --- a/tensorflow_probability/python/bijectors/chain_test.py +++ b/tensorflow_probability/python/bijectors/chain_test.py @@ -108,19 +108,24 @@ def testMinEventNdimsChain(self): self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), tfb.Affine(), tfb.Affine()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Exp(), tfb.Affine()]) + chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), tfb.Exp()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), tfb.Exp()]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), tfb.Exp(), tfb.Softplus(), tfb.Affine()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + tfb.Exp(), + tfb.Softplus(), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) @@ -129,11 +134,13 @@ def testMinEventNdimsShapeChangingAddDims(self): self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(3, chain.inverse_min_event_ndims) - chain = tfb.Chain([ShapeChanging(), tfb.Affine()]) + chain = tfb.Chain([ShapeChanging(), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(4, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), ShapeChanging()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + ShapeChanging()]) self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(3, chain.inverse_min_event_ndims) @@ -146,11 +153,13 @@ def testMinEventNdimsShapeChangingRemoveDims(self): self.assertEqual(3, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) - chain = tfb.Chain([ShapeChanging(3, 0), tfb.Affine()]) + chain = tfb.Chain([ShapeChanging(3, 0), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(3, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), ShapeChanging(3, 0)]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + ShapeChanging(3, 0)]) self.assertEqual(4, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) @@ -191,7 +200,7 @@ def testMinEventNdimsWithJointMap(self): def testChainExpAffine(self): scale_diag = np.array([1., 2., 3.], dtype=np.float32) - chain = tfb.Chain([tfb.Exp(), tfb.Affine(scale_diag=scale_diag)]) + chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=scale_diag)]) x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] y = [1., 4., 27.] self.assertAllClose(y, self.evaluate(chain.forward(x))) @@ -206,7 +215,7 @@ def testChainExpAffine(self): def testChainAffineExp(self): scale_diag = np.array([1., 2., 3.], dtype=np.float32) - chain = tfb.Chain([tfb.Affine(scale_diag=scale_diag), tfb.Exp()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=scale_diag), tfb.Exp()]) x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] y = [1., 4., 9.] self.assertAllClose(y, self.evaluate(chain.forward(x))) diff --git a/tensorflow_probability/python/bijectors/discrete_cosine_transform.py b/tensorflow_probability/python/bijectors/discrete_cosine_transform.py index deb1449311..302f4d68e1 100644 --- a/tensorflow_probability/python/bijectors/discrete_cosine_transform.py +++ b/tensorflow_probability/python/bijectors/discrete_cosine_transform.py @@ -39,7 +39,7 @@ class DiscreteCosineTransform(bijector.Bijector): The inverse `X = g^{-1}(Y) = IDCT(Y)`, where IDCT is DCT-III for type==2. - This bijector can be interleaved with Affine bijectors to build a cascade of + This bijector can be interleaved with affine bijectors to build a cascade of structured efficient linear layers as in [1]. Note that the operator applied is orthonormal (i.e. `norm='ortho'`). diff --git a/tensorflow_probability/python/bijectors/expm1.py b/tensorflow_probability/python/bijectors/expm1.py index 7f2f6c8461..468047fd64 100644 --- a/tensorflow_probability/python/bijectors/expm1.py +++ b/tensorflow_probability/python/bijectors/expm1.py @@ -32,7 +32,7 @@ class Expm1(bijector.Bijector): """Compute `Y = g(X) = exp(X) - 1`. - This `Bijector` is no different from Chain([AffineScalar(shift=-1), Exp()]). + This `Bijector` is no different from Chain([Shift(-1), Exp()]). However, this makes use of the more numerically stable routines `tf.math.expm1` and `tf.log1p`. diff --git a/tensorflow_probability/python/bijectors/hypothesis_testlib.py b/tensorflow_probability/python/bijectors/hypothesis_testlib.py index 1a8f8160dc..5ce6e685d1 100644 --- a/tensorflow_probability/python/bijectors/hypothesis_testlib.py +++ b/tensorflow_probability/python/bijectors/hypothesis_testlib.py @@ -122,9 +122,6 @@ def bijector_supports(): return BIJECTOR_SUPPORTS Support = tfp_hps.Support # pylint: disable=invalid-name supports = { - 'AffineScalar': - BijectorSupport(Support.SCALAR_UNCONSTRAINED, - Support.SCALAR_UNCONSTRAINED), 'Ascending': BijectorSupport(Support.VECTOR_UNCONSTRAINED, Support.VECTOR_STRICTLY_INCREASING), diff --git a/tensorflow_probability/python/bijectors/invert_test.py b/tensorflow_probability/python/bijectors/invert_test.py index 9ca74513d2..fa8618fd34 100644 --- a/tensorflow_probability/python/bijectors/invert_test.py +++ b/tensorflow_probability/python/bijectors/invert_test.py @@ -36,7 +36,7 @@ def testBijector(self): for fwd in [ tfb.Identity(), tfb.Exp(), - tfb.Affine(shift=[0., 1.], scale_diag=[2., 3.]), + tfb.ScaleMatvecDiag(scale_diag=[2., 3.]), tfb.Softplus(), tfb.SoftmaxCentered(), ]: diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive_test.py b/tensorflow_probability/python/bijectors/masked_autoregressive_test.py index 58b5388c22..a57f0342e2 100644 --- a/tensorflow_probability/python/bijectors/masked_autoregressive_test.py +++ b/tensorflow_probability/python/bijectors/masked_autoregressive_test.py @@ -103,7 +103,7 @@ def _bijector_fn(x): shift = reshape_output(shift) logit_gate = reshape_output(logit_gate) gate = tf.nn.sigmoid(logit_gate) - return tfb.AffineScalar(shift=(1. - gate) * shift, scale=gate) + return tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate)) return _bijector_fn diff --git a/tensorflow_probability/python/bijectors/real_nvp.py b/tensorflow_probability/python/bijectors/real_nvp.py index 9a9feeee8b..d53b2587a5 100644 --- a/tensorflow_probability/python/bijectors/real_nvp.py +++ b/tensorflow_probability/python/bijectors/real_nvp.py @@ -48,7 +48,8 @@ class RealNVP(bijector_lib.Bijector): while the first `d` units are 'masked' and left unchanged. Real NVP's `shift_and_log_scale_fn` computes vector-valued quantities. For scale-and-shift transforms that do not depend on any masked units, i.e. - `d=0`, use the `tfb.Affine` bijector with learned parameters instead. + `d=0`, use the `tfb.Scale` and `tfb.Shift` bijectors with learned parameters + instead. Masking is currently only supported for base distributions with `event_ndims=1`. For more sophisticated masking schemes like checkerboard or diff --git a/tensorflow_probability/python/bijectors/real_nvp_test.py b/tensorflow_probability/python/bijectors/real_nvp_test.py index 613699f3fe..abb9531e96 100644 --- a/tensorflow_probability/python/bijectors/real_nvp_test.py +++ b/tensorflow_probability/python/bijectors/real_nvp_test.py @@ -227,7 +227,7 @@ def _bijector_fn(x, output_units): shift = reshape_output(shift) logit_gate = reshape_output(logit_gate) gate = tf.nn.sigmoid(logit_gate) - return tfb.AffineScalar(shift=(1. - gate) * shift, scale=gate) + return tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate)) return tf1.make_template('gated_bijector', _bijector_fn) diff --git a/tensorflow_probability/python/bijectors/softfloor.py b/tensorflow_probability/python/bijectors/softfloor.py index e036111088..7860b88f21 100644 --- a/tensorflow_probability/python/bijectors/softfloor.py +++ b/tensorflow_probability/python/bijectors/softfloor.py @@ -93,7 +93,7 @@ class Softfloor(bijector.Bijector): # Ceiling is just a shifted floor at non-integer points. soft_ceiling = tfb.Chain( - [tfb.AffineScalar(1.), + [tfb.Shift(1.), tfb.Softfloor(temperature=1.)]) soft_ceiling.forward(x) # Should be close to [3., 5., 6.] ``` diff --git a/tensorflow_probability/python/bijectors/tanh.py b/tensorflow_probability/python/bijectors/tanh.py index 6090b94d71..44991b552b 100644 --- a/tensorflow_probability/python/bijectors/tanh.py +++ b/tensorflow_probability/python/bijectors/tanh.py @@ -34,9 +34,10 @@ class Tanh(bijector.Bijector): This can be achieved by an affine transform of the Sigmoid bijector, i.e., it is equivalent to ``` - tfb.Chain([tfb.Affine(shift=-1, scale=2.), + tfb.Chain([tfb.Shift(shift=-1.), + tfb.Scale(scale=2.), tfb.Sigmoid(), - tfb.Affine(scale=2.)]) + tfb.Scale(scale=2.)]) ``` However, using the `Tanh` bijector directly is slightly faster and more diff --git a/tensorflow_probability/python/bijectors/tanh_test.py b/tensorflow_probability/python/bijectors/tanh_test.py index cf210027db..0c53b75618 100644 --- a/tensorflow_probability/python/bijectors/tanh_test.py +++ b/tensorflow_probability/python/bijectors/tanh_test.py @@ -66,11 +66,10 @@ def testBijectiveAndFinite(self): def testMatchWithAffineTransform(self): direct_bj = tfb.Tanh() indirect_bj = tfb.Chain([ - tfb.AffineScalar( - shift=tf.cast(-1.0, dtype=tf.float64), - scale=tf.cast(2.0, dtype=tf.float64)), + tfb.Shift(tf.cast(-1.0, dtype=tf.float64)), + tfb.Scale(tf.cast(2.0, dtype=tf.float64)), tfb.Sigmoid(), - tfb.AffineScalar(scale=tf.cast(2.0, dtype=tf.float64)) + tfb.Scale(tf.cast(2.0, dtype=tf.float64)) ]) x = np.linspace(-3.0, 3.0, 100) diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index 36eb9dc026..b3cad34ec2 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -1735,9 +1735,10 @@ multi_substrate_py_library( ":normal", ":transformed_distribution", # tensorflow dep, - "//tensorflow_probability/python/bijectors:affine_scalar", "//tensorflow_probability/python/bijectors:chain", "//tensorflow_probability/python/bijectors:identity", + "//tensorflow_probability/python/bijectors:scale", + "//tensorflow_probability/python/bijectors:shift", "//tensorflow_probability/python/bijectors:sinh_arcsinh", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", @@ -1983,7 +1984,6 @@ multi_substrate_py_library( ":sample", ":transformed_distribution", # tensorflow dep, - "//tensorflow_probability/python/bijectors:affine_linear_operator", "//tensorflow_probability/python/bijectors:chain", "//tensorflow_probability/python/bijectors:scale_matvec_linear_operator", "//tensorflow_probability/python/bijectors:shift", diff --git a/tensorflow_probability/python/distributions/quantized_distribution.py b/tensorflow_probability/python/distributions/quantized_distribution.py index 2445917357..26059e9301 100644 --- a/tensorflow_probability/python/distributions/quantized_distribution.py +++ b/tensorflow_probability/python/distributions/quantized_distribution.py @@ -188,7 +188,7 @@ class QuantizedDistribution(distributions.Distribution): discretized_logistic_dist = tfd.QuantizedDistribution( distribution=tfd.TransformedDistribution( distribution=tfd.Logistic(loc=loc, scale=scale), - bijector=tfb.AffineScalar(shift=-0.5)), + bijector=tfb.Shift(shift=-0.5)), low=0., high=2**16 - 1.) mixture_dist = tfd.MixtureSameFamily( diff --git a/tensorflow_probability/python/distributions/sample_test.py b/tensorflow_probability/python/distributions/sample_test.py index 1d8c023d71..e8d94409b1 100644 --- a/tensorflow_probability/python/distributions/sample_test.py +++ b/tensorflow_probability/python/distributions/sample_test.py @@ -89,8 +89,7 @@ def test_kl_divergence(self): def test_transformed_affine(self): sample_shape = 3 mvn = tfd.Independent(tfd.Normal(loc=[0., 0], scale=1), 1) - aff = tfb.Affine(scale_tril=[[0.75, 0.], - [0.05, 0.5]]) + aff = tfb.ScaleMatvecTriL(scale_tril=[[0.75, 0.], [0.05, 0.5]]) def expected_lp(y): x = aff.inverse(y) # Ie, tf.random.normal([4, 3, 2]) diff --git a/tensorflow_probability/python/distributions/sinh_arcsinh.py b/tensorflow_probability/python/distributions/sinh_arcsinh.py index c23443cd76..6cc740dbf6 100644 --- a/tensorflow_probability/python/distributions/sinh_arcsinh.py +++ b/tensorflow_probability/python/distributions/sinh_arcsinh.py @@ -19,10 +19,12 @@ from __future__ import print_function import tensorflow.compat.v2 as tf -from tensorflow_probability.python.bijectors import affine_scalar as affine_scalar_bijector from tensorflow_probability.python.bijectors import chain as chain_bijector from tensorflow_probability.python.bijectors import identity as identity_bijector +from tensorflow_probability.python.bijectors import scale as scale_bijector +from tensorflow_probability.python.bijectors import shift as shift_bijector from tensorflow_probability.python.bijectors import sinh_arcsinh as sinh_arcsinh_bijector + from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.distributions import transformed_distribution from tensorflow_probability.python.internal import distribution_util @@ -179,11 +181,8 @@ def __init__(self, validate_args=validate_args) # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) - affine = affine_scalar_bijector.AffineScalar( - shift=self._loc, - scale=self._scale, - validate_args=validate_args) - + affine = shift_bijector.Shift(shift=self._loc)( + scale_bijector.Scale(scale=self._scale)) bijector = chain_bijector.Chain([affine, f]) super(SinhArcsinh, self).__init__( diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py index a2e83936cf..771510b87b 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution.py +++ b/tensorflow_probability/python/distributions/transformed_distribution.py @@ -134,9 +134,7 @@ class TransformedDistribution(distribution_lib.Distribution): tfb = tfp.bijectors normal = tfd.TransformedDistribution( distribution=tfd.Normal(loc=0., scale=1.), - bijector=tfb.Affine( - shift=-1., - scale_identity_multiplier=2.) + bijector=tfb.Shift(shift=-1.)(tfb.Scale(scale=2.)), name='NormalTransformedDistribution') ``` diff --git a/tensorflow_probability/python/distributions/vector_exponential_diag_test.py b/tensorflow_probability/python/distributions/vector_exponential_diag_test.py index 141fad12fe..c26ce5baa1 100644 --- a/tensorflow_probability/python/distributions/vector_exponential_diag_test.py +++ b/tensorflow_probability/python/distributions/vector_exponential_diag_test.py @@ -83,8 +83,8 @@ def testAssertValidSample(self): def testSingularScaleRaises(self): mu = [-1., 1] diag = [1., 0] - dist = tfd.VectorExponentialDiag(mu, diag, validate_args=True) with self.assertRaisesOpError('Singular'): + dist = tfd.VectorExponentialDiag(mu, diag, validate_args=True) self.evaluate(dist.sample(seed=test_util.test_seed())) def testSampleWithBroadcastScale(self): diff --git a/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py b/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py index 0f0396f7a9..49ea2c7cf3 100644 --- a/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py +++ b/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py @@ -19,9 +19,8 @@ from __future__ import print_function import tensorflow.compat.v2 as tf -from tensorflow_probability.python.bijectors import affine_linear_operator as affine_linear_operator_bijector from tensorflow_probability.python.bijectors import chain as chain_bijector -from tensorflow_probability.python.bijectors import scale_matvec_linear_operator as scale_matvec_linear_operator_bijector +from tensorflow_probability.python.bijectors import scale_matvec_linear_operator from tensorflow_probability.python.bijectors import shift as shift_bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.distributions import exponential @@ -193,7 +192,8 @@ def __init__(self, loc, name='loc', dtype=scale.dtype) batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( loc, scale) - + self._loc = loc + self._scale = scale super(VectorExponentialLinearOperator, self).__init__( # TODO(b/137665504): Use batch-adding meta-distribution to set the # batch shape instead of tf.ones. @@ -206,8 +206,9 @@ def __init__(self, rate=tf.ones(batch_shape, dtype=scale.dtype), allow_nan_stats=allow_nan_stats), event_shape), - bijector=affine_linear_operator_bijector.AffineLinearOperator( - shift=loc, scale=scale, validate_args=validate_args), + bijector=shift_bijector.Shift(shift=loc)( + scale_matvec_linear_operator.ScaleMatvecLinearOperator( + scale=scale, validate_args=validate_args)), validate_args=validate_args, name=name) self._parameters = parameters @@ -215,12 +216,12 @@ def __init__(self, @property def loc(self): """The `loc` `Tensor` in `Y = scale @ X + loc`.""" - return self.bijector.shift + return self._loc @property def scale(self): """The `scale` `LinearOperator` in `Y = scale @ X + loc`.""" - return self.bijector.scale + return self._scale @distribution_util.AppendDocstring(_mvn_sample_note) def _log_prob(self, x): @@ -236,7 +237,7 @@ def _mean(self): # Then this distribution is # X = loc + LW, # and then E[X] = loc + L1, where 1 is the vector of ones. - scale_x_ones = self.bijector.scale.matvec( + scale_x_ones = self.scale.matvec( tf.ones(self._mode_mean_shape(), self.dtype)) if self.loc is None: @@ -279,7 +280,7 @@ def _stddev(self): self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))) def _mode(self): - scale_x_zeros = self.bijector.scale.matvec( + scale_x_zeros = self.scale.matvec( tf.zeros(self._mode_mean_shape(), self.dtype)) if self.loc is None: @@ -311,7 +312,7 @@ def _sample_control_dependencies(self, x): def _default_event_space_bijector(self): return chain_bijector.Chain([ shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args), - scale_matvec_linear_operator_bijector.ScaleMatvecLinearOperator( + scale_matvec_linear_operator.ScaleMatvecLinearOperator( scale=self.scale, validate_args=self.validate_args), softplus_bijector.Softplus(validate_args=self.validate_args) ], validate_args=self.validate_args) diff --git a/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py b/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py index 64aa30a707..ae5125b445 100644 --- a/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py +++ b/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py @@ -303,9 +303,9 @@ def __init__(self, loc, chol_precision_tril, name=None): scale=tf.ones_like(loc)), reinterpreted_batch_ndims=1), bijector=tfb.Chain([ - tfb.Affine(shift=loc), - tfb.Invert(tfb.Affine(scale_tril=chol_precision_tril, - adjoint=True)), + tfb.Shift(shift=loc), + tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril, + adjoint=True)), ]), name=name) diff --git a/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py b/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py index cc79930532..654dcbd050 100644 --- a/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py +++ b/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py @@ -83,7 +83,7 @@ def testValuesAreCorrectScalarTransform(self, feature_ndims, dims): amplitude, length_scale, feature_ndims) input_shape = [dims] * feature_ndims - bij = tfp.bijectors.AffineScalar(self.dtype(0.), self.dtype(2.)) + bij = tfp.bijectors.Scale(scale=self.dtype(2.)) # Flat multiplication by 2. def scale_transform(x, feature_ndims, param_expansion_ndims): del feature_ndims, param_expansion_ndims @@ -114,7 +114,7 @@ def testValuesAreCorrectVectorTransform(self, feature_ndims, dims): input_shape = [dims] * feature_ndims scale_diag = np.random.uniform(-1, 1, size=(dims,)).astype(self.dtype) - bij = tfp.bijectors.Affine(scale_diag=scale_diag) + bij = tfp.bijectors.ScaleMatvecDiag(scale_diag=scale_diag) # Scaling the last dimension. def vector_transform(x, feature_ndims, param_expansion_ndims): diff --git a/tensorflow_probability/python/mcmc/transformed_kernel_test.py b/tensorflow_probability/python/mcmc/transformed_kernel_test.py index 3b25eb8081..6c9b6be1bb 100644 --- a/tensorflow_probability/python/mcmc/transformed_kernel_test.py +++ b/tensorflow_probability/python/mcmc/transformed_kernel_test.py @@ -252,8 +252,8 @@ def target_log_prob(x, y): step_size=[1.23 / 0.75, 1.23 / 0.5], num_leapfrog_steps=2), bijector=[ - tfb.AffineScalar(scale=0.75), - tfb.AffineScalar(scale=0.5), + tfb.Scale(scale=0.75), + tfb.Scale(scale=0.5), ]) # Recall, tfp.mcmc.sample_chain calls # transformed_hmc.bootstrap_results too. @@ -304,7 +304,7 @@ def test_bootstrap_correctly_untransforms(self): def test_copy_works(self): transformed = tfp.mcmc.TransformedTransitionKernel( inner_kernel=FakeInnerKernel(target_log_prob_fn=fake_target_log_prob), - bijector=tfb.AffineScalar(2.)) + bijector=tfb.Scale(2.)) transformed_copy = tfp.mcmc.TransformedTransitionKernel( **transformed.parameters) diff --git a/tensorflow_probability/python/sts/autoregressive.py b/tensorflow_probability/python/sts/autoregressive.py index fdd32f60b2..8752cb40fe 100644 --- a/tensorflow_probability/python/sts/autoregressive.py +++ b/tensorflow_probability/python/sts/autoregressive.py @@ -387,7 +387,7 @@ def __init__(self, coefficients_prior, coefficient_constraining_bijector), Parameter('level_scale', level_scale_prior, - tfb.Chain([tfb.AffineScalar(scale=observed_stddev), + tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()])) ], latent_size=order, diff --git a/tensorflow_probability/python/sts/dynamic_regression.py b/tensorflow_probability/python/sts/dynamic_regression.py index 6ad093ece2..bef1b1e2c8 100644 --- a/tensorflow_probability/python/sts/dynamic_regression.py +++ b/tensorflow_probability/python/sts/dynamic_regression.py @@ -314,7 +314,7 @@ def __init__(self, super(DynamicLinearRegression, self).__init__( parameters=[ Parameter('drift_scale', drift_scale_prior, - tfb.Chain([tfb.AffineScalar(scale=observed_stddev), + tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()])) ], latent_size=num_features, diff --git a/tensorflow_probability/python/sts/local_level.py b/tensorflow_probability/python/sts/local_level.py index bba44f437e..8a7a3f378e 100644 --- a/tensorflow_probability/python/sts/local_level.py +++ b/tensorflow_probability/python/sts/local_level.py @@ -327,7 +327,7 @@ def __init__(self, super(LocalLevel, self).__init__( parameters=[ Parameter('level_scale', level_scale_prior, - tfb.Chain([tfb.AffineScalar(scale=observed_stddev), + tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()])), ], latent_size=1, diff --git a/tensorflow_probability/python/sts/local_linear_trend.py b/tensorflow_probability/python/sts/local_linear_trend.py index 2ccb6cb0a7..61ed35d77e 100644 --- a/tensorflow_probability/python/sts/local_linear_trend.py +++ b/tensorflow_probability/python/sts/local_linear_trend.py @@ -404,7 +404,7 @@ def __init__(self, initial_slope_prior.stddev() ], axis=-1)) - scaled_softplus = tfb.Chain([tfb.AffineScalar(scale=observed_stddev), + scaled_softplus = tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()]) super(LocalLinearTrend, self).__init__( parameters=[ diff --git a/tensorflow_probability/python/sts/seasonal.py b/tensorflow_probability/python/sts/seasonal.py index 0e665450ff..7c0755a384 100644 --- a/tensorflow_probability/python/sts/seasonal.py +++ b/tensorflow_probability/python/sts/seasonal.py @@ -881,7 +881,7 @@ def __init__(self, if allow_drift: parameters.append(Parameter( 'drift_scale', drift_scale_prior, - tfb.Chain([tfb.AffineScalar(scale=observed_stddev), + tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()]))) self._allow_drift = allow_drift diff --git a/tensorflow_probability/python/sts/semilocal_linear_trend.py b/tensorflow_probability/python/sts/semilocal_linear_trend.py index f5f00b689d..acab1d2bd1 100644 --- a/tensorflow_probability/python/sts/semilocal_linear_trend.py +++ b/tensorflow_probability/python/sts/semilocal_linear_trend.py @@ -429,7 +429,7 @@ def __init__(self, else: autoregressive_coef_bijector = tfb.Identity() # unconstrained - stddev_preconditioner = tfb.AffineScalar(scale=observed_stddev) + stddev_preconditioner = tfb.Scale(scale=observed_stddev) scaled_softplus = tfb.Chain([stddev_preconditioner, tfb.Softplus()]) super(SemiLocalLinearTrend, self).__init__( parameters=[ diff --git a/tensorflow_probability/python/sts/smooth_seasonal.py b/tensorflow_probability/python/sts/smooth_seasonal.py index 3c79e6f74a..c18848f0e8 100644 --- a/tensorflow_probability/python/sts/smooth_seasonal.py +++ b/tensorflow_probability/python/sts/smooth_seasonal.py @@ -441,7 +441,7 @@ def __init__(self, if allow_drift: parameters.append(Parameter( 'drift_scale', drift_scale_prior, - tfb.Chain([tfb.AffineScalar(scale=observed_stddev), + tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()]))) self._allow_drift = allow_drift diff --git a/tensorflow_probability/python/sts/sum.py b/tensorflow_probability/python/sts/sum.py index 7fd015b094..96e2f14485 100644 --- a/tensorflow_probability/python/sts/sum.py +++ b/tensorflow_probability/python/sts/sum.py @@ -460,7 +460,7 @@ def __init__(self, parameters = [Parameter('observation_noise_scale', observation_noise_scale_prior, tfb.Chain([ - tfb.AffineScalar(scale=observed_stddev), + tfb.Scale(scale=observed_stddev), tfb.Softplus()]))] for component in components: for parameter in component.parameters: From b49a757f46e4a8c07a383e377f7b5d6841dc797d Mon Sep 17 00:00:00 2001 From: Dave Moore Date: Wed, 18 Nov 2020 13:55:28 -0800 Subject: [PATCH 25/36] Remove support for a long-deprecated trace_fn signature. PiperOrigin-RevId: 343150107 --- .../python/math/minimize.py | 32 ------------------- .../python/math/minimize_test.py | 21 +++--------- .../python/vi/optimization.py | 2 +- .../python/vi/optimization_test.py | 2 +- 4 files changed, 7 insertions(+), 50 deletions(-) diff --git a/tensorflow_probability/python/math/minimize.py b/tensorflow_probability/python/math/minimize.py index b5a5c9f47a..05c1983301 100644 --- a/tensorflow_probability/python/math/minimize.py +++ b/tensorflow_probability/python/math/minimize.py @@ -22,8 +22,6 @@ import tensorflow.compat.v2 as tf -from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import - class MinimizeTraceableQuantities(collections.namedtuple( 'MinimizeTraceableQuantities', @@ -51,33 +49,6 @@ class MinimizeTraceableQuantities(collections.namedtuple( """ -# Backwards compatibility for older `trace_fns` that took separate -# loss, grads, and params. -def _maybe_wrap_old_style_trace_fn(trace_fn): - """Returns a `trace_fn that takes the single `minimizer_state` argument.""" - - def safe_trace_fn(traceable_quantities): - """A `trace_fn that takes the single `minimizer_state` argument.""" - try: - return trace_fn(traceable_quantities) - except TypeError: - deprecated_trace_fn = deprecation.deprecated_args( - '2020-07-01', - 'The signature for `trace_fn`s passed to `minimize` has changed. ' - 'Trace functions now take a single `traceable_quantities` argument, ' - 'which is a `tfp.math.MinimizeTraceableQuantities` namedtuple ' - 'containing `traceable_quantities.loss`, ' - '`traceable_quantities.gradients`, etc. ' - 'Please update your `trace_fn` definition.', - ('loss', 'grads', 'variables') - )(trace_fn) - return deprecated_trace_fn( - traceable_quantities.loss, - traceable_quantities.gradients, - traceable_quantities.parameters) - return safe_trace_fn - - def _tile_last_written_value(trace_array, last_written_idx): last_written_value = trace_array.read(last_written_idx) _, tiled_trace_array = tf.while_loop( @@ -312,8 +283,6 @@ def minimize(loss_fn, """ - trace_fn = _maybe_wrap_old_style_trace_fn(trace_fn) - def convergence_detected(step, trace_arrays, has_converged=None, convergence_criterion_state=None): @@ -379,4 +348,3 @@ def convergence_detected(step, trace_arrays, trace_arrays) return tf.nest.map_structure(lambda array: array.stack(), trace_arrays) - diff --git a/tensorflow_probability/python/math/minimize_test.py b/tensorflow_probability/python/math/minimize_test.py index e8d2ec50a0..b776bd5f33 100644 --- a/tensorflow_probability/python/math/minimize_test.py +++ b/tensorflow_probability/python/math/minimize_test.py @@ -19,24 +19,19 @@ from __future__ import print_function # Dependency imports -from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf import tensorflow_probability as tfp - from tensorflow_probability.python.internal import test_util @test_util.test_all_tf_execution_regimes class MinimizeTests(test_util.TestCase): - @parameterized.named_parameters( - {'testcase_name': 'new_style', 'new_style_trace_fn': True}, - {'testcase_name': 'old_style', 'new_style_trace_fn': False}) - def test_custom_trace_fn(self, new_style_trace_fn): + def test_custom_trace_fn(self): init_x = np.array([0., 0.]).astype(np.float32) target_x = np.array([3., 4.]).astype(np.float32) @@ -45,15 +40,9 @@ def test_custom_trace_fn(self, new_style_trace_fn): loss_fn = lambda: tf.reduce_sum((x - target_x)**2) # The trace_fn should determine the structure and values of the results. - if new_style_trace_fn: # Takes a `MinimizerState` namedtuple. - def trace_fn(traceable_quantities): - return {'loss': traceable_quantities.loss, 'x': x, - 'sqdiff': (x - target_x)**2} - else: - def trace_fn(loss, grads, values): # Takes individual args. - del grads - del values - return {'loss': loss, 'x': x, 'sqdiff': (x - target_x)**2} + def trace_fn(traceable_quantities): + return {'loss': traceable_quantities.loss, 'x': x, + 'sqdiff': (x - target_x)**2} results = tfp.math.minimize(loss_fn, num_steps=100, optimizer=tf.optimizers.Adam(0.1), @@ -94,7 +83,7 @@ def test_works_when_results_have_dynamic_shape(self): num_steps=num_steps, # TODO(b/137299119) Replace with TF2 optimizer. optimizer=tf1.train.AdamOptimizer(0.1), - trace_fn=lambda loss, grads, vars: (loss, grads), + trace_fn=lambda t: (t.loss, t.gradients), trainable_variables=[x]) with tf.control_dependencies([losses]): final_x = tf.identity(x) diff --git a/tensorflow_probability/python/vi/optimization.py b/tensorflow_probability/python/vi/optimization.py index db5ad4d7fc..85e01c7f15 100644 --- a/tensorflow_probability/python/vi/optimization.py +++ b/tensorflow_probability/python/vi/optimization.py @@ -24,7 +24,7 @@ from tensorflow_probability.python import math as tfp_math from tensorflow_probability.python.vi import csiszar_divergence -_trace_loss = lambda loss, grads, variables: loss +_trace_loss = lambda traceable_quantities: traceable_quantities.loss # Silent fallback to score-function gradients leads to difficult-to-debug # failures, so we force reparameterization gradients by default. diff --git a/tensorflow_probability/python/vi/optimization_test.py b/tensorflow_probability/python/vi/optimization_test.py index 546d40f020..27951c0b91 100644 --- a/tensorflow_probability/python/vi/optimization_test.py +++ b/tensorflow_probability/python/vi/optimization_test.py @@ -159,7 +159,7 @@ def variational_model_fn(): num_steps=100, seed=test_util.test_seed(), sample_size=1, - trace_fn=lambda loss, grads, variables: (loss, q.sample(seed=42)[0])) + trace_fn=lambda t: (t.loss, q.sample(seed=42)[0])) self.evaluate(tf1.global_variables_initializer()) losses_, sample_path_ = self.evaluate((losses, sample_path)) From 2f313b8b1174594611c6b41b8bb100d637a1d7d0 Mon Sep 17 00:00:00 2001 From: Dave Moore Date: Tue, 15 Dec 2020 11:28:49 -0800 Subject: [PATCH 26/36] Ensure that all traceable quantities in `minimize` are actually traceable. PiperOrigin-RevId: 347655640 --- tensorflow_probability/python/math/minimize.py | 11 ++++++++++- tensorflow_probability/python/math/minimize_test.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/math/minimize.py b/tensorflow_probability/python/math/minimize.py index 05c1983301..ed52bbaebe 100644 --- a/tensorflow_probability/python/math/minimize.py +++ b/tensorflow_probability/python/math/minimize.py @@ -49,6 +49,14 @@ class MinimizeTraceableQuantities(collections.namedtuple( """ +def _sanitize_traced_values(traced_values): + """Represents Python values and `None` as Tensors.""" + return tf.nest.map_structure( + lambda x: (tf.zeros([0], dtype=tf.int32) if x is None # pylint: disable=g-long-lambda + else tf.convert_to_tensor(x)), + traced_values) + + def _tile_last_written_value(trace_array, last_written_idx): last_written_value = trace_array.read(last_written_idx) _, tiled_trace_array = tf.while_loop( @@ -98,7 +106,7 @@ def training_loop_body(step, trace_arrays, has_converged=None, loss=loss, gradients=grads, parameters=parameters, step=step, has_converged=has_converged, convergence_criterion_state=convergence_criterion_state) - traced_values = trace_fn(traceable_quantities) + traced_values = _sanitize_traced_values(trace_fn(traceable_quantities)) trace_arrays = tf.nest.map_structure( lambda ta, x: ta.write(step, x), trace_arrays, traced_values) potential_new_loop_vars = ( @@ -112,6 +120,7 @@ def _initialize_arrays(initial_values, num_steps, truncate_at_convergence): """Construct a structure of `TraceArray`s from initial values.""" + initial_values = _sanitize_traced_values(initial_values) num_steps_ = tf.get_static_value(tf.convert_to_tensor(num_steps)) size_is_dynamic = (num_steps_ is None or truncate_at_convergence) trace_arrays = tf.nest.map_structure( diff --git a/tensorflow_probability/python/math/minimize_test.py b/tensorflow_probability/python/math/minimize_test.py index b776bd5f33..58487b44e0 100644 --- a/tensorflow_probability/python/math/minimize_test.py +++ b/tensorflow_probability/python/math/minimize_test.py @@ -53,6 +53,16 @@ def trace_fn(traceable_quantities): self.assertAllClose(results_['x'][-1], target_x, atol=0.2) self.assertAllClose(results_['sqdiff'][-1], [0., 0.], atol=0.1) + def test_can_trace_all_traceable_quantities(self): + x = tf.Variable(5.0) + trace_fn = lambda traceable_quantities: traceable_quantities + results = tfp.math.minimize(loss_fn=lambda: tf.reduce_sum((x - 1.0)**2), + num_steps=10, + optimizer=tf.optimizers.Adam(0.1), + trace_fn=trace_fn) + self.evaluate(tf1.global_variables_initializer()) + self.evaluate(results) + def test_respects_trainable_variables(self): # Variables not included in `trainable_variables` should stay fixed. x = tf.Variable(5.) From 9d42d9d61d2ab3c16e950a22257285ef24aeba93 Mon Sep 17 00:00:00 2001 From: leben Date: Tue, 15 Dec 2020 11:50:30 -0800 Subject: [PATCH 27/36] Update `tensor_shape.py` for numpy backend. PiperOrigin-RevId: 347660796 --- .../python/internal/backend/numpy/gen/tensor_shape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py index b9c53973ff..d918546a0e 100755 --- a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py @@ -143,7 +143,7 @@ def dimension_value(dimension): value = tensor_shape[i] # Warning: this will return the dim value in V2! ``` - Args: + Arguments: dimension: Either a `Dimension` instance, an integer, or None. Returns: @@ -189,7 +189,7 @@ def dimension_at_index(shape, index): # instantiated on the fly. ``` - Args: + Arguments: shape: A TensorShape instance. index: An integer index. From 23d93539770a0fa89dbb6a797ccebbb4974f96ad Mon Sep 17 00:00:00 2001 From: Christopher Suter Date: Tue, 15 Dec 2020 14:22:35 -0800 Subject: [PATCH 28/36] Avoid feeding `None`s to tf.control_dependencies in mixed eager/graph contexts. Add a property-based test that tries to create Distributions in eager mode then sample from them in graph mode, to exercise the failure mode. PiperOrigin-RevId: 347692243 --- .../python/distributions/distribution.py | 7 ++++++ .../distribution_properties_test.py | 23 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/tensorflow_probability/python/distributions/distribution.py b/tensorflow_probability/python/distributions/distribution.py index 9ec0f3129e..964efb2e39 100644 --- a/tensorflow_probability/python/distributions/distribution.py +++ b/tensorflow_probability/python/distributions/distribution.py @@ -1555,6 +1555,13 @@ def _name_and_control_scope(self, name=None, value=UNSET_VALUE, kwargs=None): if not deps: yield name_scope return + # In eager mode, some `assert_util.assert_xyz` calls return None. If a + # Distribution is created in eager mode with `validate_args=True`, then + # used in a `tf.function` context, it can result in errors when + # `tf.convert_to_tensor` is called on the inputs to + # `tf.control_dependencies` below. To avoid these errors, we drop the + # `None`s here. + deps = [x for x in deps if x is not None] with tf.control_dependencies(deps) as deps_scope: yield deps_scope diff --git a/tensorflow_probability/python/distributions/distribution_properties_test.py b/tensorflow_probability/python/distributions/distribution_properties_test.py index f06d7c8198..53c65b4384 100644 --- a/tensorflow_probability/python/distributions/distribution_properties_test.py +++ b/tensorflow_probability/python/distributions/distribution_properties_test.py @@ -526,6 +526,29 @@ def disabled_testFailureCase(self): # pylint: disable=invalid-name self.assertAllClose(dist.log_prob(samps)[0], dist[0].log_prob(samps[0])) +# Don't decorate with test_util.test_all_tf_execution_regimes, since we're +# explicitly mixing modes. +class TestMixingGraphAndEagerModes(test_util.TestCase): + + @parameterized.named_parameters( + {'testcase_name': dname, 'dist_name': dname} + for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()) + + list(dhps.INSTANTIABLE_META_DISTS)) + ) + @hp.given(hps.data()) + @tfp_hps.tfp_hp_settings() + def testSampleEagerCreatedDistributionInGraphMode(self, dist_name, data): + if not tf.executing_eagerly(): + self.skipTest('Only test mixed eager/graph behavior in eager tests.') + # Create in eager mode. + dist = data.draw(dhps.distributions(dist_name=dist_name, enable_vars=False)) + + @tf.function + def f(): + dist.sample() + f() + + if __name__ == '__main__': # Hypothesis often finds numerical near misses. Debugging them is much aided # by seeing all the digits of every floating point number, instead of the From a06f0d85f7b70a05b56966512a7d5c6898670b84 Mon Sep 17 00:00:00 2001 From: Dave Moore Date: Tue, 15 Dec 2020 15:58:34 -0800 Subject: [PATCH 29/36] Avoid Shift(None) bug in `VectorExponentialLinearOperator`. PiperOrigin-RevId: 347710953 --- .../python/distributions/vector_exponential_linear_operator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py b/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py index 49ea2c7cf3..e24f2b2211 100644 --- a/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py +++ b/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py @@ -180,6 +180,8 @@ def __init__(self, TypeError: if not `scale.dtype.is_floating` """ parameters = dict(locals()) + if loc is None: + loc = 0.0 # Implicit value for backwards compatibility. if scale is None: raise ValueError('Missing required `scale` parameter.') if not dtype_util.is_floating(scale.dtype): From 876d51299cf9313b1d7c53b0bc55fae121f14b91 Mon Sep 17 00:00:00 2001 From: bjp Date: Wed, 16 Dec 2020 07:56:53 -0800 Subject: [PATCH 30/36] Adds `tfp.experimental.bijectors.inverse_log_det_jacobian_ratio`. Adds `tfp.experimental.distributions.log_prob_ratio(p, x, q, y) = p(x) - q(y)`. Custom implementations are registered for `tfd.Independent`, `tfd.Sample`, `tfd.JointDistribution*`, `tfd.TransformedDistribution`, `tfb.Chain`, and `tfb.ScaleMatvecDiag`. MVNDiag is tested as a proof of concept, in transformed_distribution_test. PiperOrigin-RevId: 347822296 --- tensorflow_probability/python/bijectors/BUILD | 10 +++ .../python/bijectors/chain.py | 14 +++ .../python/bijectors/ldj_ratio.py | 86 +++++++++++++++++++ .../python/bijectors/scale_matvec_diag.py | 9 ++ .../python/distributions/BUILD | 15 ++++ .../python/distributions/independent.py | 37 ++++++-- .../python/distributions/independent_test.py | 35 +++++++- .../distributions/joint_distribution.py | 13 +++ .../python/distributions/log_prob_ratio.py | 63 ++++++++++++++ .../python/distributions/sample.py | 36 ++++++-- .../python/distributions/sample_test.py | 66 +++++++++++++- .../distributions/transformed_distribution.py | 21 +++++ .../transformed_distribution_test.py | 37 ++++++++ .../python/experimental/bijectors/BUILD | 1 + .../python/experimental/bijectors/__init__.py | 4 +- .../python/experimental/distributions/BUILD | 1 + .../experimental/distributions/__init__.py | 2 + .../internal/backend/numpy/numpy_array.py | 2 +- 18 files changed, 432 insertions(+), 20 deletions(-) create mode 100644 tensorflow_probability/python/bijectors/ldj_ratio.py create mode 100644 tensorflow_probability/python/distributions/log_prob_ratio.py diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD index 0df4e10958..88d13d4b31 100644 --- a/tensorflow_probability/python/bijectors/BUILD +++ b/tensorflow_probability/python/bijectors/BUILD @@ -75,6 +75,7 @@ multi_substrate_py_library( ":joint_map", ":kumaraswamy_cdf", ":lambertw_transform", + ":ldj_ratio", ":masked_autoregressive", ":matrix_inverse_tril", ":moyal_cdf", @@ -260,6 +261,7 @@ multi_substrate_py_library( srcs = ["scale_matvec_diag.py"], deps = [ ":bijector", + ":ldj_ratio", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", "//tensorflow_probability/python/internal:dtype_util", @@ -581,6 +583,14 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "ldj_ratio", + srcs = ["ldj_ratio.py"], + deps = [ + # tensorflow dep, + ], +) + multi_substrate_py_library( name = "masked_autoregressive", srcs = ["masked_autoregressive.py"], diff --git a/tensorflow_probability/python/bijectors/chain.py b/tensorflow_probability/python/bijectors/chain.py index 69aa55ef9e..a8d9ff2db2 100644 --- a/tensorflow_probability/python/bijectors/chain.py +++ b/tensorflow_probability/python/bijectors/chain.py @@ -22,6 +22,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.bijectors import composition +from tensorflow_probability.python.bijectors import ldj_ratio from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import prefer_static as ps @@ -245,3 +246,16 @@ def update_i_event_ndims(bij, event_ndims): return (nest.map_structure(lambda nd: rolling_offset + nd, f_event_ndims), nest.map_structure(lambda nd: rolling_offset + nd, i_event_ndims)) + +@ldj_ratio.RegisterILDJRatio(Chain) +def _ildj_ratio_chain(p, x, q, y): + """Sum-of-diffs ILDJRatio for Chains.""" + if len(p.bijectors) != len(q.bijectors): + raise ValueError('Mismatched lengths of bijectors: `p` has ' + f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.') + ratios = [] + for p, q in zip(p.bijectors, q.bijectors): + ratios.append(ldj_ratio.inverse_log_det_jacobian_ratio( + p, x, q, y, p.inverse_min_event_ndims)) + x, y = p.inverse(x), q.inverse(y) + return tf.add_n(ratios) diff --git a/tensorflow_probability/python/bijectors/ldj_ratio.py b/tensorflow_probability/python/bijectors/ldj_ratio.py new file mode 100644 index 0000000000..d0c8f5c35e --- /dev/null +++ b/tensorflow_probability/python/bijectors/ldj_ratio.py @@ -0,0 +1,86 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Computes log-ratios of Jacobian determinants numerically stably.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python import math as tfp_math +from tensorflow_probability.python.internal import prefer_static as ps + +__all__ = [ + 'inverse_log_det_jacobian_ratio', + 'RegisterILDJRatio', +] + + +_ildj_ratio_registry = {} + + +def inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims, use_kahan_sum=True): + """Computes `p.ildj(x, ndims) - q.idlj(y, ndims)`, numerically stably. + + Args: + p: A bijector instance. + x: A tensor from the support of `p.forward`. + q: A bijector instance of the same type as `p`, with matching shape. + y: A tensor from the support of `q.forward`. + event_ndims: The number of right-hand dimensions comprising the event shapes + of `x` and `y`. + use_kahan_sum: When `True`, the reduction of any remaining `event_ndims` + beyond the minimum is done using Kahan summation. This requires statically + known ranks. + + Returns: + ildj_ratio: `log ((abs o det o jac p^-1)(x) / (abs o det o jac q^-1)(y))`, + i.e. in TFP code, `p.inverse_log_det_jacobian(x, event_ndims) - + q.inverse_log_det_jacobian(y, event_ndims)`. In some cases + this will be computed with better than naive numerical precision, e.g. by + moving differences inside of a sum reduction. + """ + assert type(p) == type(q) # pylint: disable=unidiomatic-typecheck + + min_event_ndims = p.inverse_min_event_ndims + def ildj_ratio_fn(p, x, q, y): + return (p.inverse_log_det_jacobian(x, event_ndims=min_event_ndims) - + q.inverse_log_det_jacobian(y, event_ndims=min_event_ndims)) + + for cls in inspect.getmro(type(p)): + if cls in _ildj_ratio_registry: + ildj_ratio_fn = _ildj_ratio_registry[cls] + + if use_kahan_sum: + sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total + else: + sum_fn = tf.reduce_sum + return sum_fn(ildj_ratio_fn(p, x, q, y), + axis=-1 - ps.range(event_ndims - min_event_ndims)) + + +class RegisterILDJRatio(object): + + def __init__(self, bijector_class): + self.cls = bijector_class + + def __call__(self, fn): + assert self.cls not in _ildj_ratio_registry + _ildj_ratio_registry[self.cls] = fn + return fn + diff --git a/tensorflow_probability/python/bijectors/scale_matvec_diag.py b/tensorflow_probability/python/bijectors/scale_matvec_diag.py index 37f22758e8..50a36373ba 100644 --- a/tensorflow_probability/python/bijectors/scale_matvec_diag.py +++ b/tensorflow_probability/python/bijectors/scale_matvec_diag.py @@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf +from tensorflow_probability.python.bijectors import ldj_ratio from tensorflow_probability.python.bijectors import scale_matvec_linear_operator from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import tensor_util @@ -115,3 +116,11 @@ def _composite_tensor_nonshape_params(self): those that are shape-related. """ return ('scale_diag',) + + +@ldj_ratio.RegisterILDJRatio(ScaleMatvecDiag) +def _ildj_ratio_scale_matvec_diag(p, x, q, y): + del x, y + return tf.math.reduce_sum(tf.math.log(tf.math.abs(q.scale.diag_part())) - + tf.math.log(tf.math.abs(p.scale.diag_part())), + axis=-1) diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index b3cad34ec2..fdcd195176 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -97,6 +97,7 @@ multi_substrate_py_library( ":laplace", ":linear_gaussian_ssm", ":lkj", + ":log_prob_ratio", ":logistic", ":logitnormal", ":loglogistic", @@ -874,6 +875,7 @@ multi_substrate_py_library( deps = [ ":distribution", ":kullback_leibler", + ":log_prob_ratio", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", "//tensorflow_probability/python/internal:prefer_static", @@ -952,6 +954,7 @@ multi_substrate_py_library( srcs = ["joint_distribution.py"], deps = [ ":distribution", + ":log_prob_ratio", # numpy dep, # six dep, # tensorflow dep, @@ -1164,6 +1167,14 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "log_prob_ratio", + srcs = ["log_prob_ratio.py"], + deps = [ + # tensorflow dep, + ], +) + multi_substrate_py_library( name = "logistic", srcs = ["logistic.py"], @@ -1718,6 +1729,7 @@ multi_substrate_py_library( deps = [ ":distribution", ":kullback_leibler", + ":log_prob_ratio", # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", @@ -1867,7 +1879,9 @@ multi_substrate_py_library( deps = [ ":distribution", ":kullback_leibler", + ":log_prob_ratio", # tensorflow dep, + "//tensorflow_probability/python/bijectors:ldj_ratio", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensorshape_util", ], @@ -3376,6 +3390,7 @@ multi_substrate_py_test( name = "sample_test", srcs = ["sample_test.py"], jax_size = "medium", + shard_count = 2, deps = [ # absl/testing:parameterized dep, # numpy dep, diff --git a/tensorflow_probability/python/distributions/independent.py b/tensorflow_probability/python/distributions/independent.py index 18f7cc4731..2d6f083d98 100644 --- a/tensorflow_probability/python/distributions/independent.py +++ b/tensorflow_probability/python/distributions/independent.py @@ -25,8 +25,9 @@ from tensorflow_probability.python import math as tfp_math from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import prefer_static +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util @@ -199,7 +200,7 @@ def __getitem__(self, slices): def _batch_shape_tensor(self): batch_shape = self.distribution.batch_shape_tensor() - batch_ndims = prefer_static.rank_from_shape( + batch_ndims = ps.rank_from_shape( batch_shape, self.distribution.batch_shape) return batch_shape[ :batch_ndims - self._get_reinterpreted_batch_ndims(batch_shape)] @@ -220,11 +221,11 @@ def _event_shape_tensor(self): batch_shape = self.distribution.batch_shape if not tensorshape_util.is_fully_defined(batch_shape): batch_shape = self.distribution.batch_shape_tensor() - batch_ndims = prefer_static.rank_from_shape(batch_shape) + batch_ndims = ps.rank_from_shape(batch_shape) event_shape = self.distribution.event_shape if not tensorshape_util.is_fully_defined(event_shape): event_shape = self.distribution.event_shape_tensor() - return prefer_static.concat([ + return ps.concat([ batch_shape[ batch_ndims - self._get_reinterpreted_batch_ndims(batch_shape):], event_shape, @@ -297,13 +298,13 @@ def _parameter_control_dependencies(self, is_init): assertions.append( assert_util.assert_less_equal( self._get_reinterpreted_batch_ndims(batch_shape_tensor), - prefer_static.rank_from_shape(batch_shape_tensor), + ps.rank_from_shape(batch_shape_tensor), message=('reinterpreted_batch_ndims cannot exceed ' 'distribution.batch_ndims'))) return assertions def _reduce(self, op, stat): - axis = 1 + prefer_static.range(self._get_reinterpreted_batch_ndims()) + axis = 1 + ps.range(self._get_reinterpreted_batch_ndims()) return op(stat, axis=-axis) _composite_tensor_nonshape_params = ('distribution',) @@ -372,10 +373,28 @@ def _kl_independent(a, b, name='kl_independent'): message='Event shapes do not match.'), ]): num_reduce_dims = ( - prefer_static.rank_from_shape( + ps.rank_from_shape( a_event_shape_tensor, a.event_shape) - - prefer_static.rank_from_shape( + ps.rank_from_shape( p_event_shape_tensor, p.event_shape)) - reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1) + reduce_dims = ps.range(-num_reduce_dims, 0, 1) return tf.reduce_sum( kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) + + +@log_prob_ratio.RegisterLogProbRatio(Independent) +def _independent_log_prob_ratio(p, x, q, y): + """Sum-of-diffs log(p(x)/q(y)) for `Independent`s.""" + checks = [] + if p.validate_args or q.validate_args: + checks.append(tf.debugging.assert_equal( + p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)) + if p._experimental_use_kahan_sum or q._experimental_use_kahan_sum: # pylint: disable=protected-access + sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total + else: + sum_fn = tf.reduce_sum + with tf.control_dependencies(checks): + return sum_fn( + log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y), + axis=-1 - ps.range(p.reinterpreted_batch_ndims)) + diff --git a/tensorflow_probability/python/distributions/independent_test.py b/tensorflow_probability/python/distributions/independent_test.py index 9959641ca6..2924bf0da5 100644 --- a/tensorflow_probability/python/distributions/independent_test.py +++ b/tensorflow_probability/python/distributions/independent_test.py @@ -27,11 +27,12 @@ from scipy import stats as sp_stats import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf -from tensorflow_probability.python import distributions as tfd +import tensorflow_probability as tfp from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util +tfd = tfp.distributions JAX_MODE = False @@ -522,6 +523,38 @@ def test_kahan_precision(self, jit=False): # Fails ~75% CPU, 1-75% GPU --vary_seed runs w/o experimental_use_kahan_sum. self.assertAllClose(lp64, lp, rtol=0., atol=.01) + def testLargeLogProbDiff(self): + b = 15 + n = 5_000 + d0 = tfd.Independent(tfd.Normal(tf.fill([b, n], 0.), tf.fill([n], .1)), + reinterpreted_batch_ndims=1, + experimental_use_kahan_sum=True) + d1 = tfd.Independent(tfd.Normal(tf.fill([b, n], 1e-5), tf.fill([n], .1)), + reinterpreted_batch_ndims=1, + experimental_use_kahan_sum=True) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample([b, n], seed=strm())) + x1 = self.evaluate( # overdispersed, perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = d0.copy(distribution=tfd.Normal( + tf.cast(d0.distribution.loc, tf.float64), + tf.cast(d0.distribution.scale, tf.float64))) + d1_64 = d1.copy(distribution=tfd.Normal( + tf.cast(d1.distribution.loc, tf.float64), + tf.cast(d1.distribution.scale, tf.float64))) + self.assertNotAllZero(d0.log_prob(x0) < -1_000_000) + self.assertAllClose( + d0_64.log_prob(tf.cast(x0, tf.float64)) - + d1_64.log_prob(tf.cast(x1, tf.float64)), + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.0075) + # In contrast: the below fails consistently w/ errors around 0.5-1.0 + # self.assertAllClose( + # d0_64.log_prob(tf.cast(x0, tf.float64)) - + # d1_64.log_prob(tf.cast(x1, tf.float64)), + # d0.log_prob(x0) - d1.log_prob(x1), + # rtol=0., atol=0.007) if __name__ == '__main__': # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term. diff --git a/tensorflow_probability/python/distributions/joint_distribution.py b/tensorflow_probability/python/distributions/joint_distribution.py index d7d8082a8e..c0ed4d6706 100644 --- a/tensorflow_probability/python/distributions/joint_distribution.py +++ b/tensorflow_probability/python/distributions/joint_distribution.py @@ -28,6 +28,7 @@ from tensorflow_probability.python.bijectors import composition from tensorflow_probability.python.bijectors import identity as identity_bijector from tensorflow_probability.python.distributions import distribution as distribution_lib +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import docstring_util @@ -810,3 +811,15 @@ def _inverse(self, y, **kwargs): def _inverse_log_det_jacobian(self, y, event_ndims, **kwargs): return super(_DefaultJointBijector, self)._inverse_log_det_jacobian( y, event_ndims, _jd_conditioning=y, **kwargs) + + +@log_prob_ratio.RegisterLogProbRatio(JointDistribution) +def _jd_log_prob_ratio(p, x, q, y): + tf.nest.assert_same_structure(x, y) + ps, _ = p.sample_distributions(value=x) + qs, _ = q.sample_distributions(value=y) + tf.nest.assert_same_structure(ps, qs) + parts = [] + for p_, x_, q_, y_ in zip(ps, x, qs, y): + parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_)) + return tf.add_n(parts) diff --git a/tensorflow_probability/python/distributions/log_prob_ratio.py b/tensorflow_probability/python/distributions/log_prob_ratio.py new file mode 100644 index 0000000000..214b0353e3 --- /dev/null +++ b/tensorflow_probability/python/distributions/log_prob_ratio.py @@ -0,0 +1,63 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Computes log-ratios of probs numerically stably.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + + +__all__ = [ + 'log_prob_ratio', + 'RegisterLogProbRatio', +] + + +_log_prob_ratio_registry = {} + + +def log_prob_ratio(p, x, q, y): + """Computes `p.log_prob(x) - q.log_prob(y)`, numerically stably. + + Args: + p: A distribution instance. + x: A tensor from the support of `p`. + q: A distribution instance in the same family as `p`, with matching shape. + y: A tensor from the support of `q`. + + Returns: + lp_ratio: `log (p(x) / q(y)) = p.log_prob(x) - q.log_prob(y)`. In some cases + this will be computed with better than naive numerical precision, e.g. by + moving the difference inside of a sum reduction. + """ + assert type(p) == type(q) # pylint: disable=unidiomatic-typecheck + for cls in inspect.getmro(type(p)): + if cls in _log_prob_ratio_registry: + return _log_prob_ratio_registry[cls](p, x, q, y) + return p.log_prob(x) - q.log_prob(y) + + +class RegisterLogProbRatio(object): + + def __init__(self, dist_family): + self.family = dist_family + + def __call__(self, fn): + assert self.family not in _log_prob_ratio_registry + _log_prob_ratio_registry[self.family] = fn + return fn + diff --git a/tensorflow_probability/python/distributions/sample.py b/tensorflow_probability/python/distributions/sample.py index 27bfcfa434..d0f1d6d240 100644 --- a/tensorflow_probability/python/distributions/sample.py +++ b/tensorflow_probability/python/distributions/sample.py @@ -29,6 +29,7 @@ from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps @@ -236,7 +237,7 @@ def _sum_fn(self): return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total return tf.math.reduce_sum - def _log_prob(self, x, **kwargs): + def _prepare_for_underlying(self, x): batch_ndims = ps.rank_from_shape( self.distribution.batch_shape_tensor, self.distribution.batch_shape) @@ -266,10 +267,12 @@ def _log_prob(self, x, **kwargs): ndims) perm = ps.concat( [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0) - x = tf.transpose(a=x, perm=perm) - # (3) Compute x's log_prob. - lp = self.distribution.log_prob(x, **kwargs) - # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has + x = tf.transpose(x, perm=perm) + return x, (sample_ndims, extra_sample_ndims, batch_ndims) + + def _finish_log_prob(self, lp, aux): + (sample_ndims, extra_sample_ndims, batch_ndims) = aux + # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has # full sample shape in the sample axes, before we reduce. bcast_lp_shape = ps.broadcast_shape( ps.shape(lp), @@ -277,10 +280,16 @@ def _log_prob(self, x, **kwargs): ps.reshape(self.sample_shape, shape=[-1]), ps.ones([batch_ndims], tf.int32)], axis=0)) lp = tf.broadcast_to(lp, bcast_lp_shape) - # (5) Make the final reduction in x. + # (2) Make the final reduction. axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims) return self._sum_fn()(lp, axis=axis) + def _log_prob(self, x, **kwargs): + x, aux = self._prepare_for_underlying(x) + return self._finish_log_prob( + self.distribution.log_prob(x, **kwargs), + aux) + def _entropy(self, **kwargs): h = self.distribution.entropy(**kwargs) n = ps.reduce_prod(self.sample_shape) @@ -544,3 +553,18 @@ def _kl_sample(a, b, name='kl_sample'): a.distribution, b.distribution, name=name) n = ps.reduce_prod(a.sample_shape) return tf.cast(x=n, dtype=kl.dtype) * kl + + +@log_prob_ratio.RegisterLogProbRatio(Sample) +def _sample_log_prob_ratio(p, x, q, y): + checks = [] + if p.validate_args or q.validate_args: + checks.append(tf.debugging.assert_equal(p.sample_shape, q.sample_shape)) + with tf.control_dependencies(checks): + # pylint: disable=protected-access + x, aux = p._prepare_for_underlying(x) + y, _ = q._prepare_for_underlying(y) + return p._finish_log_prob( + log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y), + aux) + # pylint: enable=protected-access diff --git a/tensorflow_probability/python/distributions/sample_test.py b/tensorflow_probability/python/distributions/sample_test.py index e8d94409b1..5a40fd9df9 100644 --- a/tensorflow_probability/python/distributions/sample_test.py +++ b/tensorflow_probability/python/distributions/sample_test.py @@ -25,10 +25,12 @@ from absl.testing import parameterized import numpy as np import tensorflow.compat.v2 as tf -from tensorflow_probability.python import bijectors as tfb -from tensorflow_probability.python import distributions as tfd +import tensorflow_probability as tfp from tensorflow_probability.python.internal import test_util +tfb = tfp.bijectors +tfd = tfp.distributions + JAX_MODE = False @@ -447,6 +449,66 @@ def test_kahan_precision(self, jit=False): # Fails 75% CPU, 0-80% GPU --vary_seed runs w/o experimental_use_kahan_sum. self.assertAllClose(lp64, lp, rtol=0., atol=.01) + def testLargeLogProbDiffScalarUnderlying(self): + shp = [25, 200] + d0 = tfd.Sample(tfd.Normal(0., .1), shp) + d1 = tfd.Sample(tfd.Normal(1e-5, .1), shp) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample(shp, seed=strm())) + x1 = self.evaluate( # overdispersed, perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = d0.copy(distribution=tfd.Normal( + tf.cast(d0.distribution.loc, tf.float64), + tf.cast(d0.distribution.scale, tf.float64))) + d1_64 = d1.copy(distribution=tfd.Normal( + tf.cast(d1.distribution.loc, tf.float64), + tf.cast(d1.distribution.scale, tf.float64))) + oracle_64 = tf.reduce_sum( + d0_64.distribution.log_prob(tf.cast(x0, tf.float64)) - + d1_64.distribution.log_prob(tf.cast(x1, tf.float64))) + self.assertAllClose( + oracle_64, + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.007) + # In contrast: below fails with errors of ~0.07 - 0.15 + # self.assertAllClose( + # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), rtol=0., atol=0.007) + + def testLargeLogProbDiffBatchOfVecUnderlying(self): + nsamp = 5 + nbatch = 3 + nevt = 250 + dim = 500 + d0 = tfd.Sample(tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 0.), + tf.fill([dim], .1)), + sample_shape=nevt) + self.assertEqual(tf.float32, d0.dtype) + d1 = tfd.Sample(tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 1e-5), + d0.distribution.scale.diag), + sample_shape=nevt) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample([nsamp, nbatch, nevt, dim], seed=strm())) + x1 = self.evaluate( # overdispersed + perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = d0.copy(distribution=tfd.MultivariateNormalDiag( + tf.cast(d0.distribution.loc, tf.float64), + tf.cast(d0.distribution.scale.diag, tf.float64))) + d1_64 = d1.copy(distribution=tfd.MultivariateNormalDiag( + tf.cast(d1.distribution.loc, tf.float64), + tf.cast(d1.distribution.scale.diag, tf.float64))) + oracle_64 = (d0_64.log_prob(tf.cast(x0, tf.float64)) - + d1_64.log_prob(tf.cast(x1, tf.float64))) + self.assertNotAllZero(d0.log_prob(x0) < -10_000_000) + self.assertAllClose( + oracle_64, + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.045) + # In contrast, the following fails w/ abs errors of ~5. to 10. + # self.assertAllClose( + # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), rtol=0., atol=0.045) + if __name__ == '__main__': # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term. diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py index 771510b87b..a0f0d208db 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution.py +++ b/tensorflow_probability/python/distributions/transformed_distribution.py @@ -21,8 +21,10 @@ import tensorflow.compat.v2 as tf +from tensorflow_probability.python.bijectors import ldj_ratio from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -569,3 +571,22 @@ def _kl_transformed_transformed(a, b, name=None): 'Unable to calculate KL divergence between {} and {} because ' 'their bijectors are not equal: {} vs. {}'.format( a, b, a.bijector, b.bijector)) + + +@log_prob_ratio.RegisterLogProbRatio(TransformedDistribution) +def _transformed_log_prob_ratio(p, x, q, y): + """Computes p.log_prob(x) - q.log_prob(y) for p and q both TDs.""" + x_ = p.bijector.inverse(x) + y_ = q.bijector.inverse(y) + + base_log_prob_ratio = log_prob_ratio.log_prob_ratio( + p.distribution, x_, q.distribution, y_) + + event_ndims = tf.nest.map_structure( + ps.rank_from_shape, + p.event_shape_tensor, + tf.nest.map_structure(tensorshape_util.merge_with, + p.event_shape, q.event_shape)) + ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio( + p.bijector, x, q.bijector, y, event_ndims) + return base_log_prob_ratio + tf.cast(ildj_ratio, base_log_prob_ratio.dtype) diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py index 712c09fe3b..66bf95f0a4 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution_test.py +++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py @@ -477,6 +477,42 @@ def testTransformedNormalNormalKL(self): self.assertAllClose(kl_val, kl_expected) self.assertAllClose(kl_expected, kl_sample_, atol=0.0, rtol=1e-2) + def testLogProbRatio(self): + nsamp = 5 + nbatch = 3 + dim = 5000 + d0 = tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 0.), + tf.fill([dim], .1)) + d1 = tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 1e-5), + d0.scale.diag) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample([nsamp, nbatch, dim], seed=strm())) + x1 = self.evaluate( # overdispersed + perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = tfd.MultivariateNormalDiag( + tf.cast(d0.loc, tf.float64), tf.cast(d0.scale.diag, tf.float64)) + d1_64 = tfd.MultivariateNormalDiag( + tf.cast(d1.loc, tf.float64), tf.cast(d1.scale.diag, tf.float64)) + oracle_64 = (d0_64.log_prob(tf.cast(x0, tf.float64)) - + d1_64.log_prob(tf.cast(x1, tf.float64))) + # For a sense of the order of magnitude log_probs we're dealing with: + self.assertNotAllZero(d0.log_prob(x0) < -1_000_000.) + self.assertAllClose( + oracle_64, + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.007) + # In contrast, this test fails with max-abs-error around 0.05 to 0.1 + # self.assertAllClose( + # oracle_64, + # d0.copy(experimental_use_kahan_sum=True).log_prob(x0) - + # d1.copy(experimental_use_kahan_sum=True).log_prob(x1), + # rtol=0., atol=0.007) + # In contrast, this test fails with max-abs-error around 0.8 to 1.5 + # self.assertAllClose( + # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), + # rtol=0., atol=0.007) + @test_util.test_all_tf_execution_regimes class ScalarToMultiTest(test_util.TestCase): @@ -1079,5 +1115,6 @@ def test_transform_joint_to_joint(self, split_sizes): noop_assert_fn, self.evaluate(restructured_dist.sample(seed=test_util.test_seed()))) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/experimental/bijectors/BUILD b/tensorflow_probability/python/experimental/bijectors/BUILD index 96ed131f8a..fd4076247d 100644 --- a/tensorflow_probability/python/experimental/bijectors/BUILD +++ b/tensorflow_probability/python/experimental/bijectors/BUILD @@ -36,6 +36,7 @@ multi_substrate_py_library( srcs_version = "PY3", deps = [ ":scalar_function_with_inferred_inverse", + "//tensorflow_probability/python/bijectors:ldj_ratio", ], ) diff --git a/tensorflow_probability/python/experimental/bijectors/__init__.py b/tensorflow_probability/python/experimental/bijectors/__init__.py index 81650620f0..baf5f4e9d7 100644 --- a/tensorflow_probability/python/experimental/bijectors/__init__.py +++ b/tensorflow_probability/python/experimental/bijectors/__init__.py @@ -14,8 +14,10 @@ # ============================================================================ """TensorFlow Probability experimental bijectors package.""" +from tensorflow_probability.python.bijectors.ldj_ratio import inverse_log_det_jacobian_ratio from tensorflow_probability.python.experimental.bijectors.scalar_function_with_inferred_inverse import ScalarFunctionWithInferredInverse __all__ = [ - 'ScalarFunctionWithInferredInverse' + 'inverse_log_det_jacobian_ratio', + 'ScalarFunctionWithInferredInverse', ] diff --git a/tensorflow_probability/python/experimental/distributions/BUILD b/tensorflow_probability/python/experimental/distributions/BUILD index edc7375c7b..cd74aa6788 100644 --- a/tensorflow_probability/python/experimental/distributions/BUILD +++ b/tensorflow_probability/python/experimental/distributions/BUILD @@ -37,6 +37,7 @@ multi_substrate_py_library( deps = [ ":joint_distribution_pinned", ":mvn_precision_factor_linop", + "//tensorflow_probability/python/distributions:log_prob_ratio", ], ) diff --git a/tensorflow_probability/python/experimental/distributions/__init__.py b/tensorflow_probability/python/experimental/distributions/__init__.py index 0bab103142..2334ac9767 100644 --- a/tensorflow_probability/python/experimental/distributions/__init__.py +++ b/tensorflow_probability/python/experimental/distributions/__init__.py @@ -18,11 +18,13 @@ from __future__ import division from __future__ import print_function +from tensorflow_probability.python.distributions.log_prob_ratio import log_prob_ratio from tensorflow_probability.python.experimental.distributions.joint_distribution_pinned import JointDistributionPinned from tensorflow_probability.python.experimental.distributions.mvn_precision_factor_linop import MultivariateNormalPrecisionFactorLinearOperator __all__ = [ + 'log_prob_ratio', 'JointDistributionPinned', 'MultivariateNormalPrecisionFactorLinearOperator', ] diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_array.py b/tensorflow_probability/python/internal/backend/numpy/numpy_array.py index 51ba5b2c98..0af0c81ffd 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_array.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_array.py @@ -384,7 +384,7 @@ def _zeros_like(input, dtype=None, name=None): # pylint: disable=redefined-buil fill = utils.copy_docstring( 'tf.fill', - lambda dims, value, name=None: np.full(dims, value)) + lambda dims, value, name=None: np.full(dims, ops.convert_to_tensor(value))) gather = utils.copy_docstring( 'tf.gather', From e7d64b77f6d40070d1070d73678b28ec5c6f6de5 Mon Sep 17 00:00:00 2001 From: bjp Date: Wed, 16 Dec 2020 09:02:43 -0800 Subject: [PATCH 31/36] Add a default bijector to the Deterministic distributions. Also fixes an XLA compilation issue (no StringFormat, PrintV2 ops in XLA) introduced by the logdet degree-of-freedom warning. This was exposed by the new test in deterministic_test.py (Deterministic default bijector uses Chain, which uses composition.py). PiperOrigin-RevId: 347832464 --- .../python/bijectors/composition.py | 9 +++ .../python/bijectors/pad.py | 61 +++++++++---------- .../python/distributions/deterministic.py | 10 ++- .../distributions/deterministic_test.py | 48 +++++++++++++++ .../python/internal/prefer_static.py | 3 +- 5 files changed, 97 insertions(+), 34 deletions(-) diff --git a/tensorflow_probability/python/bijectors/composition.py b/tensorflow_probability/python/bijectors/composition.py index d4fd398e28..bab364d60f 100644 --- a/tensorflow_probability/python/bijectors/composition.py +++ b/tensorflow_probability/python/bijectors/composition.py @@ -21,6 +21,7 @@ import abc import sys +import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector @@ -28,6 +29,7 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import @@ -36,6 +38,9 @@ ] +JAX_MODE = False + + def pack_structs_like(template, *structures): """Converts a tuple of structs like `template` to a structure of tuples.""" if not structures: @@ -491,6 +496,10 @@ def _maybe_warn_increased_dof(self, raise ValueError(error_message) return assert_util.assert_equal(False, increased_dof, error_message) + if (not tf.executing_eagerly() and + control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())): + return # No StringFormat or Print ops in XLA. + # Otherwise, we print a warning and continue. return ps.cond( pred=increased_dof, diff --git a/tensorflow_probability/python/bijectors/pad.py b/tensorflow_probability/python/bijectors/pad.py index 714944b02a..765d121d6b 100644 --- a/tensorflow_probability/python/bijectors/pad.py +++ b/tensorflow_probability/python/bijectors/pad.py @@ -23,7 +23,7 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import prefer_static +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util @@ -124,14 +124,14 @@ def __init__(self, parameters = dict(locals()) with tf.name_scope(name or 'pad') as name: paddings = tensor_util.convert_nonref_to_tensor( - paddings, dtype_hint=tf.int32, name='paddings') + paddings, dtype_hint=tf.int32, name='paddings', as_shape_tensor=True) if axis is None: - axis = prefer_static.range( - start=-prefer_static.size0(paddings), limit=0, + axis = ps.range( + start=-ps.size0(paddings), limit=0, dtype=tf.int32, name='axis') else: axis = tensor_util.convert_nonref_to_tensor( - axis, dtype_hint=tf.int32, name='axis') + axis, dtype_hint=tf.int32, name='axis', as_shape_tensor=True) axis_ = tf.get_static_value(axis) if axis_ is None: raise NotImplementedError( @@ -170,29 +170,27 @@ def axis(self): return self._axis def _forward(self, x): - ndims = prefer_static.rank(x) - indices = prefer_static.reshape(prefer_static.add(self.axis, ndims), - shape=[-1, 1]) + ndims = ps.rank(x) + indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1]) return tf.pad( x, - paddings=prefer_static.tensor_scatter_nd_update( - prefer_static.zeros([ndims, 2], dtype=tf.int32), + paddings=ps.tensor_scatter_nd_update( + ps.zeros([ndims, 2], dtype=tf.int32), indices, self.paddings), mode=self.mode, - constant_values=prefer_static.cast(self.constant_values, dtype=x.dtype)) + constant_values=ps.cast(self.constant_values, dtype=x.dtype)) def _inverse(self, y): - ndims = prefer_static.rank(y) - indices = prefer_static.reshape(prefer_static.add(self.axis, ndims), - shape=[-1, 1]) - num_left, num_right = prefer_static.unstack(self.paddings, num=2, axis=-1) + ndims = ps.rank(y) + indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1]) + num_left, num_right = ps.unstack(self.paddings, num=2, axis=-1) x = tf.slice( y, - begin=prefer_static.tensor_scatter_nd_update( - prefer_static.zeros(ndims, dtype=tf.int32), + begin=ps.tensor_scatter_nd_update( + ps.zeros(ndims, dtype=tf.int32), indices, num_left), - size=prefer_static.tensor_scatter_nd_sub( - prefer_static.shape(y), + size=ps.tensor_scatter_nd_sub( + ps.shape(y), indices, num_left + num_right)) if not self.validate_args: return x @@ -225,13 +223,12 @@ def _forward_event_shape(self, input_shape, is_inverse=False): return output_shape def _forward_event_shape_tensor(self, input_shape, is_inverse=False): - ndims = prefer_static.size(input_shape) - indices = prefer_static.reshape(prefer_static.add(self.axis, ndims), - shape=[-1, 1]) - extra_sizes = prefer_static.reduce_sum(self.paddings, axis=-1) - update_fn = (prefer_static.tensor_scatter_nd_sub if is_inverse else - prefer_static.tensor_scatter_nd_add) - return update_fn(prefer_static.identity(input_shape), indices, extra_sizes) + ndims = ps.size(input_shape) + indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1]) + extra_sizes = ps.reduce_sum(self.paddings, axis=-1) + update_fn = (ps.tensor_scatter_nd_sub if is_inverse else + ps.tensor_scatter_nd_add) + return update_fn(ps.identity(input_shape), indices, extra_sizes) def _inverse_event_shape(self, output_shape): input_shape = self._forward_event_shape(output_shape, is_inverse=True) @@ -284,8 +281,8 @@ def _parameter_control_dependencies(self, is_init): elif self.validate_args: if axis is None: axis = tf.convert_to_tensor(self.axis) assertions.append(assert_util.assert_equal( - prefer_static.size0(axis), - prefer_static.size0(prefer_static.setdiff1d(axis)), + ps.size0(axis), + ps.size0(ps.setdiff1d(axis)), message=msg)) if is_init != tensor_util.is_ref(self.paddings): @@ -320,19 +317,19 @@ def _parameter_control_dependencies(self, is_init): axis_ = tf.get_static_value(self.axis) if axis_ is None and axis is None: axis = tf.convert_to_tensor(self.axis) - len_axis = prefer_static.size0(prefer_static.reshape( + len_axis = ps.size0(ps.reshape( axis if axis_ is None else axis_, shape=-1)) paddings_ = tf.get_static_value(self.paddings) if paddings_ is None and paddings is None: paddings = tf.convert_to_tensor(self.paddings) - len_paddings = prefer_static.size0( + len_paddings = ps.size0( paddings if paddings_ is None else paddings_) msg = ('Arguments `axis` and `paddings` must have the same number ' 'of elements.') - if (prefer_static.is_numpy(len_axis) and - prefer_static.is_numpy(len_paddings)): + if (ps.is_numpy(len_axis) and + ps.is_numpy(len_paddings)): if len_axis != len_paddings: raise ValueError(msg + ' Saw: {}, {}.'.format( self.axis, self.paddings)) diff --git a/tensorflow_probability/python/distributions/deterministic.py b/tensorflow_probability/python/distributions/deterministic.py index eb4b0df3ed..c2ac05e9e1 100644 --- a/tensorflow_probability/python/distributions/deterministic.py +++ b/tensorflow_probability/python/distributions/deterministic.py @@ -24,6 +24,7 @@ import six import tensorflow.compat.v2 as tf +from tensorflow_probability.python import bijectors as tfb from tensorflow_probability.python.distributions import distribution from tensorflow_probability.python.distributions import kullback_leibler from tensorflow_probability.python.internal import assert_util @@ -154,7 +155,14 @@ def _sample_n(self, n, seed=None): axis=0)) def _default_event_space_bijector(self): - return + """The bijector maps a zero-dimensional null Tensor input to `self.loc`.""" + # The shape of the pulled back null tensor will be `self.loc.shape + (0,)`. + # First we pad to a tensor of zeros with shape `self.loc.shape + (1,)`. + pad_zero = tfb.Pad([(1, 0)]) + # Next, we squeeze to a tensor of zeros with shape matching `self.loc`. + zeros_squeezed = tfb.Reshape([], event_shape_in=[1])(pad_zero) + # Finally, we shift the zeros by `self.loc`. + return tfb.Shift(self.loc)(zeros_squeezed) def _parameter_control_dependencies(self, is_init): assertions = [] diff --git a/tensorflow_probability/python/distributions/deterministic_test.py b/tensorflow_probability/python/distributions/deterministic_test.py index df63b9c391..50e31ee92f 100644 --- a/tensorflow_probability/python/distributions/deterministic_test.py +++ b/tensorflow_probability/python/distributions/deterministic_test.py @@ -17,11 +17,13 @@ from __future__ import print_function # Dependency imports +from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util rng = np.random.RandomState(0) @@ -255,6 +257,7 @@ def testVariableAssertions(self): self.evaluate(deterministic.log_prob(1.)) +@test_util.test_all_tf_execution_regimes class VectorDeterministicTest(test_util.TestCase): def testParamBroadcasts(self): @@ -454,6 +457,51 @@ def testVariableAssertions(self): 'Condition x >= 0'): self.evaluate(deterministic.log_prob([1.])) + @parameterized.named_parameters( + dict(testcase_name='_scalar', + dist_fn=lambda: tfd.Deterministic(3.)), + dict(testcase_name='_batch_scalar', + dist_fn=lambda: tfd.Deterministic([3., -7.])), + dict(testcase_name='_vector', + dist_fn=lambda: tfd.VectorDeterministic([3., -7.])), + dict(testcase_name='_batch_vector', + dist_fn=lambda: tfd.VectorDeterministic([[3., -7.], [-2, 4.]]))) + def testDefaultBijector(self, dist_fn): + dist = dist_fn() + bijector = dist.experimental_default_event_space_bijector() + self.assertEqual(dist.loc.shape, dist.batch_shape + dist.event_shape) + self.assertEqual(dist.event_shape + (0,), + bijector.inverse_event_shape(dist.event_shape)) + self.assertEqual(dist.loc.shape + (0,), + bijector.inverse_event_shape(dist.loc.shape)) + null_point = tf.ones(bijector.inverse_event_shape(dist.loc.shape)) + self.assertAllEqual( + tf.zeros([]), + bijector.forward_log_det_jacobian( + null_point, tensorshape_util.rank(null_point.shape))) + self.assertAllEqual(dist.loc, bijector(null_point)) + + @parameterized.named_parameters( + dict(testcase_name='_scalar', + dist_fn=lambda: tfd.Deterministic(3.)), + dict(testcase_name='_batch_scalar', + dist_fn=lambda: tfd.Deterministic([3., -7.])), + dict(testcase_name='_vector', + dist_fn=lambda: tfd.VectorDeterministic([3., -7.])), + dict(testcase_name='_batch_vector', + dist_fn=lambda: tfd.VectorDeterministic([[3., -7.], [-2, 4.]]))) + def testDefaultBijectorXLA(self, dist_fn): + self.skip_if_no_xla() + @tf.function(experimental_compile=True) + def fn(x): + bijector = dist_fn().experimental_default_event_space_bijector() + ndim = tensorshape_util.rank(x.shape) + return (bijector(x), + bijector.forward_log_det_jacobian(x, ndim), + bijector.inverse(0 + bijector(x)), + bijector.inverse_log_det_jacobian(0 + bijector(x), ndim - 1)) + self.evaluate(fn(tf.zeros(dist_fn().loc.shape + (0,)))) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/internal/prefer_static.py b/tensorflow_probability/python/internal/prefer_static.py index b1d30bbd0d..d719eb8875 100644 --- a/tensorflow_probability/python/internal/prefer_static.py +++ b/tensorflow_probability/python/internal/prefer_static.py @@ -137,7 +137,8 @@ def _convert_to_shape_tensor_jax(value, dtype=None, dtype_hint=None, name=None): """Converts vectors and scalars of `int`-like to `ndarray`.""" dtype = dtype_util.as_numpy_dtype(dtype or dtype_hint or np.int32) try: - return np.array([int(v) for v in value], dtype=dtype) + return np.array([_convert_to_shape_tensor_jax(v, dtype) for v in value], + dtype=dtype) except: # JAX throws raw Exception in some cases. # pylint: disable=bare-except pass return np.array(int(value), dtype=dtype) From d3bf5a09913df88f1d7d202f49f0c7652cf94059 Mon Sep 17 00:00:00 2001 From: axch Date: Wed, 16 Dec 2020 09:27:05 -0800 Subject: [PATCH 32/36] Rewrite experimental sample_chain in terms of run_kernel. Implement burn-in by sequencing two calls of run_kernel, and thinning by inserting a ThinningKernel into the kernel onion. Also change the default tracing function to account for the fact that the chain state history is no longer returned separately by default. Delete tests that no longer make sense. PiperOrigin-RevId: 347836706 --- .../python/experimental/mcmc/BUILD | 2 + .../python/experimental/mcmc/__init__.py | 4 +- .../python/experimental/mcmc/run.py | 1 - .../python/experimental/mcmc/sample_fold.py | 114 +++---- .../experimental/mcmc/sample_fold_test.py | 280 +++++------------- 5 files changed, 122 insertions(+), 279 deletions(-) diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 639f4d2976..3ca9add236 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -663,8 +663,10 @@ py_library( srcs = ["sample_fold.py"], srcs_version = "PY3", deps = [ + ":run", ":sample", ":sample_discarding_kernel", + ":thinning_kernel", ":tracing_reducer", ":with_reductions", # numpy dep, diff --git a/tensorflow_probability/python/experimental/mcmc/__init__.py b/tensorflow_probability/python/experimental/mcmc/__init__.py index dc8c6c6d7f..3dd3ea1e93 100644 --- a/tensorflow_probability/python/experimental/mcmc/__init__.py +++ b/tensorflow_probability/python/experimental/mcmc/__init__.py @@ -44,7 +44,7 @@ from tensorflow_probability.python.experimental.mcmc.run import run_kernel from tensorflow_probability.python.experimental.mcmc.sample import step_kernel from tensorflow_probability.python.experimental.mcmc.sample_discarding_kernel import SampleDiscardingKernel -from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_chain +from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_chain_with_burnin from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_fold from tensorflow_probability.python.experimental.mcmc.sample_sequential_monte_carlo import default_make_hmc_kernel_fn from tensorflow_probability.python.experimental.mcmc.sample_sequential_monte_carlo import gen_make_hmc_kernel_fn @@ -108,7 +108,7 @@ 'resample_stratified', 'resample_systematic', 'run_kernel', - 'sample_chain', + 'sample_chain_with_burnin', 'sample_fold', 'sample_sequential_monte_carlo', 'SampleDiscardingKernel', diff --git a/tensorflow_probability/python/experimental/mcmc/run.py b/tensorflow_probability/python/experimental/mcmc/run.py index 32db8c0077..d5b8f78b74 100644 --- a/tensorflow_probability/python/experimental/mcmc/run.py +++ b/tensorflow_probability/python/experimental/mcmc/run.py @@ -161,7 +161,6 @@ def run_kernel( Default value: `None` (i.e., 'mcmc_run_kernel'). Returns: - result: A `RunKernelResults` instance containing information about the sampling run. Main fields are `trace`, the history of outputs of `trace_fn`, and `reduction_results`, the final outputs of all supplied diff --git a/tensorflow_probability/python/experimental/mcmc/sample_fold.py b/tensorflow_probability/python/experimental/mcmc/sample_fold.py index ea82ac8275..9f8af741b4 100644 --- a/tensorflow_probability/python/experimental/mcmc/sample_fold.py +++ b/tensorflow_probability/python/experimental/mcmc/sample_fold.py @@ -22,16 +22,17 @@ # Dependency imports import tensorflow.compat.v2 as tf +from tensorflow_probability.python import random +from tensorflow_probability.python.experimental.mcmc import run from tensorflow_probability.python.experimental.mcmc import sample as exp_sample_lib from tensorflow_probability.python.experimental.mcmc import sample_discarding_kernel -from tensorflow_probability.python.experimental.mcmc import tracing_reducer +from tensorflow_probability.python.experimental.mcmc import thinning_kernel from tensorflow_probability.python.experimental.mcmc import with_reductions -from tensorflow_probability.python.mcmc import sample from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import __all__ = [ - 'sample_chain', + 'sample_chain_with_burnin', 'sample_fold', ] @@ -126,19 +127,19 @@ def sample_fold( if reducer is None: reducer = [] reducer_was_none = True - thinning_kernel = sample_discarding_kernel.SampleDiscardingKernel( + thinning_k = sample_discarding_kernel.SampleDiscardingKernel( inner_kernel=kernel, num_burnin_steps=num_burnin_steps, num_steps_between_results=num_steps_between_results) reduction_kernel = with_reductions.WithReductions( - inner_kernel=thinning_kernel, + inner_kernel=thinning_k, reducer=reducer, # Strip thinning kernel results layer adjust_kr_fn=lambda kr: kr.inner_results, ) if previous_kernel_results is None: previous_kernel_results = kernel.bootstrap_results(current_state) - thinning_pkr = thinning_kernel.bootstrap_results( + thinning_pkr = thinning_k.bootstrap_results( current_state, previous_kernel_results) reduction_pkr = reduction_kernel.bootstrap_results( current_state, thinning_pkr, previous_reducer_state) @@ -176,20 +177,19 @@ def sample_fold( final_kernel_results.inner_results.inner_results) -def _trace_kernel_results(current_state, kernel_results): - del current_state - return kernel_results +def _trace_current_state(current_state, kernel_results): + del kernel_results + return current_state -def sample_chain( +def sample_chain_with_burnin( num_results, current_state, previous_kernel_results=None, kernel=None, num_burnin_steps=0, num_steps_between_results=0, - trace_fn=_trace_kernel_results, - return_final_kernel_results=False, + trace_fn=_trace_current_state, parallel_iterations=10, seed=None, name=None, @@ -216,9 +216,8 @@ def sample_chain( In addition to returning the chain state, this function supports tracing of auxiliary variables used by the kernel. The traced values are selected by - specifying `trace_fn`. By default, all kernel results are traced but in the - future the default will be changed to no results being traced, so plan - accordingly. See below for some examples of this feature. + specifying `trace_fn`. By default, all chain states but no kernel results are + traced. Args: num_results: Integer number of Markov chain draws. @@ -239,27 +238,17 @@ def sample_chain( trace_fn: A callable that takes in the current chain state and the previous kernel results and return a `Tensor` or a nested collection of `Tensor`s that is then traced along with the chain state. - return_final_kernel_results: If `True`, then the final kernel results are - returned alongside the chain state and the trace specified by the - `trace_fn`. parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer. See `tf.while_loop` for more details. seed: Optional, a seed for reproducible sampling. name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., 'experimental_mcmc_sample_chain'). + Default value: `None` (i.e., + 'experimental_mcmc_sample_chain_with_burnin'). Returns: - checkpointable_states_and_trace: if `return_final_kernel_results` is - `True`. The return value is an instance of - `CheckpointableStatesAndTrace`. - all_states: if `return_final_kernel_results` is `False` and `trace_fn` is - `None`. The return value is a `Tensor` or Python list of `Tensor`s - representing the state(s) of the Markov chain(s) at each result step. Has - same shape as input `current_state` but with a prepended - `num_results`-size dimension. - states_and_trace: if `return_final_kernel_results` is `False` and - `trace_fn` is not `None`. The return value is an instance of - `StatesAndTrace`. + result: A `RunKernelResults` instance containing information about the + sampling run. Main field is `trace`, the history of outputs of + `trace_fn`. See `RunKernelResults` for contents of other fields. #### References @@ -267,51 +256,42 @@ def sample_chain( _Technical Report_, 2017. http://statweb.stanford.edu/~owen/reports/bestthinning.pdf """ - with tf.name_scope(name or 'experimental_mcmc_sample_chain'): + with tf.name_scope(name or 'experimental_mcmc_sample_chain_with_burnin'): if not kernel.is_calibrated: warnings.warn('supplied `TransitionKernel` is not calibrated. Markov ' 'chain may not converge to intended target distribution.') if trace_fn is None: trace_fn = lambda *args: () - no_trace = True - else: - no_trace = False - - if trace_fn is sample_chain.__defaults__[4]: - warnings.warn('Tracing all kernel results by default is deprecated. Set ' - 'the `trace_fn` argument to None (the future default ' - 'value) or an explicit callback that traces the values ' - 'you are interested in.') - - def real_trace_fn(curr_state, kr): - return curr_state, trace_fn(curr_state, kr) - trace_reducer = tracing_reducer.TracingReducer( - trace_fn=real_trace_fn, - size=num_results - ) - # pylint: disable=unbalanced-tuple-unpacking - trace_results, _, final_kernel_results = sample_fold( - num_steps=num_results, + + burnin_seed, sampling_seed = random.split_seed(seed, n=2) + + # Burn-in run + chain_state, kr = exp_sample_lib.step_kernel( + num_steps=num_burnin_steps, current_state=current_state, previous_kernel_results=previous_kernel_results, kernel=kernel, - reducer=trace_reducer, - num_burnin_steps=num_burnin_steps, - num_steps_between_results=num_steps_between_results, + return_final_kernel_results=True, parallel_iterations=parallel_iterations, - seed=seed, - name=name, - ) + seed=burnin_seed, + name='burnin') - all_states, trace = trace_results - if return_final_kernel_results: - return sample.CheckpointableStatesAndTrace( - all_states=all_states, - trace=trace, - final_kernel_results=final_kernel_results) - else: - if no_trace: - return all_states - else: - return sample.StatesAndTrace(all_states=all_states, trace=trace) + thinning_k = thinning_kernel.ThinningKernel( + kernel, num_steps_to_skip=num_steps_between_results) + + # ThinningKernel doesn't wrap the kernel_results structure, so we don't need + # any of the usual munging. + results = run.run_kernel( + num_results=num_results, + current_state=chain_state, + previous_kernel_results=kr, + kernel=thinning_k, + trace_fn=trace_fn, + parallel_iterations=parallel_iterations, + seed=sampling_seed, + name='sampling') + + del results.resume_kwargs['reducer'] + del results.resume_kwargs['previous_reducer_state'] + return results diff --git a/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py b/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py index bbbbee0952..0779b0c112 100644 --- a/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py @@ -44,11 +44,9 @@ def test_simple_operation(self): num_steps=5, current_state=0., kernel=fake_kernel, - reducer=fake_reducer, - ) + reducer=fake_reducer) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual(3, reduction_rslt) self.assertEqual(5, last_sample) self.assertEqual(5, kernel_results.counter_1) @@ -60,11 +58,9 @@ def test_simple_operation(self): current_state=last_sample, kernel=fake_kernel, reducer=fake_reducer, - previous_kernel_results=kernel_results, - ) + previous_kernel_results=kernel_results) reduction_rslt_2, last_sample_2, kernel_results_2 = self.evaluate([ - reduction_rslt_2, last_sample_2, kr_2 - ]) + reduction_rslt_2, last_sample_2, kr_2]) self.assertEqual(8, reduction_rslt_2) self.assertEqual(10, last_sample_2) self.assertEqual(10, kernel_results_2.counter_1) @@ -78,11 +74,9 @@ def test_reducer_warm_restart(self): current_state=0., kernel=fake_kernel, reducer=fake_reducer, - return_final_reducer_states=True, - ) + return_final_reducer_states=True) red_res, last_sample, kernel_results, red_states = self.evaluate([ - red_res, last_sample, kr, red_states - ]) + red_res, last_sample, kr, red_states]) self.assertEqual(3, red_res) self.assertEqual(5, last_sample) self.assertEqual(5, kernel_results.counter_1) @@ -95,11 +89,9 @@ def test_reducer_warm_restart(self): previous_kernel_results=kernel_results, kernel=fake_kernel, reducer=fake_reducer, - previous_reducer_state=red_states - ) + previous_reducer_state=red_states) reduction_rslt_2, last_sample_2, kernel_results_2 = self.evaluate([ - reduction_rslt_2, last_sample_2, kr_2 - ]) + reduction_rslt_2, last_sample_2, kr_2]) self.assertEqual(5.5, reduction_rslt_2) self.assertEqual(10, last_sample_2) self.assertEqual(10, kernel_results_2.counter_1) @@ -113,11 +105,9 @@ def test_current_state(self, curr_state): num_steps=5, current_state=curr_state, kernel=fake_kernel, - reducer=fake_reducer, - ) + reducer=fake_reducer) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual( np.mean(np.arange(curr_state + 1, curr_state + 6)), reduction_rslt) self.assertEqual(curr_state + 5, last_sample) @@ -136,11 +126,9 @@ def reduction_target(current_state, kernel_results): num_steps=5, current_state=0., kernel=kernel, - reducer=reduction, - ) + reducer=reduction) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual(np.mean(np.arange(2, 12, 2)), reduction_rslt) self.assertEqual(5, last_sample) self.assertEqual(5, kernel_results.counter_1) @@ -151,17 +139,14 @@ def test_nested_reducers(self): fake_reducers = [ [test_fixtures.NaiveMeanReducer(), tfp.experimental.mcmc.CovarianceReducer()], - [test_fixtures.NaiveMeanReducer()] - ] + [test_fixtures.NaiveMeanReducer()]] reduction_rslt, last_sample, kr = tfp.experimental.mcmc.sample_fold( num_steps=3, current_state=0., kernel=fake_kernel, - reducer=fake_reducers, - ) + reducer=fake_reducers) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual(2, len(reduction_rslt)) self.assertEqual(2, len(reduction_rslt[0])) self.assertEqual(1, len(reduction_rslt[1])) @@ -196,8 +181,7 @@ def test_batched_streaming_covariance(self): current_state=tf.convert_to_tensor( [[0., 0., 0.], [0., 0., 0.]]), kernel=fake_kernel, - reducer=cov_reducer, - ) + reducer=cov_reducer) reduction_rslt = self.evaluate(reduction_rslt) self.assertEqual((2, 3, 3), reduction_rslt.shape) self.assertAllEqual(np.ones(reduction_rslt.shape) * 2, reduction_rslt) @@ -212,18 +196,15 @@ def test_seed_reproducibility(self): current_state=0., kernel=fake_kernel, reducer=fake_reducer, - seed=seed - ) + seed=seed) second_reduction_rslt, _, _ = tfp.experimental.mcmc.sample_fold( num_steps=3, current_state=0., kernel=fake_kernel, reducer=fake_reducer, - seed=seed - ) + seed=seed) first_reduction_rslt, second_reduction_rslt = self.evaluate([ - first_reduction_rslt, second_reduction_rslt - ]) + first_reduction_rslt, second_reduction_rslt]) self.assertEqual(first_reduction_rslt, second_reduction_rslt) def test_thinning_and_burnin(self): @@ -235,13 +216,11 @@ def test_thinning_and_burnin(self): kernel=fake_kernel, reducer=fake_reducer, num_burnin_steps=10, - num_steps_between_results=1, - ) + num_steps_between_results=1) reduction_rslt, last_sample, kernel_results = self.evaluate([ reduction_rslt, last_sample, - kr - ]) + kr]) self.assertEqual(16, reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual( @@ -258,13 +237,11 @@ def test_tensor_thinning_and_burnin(self): kernel=fake_kernel, reducer=fake_reducer, num_burnin_steps=tf.convert_to_tensor(10), - num_steps_between_results=tf.convert_to_tensor(1), - ) + num_steps_between_results=tf.convert_to_tensor(1)) reduction_rslt, last_sample, kernel_results = self.evaluate([ reduction_rslt, last_sample, - kr - ]) + kr]) self.assertEqual(16, reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual( @@ -280,11 +257,9 @@ def test_none_reducer(self): kernel=fake_kernel, reducer=None, num_burnin_steps=10, - num_steps_between_results=1, - ) + num_steps_between_results=1) last_sample, kernel_results = self.evaluate([ - last_sample, kr - ]) + last_sample, kr]) self.assertIsNone(reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual(20, kernel_results.counter_1) @@ -298,11 +273,9 @@ def test_empty_reducer(self): kernel=fake_kernel, reducer=[], num_burnin_steps=10, - num_steps_between_results=1, - ) + num_steps_between_results=1) last_sample, kernel_results = self.evaluate([ - last_sample, kr - ]) + last_sample, kr]) self.assertEqual([], reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual(20, kernel_results.counter_1) @@ -319,200 +292,93 @@ def setUp(self): def test_basic_operation(self): kernel = test_fixtures.TestTransitionKernel() - samples, kernel_results, final_results = tfp.experimental.mcmc.sample_chain( + result = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, - return_final_kernel_results=True, seed=test_util.test_seed()) + samples = result.trace + kernel_results = result.final_kernel_results self.assertAllClose( [2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) samples, kernel_results = self.evaluate([samples, kernel_results]) self.assertAllClose([1, 2], samples) - self.assertAllClose([1, 2], kernel_results.counter_1) - self.assertAllClose([2, 4], kernel_results.counter_2) + self.assertAllClose(2, kernel_results.counter_1) + self.assertAllClose(4, kernel_results.counter_2) # Warm-restart the underlying kernel. The Trace does not support warm # restart. - samples_2, kr_2 = tfp.experimental.mcmc.sample_chain( + result_2 = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, - current_state=samples[-1], - previous_kernel_results=final_results, - kernel=kernel, - ) - samples_2, kernel_results_2 = self.evaluate([samples_2, kr_2]) + **result.resume_kwargs) + samples_2, kernel_results_2 = self.evaluate( + [result_2.trace, result_2.final_kernel_results]) self.assertAllClose([3, 4], samples_2) - self.assertAllClose([3, 4], kernel_results_2.counter_1) - self.assertAllClose([6, 8], kernel_results_2.counter_2) - - def test_basic_operation_legacy(self): - kernel = test_fixtures.TestTransitionKernel(accepts_seed=False) - samples, kernel_results = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel) - - self.assertAllClose( - [2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) - - samples, kernel_results = self.evaluate([samples, kernel_results]) - self.assertAllClose([1, 2], samples) - self.assertAllClose([1, 2], kernel_results.counter_1) - self.assertAllClose([2, 4], kernel_results.counter_2) + self.assertAllClose(4, kernel_results_2.counter_1) + self.assertAllClose(8, kernel_results_2.counter_2) def test_burn_in(self): kernel = test_fixtures.TestTransitionKernel() - samples, kernel_results = tfp.experimental.mcmc.sample_chain( + result = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, num_burnin_steps=1, seed=test_util.test_seed()) + samples = result.trace + kernel_results = result.final_kernel_results self.assertAllClose([2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) samples, kernel_results = self.evaluate([samples, kernel_results]) self.assertAllClose([2, 3], samples) - self.assertAllClose([2, 3], kernel_results.counter_1) - self.assertAllClose([4, 6], kernel_results.counter_2) + self.assertAllClose(3, kernel_results.counter_1) + self.assertAllClose(6, kernel_results.counter_2) def test_thinning(self): kernel = test_fixtures.TestTransitionKernel() - samples, kernel_results = tfp.experimental.mcmc.sample_chain( + result = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, num_steps_between_results=2, seed=test_util.test_seed()) + samples = result.trace + kernel_results = result.final_kernel_results self.assertAllClose([2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) samples, kernel_results = self.evaluate([samples, kernel_results]) self.assertAllClose([3, 6], samples) - self.assertAllClose([3, 6], kernel_results.counter_1) - self.assertAllClose([6, 12], kernel_results.counter_2) - - def test_default_trace_named_tuple(self): - kernel = test_fixtures.TestTransitionKernel() - res = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - seed=test_util.test_seed()) - - self.assertAllClose([2], tensorshape_util.as_list(res.all_states.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace.counter_2.shape)) - - res = self.evaluate(res) - self.assertAllClose([1, 2], res.all_states) - self.assertAllClose([1, 2], res.trace.counter_1) - self.assertAllClose([2, 4], res.trace.counter_2) - - def test_no_trace_fn(self): - kernel = test_fixtures.TestTransitionKernel() - samples = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - trace_fn=None, - seed=test_util.test_seed()) - self.assertAllClose([2], tensorshape_util.as_list(samples.shape)) - samples = self.evaluate(samples) - self.assertAllClose([1, 2], samples) + self.assertAllClose(6, kernel_results.counter_1) + self.assertAllClose(12, kernel_results.counter_2) def test_custom_trace(self): kernel = test_fixtures.TestTransitionKernel() - res = tfp.experimental.mcmc.sample_chain( + res = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, trace_fn=lambda *args: args, seed=test_util.test_seed()) + trace = res.trace - self.assertAllClose([2], tensorshape_util.as_list(res.all_states.shape)) - self.assertAllClose([2], tensorshape_util.as_list(res.trace[0].shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace[1].counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace[1].counter_2.shape)) - - res = self.evaluate(res) - self.assertAllClose([1, 2], res.all_states) - self.assertAllClose([1, 2], res.trace[0]) - self.assertAllClose([1, 2], res.trace[1].counter_1) - self.assertAllClose([2, 4], res.trace[1].counter_2) - - def test_checkpointing(self): - kernel = test_fixtures.TestTransitionKernel() - res = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - trace_fn=None, - return_final_kernel_results=True, - seed=test_util.test_seed()) - - self.assertAllClose([2], tensorshape_util.as_list(res.all_states.shape)) - self.assertEqual((), res.trace) + self.assertAllClose([2], tensorshape_util.as_list(trace[0].shape)) self.assertAllClose( - [], tensorshape_util.as_list(res.final_kernel_results.counter_1.shape)) + [2], tensorshape_util.as_list(trace[1].counter_1.shape)) self.assertAllClose( - [], tensorshape_util.as_list(res.final_kernel_results.counter_2.shape)) - - res = self.evaluate(res) - self.assertAllClose([1, 2], res.all_states) - self.assertAllClose(2, res.final_kernel_results.counter_1) - self.assertAllClose(4, res.final_kernel_results.counter_2) + [2], tensorshape_util.as_list(trace[1].counter_2.shape)) - def test_warnings_default(self): - with warnings.catch_warnings(record=True) as triggered: - kernel = test_fixtures.TestTransitionKernel() - tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - seed=test_util.test_seed()) - self.assertTrue( - any('Tracing all kernel results by default is deprecated' in str( - warning.message) for warning in triggered)) - - def test_no_warnings_explicit(self): - with warnings.catch_warnings(record=True) as triggered: - kernel = test_fixtures.TestTransitionKernel() - tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - trace_fn=lambda current_state, kernel_results: kernel_results, - seed=test_util.test_seed()) - self.assertFalse( - any('Tracing all kernel results by default is deprecated' in str( - warning.message) for warning in triggered)) + trace = self.evaluate(trace) + self.assertAllClose([1, 2], trace[0]) + self.assertAllClose([1, 2], trace[1].counter_1) + self.assertAllClose([2, 4], trace[1].counter_2) def test_is_calibrated(self): with warnings.catch_warnings(record=True) as triggered: kernel = test_fixtures.TestTransitionKernel(is_calibrated=False) - tfp.experimental.mcmc.sample_chain( + tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, @@ -535,12 +401,12 @@ def log_prob(x): target_log_prob_fn=log_prob, num_leapfrog_steps=3, step_size=1e-3) - return tfp.experimental.mcmc.sample_chain( + results = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=5, num_burnin_steps=4, current_state=initial_state, - kernel=kernel, - trace_fn=None) + kernel=kernel) + return results.trace # Checking that shape inference doesn't fail. sample(2) @@ -548,24 +414,21 @@ def log_prob(x): def test_seed_reproducibility(self): first_fake_kernel = test_fixtures.RandomTransitionKernel() second_fake_kernel = test_fixtures.RandomTransitionKernel() - seed = samplers.sanitize_seed(test_util.test_seed()) - first_final_state = tfp.experimental.mcmc.sample_chain( + seed = test_util.test_seed(sampler_type='stateless') + first_trace = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=5, current_state=0., kernel=first_fake_kernel, - seed=seed, - ) - second_final_state = tfp.experimental.mcmc.sample_chain( + seed=seed).trace + second_trace = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=5, current_state=1., # difference should be irrelevant kernel=second_fake_kernel, - seed=seed, - ) - first_final_state, second_final_state = self.evaluate([ - first_final_state, second_final_state - ]) + seed=seed).trace + first_trace, second_trace = self.evaluate([ + first_trace, second_trace]) self.assertAllCloseNested( - first_final_state, second_final_state, rtol=1e-6) + first_trace, second_trace, rtol=1e-6) @test_util.test_graph_mode_only @@ -589,7 +452,7 @@ def target_log_prob(x, y): z = tf.linalg.triangular_solve(true_cov_chol, z[..., tf.newaxis])[..., 0] return -0.5 * tf.reduce_sum(z**2., axis=-1) - states = tfp.experimental.mcmc.sample_chain( + states = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=num_results, current_state=[dtype(-2), dtype(2)], kernel=tfp.mcmc.HamiltonianMonteCarlo( @@ -598,8 +461,7 @@ def target_log_prob(x, y): num_leapfrog_steps=2), num_burnin_steps=20, num_steps_between_results=1, - trace_fn=None, - seed=test_util.test_seed()) + seed=test_util.test_seed()).trace self.assertAllEqual(dict(target_calls=1), counter) states = tf.stack(states, axis=-1) From 5f0dbec93c07fc0e375f0132e13fc35904201462 Mon Sep 17 00:00:00 2001 From: bjp Date: Wed, 16 Dec 2020 12:12:33 -0800 Subject: [PATCH 33/36] Add float64 support to PHMC. Previously, had the exception: "TypeError: Tensors in list passed to 'inputs' of 'AddN' Op have types [float64, float64, float32] that don't all match." PiperOrigin-RevId: 347871205 --- .../experimental/mcmc/preconditioned_hmc.py | 2 +- .../mcmc/preconditioned_hmc_test.py | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py index ceb385f4ac..892fc765cf 100644 --- a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py +++ b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py @@ -464,7 +464,7 @@ def _prepare_args(target_log_prob_fn, def _batched_isotropic_normal_like(state_part): event_ndims = ps.rank(state_part) - batch_rank return independent.Independent( - normal.Normal(ps.zeros_like(state_part, tf.float32), 1.), + normal.Normal(ps.zeros_like(state_part), 1.), reinterpreted_batch_ndims=event_ndims) momentum_distribution = jds.JointDistributionSequential( diff --git a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py index 79ff9e3c19..dd625ca481 100644 --- a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py +++ b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py @@ -386,6 +386,36 @@ def test_correctness_with_200d_mvn_tril(self, precondition_scheme): dict(testcase_name='_explicit', use_default=False)) class PreconditionedHMCTest(test_util.TestCase): + def test_f64(self, use_default): + if use_default: + momentum_distribution = None + else: + momentum_distribution = tfp.experimental.as_composite( + tfd.Normal(0., tf.constant(.5, dtype=tf.float64))) + kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo( + lambda x: -x**2, step_size=.5, num_leapfrog_steps=2, + momentum_distribution=momentum_distribution) + kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3) + self.evaluate(tfp.mcmc.sample_chain( + 1, kernel=kernel, current_state=tf.ones([], tf.float64), + num_burnin_steps=5, trace_fn=None)) + + # TODO(b/175787154): Enable this test + def DISABLED_test_f64_multichain(self, use_default): + if use_default: + momentum_distribution = None + else: + momentum_distribution = tfp.experimental.as_composite( + tfd.Normal(0., tf.constant(.5, dtype=tf.float64))) + kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo( + lambda x: -x**2, step_size=.5, num_leapfrog_steps=2, + momentum_distribution=momentum_distribution) + kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3) + nchains = 7 + self.evaluate(tfp.mcmc.sample_chain( + 1, kernel=kernel, current_state=tf.ones([nchains], tf.float64), + num_burnin_steps=5, trace_fn=None)) + def test_diag(self, use_default): """Test that a diagonal multivariate normal can be effectively sampled from. From 1cc5c39aee7655946df3c0ba589f2711c58d7b86 Mon Sep 17 00:00:00 2001 From: axch Date: Wed, 16 Dec 2020 12:38:44 -0800 Subject: [PATCH 34/36] Allow 0 concentration in gamma samplers. Gamma(concentration=0) always samples 0. ExpGamma(concentration=0) always samples -inf. Beta(concentration1=0) always samples 0. Beta(concentration0=0) always samples 1. BetaBinomial(concentration1=0) always samples 0. BetaBinomial(concentration0=0) always samples total_counts. likewise for Dirichlet and DirichletMultinomial. Not changing the validation because (i) inertia, and (ii) any of these distributions is still degenerate with a 0 concentration, so should arguably be avoided when possible. PiperOrigin-RevId: 347876446 --- .../python/distributions/beta_binomial_test.py | 6 ++++++ .../python/distributions/exp_gamma_test.py | 7 ++++++- tensorflow_probability/python/distributions/gamma.py | 4 ++-- tensorflow_probability/python/distributions/gamma_test.py | 8 +++++++- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/tensorflow_probability/python/distributions/beta_binomial_test.py b/tensorflow_probability/python/distributions/beta_binomial_test.py index a66d576ae1..48932dd023 100644 --- a/tensorflow_probability/python/distributions/beta_binomial_test.py +++ b/tensorflow_probability/python/distributions/beta_binomial_test.py @@ -143,6 +143,12 @@ def testSampleAgainstProb(self): np.sum(x == i, axis=0) / (num_samples * 1.0), atol=0.01, rtol=0.1) + def testSampleCornerConcentrations(self): + seed_stream = test_util.test_seed_stream() + d = tfd.BetaBinomial(concentration0=[1., 0.], concentration1=[0., 1.], + total_count=50.) + self.assertAllEqual(d.sample(10, seed=seed_stream()), [[0, 50]] * 10) + def testEmpiricalCdfAgainstDirichletMultinomial(self): # This test is too slow for Eager mode. if tf.executing_eagerly(): diff --git a/tensorflow_probability/python/distributions/exp_gamma_test.py b/tensorflow_probability/python/distributions/exp_gamma_test.py index 4f6d3a9131..74a7f579c8 100644 --- a/tensorflow_probability/python/distributions/exp_gamma_test.py +++ b/tensorflow_probability/python/distributions/exp_gamma_test.py @@ -217,7 +217,7 @@ def testSample(self): d.variance(), atol=.15) - def testSampleReturnsNansForNonPositiveParameters(self): + def testSampleNonPositiveParameters(self): d = tfd.ExpGamma([1., 2.], 1., validate_args=False) seed_stream = test_util.test_seed_stream() samples = self.evaluate(d.sample(seed=seed_stream())) @@ -227,6 +227,11 @@ def testSampleReturnsNansForNonPositiveParameters(self): d = tfd.ExpGamma([0., 2.], 1., validate_args=False) samples = self.evaluate(d.sample(seed=seed_stream())) self.assertEqual(samples.shape, (2,)) + self.assertAllEqual([s == -np.inf for s in samples], [True, False]) + + d = tfd.ExpGamma([-0.001, 2.], 1., validate_args=False) + samples = self.evaluate(d.sample(seed=seed_stream())) + self.assertEqual(samples.shape, (2,)) self.assertAllEqual([np.isnan(s) for s in samples], [True, False]) d = tfd.ExpGamma([1., -1.], 1., validate_args=False) diff --git a/tensorflow_probability/python/distributions/gamma.py b/tensorflow_probability/python/distributions/gamma.py index e25349c21b..411db9e05d 100644 --- a/tensorflow_probability/python/distributions/gamma.py +++ b/tensorflow_probability/python/distributions/gamma.py @@ -384,7 +384,7 @@ def _tensorshape_or_scalar(v0, v1): def _random_gamma_cpu( shape, concentration, rate=None, log_rate=None, seed=None, log_space=False): """Sample using *fast* `tf.random.stateless_gamma`.""" - bad_concentration = (concentration <= 0.) | tf.math.is_nan(concentration) + bad_concentration = (concentration < 0.) | tf.math.is_nan(concentration) safe_concentration = tf.where( bad_concentration, dtype_util.as_numpy_dtype(concentration.dtype)(100.), concentration) @@ -711,7 +711,7 @@ def rejection_sample(concentration): # Note, concentration here already has a shape that is broadcast with rate. cast_concentration = tf.cast(concentration, internal_dtype) - good_params_mask = (concentration > 0.) + good_params_mask = (concentration >= 0.) # When replacing NaN values, use 100. for concentration, since that leads to # a high-likelihood of the rejection sampler accepting on the first pass. safe_concentration = tf.where(good_params_mask, cast_concentration, 100.) diff --git a/tensorflow_probability/python/distributions/gamma_test.py b/tensorflow_probability/python/distributions/gamma_test.py index 0f1bb7245c..86f4898b9b 100644 --- a/tensorflow_probability/python/distributions/gamma_test.py +++ b/tensorflow_probability/python/distributions/gamma_test.py @@ -303,7 +303,7 @@ def testGammaSample(self): sp_stats.gamma.var(concentration_v, scale=1 / rate_v), atol=.15) - def testGammaSampleReturnsNansForNonPositiveParameters(self): + def testGammaSampleZeroAndNegativeParameters(self): gamma = tfd.Gamma([1., 2.], 1., validate_args=False) seed_stream = test_util.test_seed_stream() samples = self.evaluate(gamma.sample(seed=seed_stream())) @@ -313,6 +313,12 @@ def testGammaSampleReturnsNansForNonPositiveParameters(self): gamma = tfd.Gamma([0., 2.], 1., validate_args=False) samples = self.evaluate(gamma.sample(seed=seed_stream())) self.assertEqual(samples.shape, (2,)) + self.assertAllEqual([s in [0, np.finfo(np.float32).tiny] + for s in samples], [True, False]) + + gamma = tfd.Gamma([-0.001, 2.], 1., validate_args=False) + samples = self.evaluate(gamma.sample(seed=seed_stream())) + self.assertEqual(samples.shape, (2,)) self.assertAllEqual([np.isnan(s) for s in samples], [True, False]) gamma = tfd.Gamma([1., -1.], 1., validate_args=False) From ead85013bf526d37682cbb3299c70b0b814fd381 Mon Sep 17 00:00:00 2001 From: vanderplas Date: Sun, 20 Dec 2020 12:55:29 -0800 Subject: [PATCH 35/36] internal change PiperOrigin-RevId: 348378916 --- spinoffs/oryx/oryx/core/interpreters/inverse/BUILD | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD b/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD index 46cb94a7da..b203bbb2f1 100644 --- a/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD +++ b/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD @@ -93,6 +93,11 @@ py_test( name = "inverse_test", srcs = ["inverse_test.py"], python_version = "PY3", + # This test no longer works after cl/346850541, because no inverse is registered + # for convert_element_type. + tags = [ + "notap", + ], deps = [ ":core", ":rules", From 90fa3f0693c2d67bf6c74212dcd7c267443b8c96 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Mon, 21 Dec 2020 16:25:17 -0800 Subject: [PATCH 36/36] Set the version for the TFP 0.12.0 release. --- tensorflow_probability/python/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index 0b75390eae..e18bcad27d 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -24,7 +24,7 @@ # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a # release branch, the current version is by default assumed to be a # 'development' version, labeled 'dev'. -_VERSION_SUFFIX = 'rc4' +_VERSION_SUFFIX = '' # Example, '0.4.0-dev' __version__ = '.'.join([