diff --git a/discussion/fun_mcmc/fun_mcmc_test.py b/discussion/fun_mcmc/fun_mcmc_test.py index 28ac5a950a..2dae3d1bba 100644 --- a/discussion/fun_mcmc/fun_mcmc_test.py +++ b/discussion/fun_mcmc/fun_mcmc_test.py @@ -371,8 +371,8 @@ def log_prob_fn(x, y): tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), () bijectors = [ - tfp.bijectors.AffineScalar(scale=self._constant(2.)), - tfp.bijectors.AffineScalar(scale=self._constant(3.)) + tfp.bijectors.Scale(scale=self._constant(2.)), + tfp.bijectors.Scale(scale=self._constant(3.)) ] (transformed_log_prob_fn, @@ -398,8 +398,8 @@ def log_prob_fn(x, y): tfp.distributions.Normal(self._constant(1.), 1.).log_prob(y)), () bijectors = { - 'x': tfp.bijectors.AffineScalar(scale=self._constant(2.)), - 'y': tfp.bijectors.AffineScalar(scale=self._constant(3.)) + 'x': tfp.bijectors.Scale(scale=self._constant(2.)), + 'y': tfp.bijectors.Scale(scale=self._constant(3.)) } (transformed_log_prob_fn, diff --git a/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD b/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD index 46cb94a7da..b203bbb2f1 100644 --- a/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD +++ b/spinoffs/oryx/oryx/core/interpreters/inverse/BUILD @@ -93,6 +93,11 @@ py_test( name = "inverse_test", srcs = ["inverse_test.py"], python_version = "PY3", + # This test no longer works after cl/346850541, because no inverse is registered + # for convert_element_type. + tags = [ + "notap", + ], deps = [ ":core", ":rules", diff --git a/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb index a87a610d77..eacd036a67 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb @@ -3,7 +3,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "htW5SiGzeXYm" }, "source": [ @@ -14,11 +13,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "cellView": "form", - "colab": {}, - "colab_type": "code", "id": "9HGeUNoteaSm" }, "outputs": [], @@ -39,7 +35,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "JJ3UDciDVcB5" }, "source": [ @@ -64,7 +59,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "lin40yCC6eBo" }, "source": [ @@ -74,7 +68,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "eZs1ShikNBK2" }, "source": [ @@ -84,25 +77,23 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "7JjokKMbk2hJ" }, "source": [ "For $k\\in\\{1,\\ldots, K\\}$ mixture components each of dimension $D$, we'd like to model $i\\in\\{1,\\ldots,N\\}$ iid samples using the following Bayesian Gaussian Mixture Model:\n", "\n", "$$\\begin{align*}\n", - "\\theta &\\sim \\text{Dirichlet}(\\text{concentration}=\\alpha_0)\\\\\n", - "\\mu_k &\\sim \\text{Normal}(\\text{loc}=\\mu_{0k}, \\text{scale}=I_D)\\\\\n", - "T_k &\\sim \\text{Wishart}(\\text{df}=5, \\text{scale}=I_D)\\\\\n", - "Z_i &\\sim \\text{Categorical}(\\text{probs}=\\theta)\\\\\n", - "Y_i &\\sim \\text{Normal}(\\text{loc}=\\mu_{z_i}, \\text{scale}=T_{z_i}^{-1/2})\\\\\n", + "\\theta \u0026\\sim \\text{Dirichlet}(\\text{concentration}=\\alpha_0)\\\\\n", + "\\mu_k \u0026\\sim \\text{Normal}(\\text{loc}=\\mu_{0k}, \\text{scale}=I_D)\\\\\n", + "T_k \u0026\\sim \\text{Wishart}(\\text{df}=5, \\text{scale}=I_D)\\\\\n", + 
"Z_i \u0026\\sim \\text{Categorical}(\\text{probs}=\\theta)\\\\\n", + "Y_i \u0026\\sim \\text{Normal}(\\text{loc}=\\mu_{z_i}, \\text{scale}=T_{z_i}^{-1/2})\\\\\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "iySRABi0qZnQ" }, "source": [ @@ -112,7 +103,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Y6X_Beihwzyi" }, "source": [ @@ -131,10 +121,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "uswTWdgNu46j" }, "outputs": [], @@ -163,7 +151,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Uj9uHZN2yUqz" }, "source": [ @@ -180,10 +167,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "nc4yy6vW-lC_" }, "outputs": [], @@ -197,9 +182,9 @@ " scale=tf.ones_like(loc)),\n", " reinterpreted_batch_ndims=1),\n", " bijector=tfb.Chain([\n", - " tfb.Affine(shift=loc),\n", - " tfb.Invert(tfb.Affine(scale_tril=chol_precision_tril,\n", - " adjoint=True)),\n", + " tfb.Shift(shift=loc),\n", + " tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,\n", + " adjoint=True)),\n", " ]),\n", " name=name)" ] @@ -207,7 +192,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "JDOkWhDQg4ZG" }, "source": [ @@ -219,7 +203,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Pfkc8cmhh2Qz" }, "source": [ @@ -228,12 +211,11 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": { "height": 152 }, - "colab_type": "code", "id": "GhqbjwlIh1Vn", "outputId": "3ea12c10-cb9b-4558-aedd-386b37adc909" }, @@ -291,7 +273,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "N60z8scN1v6E" }, "source": [ @@ -300,10 +281,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "xhzxySDjL2-S" }, "outputs": [], @@ -316,10 +295,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "xAOmHhZ7LzDQ" }, "outputs": [], @@ -353,10 +330,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "CpLnRJr2TXYD" }, "outputs": [], @@ -385,7 +360,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "7jTMXdymV1QJ" }, "source": [ @@ -395,7 +369,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "rl4brz3G3pS7" }, "source": [ @@ -404,10 +377,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "1AJZAtwXV8RQ" }, "outputs": [], @@ -425,7 +396,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "zVOvMh7MV37A" }, "source": [ @@ -435,7 +405,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "cdN3iKFT32Jp" }, "source": [ @@ -446,10 +415,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "tVoaDFSf7L_j" }, "outputs": [], @@ -459,10 +426,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "a0OMIWIYeMmQ" }, "outputs": [], @@ -482,7 +447,6 @@ { 
"cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "TVpiT3LLyfcO" }, "source": [ @@ -492,7 +456,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "JS8XOsxiyiBV" }, "source": [ @@ -507,7 +470,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "Vt9SXJzO0Cks" }, "source": [ @@ -526,10 +488,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "_atEQrDR7JvG" }, "outputs": [], @@ -545,10 +505,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "0zq6QJJ-NSPJ" }, "outputs": [], @@ -575,7 +533,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "QLEz96mg6fpZ" }, "source": [ @@ -584,10 +541,8 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "colab": {}, - "colab_type": "code", "id": "_ceX1A3-ZFiN" }, "outputs": [], @@ -601,12 +556,11 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": { "height": 270 }, - "colab_type": "code", "id": "bqJ6RSJxegC6", "outputId": "e0867545-0509-4077-d89d-74e1d5280062" }, @@ -642,12 +596,11 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": { "height": 289 }, - "colab_type": "code", "id": "zFOU0j9kPdUy", "outputId": "17f4ce0c-24c3-4cf4-ebe8-b932caac7ba4" }, @@ -676,7 +629,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "NmfNIM1c6mwc" }, "source": [ @@ -686,7 +638,6 @@ { "cell_type": "markdown", "metadata": { - "colab_type": "text", "id": "t8LeIeMn6ot4" }, "source": [ @@ -698,7 +649,6 @@ "colab": { "collapsed_sections": [], "name": "Bayesian Gaussian Mixture Model", - "private_outputs": false, "provenance": [], "toc_visible": true }, diff --git a/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb b/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb index 66720fde64..1b53a54cda 100644 --- a/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb +++ b/tensorflow_probability/examples/jupyter_notebooks/TensorFlow_Probability_Case_Study_Covariance_Estimation.ipynb @@ -2672,8 +2672,8 @@ "\n", "Our approach (courtesy of [this notebook](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb)):\n", "1. Use `tfd.Independent()` to combine a batch of 1-D `Normal` random variables into a single multi-dimensional random variable. The `reinterpreted_batch_ndims` parameter for `Independent()` specifies the number of batch dimensions that should be reinterpreted as event dimensions. In our case we create a 1-D batch of length 2 that we transform into a 1-D event of length 2, so `reinterpreted_batch_ndims=1`.\n", - "2. Apply a bijector to add the desired covariance: `tfb.Invert(tfb.Affine(scale_tril=precision_cholesky, adjoint=True))`. Note that above we're multiplying our iid normal random variables by the transpose of the inverse of the Cholesky factor of the precision matrix $(B^{-T}X)$. The `tfb.Invert` takes care of inverting $B$, and the `adjoint=True` flag performs the transpose.\n", - "3. 
Apply a bijector to add the desired offset: `tfb.Affine(shift=shift)` Note that we have to do the shift as a separate step from the initial inverted affine transform because otherwise the inverted scale is applied to the shift (since the inverse of $y=Ax+b$ is $x=A^{-1}y - A^{-1}b$).\n" + "2. Apply a bijector to add the desired covariance: `tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=precision_cholesky, adjoint=True))`. Note that above we're multiplying our iid normal random variables by the transpose of the inverse of the Cholesky factor of the precision matrix $(B^{-T}X)$. The `tfb.Invert` takes care of inverting $B$, and the `adjoint=True` flag performs the transpose.\n", + "3. Apply a bijector to add the desired offset: `tfb.Shift(shift=shift)` Note that we have to do the shift as a separate step from the initial inverted affine transform because otherwise the inverted scale is applied to the shift (since the inverse of $y=Ax+b$ is $x=A^{-1}y - A^{-1}b$).\n" ] }, { @@ -2694,8 +2694,8 @@ " scale=tf.ones_like(loc)),\n", " reinterpreted_batch_ndims=1),\n", " bijector=tfb.Chain([\n", - " tfb.Affine(shift=loc),\n", - " tfb.Invert(tfb.Affine(scale_tril=precision_cholesky,\n", + " tfb.Shift(shift=loc),\n", + " tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=precision_cholesky,\n", " adjoint=True)),\n", " ]),\n", " name=name)" diff --git a/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py b/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py index 5366f7659d..239239594a 100644 --- a/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py +++ b/tensorflow_probability/examples/latent_dirichlet_allocation_distributions.py @@ -248,7 +248,7 @@ def make_prior(num_topics, initial_value): def model_fn(features, labels, mode, params, config): """Build the model function for use in an estimator. - Arguments: + Args: features: The input features for the estimator. labels: The labels, unused here. mode: Signifies whether it is train or test or predict. @@ -353,7 +353,7 @@ def get_topics_strings(topics_words, alpha, vocabulary, topics_to_print=10, words_per_topic=10): """Returns the summary of the learned topics. - Arguments: + Args: topics_words: KxV tensor with topics as rows and words as columns. alpha: 1xK tensor of prior Dirichlet concentrations for the topics. @@ -464,7 +464,7 @@ def build_input_fns(data_dir, batch_size): Each object is represented as a bag-of-words vector. - Arguments: + Args: data_dir: Folder in which to store the data. batch_size: Batch size for both train and evaluation. Returns: diff --git a/tensorflow_probability/examples/vae.py b/tensorflow_probability/examples/vae.py index 43e3ce9e5c..66d7b5a2a2 100644 --- a/tensorflow_probability/examples/vae.py +++ b/tensorflow_probability/examples/vae.py @@ -325,7 +325,7 @@ def image_tile_summary(name, tensor, rows=8, cols=8): def model_fn(features, labels, mode, params, config): """Builds the model function for use in an estimator. - Arguments: + Args: features: The input features for the estimator. labels: The labels, unused here. mode: Signifies whether it is train or test or predict. 
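
A recurring change in the hunks above is the migration off the deprecated `tfb.Affine` / `tfb.AffineScalar` bijectors onto `tfb.Shift`, `tfb.Scale`, `tfb.ScaleMatvecDiag`, and `tfb.ScaleMatvecTriL`. Below is a minimal sketch of the pattern the updated notebook cells use; it is illustrative only (the `loc` and `chol_precision_tril` values are made up) and is not part of the patch itself.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

loc = tf.constant([1., -1.])                  # example shift
chol_precision_tril = tf.constant([[2., 0.],  # example lower-triangular
                                   [1., 3.]])  # Cholesky factor of the precision

# Previously: tfb.Chain([tfb.Affine(shift=loc),
#                        tfb.Invert(tfb.Affine(scale_tril=chol_precision_tril,
#                                              adjoint=True))])
# Now the shift and the (inverted, adjoint) triangular matvec are separate
# bijectors, exactly as in the updated notebook cells:
mvn_from_precision = tfd.TransformedDistribution(
    distribution=tfd.Independent(
        tfd.Normal(loc=tf.zeros_like(loc), scale=tf.ones_like(loc)),
        reinterpreted_batch_ndims=1),
    bijector=tfb.Chain([
        tfb.Shift(shift=loc),
        tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,
                                       adjoint=True)),
    ]))
```

As the notebook text notes, the shift must remain a separate step from the inverted scale, since the inverse of `y = Ax + b` is `x = A^{-1}y - A^{-1}b`.
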
diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index 345dbe2511..d5fd9afa7e 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -19,6 +19,7 @@ from __future__ import print_function import functools +import sys import types from tensorflow_probability.python.internal import all_util @@ -39,7 +40,7 @@ def _validate_tf_environment(package): """ try: import tensorflow.compat.v1 as tf - except ImportError: + except (ImportError, ModuleNotFoundError): # Print more informative error message, then reraise. print('\n\nFailed to import TensorFlow. Please note that TensorFlow is not ' 'installed by default when you install TensorFlow Probability. This ' @@ -96,13 +97,11 @@ def _validate_tf_environment(package): util: types.ModuleType vi: types.ModuleType -_allowed_symbols = [ +_lazy_load = [ 'bijectors', 'debugging', 'distributions', - 'experimental', 'glm', - 'layers', 'math', 'mcmc', 'monte_carlo', @@ -114,11 +113,33 @@ def _validate_tf_environment(package): 'vi', ] -for pkg in _allowed_symbols: - globals()[pkg] = lazy_loader.LazyLoader( - pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg), +# If TensorFlow is already imported, we should non-lazily load modules which +# include registrations (e.g., Keras layer registrations and CompositeTensor +# registrations) -- which must be loaded when deserializing tensorflow +# saved models. +_maybe_nonlazy_load = [ + 'experimental', + 'layers', +] + + +def _tf_loaded(): + return 'compat' in dir(sys.modules.get('tensorflow', None)) + + +# To start with, lazy-load everything. Later we may replace some of the +# lazy-loaded modules by forcing a load. +for pkg_name in _lazy_load + _maybe_nonlazy_load: + globals()[pkg_name] = lazy_loader.LazyLoader( + pkg_name, globals(), 'tensorflow_probability.python.{}'.format(pkg_name), # These checks need to happen before lazy-loading, since the modules # themselves will try to import tensorflow, too. - on_first_access=functools.partial(_validate_tf_environment, pkg)) + on_first_access=functools.partial(_validate_tf_environment, pkg_name)) + +if _tf_loaded(): + # Non-lazy load of packages that register with tensorflow or keras. + for pkg_name in _maybe_nonlazy_load: + dir(globals()[pkg_name]) # Forces loading the package from its lazy loader. 
+ -all_util.remove_undocumented(__name__, _allowed_symbols) +all_util.remove_undocumented(__name__, _lazy_load + _maybe_nonlazy_load) diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD index 0df4e10958..88d13d4b31 100644 --- a/tensorflow_probability/python/bijectors/BUILD +++ b/tensorflow_probability/python/bijectors/BUILD @@ -75,6 +75,7 @@ multi_substrate_py_library( ":joint_map", ":kumaraswamy_cdf", ":lambertw_transform", + ":ldj_ratio", ":masked_autoregressive", ":matrix_inverse_tril", ":moyal_cdf", @@ -260,6 +261,7 @@ multi_substrate_py_library( srcs = ["scale_matvec_diag.py"], deps = [ ":bijector", + ":ldj_ratio", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", "//tensorflow_probability/python/internal:dtype_util", @@ -581,6 +583,14 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "ldj_ratio", + srcs = ["ldj_ratio.py"], + deps = [ + # tensorflow dep, + ], +) + multi_substrate_py_library( name = "masked_autoregressive", srcs = ["masked_autoregressive.py"], diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py index 1bbaddc9b7..4de390e18c 100644 --- a/tensorflow_probability/python/bijectors/bijector.py +++ b/tensorflow_probability/python/bijectors/bijector.py @@ -718,14 +718,14 @@ def __call__(self, value, name=None, **kwargs): ```python sigmoid = tfb.Reciprocal()( - tfb.AffineScalar(shift=1.)( + tfb.Shift(shift=1.)( tfb.Exp()( - tfb.AffineScalar(scale=-1.)))) + tfb.Scale(scale=-1.)))) # ==> `tfb.Chain([ # tfb.Reciprocal(), - # tfb.AffineScalar(shift=1.), + # tfb.Shift(shift=1.), # tfb.Exp(), - # tfb.AffineScalar(scale=-1.), + # tfb.Scale(scale=-1.), # ])` # ie, `tfb.Sigmoid()` log_normal = tfb.Exp()(tfd.Normal(0, 1)) diff --git a/tensorflow_probability/python/bijectors/bijector_composition_test.py b/tensorflow_probability/python/bijectors/bijector_composition_test.py index 5e4b80e547..ea89959302 100644 --- a/tensorflow_probability/python/bijectors/bijector_composition_test.py +++ b/tensorflow_probability/python/bijectors/bijector_composition_test.py @@ -38,9 +38,9 @@ def testComposeFromChainBijector(self): x = tf.constant([-5., 0., 5.]) sigmoid = functools.reduce(lambda chain, f: chain(f), [ tfb.Reciprocal(), - tfb.AffineScalar(shift=1.), + tfb.Shift(shift=1.), tfb.Exp(), - tfb.AffineScalar(scale=-1.), + tfb.Scale(scale=-1.), ]) self.assertIsInstance(sigmoid, tfb.Chain) self.assertAllClose( @@ -50,7 +50,7 @@ def testComposeFromChainBijector(self): def testComposeFromTransformedDistribution(self): actual_log_normal = tfb.Exp()(tfd.TransformedDistribution( distribution=tfd.Normal(0, 1), - bijector=tfb.AffineScalar(shift=0.5, scale=2.))) + bijector=tfb.Shift(shift=0.5)(tfb.Scale(scale=2.)))) expected_log_normal = tfd.LogNormal(0.5, 2.) 
x = tf.constant([0.1, 1., 5.]) self.assertAllClose( diff --git a/tensorflow_probability/python/bijectors/bijector_properties_test.py b/tensorflow_probability/python/bijectors/bijector_properties_test.py index c6b73b37f1..c5d73a8dad 100644 --- a/tensorflow_probability/python/bijectors/bijector_properties_test.py +++ b/tensorflow_probability/python/bijectors/bijector_properties_test.py @@ -36,7 +36,6 @@ TF2_FRIENDLY_BIJECTORS = ( - 'AffineScalar', 'Ascending', 'BatchNormalization', # 'CategoricalToDiscrete', TODO(b/137956955): Add support @@ -59,7 +58,6 @@ 'KumaraswamyCDF', 'Log', 'Log1p', - 'MatvecLU', 'MatrixInverseTriL', 'MoyalCDF', 'NormalCDF', @@ -92,13 +90,11 @@ ) BIJECTOR_PARAMS_NDIMS = { - 'AffineScalar': dict(shift=0, scale=0, log_scale=0), 'FrechetCDF': dict(loc=0, scale=0, concentration=0), 'GompertzCDF': dict(concentration=0, rate=0), 'GumbelCDF': dict(loc=0, scale=0), 'GeneralizedExtremeValueCDF': dict(loc=0, scale=0, concentration=0), 'KumaraswamyCDF': dict(concentration1=0, concentration0=0), - 'MatvecLU': dict(lower_upper=2, permutation=1), 'MoyalCDF': dict(loc=0, scale=0), 'Power': dict(power=0), 'RayleighCDF': dict(scale=0), @@ -125,7 +121,6 @@ INVERT_LDJ = {FLDJ: ILDJ, ILDJ: FLDJ} NO_LDJ_GRADS_EXPECTED = { - 'AffineScalar': dict(shift={FLDJ, ILDJ}), 'BatchNormalization': dict(beta={FLDJ, ILDJ}), 'FrechetCDF': dict(loc={ILDJ}), 'GeneralizedExtremeValueCDF': dict(loc={ILDJ}), @@ -135,7 +130,6 @@ } TRANSFORM_DIAGONAL_ALLOWLIST = { - 'AffineScalar', 'BatchNormalization', 'DiscreteCosineTransform', 'Exp', @@ -813,8 +807,6 @@ def ensure_nonzero(x): tfp_hps.softplus_plus_eps(), 'temperature': tfp_hps.softplus_plus_eps(eps=0.5), - 'AffineScalar.scale': - tfp_hps.softplus_plus_eps(), 'Scale.scale': tfp_hps.softplus_plus_eps(), 'ScaleMatvecDiag.scale_diag': diff --git a/tensorflow_probability/python/bijectors/blockwise_test.py b/tensorflow_probability/python/bijectors/blockwise_test.py index 31dc836190..b64b47f53c 100644 --- a/tensorflow_probability/python/bijectors/blockwise_test.py +++ b/tensorflow_probability/python/bijectors/blockwise_test.py @@ -48,7 +48,7 @@ def testExplicitBlocks(self, dynamic_shape, batch_shape): block_sizes.shape)) exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise( bijectors=[exp, sp, aff], block_sizes=block_sizes, @@ -123,7 +123,7 @@ def testSizeChangingExplicitBlocks(self, dynamic_shape, batch_shape): block_sizes, shape=block_sizes.shape) exp = tfb.Exp() sc = tfb.SoftmaxCentered() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise( bijectors=[exp, sc, aff], block_sizes=block_sizes, @@ -201,7 +201,7 @@ def testSizeChangingExplicitBlocks(self, dynamic_shape, batch_shape): def testBijectiveAndFinite(self): exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2., 3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise(bijectors=[exp, sp, aff], block_sizes=[2, 1, 3]) x = tf.cast([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=tf.float32) @@ -219,17 +219,17 @@ def testBijectiveAndFinite(self): def testImplicitBlocks(self): exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2.]) blockwise = tfb.Blockwise(bijectors=[exp, sp, aff]) self.assertAllEqual(self.evaluate(blockwise.block_sizes), [1, 1, 1]) def testName(self): exp = tfb.Exp() sp = tfb.Softplus() - aff = tfb.Affine(scale_diag=[2., 
3., 4.]) + aff = tfb.ScaleMatvecDiag(scale_diag=[2., 3., 4.]) blockwise = tfb.Blockwise(bijectors=[exp, sp, aff], block_sizes=[2, 1, 3]) self.assertStartsWith(blockwise.name, - 'blockwise_of_exp_and_softplus_and_affine') + 'blockwise_of_exp_and_softplus_and_scale_matvec_diag') def testNameOneBijector(self): exp = tfb.Exp() diff --git a/tensorflow_probability/python/bijectors/chain.py b/tensorflow_probability/python/bijectors/chain.py index 69aa55ef9e..a8d9ff2db2 100644 --- a/tensorflow_probability/python/bijectors/chain.py +++ b/tensorflow_probability/python/bijectors/chain.py @@ -22,6 +22,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.bijectors import composition +from tensorflow_probability.python.bijectors import ldj_ratio from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import prefer_static as ps @@ -245,3 +246,16 @@ def update_i_event_ndims(bij, event_ndims): return (nest.map_structure(lambda nd: rolling_offset + nd, f_event_ndims), nest.map_structure(lambda nd: rolling_offset + nd, i_event_ndims)) + +@ldj_ratio.RegisterILDJRatio(Chain) +def _ildj_ratio_chain(p, x, q, y): + """Sum-of-diffs ILDJRatio for Chains.""" + if len(p.bijectors) != len(q.bijectors): + raise ValueError('Mismatched lengths of bijectors: `p` has ' + f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.') + ratios = [] + for p, q in zip(p.bijectors, q.bijectors): + ratios.append(ldj_ratio.inverse_log_det_jacobian_ratio( + p, x, q, y, p.inverse_min_event_ndims)) + x, y = p.inverse(x), q.inverse(y) + return tf.add_n(ratios) diff --git a/tensorflow_probability/python/bijectors/chain_test.py b/tensorflow_probability/python/bijectors/chain_test.py index 1bf97c1a0f..30aa14eb87 100644 --- a/tensorflow_probability/python/bijectors/chain_test.py +++ b/tensorflow_probability/python/bijectors/chain_test.py @@ -108,19 +108,24 @@ def testMinEventNdimsChain(self): self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), tfb.Affine(), tfb.Affine()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Exp(), tfb.Affine()]) + chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), tfb.Exp()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), tfb.Exp()]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), tfb.Exp(), tfb.Softplus(), tfb.Affine()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + tfb.Exp(), + tfb.Softplus(), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) @@ -129,11 +134,13 @@ def testMinEventNdimsShapeChangingAddDims(self): self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(3, chain.inverse_min_event_ndims) - chain = tfb.Chain([ShapeChanging(), tfb.Affine()]) + chain = tfb.Chain([ShapeChanging(), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(1, 
chain.forward_min_event_ndims) self.assertEqual(4, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), ShapeChanging()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + ShapeChanging()]) self.assertEqual(0, chain.forward_min_event_ndims) self.assertEqual(3, chain.inverse_min_event_ndims) @@ -146,11 +153,13 @@ def testMinEventNdimsShapeChangingRemoveDims(self): self.assertEqual(3, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) - chain = tfb.Chain([ShapeChanging(3, 0), tfb.Affine()]) + chain = tfb.Chain([ShapeChanging(3, 0), + tfb.ScaleMatvecDiag(scale_diag=[1., 1.])]) self.assertEqual(3, chain.forward_min_event_ndims) self.assertEqual(0, chain.inverse_min_event_ndims) - chain = tfb.Chain([tfb.Affine(), ShapeChanging(3, 0)]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=[1., 1.]), + ShapeChanging(3, 0)]) self.assertEqual(4, chain.forward_min_event_ndims) self.assertEqual(1, chain.inverse_min_event_ndims) @@ -191,7 +200,7 @@ def testMinEventNdimsWithJointMap(self): def testChainExpAffine(self): scale_diag = np.array([1., 2., 3.], dtype=np.float32) - chain = tfb.Chain([tfb.Exp(), tfb.Affine(scale_diag=scale_diag)]) + chain = tfb.Chain([tfb.Exp(), tfb.ScaleMatvecDiag(scale_diag=scale_diag)]) x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] y = [1., 4., 27.] self.assertAllClose(y, self.evaluate(chain.forward(x))) @@ -206,7 +215,7 @@ def testChainExpAffine(self): def testChainAffineExp(self): scale_diag = np.array([1., 2., 3.], dtype=np.float32) - chain = tfb.Chain([tfb.Affine(scale_diag=scale_diag), tfb.Exp()]) + chain = tfb.Chain([tfb.ScaleMatvecDiag(scale_diag=scale_diag), tfb.Exp()]) x = [0., np.log(2., dtype=np.float32), np.log(3., dtype=np.float32)] y = [1., 4., 9.] self.assertAllClose(y, self.evaluate(chain.forward(x))) diff --git a/tensorflow_probability/python/bijectors/composition.py b/tensorflow_probability/python/bijectors/composition.py index f26277a4c6..bab364d60f 100644 --- a/tensorflow_probability/python/bijectors/composition.py +++ b/tensorflow_probability/python/bijectors/composition.py @@ -21,6 +21,7 @@ import abc import sys +import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector @@ -28,6 +29,7 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import @@ -36,6 +38,9 @@ ] +JAX_MODE = False + + def pack_structs_like(template, *structures): """Converts a tuple of structs like `template` to a structure of tuples.""" if not structures: @@ -300,7 +305,7 @@ def _walk_forward(self, step_fn, argument, **kwargs): The `_walk_{direction}` methods define how arguments are routed through nested bijectors, expressing the directed topology of the underlying graph. - Arguments: + Args: step_fn: A method taking a bijector, a single positional argument matching `bijector.forward_min_event_ndims`, and arbitrary **kwargs, and returning a structure matching `bijector.inverse_min_event_ndims`. @@ -319,7 +324,7 @@ def _walk_inverse(self, step_fn, argument, **kwargs): The `_walk_{direction}` methods define how arguments are routed through nested bijectors, expressing the directed topology of the underlying graph. 
- Arguments: + Args: step_fn: A method taking a bijector, a single positional argument matching `bijector.inverse_min_event_ndims`, and arbitrary **kwargs, and returning a structure matching `bijector.forward_min_event_ndims`. @@ -491,6 +496,10 @@ def _maybe_warn_increased_dof(self, raise ValueError(error_message) return assert_util.assert_equal(False, increased_dof, error_message) + if (not tf.executing_eagerly() and + control_flow_util.GraphOrParentsInXlaContext(tf1.get_default_graph())): + return # No StringFormat or Print ops in XLA. + # Otherwise, we print a warning and continue. return ps.cond( pred=increased_dof, diff --git a/tensorflow_probability/python/bijectors/discrete_cosine_transform.py b/tensorflow_probability/python/bijectors/discrete_cosine_transform.py index deb1449311..302f4d68e1 100644 --- a/tensorflow_probability/python/bijectors/discrete_cosine_transform.py +++ b/tensorflow_probability/python/bijectors/discrete_cosine_transform.py @@ -39,7 +39,7 @@ class DiscreteCosineTransform(bijector.Bijector): The inverse `X = g^{-1}(Y) = IDCT(Y)`, where IDCT is DCT-III for type==2. - This bijector can be interleaved with Affine bijectors to build a cascade of + This bijector can be interleaved with affine bijectors to build a cascade of structured efficient linear layers as in [1]. Note that the operator applied is orthonormal (i.e. `norm='ortho'`). diff --git a/tensorflow_probability/python/bijectors/expm1.py b/tensorflow_probability/python/bijectors/expm1.py index 7f2f6c8461..468047fd64 100644 --- a/tensorflow_probability/python/bijectors/expm1.py +++ b/tensorflow_probability/python/bijectors/expm1.py @@ -32,7 +32,7 @@ class Expm1(bijector.Bijector): """Compute `Y = g(X) = exp(X) - 1`. - This `Bijector` is no different from Chain([AffineScalar(shift=-1), Exp()]). + This `Bijector` is no different from Chain([Shift(-1), Exp()]). However, this makes use of the more numerically stable routines `tf.math.expm1` and `tf.log1p`. diff --git a/tensorflow_probability/python/bijectors/hypothesis_testlib.py b/tensorflow_probability/python/bijectors/hypothesis_testlib.py index 1a8f8160dc..5ce6e685d1 100644 --- a/tensorflow_probability/python/bijectors/hypothesis_testlib.py +++ b/tensorflow_probability/python/bijectors/hypothesis_testlib.py @@ -122,9 +122,6 @@ def bijector_supports(): return BIJECTOR_SUPPORTS Support = tfp_hps.Support # pylint: disable=invalid-name supports = { - 'AffineScalar': - BijectorSupport(Support.SCALAR_UNCONSTRAINED, - Support.SCALAR_UNCONSTRAINED), 'Ascending': BijectorSupport(Support.VECTOR_UNCONSTRAINED, Support.VECTOR_STRICTLY_INCREASING), diff --git a/tensorflow_probability/python/bijectors/invert_test.py b/tensorflow_probability/python/bijectors/invert_test.py index 9ca74513d2..fa8618fd34 100644 --- a/tensorflow_probability/python/bijectors/invert_test.py +++ b/tensorflow_probability/python/bijectors/invert_test.py @@ -36,7 +36,7 @@ def testBijector(self): for fwd in [ tfb.Identity(), tfb.Exp(), - tfb.Affine(shift=[0., 1.], scale_diag=[2., 3.]), + tfb.ScaleMatvecDiag(scale_diag=[2., 3.]), tfb.Softplus(), tfb.SoftmaxCentered(), ]: diff --git a/tensorflow_probability/python/bijectors/ldj_ratio.py b/tensorflow_probability/python/bijectors/ldj_ratio.py new file mode 100644 index 0000000000..d0c8f5c35e --- /dev/null +++ b/tensorflow_probability/python/bijectors/ldj_ratio.py @@ -0,0 +1,86 @@ +# Copyright 2020 The TensorFlow Probability Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Computes log-ratios of Jacobian determinants numerically stably.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python import math as tfp_math +from tensorflow_probability.python.internal import prefer_static as ps + +__all__ = [ + 'inverse_log_det_jacobian_ratio', + 'RegisterILDJRatio', +] + + +_ildj_ratio_registry = {} + + +def inverse_log_det_jacobian_ratio(p, x, q, y, event_ndims, use_kahan_sum=True): + """Computes `p.ildj(x, ndims) - q.idlj(y, ndims)`, numerically stably. + + Args: + p: A bijector instance. + x: A tensor from the support of `p.forward`. + q: A bijector instance of the same type as `p`, with matching shape. + y: A tensor from the support of `q.forward`. + event_ndims: The number of right-hand dimensions comprising the event shapes + of `x` and `y`. + use_kahan_sum: When `True`, the reduction of any remaining `event_ndims` + beyond the minimum is done using Kahan summation. This requires statically + known ranks. + + Returns: + ildj_ratio: `log ((abs o det o jac p^-1)(x) / (abs o det o jac q^-1)(y))`, + i.e. in TFP code, `p.inverse_log_det_jacobian(x, event_ndims) - + q.inverse_log_det_jacobian(y, event_ndims)`. In some cases + this will be computed with better than naive numerical precision, e.g. by + moving differences inside of a sum reduction. + """ + assert type(p) == type(q) # pylint: disable=unidiomatic-typecheck + + min_event_ndims = p.inverse_min_event_ndims + def ildj_ratio_fn(p, x, q, y): + return (p.inverse_log_det_jacobian(x, event_ndims=min_event_ndims) - + q.inverse_log_det_jacobian(y, event_ndims=min_event_ndims)) + + for cls in inspect.getmro(type(p)): + if cls in _ildj_ratio_registry: + ildj_ratio_fn = _ildj_ratio_registry[cls] + + if use_kahan_sum: + sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total + else: + sum_fn = tf.reduce_sum + return sum_fn(ildj_ratio_fn(p, x, q, y), + axis=-1 - ps.range(event_ndims - min_event_ndims)) + + +class RegisterILDJRatio(object): + + def __init__(self, bijector_class): + self.cls = bijector_class + + def __call__(self, fn): + assert self.cls not in _ildj_ratio_registry + _ildj_ratio_registry[self.cls] = fn + return fn + diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive.py b/tensorflow_probability/python/bijectors/masked_autoregressive.py index b6ee69e146..1230522692 100644 --- a/tensorflow_probability/python/bijectors/masked_autoregressive.py +++ b/tensorflow_probability/python/bijectors/masked_autoregressive.py @@ -435,7 +435,7 @@ def masked_dense(inputs, See [Germain et al. (2015)][1] for detailed explanation. - Arguments: + Args: inputs: Tensor input. units: Python `int` scalar representing the dimensionality of the output space. 
@@ -894,7 +894,7 @@ def __init__(self, **kwargs): """Constructs the MADE layer. - Arguments: + Args: params: Python integer specifying the number of parameters to output per input. event_shape: Python `list`-like of positive integers (or a single int), diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive_test.py b/tensorflow_probability/python/bijectors/masked_autoregressive_test.py index 58b5388c22..a57f0342e2 100644 --- a/tensorflow_probability/python/bijectors/masked_autoregressive_test.py +++ b/tensorflow_probability/python/bijectors/masked_autoregressive_test.py @@ -103,7 +103,7 @@ def _bijector_fn(x): shift = reshape_output(shift) logit_gate = reshape_output(logit_gate) gate = tf.nn.sigmoid(logit_gate) - return tfb.AffineScalar(shift=(1. - gate) * shift, scale=gate) + return tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate)) return _bijector_fn diff --git a/tensorflow_probability/python/bijectors/pad.py b/tensorflow_probability/python/bijectors/pad.py index 714944b02a..765d121d6b 100644 --- a/tensorflow_probability/python/bijectors/pad.py +++ b/tensorflow_probability/python/bijectors/pad.py @@ -23,7 +23,7 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import prefer_static +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util @@ -124,14 +124,14 @@ def __init__(self, parameters = dict(locals()) with tf.name_scope(name or 'pad') as name: paddings = tensor_util.convert_nonref_to_tensor( - paddings, dtype_hint=tf.int32, name='paddings') + paddings, dtype_hint=tf.int32, name='paddings', as_shape_tensor=True) if axis is None: - axis = prefer_static.range( - start=-prefer_static.size0(paddings), limit=0, + axis = ps.range( + start=-ps.size0(paddings), limit=0, dtype=tf.int32, name='axis') else: axis = tensor_util.convert_nonref_to_tensor( - axis, dtype_hint=tf.int32, name='axis') + axis, dtype_hint=tf.int32, name='axis', as_shape_tensor=True) axis_ = tf.get_static_value(axis) if axis_ is None: raise NotImplementedError( @@ -170,29 +170,27 @@ def axis(self): return self._axis def _forward(self, x): - ndims = prefer_static.rank(x) - indices = prefer_static.reshape(prefer_static.add(self.axis, ndims), - shape=[-1, 1]) + ndims = ps.rank(x) + indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1]) return tf.pad( x, - paddings=prefer_static.tensor_scatter_nd_update( - prefer_static.zeros([ndims, 2], dtype=tf.int32), + paddings=ps.tensor_scatter_nd_update( + ps.zeros([ndims, 2], dtype=tf.int32), indices, self.paddings), mode=self.mode, - constant_values=prefer_static.cast(self.constant_values, dtype=x.dtype)) + constant_values=ps.cast(self.constant_values, dtype=x.dtype)) def _inverse(self, y): - ndims = prefer_static.rank(y) - indices = prefer_static.reshape(prefer_static.add(self.axis, ndims), - shape=[-1, 1]) - num_left, num_right = prefer_static.unstack(self.paddings, num=2, axis=-1) + ndims = ps.rank(y) + indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1]) + num_left, num_right = ps.unstack(self.paddings, num=2, axis=-1) x = tf.slice( y, - begin=prefer_static.tensor_scatter_nd_update( - prefer_static.zeros(ndims, dtype=tf.int32), + begin=ps.tensor_scatter_nd_update( + ps.zeros(ndims, dtype=tf.int32), indices, num_left), - size=prefer_static.tensor_scatter_nd_sub( - 
prefer_static.shape(y), + size=ps.tensor_scatter_nd_sub( + ps.shape(y), indices, num_left + num_right)) if not self.validate_args: return x @@ -225,13 +223,12 @@ def _forward_event_shape(self, input_shape, is_inverse=False): return output_shape def _forward_event_shape_tensor(self, input_shape, is_inverse=False): - ndims = prefer_static.size(input_shape) - indices = prefer_static.reshape(prefer_static.add(self.axis, ndims), - shape=[-1, 1]) - extra_sizes = prefer_static.reduce_sum(self.paddings, axis=-1) - update_fn = (prefer_static.tensor_scatter_nd_sub if is_inverse else - prefer_static.tensor_scatter_nd_add) - return update_fn(prefer_static.identity(input_shape), indices, extra_sizes) + ndims = ps.size(input_shape) + indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1]) + extra_sizes = ps.reduce_sum(self.paddings, axis=-1) + update_fn = (ps.tensor_scatter_nd_sub if is_inverse else + ps.tensor_scatter_nd_add) + return update_fn(ps.identity(input_shape), indices, extra_sizes) def _inverse_event_shape(self, output_shape): input_shape = self._forward_event_shape(output_shape, is_inverse=True) @@ -284,8 +281,8 @@ def _parameter_control_dependencies(self, is_init): elif self.validate_args: if axis is None: axis = tf.convert_to_tensor(self.axis) assertions.append(assert_util.assert_equal( - prefer_static.size0(axis), - prefer_static.size0(prefer_static.setdiff1d(axis)), + ps.size0(axis), + ps.size0(ps.setdiff1d(axis)), message=msg)) if is_init != tensor_util.is_ref(self.paddings): @@ -320,19 +317,19 @@ def _parameter_control_dependencies(self, is_init): axis_ = tf.get_static_value(self.axis) if axis_ is None and axis is None: axis = tf.convert_to_tensor(self.axis) - len_axis = prefer_static.size0(prefer_static.reshape( + len_axis = ps.size0(ps.reshape( axis if axis_ is None else axis_, shape=-1)) paddings_ = tf.get_static_value(self.paddings) if paddings_ is None and paddings is None: paddings = tf.convert_to_tensor(self.paddings) - len_paddings = prefer_static.size0( + len_paddings = ps.size0( paddings if paddings_ is None else paddings_) msg = ('Arguments `axis` and `paddings` must have the same number ' 'of elements.') - if (prefer_static.is_numpy(len_axis) and - prefer_static.is_numpy(len_paddings)): + if (ps.is_numpy(len_axis) and + ps.is_numpy(len_paddings)): if len_axis != len_paddings: raise ValueError(msg + ' Saw: {}, {}.'.format( self.axis, self.paddings)) diff --git a/tensorflow_probability/python/bijectors/real_nvp.py b/tensorflow_probability/python/bijectors/real_nvp.py index 7645195606..d53b2587a5 100644 --- a/tensorflow_probability/python/bijectors/real_nvp.py +++ b/tensorflow_probability/python/bijectors/real_nvp.py @@ -48,7 +48,8 @@ class RealNVP(bijector_lib.Bijector): while the first `d` units are 'masked' and left unchanged. Real NVP's `shift_and_log_scale_fn` computes vector-valued quantities. For scale-and-shift transforms that do not depend on any masked units, i.e. - `d=0`, use the `tfb.Affine` bijector with learned parameters instead. + `d=0`, use the `tfb.Scale` and `tfb.Shift` bijectors with learned parameters + instead. Masking is currently only supported for base distributions with `event_ndims=1`. For more sophisticated masking schemes like checkerboard or @@ -344,7 +345,7 @@ def real_nvp_default_template(hidden_layers, Real NVP bijector, implement a conditioned shift/scale template that handles the `condition_kwargs`. 
- Arguments: + Args: hidden_layers: Python `list`-like of non-negative integer, scalars indicating the number of units in each hidden layer. Default: `[512, 512]`. diff --git a/tensorflow_probability/python/bijectors/real_nvp_test.py b/tensorflow_probability/python/bijectors/real_nvp_test.py index 613699f3fe..abb9531e96 100644 --- a/tensorflow_probability/python/bijectors/real_nvp_test.py +++ b/tensorflow_probability/python/bijectors/real_nvp_test.py @@ -227,7 +227,7 @@ def _bijector_fn(x, output_units): shift = reshape_output(shift) logit_gate = reshape_output(logit_gate) gate = tf.nn.sigmoid(logit_gate) - return tfb.AffineScalar(shift=(1. - gate) * shift, scale=gate) + return tfb.Shift(shift=(1. - gate) * shift)(tfb.Scale(scale=gate)) return tf1.make_template('gated_bijector', _bijector_fn) diff --git a/tensorflow_probability/python/bijectors/scale_matvec_diag.py b/tensorflow_probability/python/bijectors/scale_matvec_diag.py index 37f22758e8..50a36373ba 100644 --- a/tensorflow_probability/python/bijectors/scale_matvec_diag.py +++ b/tensorflow_probability/python/bijectors/scale_matvec_diag.py @@ -20,6 +20,7 @@ import tensorflow.compat.v2 as tf +from tensorflow_probability.python.bijectors import ldj_ratio from tensorflow_probability.python.bijectors import scale_matvec_linear_operator from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import tensor_util @@ -115,3 +116,11 @@ def _composite_tensor_nonshape_params(self): those that are shape-related. """ return ('scale_diag',) + + +@ldj_ratio.RegisterILDJRatio(ScaleMatvecDiag) +def _ildj_ratio_scale_matvec_diag(p, x, q, y): + del x, y + return tf.math.reduce_sum(tf.math.log(tf.math.abs(q.scale.diag_part())) - + tf.math.log(tf.math.abs(p.scale.diag_part())), + axis=-1) diff --git a/tensorflow_probability/python/bijectors/softfloor.py b/tensorflow_probability/python/bijectors/softfloor.py index e036111088..7860b88f21 100644 --- a/tensorflow_probability/python/bijectors/softfloor.py +++ b/tensorflow_probability/python/bijectors/softfloor.py @@ -93,7 +93,7 @@ class Softfloor(bijector.Bijector): # Ceiling is just a shifted floor at non-integer points. soft_ceiling = tfb.Chain( - [tfb.AffineScalar(1.), + [tfb.Shift(1.), tfb.Softfloor(temperature=1.)]) soft_ceiling.forward(x) # Should be close to [3., 5., 6.] 
``` diff --git a/tensorflow_probability/python/bijectors/tanh.py b/tensorflow_probability/python/bijectors/tanh.py index 6090b94d71..44991b552b 100644 --- a/tensorflow_probability/python/bijectors/tanh.py +++ b/tensorflow_probability/python/bijectors/tanh.py @@ -34,9 +34,10 @@ class Tanh(bijector.Bijector): This can be achieved by an affine transform of the Sigmoid bijector, i.e., it is equivalent to ``` - tfb.Chain([tfb.Affine(shift=-1, scale=2.), + tfb.Chain([tfb.Shift(shift=-1.), + tfb.Scale(scale=2.), tfb.Sigmoid(), - tfb.Affine(scale=2.)]) + tfb.Scale(scale=2.)]) ``` However, using the `Tanh` bijector directly is slightly faster and more diff --git a/tensorflow_probability/python/bijectors/tanh_test.py b/tensorflow_probability/python/bijectors/tanh_test.py index cf210027db..0c53b75618 100644 --- a/tensorflow_probability/python/bijectors/tanh_test.py +++ b/tensorflow_probability/python/bijectors/tanh_test.py @@ -66,11 +66,10 @@ def testBijectiveAndFinite(self): def testMatchWithAffineTransform(self): direct_bj = tfb.Tanh() indirect_bj = tfb.Chain([ - tfb.AffineScalar( - shift=tf.cast(-1.0, dtype=tf.float64), - scale=tf.cast(2.0, dtype=tf.float64)), + tfb.Shift(tf.cast(-1.0, dtype=tf.float64)), + tfb.Scale(tf.cast(2.0, dtype=tf.float64)), tfb.Sigmoid(), - tfb.AffineScalar(scale=tf.cast(2.0, dtype=tf.float64)) + tfb.Scale(tf.cast(2.0, dtype=tf.float64)) ]) x = np.linspace(-3.0, 3.0, 100) diff --git a/tensorflow_probability/python/build_defs.bzl b/tensorflow_probability/python/build_defs.bzl index 7e66c9c826..7559e764c7 100644 --- a/tensorflow_probability/python/build_defs.bzl +++ b/tensorflow_probability/python/build_defs.bzl @@ -183,6 +183,9 @@ def multi_substrate_py_library( srcs_version: As with `py_library`. """ + if srcs_version != "PY3": + fail("Must use PY3 for srcs_version", srcs_version) + native.py_library( name = name, srcs = srcs, diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index 7c0c1fb2e6..fdcd195176 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -97,6 +97,7 @@ multi_substrate_py_library( ":laplace", ":linear_gaussian_ssm", ":lkj", + ":log_prob_ratio", ":logistic", ":logitnormal", ":loglogistic", @@ -131,6 +132,7 @@ multi_substrate_py_library( ":sinh_arcsinh", ":skellam", ":spherical_uniform", + ":stopping_ratio_logistic", ":student_t", ":student_t_process", ":transformed_distribution", @@ -573,6 +575,21 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "exponentially_modified_gaussian", + srcs = ["exponentially_modified_gaussian.py"], + deps = [ + ":distribution", + ":exponential", + ":normal", + # tensorflow dep, + "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:reparameterization", + "//tensorflow_probability/python/internal:tensor_util", + ], +) + multi_substrate_py_library( name = "finite_discrete", srcs = ["finite_discrete.py"], @@ -858,6 +875,7 @@ multi_substrate_py_library( deps = [ ":distribution", ":kullback_leibler", + ":log_prob_ratio", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", "//tensorflow_probability/python/internal:prefer_static", @@ -936,6 +954,7 @@ multi_substrate_py_library( srcs = ["joint_distribution.py"], deps = [ ":distribution", + ":log_prob_ratio", # numpy dep, # six dep, # tensorflow dep, @@ -1148,6 +1167,14 @@ multi_substrate_py_library( ], 
) +multi_substrate_py_library( + name = "log_prob_ratio", + srcs = ["log_prob_ratio.py"], + deps = [ + # tensorflow dep, + ], +) + multi_substrate_py_library( name = "logistic", srcs = ["logistic.py"], @@ -1702,6 +1729,7 @@ multi_substrate_py_library( deps = [ ":distribution", ":kullback_leibler", + ":log_prob_ratio", # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", @@ -1719,9 +1747,10 @@ multi_substrate_py_library( ":normal", ":transformed_distribution", # tensorflow dep, - "//tensorflow_probability/python/bijectors:affine_scalar", "//tensorflow_probability/python/bijectors:chain", "//tensorflow_probability/python/bijectors:identity", + "//tensorflow_probability/python/bijectors:scale", + "//tensorflow_probability/python/bijectors:shift", "//tensorflow_probability/python/bijectors:sinh_arcsinh", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", @@ -1748,6 +1777,23 @@ multi_substrate_py_library( ], ) +multi_substrate_py_library( + name = "stopping_ratio_logistic", + srcs = ["stopping_ratio_logistic.py"], + deps = [ + ":distribution", + ":kullback_leibler", + # tensorflow dep, + "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:distribution_util", + "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:reparameterization", + "//tensorflow_probability/python/internal:tensor_util", + "//tensorflow_probability/python/internal:tensorshape_util", + ], +) + multi_substrate_py_library( name = "half_student_t", srcs = ["half_student_t.py"], @@ -1833,7 +1879,9 @@ multi_substrate_py_library( deps = [ ":distribution", ":kullback_leibler", + ":log_prob_ratio", # tensorflow dep, + "//tensorflow_probability/python/bijectors:ldj_ratio", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensorshape_util", ], @@ -1950,7 +1998,6 @@ multi_substrate_py_library( ":sample", ":transformed_distribution", # tensorflow dep, - "//tensorflow_probability/python/bijectors:affine_linear_operator", "//tensorflow_probability/python/bijectors:chain", "//tensorflow_probability/python/bijectors:scale_matvec_linear_operator", "//tensorflow_probability/python/bijectors:shift", @@ -2414,6 +2461,21 @@ multi_substrate_py_test( ], ) +multi_substrate_py_test( + name = "exponentially_modified_gaussian_test", + srcs = ["exponentially_modified_gaussian_test.py"], + jax_size = "medium", + # Disable numpy test for now because a bug in the types returned by special_math.ndtr + numpy_tags = ["notap"], + deps = [ + # numpy dep, + # scipy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + ], +) + multi_substrate_py_test( name = "finite_discrete_test", size = "medium", @@ -3328,6 +3390,7 @@ multi_substrate_py_test( name = "sample_test", srcs = ["sample_test.py"], jax_size = "medium", + shard_count = 2, deps = [ # absl/testing:parameterized dep, # numpy dep, @@ -3383,6 +3446,18 @@ multi_substrate_py_test( ], ) +multi_substrate_py_test( + name = "stopping_ratio_logistic_test", + srcs = ["stopping_ratio_logistic_test.py"], + deps = [ + # absl/testing:parameterized dep, + # numpy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + ], +) + multi_substrate_py_test( name = "half_student_t_test", size = "medium", @@ -3750,33 
+3825,3 @@ py_binary( "//tensorflow_probability/python/distributions:hypothesis_testlib", ], ) - -multi_substrate_py_library( - name = "exponentially_modified_gaussian", - srcs = ["exponentially_modified_gaussian.py"], - deps = [ - ":distribution", - ":exponential", - ":normal", - # tensorflow dep, - "//tensorflow_probability/python/internal:dtype_util", - "//tensorflow_probability/python/internal:prefer_static", - "//tensorflow_probability/python/internal:reparameterization", - "//tensorflow_probability/python/internal:tensor_util", - ], -) - -multi_substrate_py_test( - name = "exponentially_modified_gaussian_test", - srcs = ["exponentially_modified_gaussian_test.py"], - jax_size = "medium", - # Disable numpy test for now because a bug in the types returned by special_math.ndtr - numpy_tags = ["notap"], - deps = [ - # numpy dep, - # scipy dep, - # tensorflow dep, - "//tensorflow_probability", - "//tensorflow_probability/python/internal:test_util", - ], -) diff --git a/tensorflow_probability/python/distributions/__init__.py b/tensorflow_probability/python/distributions/__init__.py index 8a9f81e225..e97b16edf4 100644 --- a/tensorflow_probability/python/distributions/__init__.py +++ b/tensorflow_probability/python/distributions/__init__.py @@ -109,6 +109,7 @@ from tensorflow_probability.python.distributions.sinh_arcsinh import SinhArcsinh from tensorflow_probability.python.distributions.skellam import Skellam from tensorflow_probability.python.distributions.spherical_uniform import SphericalUniform +from tensorflow_probability.python.distributions.stopping_ratio_logistic import StoppingRatioLogistic from tensorflow_probability.python.distributions.student_t import StudentT from tensorflow_probability.python.distributions.student_t_process import StudentTProcess from tensorflow_probability.python.distributions.transformed_distribution import TransformedDistribution @@ -223,6 +224,7 @@ 'SinhArcsinh', 'Skellam', 'SphericalUniform', + 'StoppingRatioLogistic', 'StudentT', 'StudentTProcess', 'Triangular', diff --git a/tensorflow_probability/python/distributions/beta_binomial_test.py b/tensorflow_probability/python/distributions/beta_binomial_test.py index a66d576ae1..48932dd023 100644 --- a/tensorflow_probability/python/distributions/beta_binomial_test.py +++ b/tensorflow_probability/python/distributions/beta_binomial_test.py @@ -143,6 +143,12 @@ def testSampleAgainstProb(self): np.sum(x == i, axis=0) / (num_samples * 1.0), atol=0.01, rtol=0.1) + def testSampleCornerConcentrations(self): + seed_stream = test_util.test_seed_stream() + d = tfd.BetaBinomial(concentration0=[1., 0.], concentration1=[0., 1.], + total_count=50.) + self.assertAllEqual(d.sample(10, seed=seed_stream()), [[0, 50]] * 10) + def testEmpiricalCdfAgainstDirichletMultinomial(self): # This test is too slow for Eager mode. 
if tf.executing_eagerly(): diff --git a/tensorflow_probability/python/distributions/deterministic.py b/tensorflow_probability/python/distributions/deterministic.py index eb4b0df3ed..c2ac05e9e1 100644 --- a/tensorflow_probability/python/distributions/deterministic.py +++ b/tensorflow_probability/python/distributions/deterministic.py @@ -24,6 +24,7 @@ import six import tensorflow.compat.v2 as tf +from tensorflow_probability.python import bijectors as tfb from tensorflow_probability.python.distributions import distribution from tensorflow_probability.python.distributions import kullback_leibler from tensorflow_probability.python.internal import assert_util @@ -154,7 +155,14 @@ def _sample_n(self, n, seed=None): axis=0)) def _default_event_space_bijector(self): - return + """The bijector maps a zero-dimensional null Tensor input to `self.loc`.""" + # The shape of the pulled back null tensor will be `self.loc.shape + (0,)`. + # First we pad to a tensor of zeros with shape `self.loc.shape + (1,)`. + pad_zero = tfb.Pad([(1, 0)]) + # Next, we squeeze to a tensor of zeros with shape matching `self.loc`. + zeros_squeezed = tfb.Reshape([], event_shape_in=[1])(pad_zero) + # Finally, we shift the zeros by `self.loc`. + return tfb.Shift(self.loc)(zeros_squeezed) def _parameter_control_dependencies(self, is_init): assertions = [] diff --git a/tensorflow_probability/python/distributions/deterministic_test.py b/tensorflow_probability/python/distributions/deterministic_test.py index df63b9c391..50e31ee92f 100644 --- a/tensorflow_probability/python/distributions/deterministic_test.py +++ b/tensorflow_probability/python/distributions/deterministic_test.py @@ -17,11 +17,13 @@ from __future__ import print_function # Dependency imports +from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util rng = np.random.RandomState(0) @@ -255,6 +257,7 @@ def testVariableAssertions(self): self.evaluate(deterministic.log_prob(1.)) +@test_util.test_all_tf_execution_regimes class VectorDeterministicTest(test_util.TestCase): def testParamBroadcasts(self): @@ -454,6 +457,51 @@ def testVariableAssertions(self): 'Condition x >= 0'): self.evaluate(deterministic.log_prob([1.])) + @parameterized.named_parameters( + dict(testcase_name='_scalar', + dist_fn=lambda: tfd.Deterministic(3.)), + dict(testcase_name='_batch_scalar', + dist_fn=lambda: tfd.Deterministic([3., -7.])), + dict(testcase_name='_vector', + dist_fn=lambda: tfd.VectorDeterministic([3., -7.])), + dict(testcase_name='_batch_vector', + dist_fn=lambda: tfd.VectorDeterministic([[3., -7.], [-2, 4.]]))) + def testDefaultBijector(self, dist_fn): + dist = dist_fn() + bijector = dist.experimental_default_event_space_bijector() + self.assertEqual(dist.loc.shape, dist.batch_shape + dist.event_shape) + self.assertEqual(dist.event_shape + (0,), + bijector.inverse_event_shape(dist.event_shape)) + self.assertEqual(dist.loc.shape + (0,), + bijector.inverse_event_shape(dist.loc.shape)) + null_point = tf.ones(bijector.inverse_event_shape(dist.loc.shape)) + self.assertAllEqual( + tf.zeros([]), + bijector.forward_log_det_jacobian( + null_point, tensorshape_util.rank(null_point.shape))) + self.assertAllEqual(dist.loc, bijector(null_point)) + + @parameterized.named_parameters( + dict(testcase_name='_scalar', + dist_fn=lambda: tfd.Deterministic(3.)), + 
dict(testcase_name='_batch_scalar', + dist_fn=lambda: tfd.Deterministic([3., -7.])), + dict(testcase_name='_vector', + dist_fn=lambda: tfd.VectorDeterministic([3., -7.])), + dict(testcase_name='_batch_vector', + dist_fn=lambda: tfd.VectorDeterministic([[3., -7.], [-2, 4.]]))) + def testDefaultBijectorXLA(self, dist_fn): + self.skip_if_no_xla() + @tf.function(experimental_compile=True) + def fn(x): + bijector = dist_fn().experimental_default_event_space_bijector() + ndim = tensorshape_util.rank(x.shape) + return (bijector(x), + bijector.forward_log_det_jacobian(x, ndim), + bijector.inverse(0 + bijector(x)), + bijector.inverse_log_det_jacobian(0 + bijector(x), ndim - 1)) + self.evaluate(fn(tf.zeros(dist_fn().loc.shape + (0,)))) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/distributions/distribution.py b/tensorflow_probability/python/distributions/distribution.py index 9ec0f3129e..964efb2e39 100644 --- a/tensorflow_probability/python/distributions/distribution.py +++ b/tensorflow_probability/python/distributions/distribution.py @@ -1555,6 +1555,13 @@ def _name_and_control_scope(self, name=None, value=UNSET_VALUE, kwargs=None): if not deps: yield name_scope return + # In eager mode, some `assert_util.assert_xyz` calls return None. If a + # Distribution is created in eager mode with `validate_args=True`, then + # used in a `tf.function` context, it can result in errors when + # `tf.convert_to_tensor` is called on the inputs to + # `tf.control_dependencies` below. To avoid these errors, we drop the + # `None`s here. + deps = [x for x in deps if x is not None] with tf.control_dependencies(deps) as deps_scope: yield deps_scope diff --git a/tensorflow_probability/python/distributions/distribution_properties_test.py b/tensorflow_probability/python/distributions/distribution_properties_test.py index 02f9064f06..53c65b4384 100644 --- a/tensorflow_probability/python/distributions/distribution_properties_test.py +++ b/tensorflow_probability/python/distributions/distribution_properties_test.py @@ -282,7 +282,8 @@ def testCanConstructAndSampleDistribution(self, data): 'Empirical|event_ndims=2', 'FiniteDiscrete', 'MultivariateStudentTLinearOperator', 'PoissonLogNormalQuadratureCompound', - 'SphericalUniform', 'SinhArcsinh') + 'SphericalUniform', 'SinhArcsinh', + 'StoppingRatioLogistic',) non_trainable_dists = ( high_gt_low_constraint_dists + not_annotated_dists + dhps.INSTANTIABLE_META_DISTS) @@ -525,6 +526,29 @@ def disabled_testFailureCase(self): # pylint: disable=invalid-name self.assertAllClose(dist.log_prob(samps)[0], dist[0].log_prob(samps[0])) +# Don't decorate with test_util.test_all_tf_execution_regimes, since we're +# explicitly mixing modes. +class TestMixingGraphAndEagerModes(test_util.TestCase): + + @parameterized.named_parameters( + {'testcase_name': dname, 'dist_name': dname} + for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()) + + list(dhps.INSTANTIABLE_META_DISTS)) + ) + @hp.given(hps.data()) + @tfp_hps.tfp_hp_settings() + def testSampleEagerCreatedDistributionInGraphMode(self, dist_name, data): + if not tf.executing_eagerly(): + self.skipTest('Only test mixed eager/graph behavior in eager tests.') + # Create in eager mode. + dist = data.draw(dhps.distributions(dist_name=dist_name, enable_vars=False)) + + @tf.function + def f(): + dist.sample() + f() + + if __name__ == '__main__': # Hypothesis often finds numerical near misses. 
Debugging them is much aided # by seeing all the digits of every floating point number, instead of the diff --git a/tensorflow_probability/python/distributions/exp_gamma_test.py b/tensorflow_probability/python/distributions/exp_gamma_test.py index 4f6d3a9131..74a7f579c8 100644 --- a/tensorflow_probability/python/distributions/exp_gamma_test.py +++ b/tensorflow_probability/python/distributions/exp_gamma_test.py @@ -217,7 +217,7 @@ def testSample(self): d.variance(), atol=.15) - def testSampleReturnsNansForNonPositiveParameters(self): + def testSampleNonPositiveParameters(self): d = tfd.ExpGamma([1., 2.], 1., validate_args=False) seed_stream = test_util.test_seed_stream() samples = self.evaluate(d.sample(seed=seed_stream())) @@ -227,6 +227,11 @@ def testSampleReturnsNansForNonPositiveParameters(self): d = tfd.ExpGamma([0., 2.], 1., validate_args=False) samples = self.evaluate(d.sample(seed=seed_stream())) self.assertEqual(samples.shape, (2,)) + self.assertAllEqual([s == -np.inf for s in samples], [True, False]) + + d = tfd.ExpGamma([-0.001, 2.], 1., validate_args=False) + samples = self.evaluate(d.sample(seed=seed_stream())) + self.assertEqual(samples.shape, (2,)) self.assertAllEqual([np.isnan(s) for s in samples], [True, False]) d = tfd.ExpGamma([1., -1.], 1., validate_args=False) diff --git a/tensorflow_probability/python/distributions/gamma.py b/tensorflow_probability/python/distributions/gamma.py index e25349c21b..411db9e05d 100644 --- a/tensorflow_probability/python/distributions/gamma.py +++ b/tensorflow_probability/python/distributions/gamma.py @@ -384,7 +384,7 @@ def _tensorshape_or_scalar(v0, v1): def _random_gamma_cpu( shape, concentration, rate=None, log_rate=None, seed=None, log_space=False): """Sample using *fast* `tf.random.stateless_gamma`.""" - bad_concentration = (concentration <= 0.) | tf.math.is_nan(concentration) + bad_concentration = (concentration < 0.) | tf.math.is_nan(concentration) safe_concentration = tf.where( bad_concentration, dtype_util.as_numpy_dtype(concentration.dtype)(100.), concentration) @@ -711,7 +711,7 @@ def rejection_sample(concentration): # Note, concentration here already has a shape that is broadcast with rate. cast_concentration = tf.cast(concentration, internal_dtype) - good_params_mask = (concentration > 0.) + good_params_mask = (concentration >= 0.) # When replacing NaN values, use 100. for concentration, since that leads to # a high-likelihood of the rejection sampler accepting on the first pass. safe_concentration = tf.where(good_params_mask, cast_concentration, 100.) 
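Note on the `gamma.py` change above: flipping the checks to `concentration < 0.` / `concentration >= 0.` treats a zero concentration as a valid, degenerate-at-zero parameterization rather than a bad one, which is exactly what the updated `exp_gamma_test.py` expectations (and the `gamma_test.py` changes that follow) assert. A minimal sketch of the asserted behavior; eager execution and the unseeded `sample()` calls are for illustration only:

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

gamma = tfd.Gamma(concentration=[0., -1.], rate=1., validate_args=False)
s = gamma.sample()  # e.g. [0., nan]
# concentration == 0. now yields 0. (or float32 tiny) instead of NaN;
# negative concentrations are still flagged as bad and sample NaN.

exp_gamma = tfd.ExpGamma(concentration=[0., 2.], rate=1., validate_args=False)
log_s = exp_gamma.sample()  # e.g. [-inf, <finite>]
# In log space the degenerate zero-concentration sample shows up as -inf.
```
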
diff --git a/tensorflow_probability/python/distributions/gamma_test.py b/tensorflow_probability/python/distributions/gamma_test.py index 0f1bb7245c..86f4898b9b 100644 --- a/tensorflow_probability/python/distributions/gamma_test.py +++ b/tensorflow_probability/python/distributions/gamma_test.py @@ -303,7 +303,7 @@ def testGammaSample(self): sp_stats.gamma.var(concentration_v, scale=1 / rate_v), atol=.15) - def testGammaSampleReturnsNansForNonPositiveParameters(self): + def testGammaSampleZeroAndNegativeParameters(self): gamma = tfd.Gamma([1., 2.], 1., validate_args=False) seed_stream = test_util.test_seed_stream() samples = self.evaluate(gamma.sample(seed=seed_stream())) @@ -313,6 +313,12 @@ def testGammaSampleReturnsNansForNonPositiveParameters(self): gamma = tfd.Gamma([0., 2.], 1., validate_args=False) samples = self.evaluate(gamma.sample(seed=seed_stream())) self.assertEqual(samples.shape, (2,)) + self.assertAllEqual([s in [0, np.finfo(np.float32).tiny] + for s in samples], [True, False]) + + gamma = tfd.Gamma([-0.001, 2.], 1., validate_args=False) + samples = self.evaluate(gamma.sample(seed=seed_stream())) + self.assertEqual(samples.shape, (2,)) self.assertAllEqual([np.isnan(s) for s in samples], [True, False]) gamma = tfd.Gamma([1., -1.], 1., validate_args=False) diff --git a/tensorflow_probability/python/distributions/hypothesis_testlib.py b/tensorflow_probability/python/distributions/hypothesis_testlib.py index 3247d8aef6..62f2a2a529 100644 --- a/tensorflow_probability/python/distributions/hypothesis_testlib.py +++ b/tensorflow_probability/python/distributions/hypothesis_testlib.py @@ -572,31 +572,6 @@ def stringify_slices(slices): return pretty_slices -@hps.composite -def broadcasting_params(draw, - dist_name, - batch_shape, - event_dim=None, - enable_vars=False): - """Strategy for drawing parameters broadcasting to `batch_shape`.""" - if dist_name not in INSTANTIABLE_BASE_DISTS: - raise ValueError('Unknown Distribution name {}'.format(dist_name)) - - params_event_ndims = INSTANTIABLE_BASE_DISTS[dist_name].params_event_ndims - - def _constraint(param): - return constraint_for(dist_name, param) - - return draw( - tfp_hps.broadcasting_params( - batch_shape, - params_event_ndims, - event_dim=event_dim, - enable_vars=enable_vars, - constraint_fn_for=_constraint, - mutex_params=MUTEX_PARAMS)) - - def prime_factors(v): """Compute the prime factors of v.""" factors = [] @@ -639,6 +614,7 @@ def base_distribution_unconstrained_params(draw, batch_shape=None, event_dim=None, enable_vars=False, + param_strategy_fn=None, params=None): """Strategy for drawing unconstrained parameters of a base Distribution. @@ -660,6 +636,10 @@ def base_distribution_unconstrained_params(draw, initialization in slicing_test. If `False`, the returned parameters are all `tf.Tensor`s and not {`tf.Variable`, `tfp.util.DeferredTensor` `tfp.util.TransformedVariable`}. + param_strategy_fn: Optional callable with signature + `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. If provided, + overrides the default strategy for generating float-valued parameters. + Default value: `None`. params: An optional set of Distribution parameters. If params are not provided, Hypothesis will choose a set of parameters. 
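The `param_strategy_fn` hook documented above is pinned down only by its signature, `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. As a rough illustration only — the helper name, the plain `hypothesis` float strategy, and the assumption that `constraint_fn` accepts a NumPy array are not part of this change — an override might look like:

```python
import hypothesis.strategies as hps
import numpy as np

def tiny_param_strategy(shape, dtype, constraint_fn):
  """Hypothetical override: draw one small raw value, tile it, constrain it."""
  del dtype  # Assumed handled by constraint_fn; kept to match the signature.
  return hps.floats(min_value=-0.5, max_value=0.5).map(
      lambda raw: constraint_fn(np.full(tuple(shape), raw, dtype=np.float32)))

# Hypothetical use with the strategies defined in this module:
#   dist = draw(base_distributions('Normal',
#                                  param_strategy_fn=tiny_param_strategy))
```
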
@@ -675,11 +655,21 @@ def base_distribution_unconstrained_params(draw, batch_shape = draw(tfp_hps.shapes()) # Draw raw parameters + if dist_name not in INSTANTIABLE_BASE_DISTS: + raise ValueError('Unknown Distribution name {}'.format(dist_name)) + params_event_ndims = INSTANTIABLE_BASE_DISTS[dist_name].params_event_ndims + params_kwargs = draw( - broadcasting_params( - dist_name, batch_shape, event_dim=event_dim, enable_vars=enable_vars)) - hp.note('Forming dist {} with raw parameters {}'.format( - dist_name, params_kwargs)) + tfp_hps.broadcasting_params( + batch_shape, + params_event_ndims, + event_dim=event_dim, + enable_vars=enable_vars, + constraint_fn_for=lambda param: constraint_for(dist_name, param), + mutex_params=MUTEX_PARAMS, + param_strategy_fn=param_strategy_fn)) + hp.note('Forming dist {} with raw parameters {}'.format(dist_name, + params_kwargs)) return params_kwargs, batch_shape @@ -732,8 +722,9 @@ def base_distributions(draw, event_dim=None, enable_vars=False, eligibility_filter=lambda name: True, - validate_args=True, - params=None): + params=None, + param_strategy_fn=None, + validate_args=True): """Strategy for drawing arbitrary base Distributions. This does not draw compound distributions like `Independent`, @@ -756,9 +747,13 @@ def base_distributions(draw, `tfp.util.TransformedVariable`}. eligibility_filter: Optional Python callable. Blacklists some Distribution class names so they will not be drawn at the top level. - validate_args: Python `bool`; whether to enable runtime assertions. params: An optional set of Distribution parameters. If params are not provided, Hypothesis will choose a set of parameters. + param_strategy_fn: Optional callable with signature + `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. If provided, + overrides the default strategy for generating float-valued parameters. + Default value: `None`. + validate_args: Python `bool`; whether to enable runtime assertions. Returns: dists: A strategy for drawing Distributions with the specified `batch_shape` @@ -780,13 +775,14 @@ class names so they will not be drawn at the top level. if params is None: params_unconstrained, batch_shape = draw( - base_distribution_unconstrained_params(dist_name, - batch_shape=batch_shape, - event_dim=event_dim, - enable_vars=enable_vars)) + base_distribution_unconstrained_params( + dist_name, + batch_shape=batch_shape, + event_dim=event_dim, + enable_vars=enable_vars, + param_strategy_fn=param_strategy_fn)) params = constrain_params(params_unconstrained, dist_name) - params = modify_params( - params, dist_name, validate_args=validate_args) + params = modify_params(params, dist_name, validate_args=validate_args) # Actually construct the distribution dist_cls = INSTANTIABLE_BASE_DISTS[dist_name].cls result_dist = dist_cls(**params) @@ -1436,8 +1432,12 @@ class names so they will not be drawn. 
or dist_name in INSTANTIABLE_BASE_DISTS or dist_name == 'Empirical'): return draw(base_distributions( - dist_name, batch_shape, event_dim, enable_vars, - eligibility_filter, validate_args)) + dist_name, + batch_shape=batch_shape, + event_dim=event_dim, + enable_vars=enable_vars, + eligibility_filter=eligibility_filter, + validate_args=validate_args)) if dist_name == 'BatchReshape': return draw(batch_reshapes( batch_shape, event_dim, enable_vars, depth, diff --git a/tensorflow_probability/python/distributions/independent.py b/tensorflow_probability/python/distributions/independent.py index 18f7cc4731..2d6f083d98 100644 --- a/tensorflow_probability/python/distributions/independent.py +++ b/tensorflow_probability/python/distributions/independent.py @@ -25,8 +25,9 @@ from tensorflow_probability.python import math as tfp_math from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import prefer_static +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util @@ -199,7 +200,7 @@ def __getitem__(self, slices): def _batch_shape_tensor(self): batch_shape = self.distribution.batch_shape_tensor() - batch_ndims = prefer_static.rank_from_shape( + batch_ndims = ps.rank_from_shape( batch_shape, self.distribution.batch_shape) return batch_shape[ :batch_ndims - self._get_reinterpreted_batch_ndims(batch_shape)] @@ -220,11 +221,11 @@ def _event_shape_tensor(self): batch_shape = self.distribution.batch_shape if not tensorshape_util.is_fully_defined(batch_shape): batch_shape = self.distribution.batch_shape_tensor() - batch_ndims = prefer_static.rank_from_shape(batch_shape) + batch_ndims = ps.rank_from_shape(batch_shape) event_shape = self.distribution.event_shape if not tensorshape_util.is_fully_defined(event_shape): event_shape = self.distribution.event_shape_tensor() - return prefer_static.concat([ + return ps.concat([ batch_shape[ batch_ndims - self._get_reinterpreted_batch_ndims(batch_shape):], event_shape, @@ -297,13 +298,13 @@ def _parameter_control_dependencies(self, is_init): assertions.append( assert_util.assert_less_equal( self._get_reinterpreted_batch_ndims(batch_shape_tensor), - prefer_static.rank_from_shape(batch_shape_tensor), + ps.rank_from_shape(batch_shape_tensor), message=('reinterpreted_batch_ndims cannot exceed ' 'distribution.batch_ndims'))) return assertions def _reduce(self, op, stat): - axis = 1 + prefer_static.range(self._get_reinterpreted_batch_ndims()) + axis = 1 + ps.range(self._get_reinterpreted_batch_ndims()) return op(stat, axis=-axis) _composite_tensor_nonshape_params = ('distribution',) @@ -372,10 +373,28 @@ def _kl_independent(a, b, name='kl_independent'): message='Event shapes do not match.'), ]): num_reduce_dims = ( - prefer_static.rank_from_shape( + ps.rank_from_shape( a_event_shape_tensor, a.event_shape) - - prefer_static.rank_from_shape( + ps.rank_from_shape( p_event_shape_tensor, p.event_shape)) - reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1) + reduce_dims = ps.range(-num_reduce_dims, 0, 1) return tf.reduce_sum( kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims) + + +@log_prob_ratio.RegisterLogProbRatio(Independent) +def 
_independent_log_prob_ratio(p, x, q, y): + """Sum-of-diffs log(p(x)/q(y)) for `Independent`s.""" + checks = [] + if p.validate_args or q.validate_args: + checks.append(tf.debugging.assert_equal( + p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)) + if p._experimental_use_kahan_sum or q._experimental_use_kahan_sum: # pylint: disable=protected-access + sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total + else: + sum_fn = tf.reduce_sum + with tf.control_dependencies(checks): + return sum_fn( + log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y), + axis=-1 - ps.range(p.reinterpreted_batch_ndims)) + diff --git a/tensorflow_probability/python/distributions/independent_test.py b/tensorflow_probability/python/distributions/independent_test.py index 9959641ca6..2924bf0da5 100644 --- a/tensorflow_probability/python/distributions/independent_test.py +++ b/tensorflow_probability/python/distributions/independent_test.py @@ -27,11 +27,12 @@ from scipy import stats as sp_stats import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf -from tensorflow_probability.python import distributions as tfd +import tensorflow_probability as tfp from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util +tfd = tfp.distributions JAX_MODE = False @@ -522,6 +523,38 @@ def test_kahan_precision(self, jit=False): # Fails ~75% CPU, 1-75% GPU --vary_seed runs w/o experimental_use_kahan_sum. self.assertAllClose(lp64, lp, rtol=0., atol=.01) + def testLargeLogProbDiff(self): + b = 15 + n = 5_000 + d0 = tfd.Independent(tfd.Normal(tf.fill([b, n], 0.), tf.fill([n], .1)), + reinterpreted_batch_ndims=1, + experimental_use_kahan_sum=True) + d1 = tfd.Independent(tfd.Normal(tf.fill([b, n], 1e-5), tf.fill([n], .1)), + reinterpreted_batch_ndims=1, + experimental_use_kahan_sum=True) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample([b, n], seed=strm())) + x1 = self.evaluate( # overdispersed, perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = d0.copy(distribution=tfd.Normal( + tf.cast(d0.distribution.loc, tf.float64), + tf.cast(d0.distribution.scale, tf.float64))) + d1_64 = d1.copy(distribution=tfd.Normal( + tf.cast(d1.distribution.loc, tf.float64), + tf.cast(d1.distribution.scale, tf.float64))) + self.assertNotAllZero(d0.log_prob(x0) < -1_000_000) + self.assertAllClose( + d0_64.log_prob(tf.cast(x0, tf.float64)) - + d1_64.log_prob(tf.cast(x1, tf.float64)), + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.0075) + # In contrast: the below fails consistently w/ errors around 0.5-1.0 + # self.assertAllClose( + # d0_64.log_prob(tf.cast(x0, tf.float64)) - + # d1_64.log_prob(tf.cast(x1, tf.float64)), + # d0.log_prob(x0) - d1.log_prob(x1), + # rtol=0., atol=0.007) if __name__ == '__main__': # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term. 
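The `testLargeLogProbDiff` case above is the motivating use of the new `tfp.experimental.distributions.log_prob_ratio` entry point together with the `Independent` registration from `independent.py`: the per-element log-prob differences are reduced (with Kahan summation when enabled) instead of subtracting two already-rounded totals. A condensed sketch of that usage, with arbitrary shapes and no test tolerances:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

d0 = tfd.Independent(tfd.Normal(tf.zeros([5000]), 0.1),
                     reinterpreted_batch_ndims=1,
                     experimental_use_kahan_sum=True)
d1 = tfd.Independent(tfd.Normal(tf.fill([5000], 1e-5), 0.1),
                     reinterpreted_batch_ndims=1,
                     experimental_use_kahan_sum=True)
x0 = tf.random.normal([5000], stddev=2.)          # overdispersed sample
x1 = x0 + tf.random.normal([5000], stddev=1e-6)   # small perturbation
# Naive: d0.log_prob(x0) - d1.log_prob(x1) subtracts two huge, rounded sums.
# The registered rule instead sums the elementwise differences, preserving
# the small ratio.
ratio = tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1)
```
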
diff --git a/tensorflow_probability/python/distributions/jax_transformation_test.py b/tensorflow_probability/python/distributions/jax_transformation_test.py index 72aee0c4b9..8a5183028c 100644 --- a/tensorflow_probability/python/distributions/jax_transformation_test.py +++ b/tensorflow_probability/python/distributions/jax_transformation_test.py @@ -80,6 +80,7 @@ JVP_SAMPLE_BLOCKLIST = () JVP_LOGPROB_SAMPLE_BLOCKLIST = ( + 'GeneralizedExtremeValue', # http://b/175654800 'Skellam', # http://b/171079052 ) JVP_LOGPROB_PARAM_BLOCKLIST = ( @@ -89,6 +90,7 @@ VJP_SAMPLE_BLOCKLIST = () VJP_LOGPROB_SAMPLE_BLOCKLIST = ( + 'GeneralizedExtremeValue', # http://b/175654800 'Skellam', # http://b/171079052 ) VJP_LOGPROB_PARAM_BLOCKLIST = ( diff --git a/tensorflow_probability/python/distributions/joint_distribution.py b/tensorflow_probability/python/distributions/joint_distribution.py index d7d8082a8e..c0ed4d6706 100644 --- a/tensorflow_probability/python/distributions/joint_distribution.py +++ b/tensorflow_probability/python/distributions/joint_distribution.py @@ -28,6 +28,7 @@ from tensorflow_probability.python.bijectors import composition from tensorflow_probability.python.bijectors import identity as identity_bijector from tensorflow_probability.python.distributions import distribution as distribution_lib +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import docstring_util @@ -810,3 +811,15 @@ def _inverse(self, y, **kwargs): def _inverse_log_det_jacobian(self, y, event_ndims, **kwargs): return super(_DefaultJointBijector, self)._inverse_log_det_jacobian( y, event_ndims, _jd_conditioning=y, **kwargs) + + +@log_prob_ratio.RegisterLogProbRatio(JointDistribution) +def _jd_log_prob_ratio(p, x, q, y): + tf.nest.assert_same_structure(x, y) + ps, _ = p.sample_distributions(value=x) + qs, _ = q.sample_distributions(value=y) + tf.nest.assert_same_structure(ps, qs) + parts = [] + for p_, x_, q_, y_ in zip(ps, x, qs, y): + parts.append(log_prob_ratio.log_prob_ratio(p_, x_, q_, y_)) + return tf.add_n(parts) diff --git a/tensorflow_probability/python/distributions/log_prob_ratio.py b/tensorflow_probability/python/distributions/log_prob_ratio.py new file mode 100644 index 0000000000..214b0353e3 --- /dev/null +++ b/tensorflow_probability/python/distributions/log_prob_ratio.py @@ -0,0 +1,63 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Computes log-ratios of probs numerically stably.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import inspect + + +__all__ = [ + 'log_prob_ratio', + 'RegisterLogProbRatio', +] + + +_log_prob_ratio_registry = {} + + +def log_prob_ratio(p, x, q, y): + """Computes `p.log_prob(x) - q.log_prob(y)`, numerically stably. + + Args: + p: A distribution instance. + x: A tensor from the support of `p`. + q: A distribution instance in the same family as `p`, with matching shape. + y: A tensor from the support of `q`. + + Returns: + lp_ratio: `log (p(x) / q(y)) = p.log_prob(x) - q.log_prob(y)`. In some cases + this will be computed with better than naive numerical precision, e.g. by + moving the difference inside of a sum reduction. + """ + assert type(p) == type(q) # pylint: disable=unidiomatic-typecheck + for cls in inspect.getmro(type(p)): + if cls in _log_prob_ratio_registry: + return _log_prob_ratio_registry[cls](p, x, q, y) + return p.log_prob(x) - q.log_prob(y) + + +class RegisterLogProbRatio(object): + + def __init__(self, dist_family): + self.family = dist_family + + def __call__(self, fn): + assert self.family not in _log_prob_ratio_registry + _log_prob_ratio_registry[self.family] = fn + return fn + diff --git a/tensorflow_probability/python/distributions/mixture_same_family.py b/tensorflow_probability/python/distributions/mixture_same_family.py index 4c5c9d5f47..2d27201c6b 100644 --- a/tensorflow_probability/python/distributions/mixture_same_family.py +++ b/tensorflow_probability/python/distributions/mixture_same_family.py @@ -453,7 +453,7 @@ def _reparameterize_sample(self, x, event_shape): 3. Distributional transform currently only works for known rank of the batch tensor. - Arguments: + Args: x: Sample of mixture distribution event_shape: The event shape of this distribution @@ -512,7 +512,7 @@ def _distributional_transform(self, x, event_shape): w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1) and w_0^k = w_k is the mixture probability of the k-th component. - Arguments: + Args: x: Sample of mixture distribution event_shape: The event shape of this distribution @@ -682,7 +682,7 @@ def _prevent_2nd_derivative(x): NB: you need to apply a non-identity function to the output tensor for the exception to be raised. - Arguments: + Args: x: A tensor. Returns: diff --git a/tensorflow_probability/python/distributions/quantized_distribution.py b/tensorflow_probability/python/distributions/quantized_distribution.py index 2445917357..26059e9301 100644 --- a/tensorflow_probability/python/distributions/quantized_distribution.py +++ b/tensorflow_probability/python/distributions/quantized_distribution.py @@ -188,7 +188,7 @@ class QuantizedDistribution(distributions.Distribution): discretized_logistic_dist = tfd.QuantizedDistribution( distribution=tfd.TransformedDistribution( distribution=tfd.Logistic(loc=loc, scale=scale), - bijector=tfb.AffineScalar(shift=-0.5)), + bijector=tfb.Shift(shift=-0.5)), low=0., high=2**16 - 1.) 
mixture_dist = tfd.MixtureSameFamily( diff --git a/tensorflow_probability/python/distributions/sample.py b/tensorflow_probability/python/distributions/sample.py index 27bfcfa434..d0f1d6d240 100644 --- a/tensorflow_probability/python/distributions/sample.py +++ b/tensorflow_probability/python/distributions/sample.py @@ -29,6 +29,7 @@ from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps @@ -236,7 +237,7 @@ def _sum_fn(self): return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total return tf.math.reduce_sum - def _log_prob(self, x, **kwargs): + def _prepare_for_underlying(self, x): batch_ndims = ps.rank_from_shape( self.distribution.batch_shape_tensor, self.distribution.batch_shape) @@ -266,10 +267,12 @@ def _log_prob(self, x, **kwargs): ndims) perm = ps.concat( [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0) - x = tf.transpose(a=x, perm=perm) - # (3) Compute x's log_prob. - lp = self.distribution.log_prob(x, **kwargs) - # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has + x = tf.transpose(x, perm=perm) + return x, (sample_ndims, extra_sample_ndims, batch_ndims) + + def _finish_log_prob(self, lp, aux): + (sample_ndims, extra_sample_ndims, batch_ndims) = aux + # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has # full sample shape in the sample axes, before we reduce. bcast_lp_shape = ps.broadcast_shape( ps.shape(lp), @@ -277,10 +280,16 @@ def _log_prob(self, x, **kwargs): ps.reshape(self.sample_shape, shape=[-1]), ps.ones([batch_ndims], tf.int32)], axis=0)) lp = tf.broadcast_to(lp, bcast_lp_shape) - # (5) Make the final reduction in x. + # (2) Make the final reduction. 
axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims) return self._sum_fn()(lp, axis=axis) + def _log_prob(self, x, **kwargs): + x, aux = self._prepare_for_underlying(x) + return self._finish_log_prob( + self.distribution.log_prob(x, **kwargs), + aux) + def _entropy(self, **kwargs): h = self.distribution.entropy(**kwargs) n = ps.reduce_prod(self.sample_shape) @@ -544,3 +553,18 @@ def _kl_sample(a, b, name='kl_sample'): a.distribution, b.distribution, name=name) n = ps.reduce_prod(a.sample_shape) return tf.cast(x=n, dtype=kl.dtype) * kl + + +@log_prob_ratio.RegisterLogProbRatio(Sample) +def _sample_log_prob_ratio(p, x, q, y): + checks = [] + if p.validate_args or q.validate_args: + checks.append(tf.debugging.assert_equal(p.sample_shape, q.sample_shape)) + with tf.control_dependencies(checks): + # pylint: disable=protected-access + x, aux = p._prepare_for_underlying(x) + y, _ = q._prepare_for_underlying(y) + return p._finish_log_prob( + log_prob_ratio.log_prob_ratio(p.distribution, x, q.distribution, y), + aux) + # pylint: enable=protected-access diff --git a/tensorflow_probability/python/distributions/sample_test.py b/tensorflow_probability/python/distributions/sample_test.py index 1d8c023d71..5a40fd9df9 100644 --- a/tensorflow_probability/python/distributions/sample_test.py +++ b/tensorflow_probability/python/distributions/sample_test.py @@ -25,10 +25,12 @@ from absl.testing import parameterized import numpy as np import tensorflow.compat.v2 as tf -from tensorflow_probability.python import bijectors as tfb -from tensorflow_probability.python import distributions as tfd +import tensorflow_probability as tfp from tensorflow_probability.python.internal import test_util +tfb = tfp.bijectors +tfd = tfp.distributions + JAX_MODE = False @@ -89,8 +91,7 @@ def test_kl_divergence(self): def test_transformed_affine(self): sample_shape = 3 mvn = tfd.Independent(tfd.Normal(loc=[0., 0], scale=1), 1) - aff = tfb.Affine(scale_tril=[[0.75, 0.], - [0.05, 0.5]]) + aff = tfb.ScaleMatvecTriL(scale_tril=[[0.75, 0.], [0.05, 0.5]]) def expected_lp(y): x = aff.inverse(y) # Ie, tf.random.normal([4, 3, 2]) @@ -448,6 +449,66 @@ def test_kahan_precision(self, jit=False): # Fails 75% CPU, 0-80% GPU --vary_seed runs w/o experimental_use_kahan_sum. 
self.assertAllClose(lp64, lp, rtol=0., atol=.01) + def testLargeLogProbDiffScalarUnderlying(self): + shp = [25, 200] + d0 = tfd.Sample(tfd.Normal(0., .1), shp) + d1 = tfd.Sample(tfd.Normal(1e-5, .1), shp) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample(shp, seed=strm())) + x1 = self.evaluate( # overdispersed, perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = d0.copy(distribution=tfd.Normal( + tf.cast(d0.distribution.loc, tf.float64), + tf.cast(d0.distribution.scale, tf.float64))) + d1_64 = d1.copy(distribution=tfd.Normal( + tf.cast(d1.distribution.loc, tf.float64), + tf.cast(d1.distribution.scale, tf.float64))) + oracle_64 = tf.reduce_sum( + d0_64.distribution.log_prob(tf.cast(x0, tf.float64)) - + d1_64.distribution.log_prob(tf.cast(x1, tf.float64))) + self.assertAllClose( + oracle_64, + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.007) + # In contrast: below fails with errors of ~0.07 - 0.15 + # self.assertAllClose( + # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), rtol=0., atol=0.007) + + def testLargeLogProbDiffBatchOfVecUnderlying(self): + nsamp = 5 + nbatch = 3 + nevt = 250 + dim = 500 + d0 = tfd.Sample(tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 0.), + tf.fill([dim], .1)), + sample_shape=nevt) + self.assertEqual(tf.float32, d0.dtype) + d1 = tfd.Sample(tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 1e-5), + d0.distribution.scale.diag), + sample_shape=nevt) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample([nsamp, nbatch, nevt, dim], seed=strm())) + x1 = self.evaluate( # overdispersed + perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = d0.copy(distribution=tfd.MultivariateNormalDiag( + tf.cast(d0.distribution.loc, tf.float64), + tf.cast(d0.distribution.scale.diag, tf.float64))) + d1_64 = d1.copy(distribution=tfd.MultivariateNormalDiag( + tf.cast(d1.distribution.loc, tf.float64), + tf.cast(d1.distribution.scale.diag, tf.float64))) + oracle_64 = (d0_64.log_prob(tf.cast(x0, tf.float64)) - + d1_64.log_prob(tf.cast(x1, tf.float64))) + self.assertNotAllZero(d0.log_prob(x0) < -10_000_000) + self.assertAllClose( + oracle_64, + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.045) + # In contrast, the following fails w/ abs errors of ~5. to 10. + # self.assertAllClose( + # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), rtol=0., atol=0.045) + if __name__ == '__main__': # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term. 
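The `sinh_arcsinh.py` diff that follows (like the `quantized_distribution.py` and `transformed_distribution.py` docstring updates elsewhere in this change) swaps uses of the `AffineScalar` bijector for a `Shift`/`Scale` composition. A small sketch of the equivalence, with arbitrary values:

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

# Stands in for tfb.AffineScalar(shift=-1., scale=2.): Scale is applied
# first, then Shift, so forward(x) = shift + scale * x.
affine = tfb.Shift(-1.)(tfb.Scale(2.))
affine.forward(3.)   # ==> 5.0
affine.inverse(5.)   # ==> 3.0
```
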
diff --git a/tensorflow_probability/python/distributions/sinh_arcsinh.py b/tensorflow_probability/python/distributions/sinh_arcsinh.py index c23443cd76..6cc740dbf6 100644 --- a/tensorflow_probability/python/distributions/sinh_arcsinh.py +++ b/tensorflow_probability/python/distributions/sinh_arcsinh.py @@ -19,10 +19,12 @@ from __future__ import print_function import tensorflow.compat.v2 as tf -from tensorflow_probability.python.bijectors import affine_scalar as affine_scalar_bijector from tensorflow_probability.python.bijectors import chain as chain_bijector from tensorflow_probability.python.bijectors import identity as identity_bijector +from tensorflow_probability.python.bijectors import scale as scale_bijector +from tensorflow_probability.python.bijectors import shift as shift_bijector from tensorflow_probability.python.bijectors import sinh_arcsinh as sinh_arcsinh_bijector + from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.distributions import transformed_distribution from tensorflow_probability.python.internal import distribution_util @@ -179,11 +181,8 @@ def __init__(self, validate_args=validate_args) # Make the AffineScalar bijector, Z --> loc + scale * Z (2 / F_0(2)) - affine = affine_scalar_bijector.AffineScalar( - shift=self._loc, - scale=self._scale, - validate_args=validate_args) - + affine = shift_bijector.Shift(shift=self._loc)( + scale_bijector.Scale(scale=self._scale)) bijector = chain_bijector.Chain([affine, f]) super(SinhArcsinh, self).__init__( diff --git a/tensorflow_probability/python/distributions/stopping_ratio_logistic.py b/tensorflow_probability/python/distributions/stopping_ratio_logistic.py new file mode 100644 index 0000000000..49d4d7266c --- /dev/null +++ b/tensorflow_probability/python/distributions/stopping_ratio_logistic.py @@ -0,0 +1,363 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""The stopping ratio logistic distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python import math as tfp_math +from tensorflow_probability.python.distributions import categorical +from tensorflow_probability.python.distributions import distribution +from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import distribution_util +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static +from tensorflow_probability.python.internal import reparameterization +from tensorflow_probability.python.internal import tensor_util +from tensorflow_probability.python.internal import tensorshape_util + + +class StoppingRatioLogistic(distribution.Distribution): + """Stopping ratio logistic distribution. + + The StoppingRatioLogistic distribution is parameterized by a location and a + set of non-decreasing cutpoints. It is defined over the integers + `{0, 1, ..., K}` for `K` non-decreasing cutpoints. + + The difference to the OrderedLogistic is that categories can only be reached + one after another, i.e., sequentially. Specifically, while the probability + of an ordinal random variable `X` to be in category `c` + for the OrderedLogistic reads as + + ```none + P(X = c; cutpoints, loc) = P(X > c - 1) - P(X > c) + = sigmoid(loc - concat([-inf, cutpoints, inf])[c]) - + sigmoid(loc - concat([-inf, cutpoints, inf])[c + 1]) + ``` + + the StoppingRatioLogistic distribution models the probability of an ordinal + random variable `X` to be in category `c` given `X >= c` as + + ```none + P(X = c; X >= c, cutpoints, loc) = sigmoid(cutpoints[c] - loc) + ``` + + The sequential mechanism for `X` starts in category `c = 0` where a binary + decision between `c = 0` and `c > 0` is made: + + ```none + P(X = 0; cutpoints, loc) = sigmoid(cutpoints[0] - loc) + ``` + + If `X = 0`, the process stops. Otherwise the process continues with + + ```none + P(X = 1; X >= 1, cutpoints, loc) = sigmoid(cutpoints[1] - loc) + ``` + + The process continues to move on to higher level categories until it stops at + some category `X = c`. + + This distribution is useful for ordinal variables where lower categories + need to be reached first, for instance modelling the degree of a person + where the categories are `[Bachelor, Master, PhD]`. In order to obtain a PhD + title, first the degrees `Bachelor` and `Master` need to be reached. + + #### Mathematical Details + + The probability mass function (pmf) is + + ```none + pmf(x; cutpoints, loc) = + sigmoid(cutpoints[x] - loc) * + prod_{s=0}^{x - 1} (1 - sigmoid(cutpoints[s] - loc)) + ``` + + where `loc` is the location of a latent logistic distribution and + `cutpoints` define points to split up this latent distribution. + + #### Examples + + To expand on the `[Bachelor, Master, PhD]` from above, create a distribution + of three ordered categories: + + ```python + import tensorflow_probability as tfp + tfd = tfp.distributions + + dist = tfd.StoppingRatioLogistic(cutpoints=[-1.0, 1.0], loc=0.) 
+ + dist.categorical_probs() + # ==> array([0.2689414 0.53444666 0.19661193], dtype=float32) + ``` + + Here, the probability of finishing one's education with a Bachelor would be + approx. 26% in this example, while the probability of continuing to pursue + a Master's would be approx. 53% and the probability of even attaining a PhD + would be 20%. + + Some further functionality: + + ```python + dist = tfd.StoppingRatioLogistic(cutpoints=[-2., 0., 2.], loc=0.) + + dist.prob([0, 3]) + # ==> array([0.11920291, 0.05249681], dtype=float32) + + dist.log_prob(1) + # ==> -0.82007515 + + dist.sample(3) + # ==> array([2, 1, 2], dtype=int32) + ``` + + """ + + def __init__( + self, + cutpoints, + loc, + dtype=tf.int32, + validate_args=False, + allow_nan_stats=True, + name='StoppingRatioLogistic', + ): + """Initialize Stopping Ratio Logistic distributions. + + Args: + cutpoints: A floating-point `Tensor` with shape `(K,)` where + `K` is the number of cutpoints. The vector of cutpoints should be + non-decreasing, which is only checked if `validate_args=True`. + loc: A floating-point `Tensor` with shape `()`. The entry represents the + mean of the latent logistic distribution. + dtype: The type of the event samples (default: int32). + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g. mode) use the value "`NaN`" to indicate the result is + undefined. When `False`, an exception is raised if one or more of the + statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + + parameters = dict(locals()) + + with tf.name_scope(name) as name: + + float_dtype = dtype_util.common_dtype( + [cutpoints, loc], + dtype_hint=tf.float32) + + self._cutpoints = tensor_util.convert_nonref_to_tensor( + cutpoints, dtype_hint=float_dtype, name='cutpoints') + self._loc = tensor_util.convert_nonref_to_tensor( + loc, dtype_hint=float_dtype, name='loc') + + super(StoppingRatioLogistic, self).__init__( + dtype=dtype, + reparameterization_type=reparameterization.NOT_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + name=name) + + @classmethod + def _params_event_ndims(cls): + return dict(cutpoints=1, loc=0) + + @staticmethod + def _param_shapes(sample_shape): + return dict( + zip(('loc', 'scale'), + ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 2))) + + @property + def cutpoints(self): + """Cutpoints param separating the latent distribution into categories.""" + return self._cutpoints + + @property + def loc(self): + """Mean parameter of the latent logistic distribution.""" + return self._loc + + def categorical_log_probs(self): + """Log probabilities for the `K + 1` sequential categories.""" + + cutpoints = tf.convert_to_tensor(self.cutpoints) + loc = tf.convert_to_tensor(self.loc) + num_cat = self._num_categories() + + # For the StoppingRatioLogistic, we have: + # P(X = c; X >= c, cutpoints, loc) = sigmoid(cutpoints[c] - loc) + # Given these conditional probabilities, we would like to retrieve + # P(X = c; cutpoints, loc). + # Let F(c) = P(X = c; X >= c, cutpoints, loc) and + # G(c) = P(X = c; cutpoints, loc) + + # Conditional probabilities. 
These are log(F(k)) and log(1 - F(k)) + conditional_log_probs = tf.math.log_sigmoid( + cutpoints - loc[..., tf.newaxis]) + conditional_log_probs_complement = tfp_math.log1mexp(conditional_log_probs) + + # Note that F(0) = G(0). + # G(1) = P(X = 1; cutpoints, loc) = + # P(X = 1; X >= 1, cutpoints, loc) * P(X >= 1) = F(1) * (1 - G(0)) + # G(2) = P(X = 2; cutpoints, loc) = + # P(X = 2; X >= 2, cutpoints, loc) * P(X >= 2) = F(2) * (1 - G(0) - G(1)) + # In general, G(k) = F(k) * (1 - \sum_{k-1} G(i)) + + # We rewrite this recurrence in terms of F(k) + # G(1) = F(1) * (1 - G(0)) = F(1) * (1 - F(0)) + # G(2) = F(2) * (1 - G(0) - G(1)) = (1 - F(0) - F(1) * (1 - F(0)) + # = F(2) * (1 - F(0)) * (1 - F(1)) + # G(k) = F(k) * \prod_{k-1} (1 - F(i)) + + # log(F(k)) + log(\prod (1 - F(i))) + categorical_log_probs = conditional_log_probs + tf.math.cumsum( + conditional_log_probs_complement[..., :(num_cat - 1)], + axis=-1, exclusive=True) + # Finally we need to handle the last category. + return tf.concat([ + categorical_log_probs, + tf.math.reduce_sum( + conditional_log_probs_complement[ + ..., :num_cat], axis=-1, keepdims=True)], axis=-1) + + def categorical_probs(self): + """Probabilities for the `K + 1` sequential categories.""" + return tf.math.exp(self.categorical_log_probs()) + + def _num_categories(self): + return prefer_static.shape(self.cutpoints, out_type=self.dtype)[-1] + 1 + + def _sample_n(self, n, seed=None): + return categorical.Categorical( + logits=self.categorical_log_probs()).sample(n, seed=seed) + + def _batch_shape_tensor(self, cutpoints=None, loc=None): + cutpoints = self.cutpoints if cutpoints is None else cutpoints + loc = self.loc if loc is None else loc + return prefer_static.broadcast_shape( + prefer_static.shape(cutpoints)[:-1], + prefer_static.shape(loc)) + + def _batch_shape(self): + return tf.broadcast_static_shape( + self.loc.shape, self.cutpoints.shape[:-1]) + + def _event_shape_tensor(self): + return tf.constant([], dtype=tf.int32) + + def _event_shape(self): + return tf.TensorShape([]) + + def _log_prob(self, x): + return categorical.Categorical( + logits=self.categorical_log_probs()).log_prob(x) + + def _cdf(self, x): + return categorical.Categorical( + logits=self.categorical_log_probs()).cdf(x) + + def _mode(self): + log_probs = self.categorical_log_probs() + mode = tf.argmax(log_probs, axis=-1, output_type=self.dtype) + tensorshape_util.set_shape(mode, log_probs.shape[:-1]) + return mode + + def _default_event_space_bijector(self): + return + + def _parameter_control_dependencies(self, is_init): + assertions = [] + + # In init, we can always build shape and dtype checks because + # we assume shape doesn't change for Variable backed args. + if is_init: + + if not dtype_util.is_floating(self.cutpoints.dtype): + raise TypeError('Argument `cutpoints` must having floating type.') + + if not dtype_util.is_floating(self.loc.dtype): + raise TypeError('Argument `loc` must having floating type.') + + cutpoint_dims = tensorshape_util.rank(self.cutpoints.shape) + msg = 'Argument `cutpoints` must have rank at least 1.' 
+ if cutpoint_dims is not None: + if cutpoint_dims < 1: + raise ValueError(msg) + elif self.validate_args: + cutpoints = tf.convert_to_tensor(self.cutpoints) + assertions.append( + assert_util.assert_rank_at_least(cutpoints, 1, message=msg)) + + if not self.validate_args: + return [] + + if is_init != tensor_util.is_ref(self.cutpoints): + cutpoints = tf.convert_to_tensor(self.cutpoints) + assertions.append(distribution_util.assert_nondecreasing( + cutpoints, message='Argument `cutpoints` must be non-decreasing.')) + + return assertions + + def _sample_control_dependencies(self, x): + assertions = [] + if not self.validate_args: + return assertions + assertions.extend(distribution_util.assert_nonnegative_integer_form(x)) + assertions.append( + assert_util.assert_less_equal( + x, tf.cast(self._num_categories(), x.dtype), + message=('StoppingRatioLogistic samples must be `>= 0` and `<= K` ' + 'where `K` is the number of cutpoints.'))) + return assertions + + +@kullback_leibler.RegisterKL(StoppingRatioLogistic, StoppingRatioLogistic) +def _kl_stopping_ratio_logistic_stopping_ratio_logistic(a, b, name=None): + """Calculate the batched KL divergence KL(a || b), both StoppingRatioLogistic. + + This function utilises the `StoppingRatioLogistic` `categorical_log_probs` + member function to implement KL divergence for discrete probability + distributions as described in + e.g. [Wikipedia](https://en.wikipedia.org/wiki/Kullback-Leibler_divergence). + + Args: + a: instance of a StoppingRatioLogistic distribution object. + b: instance of a StoppingRatioLogistic distribution object. + name: Python `str` name to use for created operations. + Default value: `None` + + Returns: + Batchwise KL(a || b) + """ + with tf.name_scope(name or + 'kl_stopping_ratio_logistic_stopping_ratio_logistic'): + a_log_probs = a.categorical_log_probs() + b_log_probs = b.categorical_log_probs() + return tf.reduce_sum( + tf.math.multiply_no_nan( + tf.math.exp(a_log_probs), + a_log_probs - b_log_probs), + axis=-1) + diff --git a/tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py b/tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py new file mode 100644 index 0000000000..419548bbb1 --- /dev/null +++ b/tensorflow_probability/python/distributions/stopping_ratio_logistic_test.py @@ -0,0 +1,142 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +# Dependency imports +from absl.testing import parameterized +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.internal import test_util + +tfd = tfp.distributions +tfb = tfp.bijectors + + +@test_util.test_all_tf_execution_regimes +class StoppingRatioLogisticTest(test_util.TestCase): + + def _random_cutpoints(self, shape): + return self._ordered.inverse(self._rng.randn(*shape)) + + def _random_location(self, shape): + return self._rng.randn(*shape) + + def _random_rvs(self, shape): + return self._rng.multinomial(1, *shape) + + def setUp(self): + self._ordered = tfb.Ordered() + self._rng = np.random.RandomState(test_util.test_seed()) + super(StoppingRatioLogisticTest, self).setUp() + + @parameterized.parameters( + itertools.product(['cutpoints', 'loc', 'both'], [[], [1], [1, 2, 3]]) + ) + def testBatchShapes(self, test, batch_shape): + if test == 'cutpoints': + cutpoints = self._random_cutpoints(batch_shape + [2]) + loc = tf.constant(0., dtype=tf.float64) + elif test == 'loc': + cutpoints = tf.constant([1., 2.], dtype=tf.float64) + loc = self._random_location(batch_shape) + elif test == 'both': + cutpoints = self._random_cutpoints(batch_shape + [2]) + loc = self._random_location(batch_shape) + + dist = tfd.StoppingRatioLogistic(cutpoints=cutpoints, loc=loc) + + self.assertAllEqual(dist.batch_shape, batch_shape) + self.assertAllEqual( + self.evaluate(dist.batch_shape_tensor()), batch_shape) + + self.assertAllEqual(dist.event_shape, []) + self.assertAllEqual(self.evaluate(dist.event_shape_tensor()), []) + + categorical_probs = dist.categorical_probs() + categorical_probs_shape = tf.shape(categorical_probs) + self.assertAllEqual( + self.evaluate(categorical_probs_shape), batch_shape + [3]) + + samples = dist.sample(seed=test_util.test_seed()) + sample_shape = tf.shape(samples) + self.assertAllEqual(self.evaluate(sample_shape), batch_shape) + + probs = dist.prob(samples) + probs_shape = tf.shape(probs) + self.assertAllEqual(self.evaluate(probs_shape), batch_shape) + + samples = dist.sample([4, 5], seed=test_util.test_seed()) + sample_shape_n = tf.shape(samples) + self.assertAllEqual(self.evaluate(sample_shape_n), [4, 5] + batch_shape) + + probs = dist.prob(samples) + probs_shape = tf.shape(probs) + self.assertAllEqual(self.evaluate(probs_shape), [4, 5] + batch_shape) + + mode = dist.mode() + mode_shape = tf.shape(mode) + self.assertAllEqual(self.evaluate(mode_shape), batch_shape) + + def testProbs(self): + expected_probs = [0.11920291, 0.44039854, 0.38790172, 0.05249681] + dist = tfd.StoppingRatioLogistic(cutpoints=[-2., 0., 2.], loc=0.) + + categorical_probs = self.evaluate(dist.categorical_probs()) + self.assertAllClose(expected_probs, categorical_probs, atol=1e-4) + + probs = self.evaluate(dist.prob([0, 1, 2, 3])) + self.assertAllClose(expected_probs, probs, atol=1e-4) + + def testMode(self): + dist = tfd.StoppingRatioLogistic(cutpoints=[-10., 10.], loc=[-20., 0., 20.]) + mode = self.evaluate(dist.mode()) + self.assertAllEqual([0, 1, 2], mode) + + def testSample(self): + dist = tfd.StoppingRatioLogistic(cutpoints=[-1., 0., 1.], loc=0.) 
+ samples = self.evaluate(dist.sample(int(1e5), seed=test_util.test_seed())) + expected_probs = [0.2689414, 0.3655293, 0.26722333, 0.09830596] + for k, p in enumerate(expected_probs): + self.assertAllClose(np.mean(samples == k), p, atol=0.01) + + def testKLAgainstSampling(self): + a_cutpoints = self._random_cutpoints([4]) + b_cutpoints = self._random_cutpoints([4]) + loc = self._random_location([]) + + a = tfd.StoppingRatioLogistic(cutpoints=a_cutpoints, loc=loc) + b = tfd.StoppingRatioLogistic(cutpoints=b_cutpoints, loc=loc) + + samples = a.sample(int(1e5), seed=test_util.test_seed()) + sampled_kl = self.evaluate( + tf.reduce_mean(a.log_prob(samples) - b.log_prob(samples))) + kl = self.evaluate(tfd.kl_divergence(a, b)) + + self.assertAllClose(sampled_kl, kl, atol=2e-2) + + def testUnorderedCutpointsFails(self): + with self.assertRaisesRegexp( + ValueError, 'Argument `cutpoints` must be non-decreasing.'): + dist = tfd.StoppingRatioLogistic( + cutpoints=[1., 0.9], loc=0.0, validate_args=True) + self.evaluate(dist.mode()) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py index 68daa369a7..a0f0d208db 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution.py +++ b/tensorflow_probability/python/distributions/transformed_distribution.py @@ -21,8 +21,10 @@ import tensorflow.compat.v2 as tf +from tensorflow_probability.python.bijectors import ldj_ratio from tensorflow_probability.python.distributions import distribution as distribution_lib from tensorflow_probability.python.distributions import kullback_leibler +from tensorflow_probability.python.distributions import log_prob_ratio from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -134,9 +136,7 @@ class TransformedDistribution(distribution_lib.Distribution): tfb = tfp.bijectors normal = tfd.TransformedDistribution( distribution=tfd.Normal(loc=0., scale=1.), - bijector=tfb.Affine( - shift=-1., - scale_identity_multiplier=2.) + bijector=tfb.Shift(shift=-1.)(tfb.Scale(scale=2.)), name='NormalTransformedDistribution') ``` @@ -289,12 +289,11 @@ def _batch_shape_tensor(self): # dtype.) if tf.nest.is_nested(base_batch_shape_tensor): if self._is_joint: - return base_batch_shape_tensor - + return tf.nest.pack_sequence_as( + self.dtype, tf.nest.flatten(base_batch_shape_tensor)) base_batch_shape_tensor = functools.reduce( ps.broadcast_shape, tf.nest.flatten(base_batch_shape_tensor)) - return base_batch_shape_tensor def _batch_shape(self): @@ -308,7 +307,10 @@ def _batch_shape(self): # the batch shape components of the base distribution are broadcast to # obtain the batch shape of the transformed distribution. batch_shape = self.distribution.batch_shape - if tf.nest.is_nested(batch_shape) and not self._is_joint: + if tf.nest.is_nested(batch_shape): + if self._is_joint: + return tf.nest.pack_sequence_as( + self.dtype, tf.nest.flatten(batch_shape)) batch_shape = functools.reduce( tf.broadcast_static_shape, tf.nest.flatten(batch_shape)) return batch_shape @@ -569,3 +571,22 @@ def _kl_transformed_transformed(a, b, name=None): 'Unable to calculate KL divergence between {} and {} because ' 'their bijectors are not equal: {} vs. 
{}'.format( a, b, a.bijector, b.bijector)) + + +@log_prob_ratio.RegisterLogProbRatio(TransformedDistribution) +def _transformed_log_prob_ratio(p, x, q, y): + """Computes p.log_prob(x) - q.log_prob(y) for p and q both TDs.""" + x_ = p.bijector.inverse(x) + y_ = q.bijector.inverse(y) + + base_log_prob_ratio = log_prob_ratio.log_prob_ratio( + p.distribution, x_, q.distribution, y_) + + event_ndims = tf.nest.map_structure( + ps.rank_from_shape, + p.event_shape_tensor, + tf.nest.map_structure(tensorshape_util.merge_with, + p.event_shape, q.event_shape)) + ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio( + p.bijector, x, q.bijector, y, event_ndims) + return base_log_prob_ratio + tf.cast(ildj_ratio, base_log_prob_ratio.dtype) diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py index 04e3ea43e6..66bf95f0a4 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution_test.py +++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py @@ -477,6 +477,42 @@ def testTransformedNormalNormalKL(self): self.assertAllClose(kl_val, kl_expected) self.assertAllClose(kl_expected, kl_sample_, atol=0.0, rtol=1e-2) + def testLogProbRatio(self): + nsamp = 5 + nbatch = 3 + dim = 5000 + d0 = tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 0.), + tf.fill([dim], .1)) + d1 = tfd.MultivariateNormalDiag(tf.fill([nbatch, dim], 1e-5), + d0.scale.diag) + strm = test_util.test_seed_stream() + x0 = self.evaluate( # overdispersed + tfd.Normal(0, 2).sample([nsamp, nbatch, dim], seed=strm())) + x1 = self.evaluate( # overdispersed + perturbed + x0 + tfd.Normal(0, 1e-6).sample(x0.shape, seed=strm())) + d0_64 = tfd.MultivariateNormalDiag( + tf.cast(d0.loc, tf.float64), tf.cast(d0.scale.diag, tf.float64)) + d1_64 = tfd.MultivariateNormalDiag( + tf.cast(d1.loc, tf.float64), tf.cast(d1.scale.diag, tf.float64)) + oracle_64 = (d0_64.log_prob(tf.cast(x0, tf.float64)) - + d1_64.log_prob(tf.cast(x1, tf.float64))) + # For a sense of the order of magnitude log_probs we're dealing with: + self.assertNotAllZero(d0.log_prob(x0) < -1_000_000.) + self.assertAllClose( + oracle_64, + tfp.experimental.distributions.log_prob_ratio(d0, x0, d1, x1), + rtol=0., atol=0.007) + # In contrast, this test fails with max-abs-error around 0.05 to 0.1 + # self.assertAllClose( + # oracle_64, + # d0.copy(experimental_use_kahan_sum=True).log_prob(x0) - + # d1.copy(experimental_use_kahan_sum=True).log_prob(x1), + # rtol=0., atol=0.007) + # In contrast, this test fails with max-abs-error around 0.8 to 1.5 + # self.assertAllClose( + # oracle_64, d0.log_prob(x0) - d1.log_prob(x1), + # rtol=0., atol=0.007) + @test_util.test_all_tf_execution_regimes class ScalarToMultiTest(test_util.TestCase): @@ -1047,6 +1083,38 @@ def test_transform_joint_to_joint(self, split_sizes): self.assertAllEqual(tf.nest.map_structure(lambda y: y.shape, y), tf.nest.map_structure(lambda y: y.shape, y_sampled)) + # Test that a `Restructure` bijector applied to a `JointDistribution` works + # as expected. 
+ num_components = len(split_sizes) + input_keys = (split_sizes.keys() if isinstance(split_sizes, dict) + else range(num_components)) + output_keys = [str(i) for i in range(num_components)] + output_structure = {k: v for k, v in zip(output_keys, input_keys)} + restructure = tfb.Restructure(output_structure) + restructured_dist = tfd.TransformedDistribution( + base_dist, bijector=restructure, validate_args=True) + + # Check that attributes of the restructured distribution have the same + # nested structure as the `output_structure` of the bijector. Pass a no-op + # as the `assert_fn` since the contents of the structures are not + # required to be the same. + noop_assert_fn = lambda *_: None + self.assertAllAssertsNested( + noop_assert_fn, restructured_dist.event_shape, output_structure) + self.assertAllAssertsNested( + noop_assert_fn, restructured_dist.batch_shape, output_structure) + self.assertAllAssertsNested( + noop_assert_fn, + self.evaluate(restructured_dist.event_shape_tensor()), + output_structure) + self.assertAllAssertsNested( + noop_assert_fn, + self.evaluate(restructured_dist.batch_shape_tensor()), + output_structure) + self.assertAllAssertsNested( + noop_assert_fn, + self.evaluate(restructured_dist.sample(seed=test_util.test_seed()))) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/distributions/vector_exponential_diag_test.py b/tensorflow_probability/python/distributions/vector_exponential_diag_test.py index 141fad12fe..c26ce5baa1 100644 --- a/tensorflow_probability/python/distributions/vector_exponential_diag_test.py +++ b/tensorflow_probability/python/distributions/vector_exponential_diag_test.py @@ -83,8 +83,8 @@ def testAssertValidSample(self): def testSingularScaleRaises(self): mu = [-1., 1] diag = [1., 0] - dist = tfd.VectorExponentialDiag(mu, diag, validate_args=True) with self.assertRaisesOpError('Singular'): + dist = tfd.VectorExponentialDiag(mu, diag, validate_args=True) self.evaluate(dist.sample(seed=test_util.test_seed())) def testSampleWithBroadcastScale(self): diff --git a/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py b/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py index 0f0396f7a9..e24f2b2211 100644 --- a/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py +++ b/tensorflow_probability/python/distributions/vector_exponential_linear_operator.py @@ -19,9 +19,8 @@ from __future__ import print_function import tensorflow.compat.v2 as tf -from tensorflow_probability.python.bijectors import affine_linear_operator as affine_linear_operator_bijector from tensorflow_probability.python.bijectors import chain as chain_bijector -from tensorflow_probability.python.bijectors import scale_matvec_linear_operator as scale_matvec_linear_operator_bijector +from tensorflow_probability.python.bijectors import scale_matvec_linear_operator from tensorflow_probability.python.bijectors import shift as shift_bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.distributions import exponential @@ -181,6 +180,8 @@ def __init__(self, TypeError: if not `scale.dtype.is_floating` """ parameters = dict(locals()) + if loc is None: + loc = 0.0 # Implicit value for backwards compatibility. 
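# The bijector built below, `Shift(shift=loc)(ScaleMatvecLinearOperator(scale=scale))`,
# composes the two bijectors so that forward(X) = scale.matvec(X) + loc; this
# reproduces the `Y = scale @ X + loc` map that the removed `AffineLinearOperator`
# expressed as a single bijector. Defaulting `loc` to 0.0 presumably ensures the
# `Shift` bijector always receives a concrete shift.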
if scale is None: raise ValueError('Missing required `scale` parameter.') if not dtype_util.is_floating(scale.dtype): @@ -193,7 +194,8 @@ def __init__(self, loc, name='loc', dtype=scale.dtype) batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( loc, scale) - + self._loc = loc + self._scale = scale super(VectorExponentialLinearOperator, self).__init__( # TODO(b/137665504): Use batch-adding meta-distribution to set the # batch shape instead of tf.ones. @@ -206,8 +208,9 @@ def __init__(self, rate=tf.ones(batch_shape, dtype=scale.dtype), allow_nan_stats=allow_nan_stats), event_shape), - bijector=affine_linear_operator_bijector.AffineLinearOperator( - shift=loc, scale=scale, validate_args=validate_args), + bijector=shift_bijector.Shift(shift=loc)( + scale_matvec_linear_operator.ScaleMatvecLinearOperator( + scale=scale, validate_args=validate_args)), validate_args=validate_args, name=name) self._parameters = parameters @@ -215,12 +218,12 @@ def __init__(self, @property def loc(self): """The `loc` `Tensor` in `Y = scale @ X + loc`.""" - return self.bijector.shift + return self._loc @property def scale(self): """The `scale` `LinearOperator` in `Y = scale @ X + loc`.""" - return self.bijector.scale + return self._scale @distribution_util.AppendDocstring(_mvn_sample_note) def _log_prob(self, x): @@ -236,7 +239,7 @@ def _mean(self): # Then this distribution is # X = loc + LW, # and then E[X] = loc + L1, where 1 is the vector of ones. - scale_x_ones = self.bijector.scale.matvec( + scale_x_ones = self.scale.matvec( tf.ones(self._mode_mean_shape(), self.dtype)) if self.loc is None: @@ -279,7 +282,7 @@ def _stddev(self): self.scale.matmul(self.scale.to_dense(), adjoint_arg=True))) def _mode(self): - scale_x_zeros = self.bijector.scale.matvec( + scale_x_zeros = self.scale.matvec( tf.zeros(self._mode_mean_shape(), self.dtype)) if self.loc is None: @@ -311,7 +314,7 @@ def _sample_control_dependencies(self, x): def _default_event_space_bijector(self): return chain_bijector.Chain([ shift_bijector.Shift(shift=self.loc, validate_args=self.validate_args), - scale_matvec_linear_operator_bijector.ScaleMatvecLinearOperator( + scale_matvec_linear_operator.ScaleMatvecLinearOperator( scale=self.scale, validate_args=self.validate_args), softplus_bijector.Softplus(validate_args=self.validate_args) ], validate_args=self.validate_args) diff --git a/tensorflow_probability/python/distributions/von_mises.py b/tensorflow_probability/python/distributions/von_mises.py index b712f6a896..0a64dab963 100644 --- a/tensorflow_probability/python/distributions/von_mises.py +++ b/tensorflow_probability/python/distributions/von_mises.py @@ -371,7 +371,7 @@ def von_mises_cdf(x, concentration): using automatic differentiation. We use forward mode for the series case (which allows to save memory) and backward mode for the Normal approximation. - Arguments: + Args: x: The point at which to evaluate the CDF. concentration: The concentration parameter of the von Mises distribution. @@ -498,7 +498,7 @@ def cdf_func(concentration): def _von_mises_sample_no_gradient(shape, concentration, seed): """Performs rejection sampling for standardized von Mises. - Arguments: + Args: shape: The output sample shape. concentration: The concentration parameter of the distribution. seed: The random seed. @@ -641,7 +641,7 @@ def _von_mises_sample_jvp(shape, primals, tangents): def _von_mises_sample_with_gradient(shape, concentration, seed): """Performs rejection sampling for standardized von Mises. 
- Arguments: + Args: shape: The output sample shape. concentration: The concentration parameter of the distribution. seed: (optional) The random seed. @@ -662,7 +662,7 @@ def random_von_mises(shape, concentration, dtype=tf.float32, seed=None): The sampling algorithm is rejection sampling with wrapped Cauchy proposal [1]. The samples are pathwise differentiable using the approach of [2]. - Arguments: + Args: shape: The output sample shape. concentration: The concentration parameter of the von Mises distribution. dtype: The data type of concentration and the outputs. diff --git a/tensorflow_probability/python/distributions/zipf.py b/tensorflow_probability/python/distributions/zipf.py index ec18c6c6e3..92f7d49dae 100644 --- a/tensorflow_probability/python/distributions/zipf.py +++ b/tensorflow_probability/python/distributions/zipf.py @@ -353,7 +353,7 @@ def _hat_integral(self, x, power): pmf. This function implements `hat` integral: H(x) = int_x^inf h(t) dt; which is needed for sampling purposes. - Arguments: + Args: x: A Tensor of points x at which to evaluate H(x). power: Power that parameterized hat function. diff --git a/tensorflow_probability/python/experimental/BUILD b/tensorflow_probability/python/experimental/BUILD index 42f2e5e98a..414ce58fdc 100644 --- a/tensorflow_probability/python/experimental/BUILD +++ b/tensorflow_probability/python/experimental/BUILD @@ -33,11 +33,13 @@ exports_files(["LICENSE"]) multi_substrate_py_library( name = "experimental", srcs = ["__init__.py"], + numpy_omit_deps = [ + "//tensorflow_probability/python/experimental/distribute", + ], srcs_version = "PY3", substrates_omit_deps = [ ":composite_tensor", "//tensorflow_probability/python/experimental/auto_batching", - "//tensorflow_probability/python/experimental/distribute", "//tensorflow_probability/python/experimental/lazybones", "//tensorflow_probability/python/experimental/linalg", "//tensorflow_probability/python/experimental/marginalize", diff --git a/tensorflow_probability/python/experimental/bijectors/BUILD b/tensorflow_probability/python/experimental/bijectors/BUILD index 96ed131f8a..fd4076247d 100644 --- a/tensorflow_probability/python/experimental/bijectors/BUILD +++ b/tensorflow_probability/python/experimental/bijectors/BUILD @@ -36,6 +36,7 @@ multi_substrate_py_library( srcs_version = "PY3", deps = [ ":scalar_function_with_inferred_inverse", + "//tensorflow_probability/python/bijectors:ldj_ratio", ], ) diff --git a/tensorflow_probability/python/experimental/bijectors/__init__.py b/tensorflow_probability/python/experimental/bijectors/__init__.py index 81650620f0..baf5f4e9d7 100644 --- a/tensorflow_probability/python/experimental/bijectors/__init__.py +++ b/tensorflow_probability/python/experimental/bijectors/__init__.py @@ -14,8 +14,10 @@ # ============================================================================ """TensorFlow Probability experimental bijectors package.""" +from tensorflow_probability.python.bijectors.ldj_ratio import inverse_log_det_jacobian_ratio from tensorflow_probability.python.experimental.bijectors.scalar_function_with_inferred_inverse import ScalarFunctionWithInferredInverse __all__ = [ - 'ScalarFunctionWithInferredInverse' + 'inverse_log_det_jacobian_ratio', + 'ScalarFunctionWithInferredInverse', ] diff --git a/tensorflow_probability/python/experimental/distribute/BUILD b/tensorflow_probability/python/experimental/distribute/BUILD index dd24845a69..610f32e50f 100644 --- a/tensorflow_probability/python/experimental/distribute/BUILD +++ 
b/tensorflow_probability/python/experimental/distribute/BUILD @@ -14,6 +14,11 @@ # ============================================================================ # Description: # Contains utilities for writing distributed TFP code. +load( + "//tensorflow_probability/python:build_defs.bzl", + "multi_substrate_py_library", + "multi_substrate_py_test", +) licenses(["notice"]) @@ -23,10 +28,10 @@ package( ], ) -py_library( +multi_substrate_py_library( name = "distribute", srcs = ["__init__.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":distribute_lib", ":joint_distribution", @@ -34,7 +39,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "distribute_lib", srcs = ["distribute_lib.py"], srcs_version = "PY3", @@ -43,7 +48,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "sharded", srcs = ["sharded.py"], deps = [ @@ -55,7 +60,7 @@ py_library( ], ) -py_library( +multi_substrate_py_library( name = "joint_distribution", srcs = ["joint_distribution.py"], deps = [ @@ -66,39 +71,59 @@ py_library( ], ) -py_test( +multi_substrate_py_library( + name = "distribute_test_lib", + testonly = 1, + srcs = ["distribute_test_lib.py"], + srcs_version = "PY3", + deps = [ + # tensorflow dep, + "//tensorflow_probability/python/internal:test_util", + ], +) + +multi_substrate_py_test( name = "sharded_test", srcs = ["sharded_test.py"], + disabled_substrates = ["numpy"], python_version = "PY3", deps = [ + ":distribute_lib", + ":distribute_test_lib", ":sharded", # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", ], ) -py_test( +multi_substrate_py_test( name = "joint_distribution_test", srcs = ["joint_distribution_test.py"], + disabled_substrates = ["numpy"], python_version = "PY3", deps = [ + ":distribute_test_lib", ":joint_distribution", ":sharded", # absl/testing:parameterized dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", ], ) -py_test( +multi_substrate_py_test( name = "distribute_lib_test", srcs = ["distribute_lib_test.py"], + disabled_substrates = ["numpy"], python_version = "PY3", srcs_version = "PY3", deps = [ ":distribute_lib", + ":distribute_test_lib", # tensorflow dep, "//tensorflow_probability", "//tensorflow_probability/python/internal:test_util", diff --git a/tensorflow_probability/python/experimental/distribute/distribute_lib.py b/tensorflow_probability/python/experimental/distribute/distribute_lib.py index bab1b6b37e..574c5e489f 100644 --- a/tensorflow_probability/python/experimental/distribute/distribute_lib.py +++ b/tensorflow_probability/python/experimental/distribute/distribute_lib.py @@ -19,25 +19,66 @@ from __future__ import print_function import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient +JAX_MODE = False -def psum(x): +if JAX_MODE: + import jax # pylint: disable=g-import-not-at-top + from jax import lax # pylint: disable=g-import-not-at-top + + +def psum(x, axis_name=None): + if JAX_MODE: + return lax.psum(x, axis_name) ctx = tf.distribute.get_replica_context() return ctx.all_reduce('sum', x) -def pmean(x): +def pmean(x, axis_name=None): + if JAX_MODE: + return lax.pmean(x, axis_name) ctx = tf.distribute.get_replica_context() return ctx.all_reduce('mean', x) +def get_replica_id(axis_name=None): + if JAX_MODE: + return lax.axis_index(axis_name) + ctx = 
tf.distribute.get_replica_context() + return ctx.replica_id_in_sync_group + + +def get_num_replicas(axis_name=None): + if JAX_MODE: + return lax.psum(1, axis_name) + ctx = tf.distribute.get_replica_context() + return ctx.num_replicas_in_sync + + class _DummyGrads(object): + """Wraps gradients to preserve structure when computing a custom gradient.""" def __init__(self, grads): self.grads = grads + def tree_flatten(self): + return (self.grads,), () + + @classmethod + def tree_unflatten(cls, _, xs): + return cls(*xs) + + def __repr__(self): + return f'_DummyGrads({self.grads})' + + +if JAX_MODE: + from jax import tree_util # pylint: disable=g-import-not-at-top + tree_util.register_pytree_node_class(_DummyGrads) -def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded): + +def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded, axis_name=None): """Constructs a log prob parts function that all-reduces over terms. Given a log_prob_parts function, this function will return a new one that @@ -55,81 +96,116 @@ def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded): add an all-reduce sum for its term in the log prob calculation. If it is `False`, the returned function will have an all-reduce sum over the gradient of sharded terms w.r.t. to the unsharded value. + axis_name: a `str` used for the axis name in the JAX backend. Unused in the + TensorFlow backend. Returns: A new log prob parts function that can be run inside of strategy. """ - @tf.custom_gradient - def sharded_log_prob_parts(value): + def _sharded_log_prob_parts_fwd(value): tf.nest.assert_same_structure(value, is_sharded) - with tf.GradientTape(persistent=True) as tape: - tape.watch(value) + if JAX_MODE: + def flat_log_prob_parts_fn(flat_args): + args = tf.nest.pack_sequence_as(is_sharded, flat_args) + log_prob_parts = log_prob_parts_fn(args) + return tf.nest.flatten(log_prob_parts) + + def wrapped_log_prob(value): + flat_sharded = tf.nest.flatten(is_sharded) + return tf.nest.pack_sequence_as( + is_sharded, + [ + _DummyGrads(tf.nest.pack_sequence_as(is_sharded, [ # pylint: disable=g-complex-comprehension + jax.grad(lambda v: flat_log_prob_parts_fn(v)[i]) # pylint: disable=cell-var-from-loop + (tf.nest.flatten(value))[j] + for i in range(len(flat_sharded)) + ])) + for j in range(len(flat_sharded)) + ]) + log_prob_parts = log_prob_parts_fn(value) - tf.nest.assert_same_structure(log_prob_parts, is_sharded) + local_grads = wrapped_log_prob(value) + else: + with tf.GradientTape(persistent=True) as tape: + tape.watch(value) + log_prob_parts = log_prob_parts_fn(value) + tf.nest.assert_same_structure(log_prob_parts, is_sharded) + + def local_grad(v): + return _DummyGrads( + tf.nest.map_structure( + lambda log_prob_part: tape.gradient(log_prob_part, v), + log_prob_parts)) + local_grads = tf.nest.map_structure(local_grad, value) total_log_prob_parts = tf.nest.map_structure( lambda log_prob_part, sharded: ( # pylint: disable=g-long-lambda - psum(log_prob_part) if sharded else log_prob_part), + psum(log_prob_part, axis_name=axis_name) + if sharded else log_prob_part), log_prob_parts, is_sharded) - def vjp(*gs): - gs = tf.nest.pack_sequence_as(log_prob_parts, gs) - - def local_grad(v, g): - return _DummyGrads( - tf.nest.map_structure( - lambda lp: tape.gradient(lp, v, output_gradients=g), - log_prob_parts)) - - local_grads = tf.nest.map_structure(local_grad, value, gs) - - def value_grad(v, value_sharded, term_grads): - """Computes reductions of output gradients. 
- - A `log_prob_parts` function takes in a list of values and outputs - a log density for each input to the function. The vector-Jacobian - product (VJP) of a `log_prob_parts` function thus needs to compute the - gradient of each output term w.r.t. each input value. This function - overrides the default VJP of an output term `j` w.r.t to an input - value `i` to include an all-reduce-sum when: - 1) The gradient of `j` w.r.t. `i` is connected. - 2) `j` is a sharded term and `i` is an unsharded value. - - If these conditions do not hold, the gradient remains the same and - either corresponds to: - 1) The gradient of a sharded term w.r.t to a sharded value - 2) The gradient of an unsharded term w.r.t. to an unsharded value. - 3) The gradient of an unsharded term w.r.t. to an sharded value. - In any of these cases, no all-reduce-sum is necessary. - Args: - v: The output term of a `log_prob_part` function. - value_sharded: A boolean indicating whether or not the output term is - is sharded or not. - term_grads: The gradient of the output term w.r.t. to each of the - input values to the `log_prob_part` function. - Returns: - The vector Jacobian product of `v` w.r.t. the input parts of the - `log_prob_parts` function. - """ - term_grads = term_grads.grads - - def psum_grads(term_grad, term_sharded): - if term_grad is not None: - if not value_sharded and term_sharded: - term_grad = psum(term_grad) - return term_grad - - total_grad = tf.nest.map_structure(psum_grads, term_grads, - is_sharded) - if all([grad is None for grad in tf.nest.flatten(total_grad)]): - return None - return tf.add_n( - [v for v in tf.nest.flatten(total_grad) if v is not None]) - - return tf.nest.map_structure(value_grad, value, is_sharded, local_grads) - - return total_log_prob_parts, vjp + return total_log_prob_parts, (value, local_grads) + + def _sharded_log_prob_parts_bwd(res, gs): + value, local_grads = res + + def grad_mul(vs, g): + return tf.nest.map_structure(lambda v: v * g if v is not None else v, vs) + + local_grads = tf.nest.map_structure( + lambda v, g: _DummyGrads(grad_mul(v.grads, g)), local_grads, gs) + + def value_grad(v, value_sharded, term_grads): + """Computes reductions of output gradients. + + A `log_prob_parts` function takes in a list of values and outputs + a log density for each input to the function. The vector-Jacobian + product (VJP) of a `log_prob_parts` function thus needs to compute the + gradient of each output term w.r.t. each input value. This function + overrides the default VJP of an output term `j` w.r.t to an input + value `i` to include an all-reduce-sum when: + 1) The gradient of `j` w.r.t. `i` is connected. + 2) `j` is a sharded term and `i` is an unsharded value. + + If these conditions do not hold, the gradient remains the same and + either corresponds to: + 1) The gradient of a sharded term w.r.t to a sharded value + 2) The gradient of an unsharded term w.r.t. to an unsharded value. + 3) The gradient of an unsharded term w.r.t. to an sharded value. + In any of these cases, no all-reduce-sum is necessary. + Args: + v: The output term of a `log_prob_part` function. + value_sharded: A boolean indicating whether or not the output term is + sharded or not. + term_grads: The gradient of the output term w.r.t. to each of the input + values to the `log_prob_part` function. + + Returns: + The vector Jacobian product of `v` w.r.t. the input parts of the + `log_prob_parts` function. 
+ """ + term_grads = term_grads.grads + def psum_grads(term_grad, term_sharded): + if term_grad is not None: + if not value_sharded and term_sharded: + term_grad = psum(term_grad, axis_name=axis_name) + return term_grad + + total_grad = tf.nest.map_structure(psum_grads, term_grads, + is_sharded) + if all([grad is None for grad in tf.nest.flatten(total_grad)]): + return None + return tf.add_n( + [v for v in tf.nest.flatten(total_grad) if v is not None]) + + out = tf.nest.map_structure(value_grad, value, is_sharded, local_grads) + return (out,) + + @tfp_custom_gradient.custom_gradient( + vjp_fwd=_sharded_log_prob_parts_fwd, vjp_bwd=_sharded_log_prob_parts_bwd) + def sharded_log_prob_parts(value): + return _sharded_log_prob_parts_fwd(value)[0] return sharded_log_prob_parts diff --git a/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py b/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py index 3f65709a54..9bd095962f 100644 --- a/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py +++ b/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py @@ -20,33 +20,17 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.python.experimental.distribute import distribute_lib +from tensorflow_probability.python.experimental.distribute import distribute_test_lib as test_lib from tensorflow_probability.python.internal import test_util tfd = tfp.distributions -NUM_DEVICES = 4 - - -def per_replica_to_tensor(value): - return tf.nest.map_structure( - lambda per_replica: tf.stack(per_replica.values, axis=0), value) - @test_util.test_all_tf_execution_regimes -class LogProbPartsTest(test_util.TestCase): - - def setUp(self): - super(LogProbPartsTest, self).setUp() - self.strategy = tf.distribute.MirroredStrategy( - devices=tf.config.list_logical_devices()) - - def shard_values(self, values): - - def value_fn(ctx): - return values[ctx.replica_id_in_sync_group] - - return self.strategy.experimental_distribute_values_from_function(value_fn) +class LogProbPartsTest(test_lib.DistributedTest): + @test_util.disable_test_for_backend( + disable_jax=True, reason='Behavior supported natively') def test_can_shard_values_across_logical_devices(self): @tf.function(autograph=False) @@ -59,9 +43,12 @@ def add_one(x): values = self.strategy.experimental_distribute_values_from_function( value_fn) out_values = self.evaluate( - per_replica_to_tensor(self.strategy.run(add_one, (values,)))) + self.per_replica_to_tensor(self.strategy_run(add_one, (values,)))) self.assertAllEqual(out_values, [1., 2., 3., 4.]) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot use sharded distributions outside of pmap.') def test_correct_log_prob_for_global_variable_no_strategy(self): data = tf.ones(4) @@ -73,7 +60,7 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True]) + log_prob_parts, [False, True], axis_name=None) self.assertAllEqualNested( self.evaluate(sharded_log_prob_parts([tf.constant(0.), data])), self.evaluate([ @@ -81,6 +68,9 @@ def log_prob_parts(value): tf.reduce_sum(tfd.Normal(0., 1.).log_prob(data)) ])) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot use sharded distributions outside of pmap.') def test_correct_log_prob_for_local_variable_no_strategy(self): data = tf.ones(4) @@ -93,7 +83,7 @@ def log_prob_parts(value): ] sharded_log_prob_parts = 
distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [True, True]) + log_prob_parts, [True, True], axis_name=None) self.assertAllEqualNested( self.evaluate(sharded_log_prob_parts([tf.ones(4), data])), self.evaluate([ @@ -103,7 +93,6 @@ def log_prob_parts(value): def test_correct_log_prob_for_global_variable(self): - @tf.function(autograph=False) def run(x, data): def log_prob_parts(value): @@ -114,14 +103,15 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True]) + log_prob_parts, [False, True], axis_name=self.axis_name) return sharded_log_prob_parts([x, data]) x = tf.constant(0.) data = tf.ones(4) sharded_data = self.shard_values(data) - out_parts = per_replica_to_tensor(self.strategy.run(run, (x, sharded_data))) + out_parts = self.per_replica_to_tensor( + self.strategy_run(run, (x, sharded_data), in_axes=(None, 0))) self.assertAllEqualNested( self.evaluate(out_parts), @@ -132,7 +122,6 @@ def log_prob_parts(value): def test_correct_log_prob_for_local_variable(self): - @tf.function(autograph=False) def run(x, data): def log_prob_parts(value): @@ -143,7 +132,7 @@ def log_prob_parts(value): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [True, True]) + log_prob_parts, [True, True], axis_name=self.axis_name) return sharded_log_prob_parts([x, data]) @@ -151,8 +140,8 @@ def log_prob_parts(value): sharded_x = self.shard_values(x) data = tf.ones(4) sharded_data = self.shard_values(data) - out_parts = per_replica_to_tensor( - self.strategy.run(run, (sharded_x, sharded_data))) + out_parts = self.per_replica_to_tensor( + self.strategy_run(run, (sharded_x, sharded_data))) self.assertAllEqualNested( self.evaluate(out_parts), @@ -163,7 +152,6 @@ def log_prob_parts(value): def test_correct_log_prob_for_global_and_local_variable(self): - @tf.function(autograph=False) def run(w, x, data): def log_prob_parts(values): @@ -175,7 +163,7 @@ def log_prob_parts(values): ] sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True, True]) + log_prob_parts, [False, True, True], axis_name=self.axis_name) return sharded_log_prob_parts([w, x, data]) @@ -184,8 +172,9 @@ def log_prob_parts(values): sharded_x = self.shard_values(x) data = 3 * tf.ones(4) sharded_data = self.shard_values(data) - out_parts = per_replica_to_tensor( - self.strategy.run(run, (w, sharded_x, sharded_data))) + out_parts = self.per_replica_to_tensor( + self.strategy_run( + run, (w, sharded_x, sharded_data), in_axes=(None, 0, 0))) self.assertAllEqualNested( self.evaluate(out_parts), @@ -197,7 +186,6 @@ def log_prob_parts(values): def test_correct_gradient_for_global_variable(self): - @tf.function(autograph=False) def run(x, data): def log_prob_parts(value): @@ -209,7 +197,7 @@ def log_prob_parts(value): def log_prob(x): sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True]) + log_prob_parts, [False, True], axis_name=self.axis_name) parts = sharded_log_prob_parts([x, data]) return tf.add_n(parts) @@ -218,7 +206,8 @@ def log_prob(x): x = tf.constant(1.) 
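# In the run below, `x` is a global (unsharded) value broadcast to every replica
# via `in_axes=(None, 0)`, while `data` is split one element per replica by
# `shard_values`. Because the likelihood term is marked as sharded, its per-replica
# gradients w.r.t. `x` are all-reduce summed, so each replica should see the same
# gradient as differentiating the full-data log prob.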
data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (x, sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run(run, (x, sharded_data), in_axes=(None, 0))) def true_log_prob(x): return (tfd.Normal(0., 1.).log_prob(x) + @@ -242,7 +231,7 @@ def log_prob_parts(value): def log_prob(x): sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [True, True]) + log_prob_parts, [True, True], axis_name=self.axis_name) parts = sharded_log_prob_parts([x, data]) return tf.add_n(parts) @@ -252,8 +241,8 @@ def log_prob(x): sharded_x = self.shard_values(x) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (sharded_x, - sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run(run, (sharded_x, sharded_data))) def true_log_prob(x): return (tf.reduce_sum(tfd.Normal(0., 1.).log_prob(x)) + @@ -265,7 +254,6 @@ def true_log_prob(x): def test_correct_gradient_for_global_and_local_variable(self): - @tf.function(autograph=False) def run(w, x, data): def log_prob_parts(value): @@ -279,7 +267,7 @@ def log_prob_parts(value): def log_prob(*value): w, x = value sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, [False, True, True]) + log_prob_parts, [False, True, True], axis_name=self.axis_name) parts = sharded_log_prob_parts([w, x, data]) return tf.add_n(parts) @@ -290,8 +278,9 @@ def log_prob(*value): sharded_x = self.shard_values(x) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (w, sharded_x, - sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run( + run, (w, sharded_x, sharded_data), in_axes=(None, 0, 0))) def true_log_prob(*value): w, x = value @@ -302,8 +291,8 @@ def true_log_prob(*value): true_grad = tfp.math.value_and_gradient(true_log_prob, [w, x])[1] true_grad[0] = tf.ones(4) * true_grad[0] - self.assertAllEqualNested(self.evaluate(out_grads), - self.evaluate(true_grad)) + self.assertAllEqualNested( + self.evaluate(out_grads), self.evaluate(true_grad)) def test_correct_gradient_for_global_and_local_variable_dict(self): @@ -313,14 +302,15 @@ def run(w, x, data): def log_prob_parts(value): return { 'w': tfd.Normal(0., 1.).log_prob(value['w']), - 'x': tfd.Normal(w, 1.).log_prob(value['x']), - 'data': tfd.Normal(x, 1.).log_prob(value['data']), + 'x': tfd.Normal(value['w'], 1.).log_prob(value['x']), + 'data': tfd.Normal(value['x'], 1.).log_prob(value['data']), } def log_prob(*value): w, x = value sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts( - log_prob_parts, {'w': False, 'x': True, 'data': True}) + log_prob_parts, {'w': False, 'x': True, 'data': True}, + axis_name=self.axis_name) parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data}) return tf.add_n(tf.nest.flatten(parts)) @@ -331,8 +321,9 @@ def log_prob(*value): sharded_x = self.shard_values(x) data = 2 * tf.ones(4) sharded_data = self.shard_values(data) - out_grads = per_replica_to_tensor(self.strategy.run(run, (w, sharded_x, - sharded_data))) + out_grads = self.per_replica_to_tensor( + self.strategy_run(run, (w, sharded_x, sharded_data), + in_axes=(None, 0, 0))) def true_log_prob(*value): w, x = value @@ -347,11 +338,4 @@ def true_log_prob(*value): self.evaluate(true_grad)) if __name__ == '__main__': - tf.enable_v2_behavior() - physical_devices = tf.config.experimental.list_physical_devices() 
- - num_logical_devices = 4 - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], - [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) tf.test.main() diff --git a/tensorflow_probability/python/experimental/distribute/distribute_test_lib.py b/tensorflow_probability/python/experimental/distribute/distribute_test_lib.py new file mode 100644 index 0000000000..9957a174fc --- /dev/null +++ b/tensorflow_probability/python/experimental/distribute/distribute_test_lib.py @@ -0,0 +1,74 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utilities for distributed testing.""" +import os + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.internal import test_util + +tf.enable_v2_behavior() +JAX_MODE = False +NUM_DEVICES = 4 + +if JAX_MODE: + import jax # pylint: disable=g-import-not-at-top + + +class DistributedTest(test_util.TestCase): + """Sets up distributed devices and sharding.""" + + def setUp(self): + super(DistributedTest, self).setUp() + if JAX_MODE: + os.environ['XLA_FLAGS'] = ( + '--xla_force_host_platform_device_count={}'.format(NUM_DEVICES)) + assert jax.device_count() == NUM_DEVICES + self.key = jax.random.PRNGKey(0) + else: + physical_devices = tf.config.experimental.list_physical_devices() + + tf.config.experimental.set_virtual_device_configuration( + physical_devices[0], + [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) + self.strategy = tf.distribute.MirroredStrategy( + devices=tf.config.list_logical_devices()) + self.key = [0, 0] + self.axis_name = 'i' + + def per_replica_to_tensor(self, value): + if JAX_MODE: + return value + return tf.nest.map_structure( + lambda per_replica: tf.stack(per_replica.values, axis=0), value) + + def strategy_run(self, f, args, in_axes=0): + if JAX_MODE: + if in_axes is None: + return jax.pmap( + lambda _, args: f(*args), + in_axes=(0, None), + axis_name=self.axis_name)(tf.ones(NUM_DEVICES), args) + return jax.pmap(f, axis_name=self.axis_name, in_axes=in_axes)(*args) + return self.strategy.run(tf.function(f, autograph=False), args) + + def shard_values(self, values): + if JAX_MODE: + return jax.pmap(lambda x: x)(values) + + def value_fn(ctx): + return values[ctx.replica_id_in_sync_group] + + return self.strategy.experimental_distribute_values_from_function(value_fn) diff --git a/tensorflow_probability/python/experimental/distribute/joint_distribution.py b/tensorflow_probability/python/experimental/distribute/joint_distribution.py index 6f9b0034cd..d2e05f4e44 100644 --- a/tensorflow_probability/python/experimental/distribute/joint_distribution.py +++ b/tensorflow_probability/python/experimental/distribute/joint_distribution.py @@ -18,8 +18,11 @@ from __future__ import division from __future__ import print_function +import functools + import tensorflow.compat.v2 as tf from tensorflow_probability.python import distributions as 
distribution_lib +from tensorflow_probability.python.distributions import joint_distribution as jd_lib from tensorflow_probability.python.experimental.distribute import distribute_lib from tensorflow_probability.python.experimental.distribute import sharded @@ -31,9 +34,13 @@ class JointDistributionDistributedMixin(object): def get_sharded_distributions(self): """Indicates for each part distribution whether or not it is sharded.""" ds = self._get_single_sample_distributions() - return self._model_unflatten(( - isinstance(d, (sharded.ShardedIndependent, sharded.ShardedSample)) - for d in ds)) + return self._model_unflatten( + (isinstance(d, (sharded.ShardedIndependent, sharded.ShardedSample)) + for d in ds)) + + @property + def shard_axis_name(self): + return self._parameters['shard_axis_name'] def _map_measure_over_dists(self, attr, value): """Overrides the default implementation to shard its log_prob calculation.""" @@ -44,34 +51,122 @@ def _map_measure_over_dists(self, attr, value): def inner_log_prob_parts(flat_value): unflat_value = self._model_unflatten(flat_value) ds, xs = self._call_flat_sample_distributions( - value=unflat_value, seed=42) + value=unflat_value, seed=jd_lib.dummy_seed()) + # For sharded distributions, we need to make sure not to do an + # all-reduce. + flat_sharded = self._model_flatten(self.get_sharded_distributions()) + log_prob_fns = [ + functools.partial(d.log_prob, reduce_over_shards=False) + if s else d.log_prob for d, s in zip(ds, flat_sharded)] # We need to flatten and unflatten here to ensure the output structure # matches `flat_sharded_distributions`. vals = self._model_unflatten( - [getattr(d, attr)(x) for d, x in zip(ds, xs)]) + [log_prob_fn(x) for log_prob_fn, x in zip(log_prob_fns, xs)]) return self._model_flatten(vals) flat_value = self._model_flatten(value) flat_sharded_distributions = self._model_flatten( self.get_sharded_distributions()) flat_xs = distribute_lib.make_sharded_log_prob_parts( - inner_log_prob_parts, flat_sharded_distributions)( + inner_log_prob_parts, + flat_sharded_distributions, + axis_name=self.shard_axis_name)( flat_value) return iter(flat_xs) - ds, xs = self._call_flat_sample_distributions(value=value, seed=42) + ds, xs = self._call_flat_sample_distributions( + value=value, seed=jd_lib.dummy_seed()) return (getattr(d, attr)(x) for d, x in zip(ds, xs)) class JointDistributionSequential(JointDistributionDistributedMixin, distribution_lib.JointDistributionSequential): - pass + """A sharding-aware JointDistributionSequential.""" + + def __init__(self, + model, + validate_args=False, + shard_axis_name=None, + name=None): + """Construct the `JointDistributionSequential` distribution. + + Args: + model: Python list of either tfd.Distribution instances and/or lambda + functions which take the `k` previous distributions and returns a new + tfd.Distribution instance. + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + Default value: `False`. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `None` (i.e., `"JointDistributionSequential"`). 
+ """ + super(JointDistributionSequential, self).__init__( + model, validate_args=validate_args, name=name) + self._parameters['shard_axis_name'] = shard_axis_name class JointDistributionNamed(JointDistributionDistributedMixin, distribution_lib.JointDistributionNamed): - pass + """A sharding-aware JointDistributionNamed.""" + + def __init__(self, + model, + validate_args=False, + shard_axis_name=None, + name=None): + """Construct the `JointDistributionNamed` distribution. + + Args: + model: Python `dict`, `collections.OrderedDict`, or `namedtuple` of + distribution-making functions each with required args corresponding only + to other keys. + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + Default value: `False`. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `None` (i.e., `"JointDistributionNamed"`). + """ + super(JointDistributionNamed, + self).__init__(model, validate_args, name or 'JointDistributionNamed') + self._parameters['shard_axis_name'] = shard_axis_name class JointDistributionCoroutine(JointDistributionDistributedMixin, distribution_lib.JointDistributionCoroutine): - pass + """A sharding-aware JointDistributionCoroutine.""" + + def __init__( + self, + model, + sample_dtype=None, + validate_args=False, + shard_axis_name=None, + name=None, + ): + """Construct the `JointDistributionCoroutine` distribution. + + Args: + model: A generator that yields a sequence of `tfd.Distribution`-like + instances. + sample_dtype: Samples from this distribution will be structured like + `tf.nest.pack_sequence_as(sample_dtype, list_)`. `sample_dtype` is only + used for `tf.nest.pack_sequence_as` structuring of outputs, never + casting (which is the responsibility of the component distributions). + Default value: `None` (i.e. `namedtuple`). + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + Default value: `False`. + shard_axis_name: `str` for axis name for use in JAX backend. + name: The name for ops managed by the distribution. + Default value: `None` (i.e., `JointDistributionCoroutine`). 
+ """ + super(JointDistributionCoroutine, self).__init__( + model, + sample_dtype=sample_dtype, + validate_args=validate_args, + name=name) + self._parameters['shard_axis_name'] = shard_axis_name diff --git a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py index 14867fda3b..37a0b3799e 100644 --- a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py +++ b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py @@ -21,77 +21,88 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.distribute import distribute_test_lib as test_lib from tensorflow_probability.python.experimental.distribute import joint_distribution as jd from tensorflow_probability.python.experimental.distribute import sharded from tensorflow_probability.python.internal import test_util tfd = tfp.distributions -NUM_DEVICES = 4 +def make_jd_sequential(axis_name): + return jd.JointDistributionSequential([ + tfd.Normal(0., 1.), + lambda w: sharded.ShardedSample( # pylint: disable=g-long-lambda + tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name), + lambda x: sharded.ShardedIndependent( # pylint: disable=g-long-lambda + tfd.Normal(x, 1.), 1, shard_axis_name=axis_name), + ], shard_axis_name=axis_name) -def per_replica_to_tensor(value): - return tf.nest.map_structure( - lambda per_replica: tf.stack(per_replica.values, axis=0), value) +def make_jd_named(axis_name): + return jd.JointDistributionNamed( # pylint: disable=g-long-lambda + dict( + w=tfd.Normal(0., 1.), + x=lambda w: sharded.ShardedSample( # pylint: disable=g-long-lambda + tfd.Normal(w, 1.), + test_lib.NUM_DEVICES, + shard_axis_name=axis_name), + data=lambda x: sharded.ShardedIndependent( # pylint: disable=g-long-lambda + tfd.Normal(x, 1.), + 1, + shard_axis_name=axis_name), + ), shard_axis_name=axis_name) -def model_coroutine(): - w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.)) - x = yield sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES) - yield sharded.ShardedIndependent(tfd.Normal(x, 1.), 1) +def make_jd_coroutine(axis_name): -distributions = ( - ('coroutine', lambda: jd.JointDistributionCoroutine(model_coroutine)), - ('sequential', lambda: jd.JointDistributionSequential([ # pylint: disable=g-long-lambda - tfd.Normal(0., 1.), - lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES), - lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1), - ])), - ('named', lambda: jd.JointDistributionNamed( # pylint: disable=g-long-lambda - dict( - w=tfd.Normal(0., 1.), - x=lambda w: sharded.ShardedSample(tfd.Normal(w, 1.), NUM_DEVICES), - data=lambda x: sharded.ShardedIndependent(tfd.Normal(x, 1.), 1), - ))), -) + def model_coroutine(): + w = yield tfd.JointDistributionCoroutine.Root(tfd.Normal(0., 1.)) + x = yield sharded.ShardedSample( + tfd.Normal(w, 1.), test_lib.NUM_DEVICES, shard_axis_name=axis_name) + yield sharded.ShardedIndependent( + tfd.Normal(x, 1.), 1, shard_axis_name=axis_name) + return jd.JointDistributionCoroutine( + model_coroutine, shard_axis_name=axis_name) -@test_util.test_all_tf_execution_regimes -class JointDistributionTest(test_util.TestCase): - - def setUp(self): - super(JointDistributionTest, self).setUp() - self.strategy = tf.distribute.MirroredStrategy( - devices=tf.config.list_logical_devices()) - def shard_values(self, values): +distributions = ( + ('coroutine', 
make_jd_coroutine), + ('sequential', make_jd_sequential), + ('named', make_jd_named), +) - def value_fn(ctx): - return values[ctx.replica_id_in_sync_group] - return self.strategy.experimental_distribute_values_from_function(value_fn) +@test_util.test_all_tf_execution_regimes +class JointDistributionTest(test_lib.DistributedTest): + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot call `get_sharded_distributions` outside of pmap.') def test_get_sharded_distribution_coroutine(self): - dist = distributions[0][1]() - self.assertTupleEqual(dist.get_sharded_distributions(), - (False, True, True)) + dist = distributions[0][1](self.axis_name) + self.assertTupleEqual(dist.get_sharded_distributions(), (False, True, True)) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot call `get_sharded_distributions` outside of pmap.') def test_get_sharded_distribution_sequential(self): - dist = distributions[1][1]() - self.assertListEqual(dist.get_sharded_distributions(), - [False, True, True]) + dist = distributions[1][1](self.axis_name) + self.assertListEqual(dist.get_sharded_distributions(), [False, True, True]) + @test_util.disable_test_for_backend( + disable_jax=True, + reason='Cannot call `get_sharded_distributions` outside of pmap.') def test_get_sharded_distribution_named(self): - dist = distributions[2][1]() + dist = distributions[2][1](self.axis_name) self.assertDictEqual(dist.get_sharded_distributions(), dict(w=False, x=True, data=True)) @parameterized.named_parameters(*distributions) def test_jd(self, dist_fn): - dist = dist_fn() + dist = dist_fn(self.axis_name) - @tf.function(autograph=False) def run(key): sample = dist.sample(seed=key) # The identity is to prevent reparameterization gradients from kicking in. 
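# A minimal usage sketch (an illustration, not part of this patch): under the JAX
# backend, the sharding-aware joint distributions built by the factories above are
# meant to be used inside `jax.pmap`, with the pmap `axis_name` matching the
# `shard_axis_name` given to the model. The sketch assumes the JAX substrate is
# active and that four host devices are configured, as `distribute_test_lib` does.
import jax
import jax.numpy as jnp

axis_name = 'i'
dist = make_jd_coroutine(axis_name)

def replica_fn(_, key):
  # The same key is broadcast to every replica; ShardedSample/ShardedIndependent
  # fold the replica id into the seed, so each device still draws distinct samples.
  sample = dist.sample(seed=key)
  return sample, dist.log_prob(sample)  # log_prob psums the sharded terms.

samples, log_probs = jax.pmap(
    replica_fn, axis_name=axis_name, in_axes=(0, None))(
        jnp.ones(test_lib.NUM_DEVICES), jax.random.PRNGKey(0))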
@@ -99,9 +110,9 @@ def run(key): dist.log_prob, (tf.nest.map_structure(tf.identity, sample),)) return sample, log_prob, log_prob_grads - sample, log_prob, log_prob_grads = self.strategy.run( - run, (tf.ones(2, tf.int32),)) - sample, log_prob, log_prob_grads = per_replica_to_tensor( + sample, log_prob, log_prob_grads = self.strategy_run( + run, (self.key,), in_axes=None) + sample, log_prob, log_prob_grads = self.per_replica_to_tensor( (sample, log_prob, log_prob_grads)) def true_log_prob_fn(w, x, data): @@ -130,11 +141,4 @@ def true_log_prob_fn(w, x, data): if __name__ == '__main__': - tf.enable_v2_behavior() - physical_devices = tf.config.experimental.list_physical_devices() - - num_logical_devices = 4 - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], - [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) tf.test.main() diff --git a/tensorflow_probability/python/experimental/distribute/sharded.py b/tensorflow_probability/python/experimental/distribute/sharded.py index 31976f5304..14ed7c2ab2 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded.py +++ b/tensorflow_probability/python/experimental/distribute/sharded.py @@ -24,10 +24,14 @@ from tensorflow_probability.python.distributions import independent as independent_lib from tensorflow_probability.python.distributions import sample as sample_lib +from tensorflow_probability.python.experimental.distribute import distribute_lib from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers +JAX_MODE = False + + class ShardedSample(sample_lib.Sample): """A version of `tfd.Sample` that shards its output across devices.""" @@ -35,6 +39,7 @@ def __init__(self, distribution, sample_shape=(), shard_axis=0, + shard_axis_name=None, validate_args=False, experimental_use_kahan_sum=False, name=None): @@ -47,6 +52,7 @@ def __init__(self, single sample. shard_axis: `int` representing which axis of `sample_shape` will be sharded across devices. + shard_axis_name: `str` for axis name for use in JAX backend. validate_args: Python `bool`. Whether to validate input with asserts. If `validate_args` is `False`, and the inputs are invalid, correct behavior is not guaranteed. @@ -56,13 +62,13 @@ def __init__(self, noticeable in particular for large dimensions in float32. See CPU caveat on `tfp.math.reduce_kahan_sum`. name: The name for ops managed by the distribution. - Default value: `None` (i.e., `'Sample' + distribution.name`). + Default value: `None` (i.e., `'ShardedSample' + distribution.name`). 
""" parameters = dict(locals()) with tf.name_scope(name or 'ShardedSample' + distribution.name) as name: self._shard_axis = shard_axis - + self._shard_axis_name = shard_axis_name super(ShardedSample, self).__init__( distribution, validate_args=validate_args, @@ -82,39 +88,116 @@ def sample_shape(self): sample_shape = ps.concat([ sample_shape[:self.shard_axis], [shard_size], sample_shape[self.shard_axis + 1:] - ], - axis=0) + ], axis=0) return sample_shape + @property + def shard_axis_name(self): + return self._shard_axis_name + @property def shard_axis(self): return self._shard_axis @property def replica_id(self): - ctx = tf.distribute.get_replica_context() - return ctx.replica_id_in_sync_group + return distribute_lib.get_replica_id(axis_name=self.shard_axis_name) @property def num_devices(self): - ctx = tf.distribute.get_replica_context() - return ctx.num_replicas_in_sync + return distribute_lib.get_num_replicas(axis_name=self.shard_axis_name) def _sample_n(self, n, seed, **kwargs): seed = samplers.sanitize_seed(seed, salt='sharded_sample_sample') - return super(ShardedSample, self)._sample_n(n, seed + self.replica_id, - **kwargs) + seed = samplers.fold_in(seed, tf.cast(self.replica_id, tf.int32)) + return super(ShardedSample, self)._sample_n(n, seed, **kwargs) + + def _log_prob(self, value, reduce_over_shards=True, **kwargs): + out_log_prob = super(ShardedSample, self)._log_prob(value, **kwargs) + if reduce_over_shards: + return distribute_lib.psum(out_log_prob, axis_name=self.shard_axis_name) + return out_log_prob + + def _parameter_control_dependencies(self, is_init=False): + if not self.validate_args: + return [] + return super(ShardedSample, self)._parameter_control_dependencies( + is_init=is_init) class ShardedIndependent(independent_lib.Independent): """A version of `tfd.Independent` that folds device id into its randomness.""" + def __init__(self, + distribution, + reinterpreted_batch_ndims=None, + validate_args=False, + shard_axis_name=None, + experimental_use_kahan_sum=False, + name=None): + """Construct a `ShardedIndependent` distribution. + + Args: + distribution: The base distribution instance to transform. Typically an + instance of `Distribution`. + reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims + which will be regarded as event dims. When `None` all but the first + batch axis (batch axis 0) will be transferred to event dimensions + (analogous to `tf.layers.flatten`). + validate_args: Python `bool`. Whether to validate input with asserts. If + `validate_args` is `False`, and the inputs are invalid, correct behavior + is not guaranteed. + shard_axis_name: `str` for axis name for use in JAX backend. + experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan + summation to aggregate independent underlying log_prob values, which + improves against the precision of a naive float32 sum. This can be + noticeable in particular for large dimensions in float32. See CPU caveat + on `tfp.math.reduce_kahan_sum`. + name: The name for ops managed by the distribution. + Default value: `'ShardedIndependent' + distribution.name`. 
+ + Raises: + ValueError: if `reinterpreted_batch_ndims` exceeds + `distribution.batch_ndims` + """ + parameters = dict(locals()) + + with tf.name_scope(name or + 'ShardedIndependent' + distribution.name) as name: + self._shard_axis_name = shard_axis_name + super(ShardedIndependent, self).__init__( + distribution, + reinterpreted_batch_ndims=reinterpreted_batch_ndims, + validate_args=validate_args, + experimental_use_kahan_sum=experimental_use_kahan_sum, + name=name) + self._parameters = parameters + + @property + def shard_axis_name(self): + return self._shard_axis_name + + def _log_prob(self, value, reduce_over_shards=True, **kwargs): + out_log_prob = super(ShardedIndependent, self)._log_prob(value, **kwargs) + if reduce_over_shards: + return distribute_lib.psum(out_log_prob, axis_name=self.shard_axis_name) + return out_log_prob + @property def replica_id(self): - ctx = tf.distribute.get_replica_context() - return ctx.replica_id_in_sync_group + return distribute_lib.get_replica_id(axis_name=self.shard_axis_name) + + @property + def num_devices(self): + return distribute_lib.get_num_replicas(axis_name=self.shard_axis_name) def _sample_n(self, n, seed, **kwargs): seed = samplers.sanitize_seed(seed, salt='sharded_independent_sample') - return super(ShardedIndependent, self)._sample_n(n, seed + self.replica_id, - **kwargs) + seed = samplers.fold_in(seed, tf.cast(self.replica_id, tf.int32)) + return super(ShardedIndependent, self)._sample_n(n, seed, **kwargs) + + def _parameter_control_dependencies(self, is_init): + if JAX_MODE: + return [] + return super(ShardedIndependent, self)._parameter_control_dependencies( + is_init=is_init) diff --git a/tensorflow_probability/python/experimental/distribute/sharded_test.py b/tensorflow_probability/python/experimental/distribute/sharded_test.py index fdb7387130..03c17f07c6 100644 --- a/tensorflow_probability/python/experimental/distribute/sharded_test.py +++ b/tensorflow_probability/python/experimental/distribute/sharded_test.py @@ -19,35 +19,28 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.distribute import distribute_test_lib as test_lib from tensorflow_probability.python.experimental.distribute import sharded from tensorflow_probability.python.internal import test_util tfd = tfp.distributions -NUM_DEVICES = 4 - -def per_replica_to_tensor(value): - return tf.nest.map_structure( - lambda per_replica: tf.stack(per_replica.values, axis=0), value) - - -class ShardedDistributionTest(test_util.TestCase): - - def setUp(self): - super(ShardedDistributionTest, self).setUp() - self.strategy = tf.distribute.MirroredStrategy( - devices=tf.config.list_logical_devices()) +@test_util.test_all_tf_execution_regimes +class ShardedDistributionTest(test_lib.DistributedTest): def test_sharded_sample_samples_differently_across_shards(self): @tf.function(autograph=False) def run(key): - return sharded.ShardedSample(tfd.Normal(0., 1.), - NUM_DEVICES).sample(seed=key) + return sharded.ShardedSample( + tfd.Normal(0., 1.), + test_lib.NUM_DEVICES, + shard_axis_name=self.axis_name).sample(seed=key) sample = self.evaluate( - per_replica_to_tensor(self.strategy.run(run, (tf.zeros(2, tf.int32),)))) + self.per_replica_to_tensor( + self.strategy_run(run, (self.key,), in_axes=None))) for i in range(4): for j in range(4): if i == j: @@ -59,10 +52,13 @@ def test_sharded_independent_samples_differently_across_shards(self): @tf.function(autograph=False) def run(key): return sharded.ShardedIndependent( - 
tfd.Normal(tf.zeros(1), tf.ones(1)), 1).sample(seed=key) + tfd.Normal(tf.zeros(1), tf.ones(1)), + 1, + shard_axis_name=self.axis_name).sample(seed=key) sample = self.evaluate( - per_replica_to_tensor(self.strategy.run(run, (tf.zeros(2, tf.int32),)))) + self.per_replica_to_tensor( + self.strategy_run(run, (self.key,), in_axes=None))) for i in range(4): for j in range(4): if i == j: @@ -71,11 +67,4 @@ def run(key): if __name__ == "__main__": - tf.enable_v2_behavior() - physical_devices = tf.config.experimental.list_physical_devices() - - num_logical_devices = 4 - tf.config.experimental.set_virtual_device_configuration( - physical_devices[0], - [tf.config.experimental.VirtualDeviceConfiguration()] * NUM_DEVICES) tf.test.main() diff --git a/tensorflow_probability/python/experimental/distributions/BUILD b/tensorflow_probability/python/experimental/distributions/BUILD index edc7375c7b..cd74aa6788 100644 --- a/tensorflow_probability/python/experimental/distributions/BUILD +++ b/tensorflow_probability/python/experimental/distributions/BUILD @@ -37,6 +37,7 @@ multi_substrate_py_library( deps = [ ":joint_distribution_pinned", ":mvn_precision_factor_linop", + "//tensorflow_probability/python/distributions:log_prob_ratio", ], ) diff --git a/tensorflow_probability/python/experimental/distributions/__init__.py b/tensorflow_probability/python/experimental/distributions/__init__.py index 0bab103142..2334ac9767 100644 --- a/tensorflow_probability/python/experimental/distributions/__init__.py +++ b/tensorflow_probability/python/experimental/distributions/__init__.py @@ -18,11 +18,13 @@ from __future__ import division from __future__ import print_function +from tensorflow_probability.python.distributions.log_prob_ratio import log_prob_ratio from tensorflow_probability.python.experimental.distributions.joint_distribution_pinned import JointDistributionPinned from tensorflow_probability.python.experimental.distributions.mvn_precision_factor_linop import MultivariateNormalPrecisionFactorLinearOperator __all__ = [ + 'log_prob_ratio', 'JointDistributionPinned', 'MultivariateNormalPrecisionFactorLinearOperator', ] diff --git a/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py b/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py index 4f385d54fd..055269503c 100644 --- a/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py +++ b/tensorflow_probability/python/experimental/lazybones/utils/weak_container.py @@ -114,7 +114,7 @@ class HashableWeakRef(weakref.ref): def __init__(self, referrent, callback=None): """weakref.ref which makes any object hashable. - Arguments: + Args: referrent: Object that is being referred to. callback: Optional callback to invoke when object is GCed. 
""" diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 8c87d50e75..3ca9add236 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -55,6 +55,7 @@ multi_substrate_py_library( ":sample", ":sample_discarding_kernel", ":sample_fold", + ":thinning_kernel", ":tracing_reducer", ":with_reductions", ], @@ -79,6 +80,7 @@ multi_substrate_py_library( ":sample_fold", ":sample_sequential_monte_carlo", ":sequential_monte_carlo_kernel", + ":thinning_kernel", ":tracing_reducer", ":weighted_resampling", ":with_reductions", @@ -134,7 +136,7 @@ py_test( py_library( name = "kernel_outputs", srcs = ["kernel_outputs.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":tracing_reducer", "//tensorflow_probability/python/internal:unnest", @@ -191,7 +193,7 @@ py_test( py_library( name = "preconditioned_hmc", srcs = ["preconditioned_hmc.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # tensorflow dep, "//tensorflow_probability/python/distributions:independent", @@ -225,7 +227,7 @@ py_test( py_library( name = "progress_bar_reducer", srcs = ["progress_bar_reducer.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":reducer", # tensorflow dep, @@ -661,8 +663,10 @@ py_library( srcs = ["sample_fold.py"], srcs_version = "PY3", deps = [ + ":run", ":sample", ":sample_discarding_kernel", + ":thinning_kernel", ":tracing_reducer", ":with_reductions", # numpy dep, @@ -714,6 +718,34 @@ py_test( ], ) +py_library( + name = "thinning_kernel", + srcs = ["thinning_kernel.py"], + srcs_version = "PY3", + deps = [ + ":sample", + # tensorflow dep, + "//tensorflow_probability/python/mcmc:kernel", + "//tensorflow_probability/python/mcmc/internal", + ], +) + +py_test( + name = "thinning_kernel_test", + size = "small", + srcs = ["thinning_kernel_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + ":thinning_kernel", + # numpy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/experimental/mcmc/internal:test_fixtures", + "//tensorflow_probability/python/internal:test_util", + ], +) + py_library( name = "tracing_reducer", srcs = ["tracing_reducer.py"], diff --git a/tensorflow_probability/python/experimental/mcmc/__init__.py b/tensorflow_probability/python/experimental/mcmc/__init__.py index a13b9bc59d..3dd3ea1e93 100644 --- a/tensorflow_probability/python/experimental/mcmc/__init__.py +++ b/tensorflow_probability/python/experimental/mcmc/__init__.py @@ -44,7 +44,7 @@ from tensorflow_probability.python.experimental.mcmc.run import run_kernel from tensorflow_probability.python.experimental.mcmc.sample import step_kernel from tensorflow_probability.python.experimental.mcmc.sample_discarding_kernel import SampleDiscardingKernel -from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_chain +from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_chain_with_burnin from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_fold from tensorflow_probability.python.experimental.mcmc.sample_sequential_monte_carlo import default_make_hmc_kernel_fn from tensorflow_probability.python.experimental.mcmc.sample_sequential_monte_carlo import gen_make_hmc_kernel_fn @@ -56,6 +56,7 @@ from tensorflow_probability.python.experimental.mcmc.sequential_monte_carlo_kernel import SequentialMonteCarlo from 
tensorflow_probability.python.experimental.mcmc.sequential_monte_carlo_kernel import SequentialMonteCarloResults from tensorflow_probability.python.experimental.mcmc.sequential_monte_carlo_kernel import WeightedParticles +from tensorflow_probability.python.experimental.mcmc.thinning_kernel import ThinningKernel from tensorflow_probability.python.experimental.mcmc.tracing_reducer import TracingReducer from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_deterministic_minimum_error from tensorflow_probability.python.experimental.mcmc.weighted_resampling import resample_independent @@ -107,7 +108,7 @@ 'resample_stratified', 'resample_systematic', 'run_kernel', - 'sample_chain', + 'sample_chain_with_burnin', 'sample_fold', 'sample_sequential_monte_carlo', 'SampleDiscardingKernel', @@ -116,6 +117,7 @@ 'simple_heuristic_tuning', 'StateWithHistory', 'step_kernel', + 'ThinningKernel', 'TracingReducer', 'VarianceReducer', 'WeightedParticles', diff --git a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py index 0c728cff9f..0e8c966fc9 100644 --- a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py +++ b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py @@ -240,7 +240,7 @@ def bootstrap_results(self, init_state): sample_stats.RunningVariance): variance_parts = [self.initial_running_variance] else: - variance_parts = self.initial_running_variance + variance_parts = list(self.initial_running_variance) diags = [variance_part.variance() for variance_part in variance_parts] diff --git a/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py b/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py index 64aa30a707..ae5125b445 100644 --- a/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py +++ b/tensorflow_probability/python/experimental/mcmc/nuts_autobatching_test.py @@ -303,9 +303,9 @@ def __init__(self, loc, chol_precision_tril, name=None): scale=tf.ones_like(loc)), reinterpreted_batch_ndims=1), bijector=tfb.Chain([ - tfb.Affine(shift=loc), - tfb.Invert(tfb.Affine(scale_tril=chol_precision_tril, - adjoint=True)), + tfb.Shift(shift=loc), + tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril, + adjoint=True)), ]), name=name) diff --git a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py index ceb385f4ac..892fc765cf 100644 --- a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py +++ b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py @@ -464,7 +464,7 @@ def _prepare_args(target_log_prob_fn, def _batched_isotropic_normal_like(state_part): event_ndims = ps.rank(state_part) - batch_rank return independent.Independent( - normal.Normal(ps.zeros_like(state_part, tf.float32), 1.), + normal.Normal(ps.zeros_like(state_part), 1.), reinterpreted_batch_ndims=event_ndims) momentum_distribution = jds.JointDistributionSequential( diff --git a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py index 79ff9e3c19..dd625ca481 100644 --- a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py +++ b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc_test.py @@ -386,6 +386,36 @@ def 
test_correctness_with_200d_mvn_tril(self, precondition_scheme): dict(testcase_name='_explicit', use_default=False)) class PreconditionedHMCTest(test_util.TestCase): + def test_f64(self, use_default): + if use_default: + momentum_distribution = None + else: + momentum_distribution = tfp.experimental.as_composite( + tfd.Normal(0., tf.constant(.5, dtype=tf.float64))) + kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo( + lambda x: -x**2, step_size=.5, num_leapfrog_steps=2, + momentum_distribution=momentum_distribution) + kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3) + self.evaluate(tfp.mcmc.sample_chain( + 1, kernel=kernel, current_state=tf.ones([], tf.float64), + num_burnin_steps=5, trace_fn=None)) + + # TODO(b/175787154): Enable this test + def DISABLED_test_f64_multichain(self, use_default): + if use_default: + momentum_distribution = None + else: + momentum_distribution = tfp.experimental.as_composite( + tfd.Normal(0., tf.constant(.5, dtype=tf.float64))) + kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo( + lambda x: -x**2, step_size=.5, num_leapfrog_steps=2, + momentum_distribution=momentum_distribution) + kernel = tfp.mcmc.SimpleStepSizeAdaptation(kernel, num_adaptation_steps=3) + nchains = 7 + self.evaluate(tfp.mcmc.sample_chain( + 1, kernel=kernel, current_state=tf.ones([nchains], tf.float64), + num_burnin_steps=5, trace_fn=None)) + def test_diag(self, use_default): """Test that a diagonal multivariate normal can be effectively sampled from. diff --git a/tensorflow_probability/python/experimental/mcmc/run.py b/tensorflow_probability/python/experimental/mcmc/run.py index 32db8c0077..d5b8f78b74 100644 --- a/tensorflow_probability/python/experimental/mcmc/run.py +++ b/tensorflow_probability/python/experimental/mcmc/run.py @@ -161,7 +161,6 @@ def run_kernel( Default value: `None` (i.e., 'mcmc_run_kernel'). Returns: - result: A `RunKernelResults` instance containing information about the sampling run. 
Main fields are `trace`, the history of outputs of `trace_fn`, and `reduction_results`, the final outputs of all supplied diff --git a/tensorflow_probability/python/experimental/mcmc/sample_fold.py b/tensorflow_probability/python/experimental/mcmc/sample_fold.py index ea82ac8275..9f8af741b4 100644 --- a/tensorflow_probability/python/experimental/mcmc/sample_fold.py +++ b/tensorflow_probability/python/experimental/mcmc/sample_fold.py @@ -22,16 +22,17 @@ # Dependency imports import tensorflow.compat.v2 as tf +from tensorflow_probability.python import random +from tensorflow_probability.python.experimental.mcmc import run from tensorflow_probability.python.experimental.mcmc import sample as exp_sample_lib from tensorflow_probability.python.experimental.mcmc import sample_discarding_kernel -from tensorflow_probability.python.experimental.mcmc import tracing_reducer +from tensorflow_probability.python.experimental.mcmc import thinning_kernel from tensorflow_probability.python.experimental.mcmc import with_reductions -from tensorflow_probability.python.mcmc import sample from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import __all__ = [ - 'sample_chain', + 'sample_chain_with_burnin', 'sample_fold', ] @@ -126,19 +127,19 @@ def sample_fold( if reducer is None: reducer = [] reducer_was_none = True - thinning_kernel = sample_discarding_kernel.SampleDiscardingKernel( + thinning_k = sample_discarding_kernel.SampleDiscardingKernel( inner_kernel=kernel, num_burnin_steps=num_burnin_steps, num_steps_between_results=num_steps_between_results) reduction_kernel = with_reductions.WithReductions( - inner_kernel=thinning_kernel, + inner_kernel=thinning_k, reducer=reducer, # Strip thinning kernel results layer adjust_kr_fn=lambda kr: kr.inner_results, ) if previous_kernel_results is None: previous_kernel_results = kernel.bootstrap_results(current_state) - thinning_pkr = thinning_kernel.bootstrap_results( + thinning_pkr = thinning_k.bootstrap_results( current_state, previous_kernel_results) reduction_pkr = reduction_kernel.bootstrap_results( current_state, thinning_pkr, previous_reducer_state) @@ -176,20 +177,19 @@ def sample_fold( final_kernel_results.inner_results.inner_results) -def _trace_kernel_results(current_state, kernel_results): - del current_state - return kernel_results +def _trace_current_state(current_state, kernel_results): + del kernel_results + return current_state -def sample_chain( +def sample_chain_with_burnin( num_results, current_state, previous_kernel_results=None, kernel=None, num_burnin_steps=0, num_steps_between_results=0, - trace_fn=_trace_kernel_results, - return_final_kernel_results=False, + trace_fn=_trace_current_state, parallel_iterations=10, seed=None, name=None, @@ -216,9 +216,8 @@ def sample_chain( In addition to returning the chain state, this function supports tracing of auxiliary variables used by the kernel. The traced values are selected by - specifying `trace_fn`. By default, all kernel results are traced but in the - future the default will be changed to no results being traced, so plan - accordingly. See below for some examples of this feature. + specifying `trace_fn`. By default, all chain states but no kernel results are + traced. Args: num_results: Integer number of Markov chain draws. @@ -239,27 +238,17 @@ def sample_chain( trace_fn: A callable that takes in the current chain state and the previous kernel results and return a `Tensor` or a nested collection of `Tensor`s that is then traced along with the chain state. 
- return_final_kernel_results: If `True`, then the final kernel results are - returned alongside the chain state and the trace specified by the - `trace_fn`. parallel_iterations: The number of iterations allowed to run in parallel. It must be a positive integer. See `tf.while_loop` for more details. seed: Optional, a seed for reproducible sampling. name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., 'experimental_mcmc_sample_chain'). + Default value: `None` (i.e., + 'experimental_mcmc_sample_chain_with_burnin'). Returns: - checkpointable_states_and_trace: if `return_final_kernel_results` is - `True`. The return value is an instance of - `CheckpointableStatesAndTrace`. - all_states: if `return_final_kernel_results` is `False` and `trace_fn` is - `None`. The return value is a `Tensor` or Python list of `Tensor`s - representing the state(s) of the Markov chain(s) at each result step. Has - same shape as input `current_state` but with a prepended - `num_results`-size dimension. - states_and_trace: if `return_final_kernel_results` is `False` and - `trace_fn` is not `None`. The return value is an instance of - `StatesAndTrace`. + result: A `RunKernelResults` instance containing information about the + sampling run. Main field is `trace`, the history of outputs of + `trace_fn`. See `RunKernelResults` for contents of other fields. #### References @@ -267,51 +256,42 @@ def sample_chain( _Technical Report_, 2017. http://statweb.stanford.edu/~owen/reports/bestthinning.pdf """ - with tf.name_scope(name or 'experimental_mcmc_sample_chain'): + with tf.name_scope(name or 'experimental_mcmc_sample_chain_with_burnin'): if not kernel.is_calibrated: warnings.warn('supplied `TransitionKernel` is not calibrated. Markov ' 'chain may not converge to intended target distribution.') if trace_fn is None: trace_fn = lambda *args: () - no_trace = True - else: - no_trace = False - - if trace_fn is sample_chain.__defaults__[4]: - warnings.warn('Tracing all kernel results by default is deprecated. Set ' - 'the `trace_fn` argument to None (the future default ' - 'value) or an explicit callback that traces the values ' - 'you are interested in.') - - def real_trace_fn(curr_state, kr): - return curr_state, trace_fn(curr_state, kr) - trace_reducer = tracing_reducer.TracingReducer( - trace_fn=real_trace_fn, - size=num_results - ) - # pylint: disable=unbalanced-tuple-unpacking - trace_results, _, final_kernel_results = sample_fold( - num_steps=num_results, + + burnin_seed, sampling_seed = random.split_seed(seed, n=2) + + # Burn-in run + chain_state, kr = exp_sample_lib.step_kernel( + num_steps=num_burnin_steps, current_state=current_state, previous_kernel_results=previous_kernel_results, kernel=kernel, - reducer=trace_reducer, - num_burnin_steps=num_burnin_steps, - num_steps_between_results=num_steps_between_results, + return_final_kernel_results=True, parallel_iterations=parallel_iterations, - seed=seed, - name=name, - ) + seed=burnin_seed, + name='burnin') - all_states, trace = trace_results - if return_final_kernel_results: - return sample.CheckpointableStatesAndTrace( - all_states=all_states, - trace=trace, - final_kernel_results=final_kernel_results) - else: - if no_trace: - return all_states - else: - return sample.StatesAndTrace(all_states=all_states, trace=trace) + thinning_k = thinning_kernel.ThinningKernel( + kernel, num_steps_to_skip=num_steps_between_results) + + # ThinningKernel doesn't wrap the kernel_results structure, so we don't need + # any of the usual munging. 
+ results = run.run_kernel( + num_results=num_results, + current_state=chain_state, + previous_kernel_results=kr, + kernel=thinning_k, + trace_fn=trace_fn, + parallel_iterations=parallel_iterations, + seed=sampling_seed, + name='sampling') + + del results.resume_kwargs['reducer'] + del results.resume_kwargs['previous_reducer_state'] + return results diff --git a/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py b/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py index bbbbee0952..0779b0c112 100644 --- a/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py @@ -44,11 +44,9 @@ def test_simple_operation(self): num_steps=5, current_state=0., kernel=fake_kernel, - reducer=fake_reducer, - ) + reducer=fake_reducer) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual(3, reduction_rslt) self.assertEqual(5, last_sample) self.assertEqual(5, kernel_results.counter_1) @@ -60,11 +58,9 @@ def test_simple_operation(self): current_state=last_sample, kernel=fake_kernel, reducer=fake_reducer, - previous_kernel_results=kernel_results, - ) + previous_kernel_results=kernel_results) reduction_rslt_2, last_sample_2, kernel_results_2 = self.evaluate([ - reduction_rslt_2, last_sample_2, kr_2 - ]) + reduction_rslt_2, last_sample_2, kr_2]) self.assertEqual(8, reduction_rslt_2) self.assertEqual(10, last_sample_2) self.assertEqual(10, kernel_results_2.counter_1) @@ -78,11 +74,9 @@ def test_reducer_warm_restart(self): current_state=0., kernel=fake_kernel, reducer=fake_reducer, - return_final_reducer_states=True, - ) + return_final_reducer_states=True) red_res, last_sample, kernel_results, red_states = self.evaluate([ - red_res, last_sample, kr, red_states - ]) + red_res, last_sample, kr, red_states]) self.assertEqual(3, red_res) self.assertEqual(5, last_sample) self.assertEqual(5, kernel_results.counter_1) @@ -95,11 +89,9 @@ def test_reducer_warm_restart(self): previous_kernel_results=kernel_results, kernel=fake_kernel, reducer=fake_reducer, - previous_reducer_state=red_states - ) + previous_reducer_state=red_states) reduction_rslt_2, last_sample_2, kernel_results_2 = self.evaluate([ - reduction_rslt_2, last_sample_2, kr_2 - ]) + reduction_rslt_2, last_sample_2, kr_2]) self.assertEqual(5.5, reduction_rslt_2) self.assertEqual(10, last_sample_2) self.assertEqual(10, kernel_results_2.counter_1) @@ -113,11 +105,9 @@ def test_current_state(self, curr_state): num_steps=5, current_state=curr_state, kernel=fake_kernel, - reducer=fake_reducer, - ) + reducer=fake_reducer) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual( np.mean(np.arange(curr_state + 1, curr_state + 6)), reduction_rslt) self.assertEqual(curr_state + 5, last_sample) @@ -136,11 +126,9 @@ def reduction_target(current_state, kernel_results): num_steps=5, current_state=0., kernel=kernel, - reducer=reduction, - ) + reducer=reduction) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual(np.mean(np.arange(2, 12, 2)), reduction_rslt) self.assertEqual(5, last_sample) self.assertEqual(5, kernel_results.counter_1) @@ -151,17 +139,14 @@ def test_nested_reducers(self): fake_reducers = [ [test_fixtures.NaiveMeanReducer(), 
tfp.experimental.mcmc.CovarianceReducer()], - [test_fixtures.NaiveMeanReducer()] - ] + [test_fixtures.NaiveMeanReducer()]] reduction_rslt, last_sample, kr = tfp.experimental.mcmc.sample_fold( num_steps=3, current_state=0., kernel=fake_kernel, - reducer=fake_reducers, - ) + reducer=fake_reducers) reduction_rslt, last_sample, kernel_results = self.evaluate([ - reduction_rslt, last_sample, kr - ]) + reduction_rslt, last_sample, kr]) self.assertEqual(2, len(reduction_rslt)) self.assertEqual(2, len(reduction_rslt[0])) self.assertEqual(1, len(reduction_rslt[1])) @@ -196,8 +181,7 @@ def test_batched_streaming_covariance(self): current_state=tf.convert_to_tensor( [[0., 0., 0.], [0., 0., 0.]]), kernel=fake_kernel, - reducer=cov_reducer, - ) + reducer=cov_reducer) reduction_rslt = self.evaluate(reduction_rslt) self.assertEqual((2, 3, 3), reduction_rslt.shape) self.assertAllEqual(np.ones(reduction_rslt.shape) * 2, reduction_rslt) @@ -212,18 +196,15 @@ def test_seed_reproducibility(self): current_state=0., kernel=fake_kernel, reducer=fake_reducer, - seed=seed - ) + seed=seed) second_reduction_rslt, _, _ = tfp.experimental.mcmc.sample_fold( num_steps=3, current_state=0., kernel=fake_kernel, reducer=fake_reducer, - seed=seed - ) + seed=seed) first_reduction_rslt, second_reduction_rslt = self.evaluate([ - first_reduction_rslt, second_reduction_rslt - ]) + first_reduction_rslt, second_reduction_rslt]) self.assertEqual(first_reduction_rslt, second_reduction_rslt) def test_thinning_and_burnin(self): @@ -235,13 +216,11 @@ def test_thinning_and_burnin(self): kernel=fake_kernel, reducer=fake_reducer, num_burnin_steps=10, - num_steps_between_results=1, - ) + num_steps_between_results=1) reduction_rslt, last_sample, kernel_results = self.evaluate([ reduction_rslt, last_sample, - kr - ]) + kr]) self.assertEqual(16, reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual( @@ -258,13 +237,11 @@ def test_tensor_thinning_and_burnin(self): kernel=fake_kernel, reducer=fake_reducer, num_burnin_steps=tf.convert_to_tensor(10), - num_steps_between_results=tf.convert_to_tensor(1), - ) + num_steps_between_results=tf.convert_to_tensor(1)) reduction_rslt, last_sample, kernel_results = self.evaluate([ reduction_rslt, last_sample, - kr - ]) + kr]) self.assertEqual(16, reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual( @@ -280,11 +257,9 @@ def test_none_reducer(self): kernel=fake_kernel, reducer=None, num_burnin_steps=10, - num_steps_between_results=1, - ) + num_steps_between_results=1) last_sample, kernel_results = self.evaluate([ - last_sample, kr - ]) + last_sample, kr]) self.assertIsNone(reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual(20, kernel_results.counter_1) @@ -298,11 +273,9 @@ def test_empty_reducer(self): kernel=fake_kernel, reducer=[], num_burnin_steps=10, - num_steps_between_results=1, - ) + num_steps_between_results=1) last_sample, kernel_results = self.evaluate([ - last_sample, kr - ]) + last_sample, kr]) self.assertEqual([], reduction_rslt) self.assertEqual(20, last_sample) self.assertEqual(20, kernel_results.counter_1) @@ -319,200 +292,93 @@ def setUp(self): def test_basic_operation(self): kernel = test_fixtures.TestTransitionKernel() - samples, kernel_results, final_results = tfp.experimental.mcmc.sample_chain( + result = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, - return_final_kernel_results=True, seed=test_util.test_seed()) + samples = result.trace + kernel_results = result.final_kernel_results 
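Note: as the updated docstring and this test show, `sample_chain_with_burnin` returns a `RunKernelResults` whose `trace` field holds the traced values (chain states by default, since kernel results are no longer traced) and whose `resume_kwargs` can be splatted back into a follow-up call for a warm restart. A minimal sketch of the intended call pattern; the HMC target, step size, and seed below are illustrative assumptions, not part of this change:

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -x**2,  # toy target, for illustration only
    step_size=0.1,
    num_leapfrog_steps=3)

result = tfp.experimental.mcmc.sample_chain_with_burnin(
    num_results=100,
    current_state=tf.zeros([]),
    kernel=kernel,
    num_burnin_steps=50,
    num_steps_between_results=2,
    seed=(1, 2))  # stateless seed pair, illustrative

samples = result.trace  # shape [100]; trace_fn defaults to tracing chain states only
# Warm restart, mirroring the pattern exercised in the test above.
more = tfp.experimental.mcmc.sample_chain_with_burnin(
    num_results=100, **result.resume_kwargs)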
self.assertAllClose( [2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) samples, kernel_results = self.evaluate([samples, kernel_results]) self.assertAllClose([1, 2], samples) - self.assertAllClose([1, 2], kernel_results.counter_1) - self.assertAllClose([2, 4], kernel_results.counter_2) + self.assertAllClose(2, kernel_results.counter_1) + self.assertAllClose(4, kernel_results.counter_2) # Warm-restart the underlying kernel. The Trace does not support warm # restart. - samples_2, kr_2 = tfp.experimental.mcmc.sample_chain( + result_2 = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, - current_state=samples[-1], - previous_kernel_results=final_results, - kernel=kernel, - ) - samples_2, kernel_results_2 = self.evaluate([samples_2, kr_2]) + **result.resume_kwargs) + samples_2, kernel_results_2 = self.evaluate( + [result_2.trace, result_2.final_kernel_results]) self.assertAllClose([3, 4], samples_2) - self.assertAllClose([3, 4], kernel_results_2.counter_1) - self.assertAllClose([6, 8], kernel_results_2.counter_2) - - def test_basic_operation_legacy(self): - kernel = test_fixtures.TestTransitionKernel(accepts_seed=False) - samples, kernel_results = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel) - - self.assertAllClose( - [2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) - - samples, kernel_results = self.evaluate([samples, kernel_results]) - self.assertAllClose([1, 2], samples) - self.assertAllClose([1, 2], kernel_results.counter_1) - self.assertAllClose([2, 4], kernel_results.counter_2) + self.assertAllClose(4, kernel_results_2.counter_1) + self.assertAllClose(8, kernel_results_2.counter_2) def test_burn_in(self): kernel = test_fixtures.TestTransitionKernel() - samples, kernel_results = tfp.experimental.mcmc.sample_chain( + result = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, num_burnin_steps=1, seed=test_util.test_seed()) + samples = result.trace + kernel_results = result.final_kernel_results self.assertAllClose([2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) samples, kernel_results = self.evaluate([samples, kernel_results]) self.assertAllClose([2, 3], samples) - self.assertAllClose([2, 3], kernel_results.counter_1) - self.assertAllClose([4, 6], kernel_results.counter_2) + self.assertAllClose(3, kernel_results.counter_1) + self.assertAllClose(6, kernel_results.counter_2) def test_thinning(self): kernel = test_fixtures.TestTransitionKernel() - samples, kernel_results = tfp.experimental.mcmc.sample_chain( + result = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, num_steps_between_results=2, seed=test_util.test_seed()) + samples = result.trace + kernel_results = result.final_kernel_results self.assertAllClose([2], tensorshape_util.as_list(samples.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(kernel_results.counter_2.shape)) 
samples, kernel_results = self.evaluate([samples, kernel_results]) self.assertAllClose([3, 6], samples) - self.assertAllClose([3, 6], kernel_results.counter_1) - self.assertAllClose([6, 12], kernel_results.counter_2) - - def test_default_trace_named_tuple(self): - kernel = test_fixtures.TestTransitionKernel() - res = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - seed=test_util.test_seed()) - - self.assertAllClose([2], tensorshape_util.as_list(res.all_states.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace.counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace.counter_2.shape)) - - res = self.evaluate(res) - self.assertAllClose([1, 2], res.all_states) - self.assertAllClose([1, 2], res.trace.counter_1) - self.assertAllClose([2, 4], res.trace.counter_2) - - def test_no_trace_fn(self): - kernel = test_fixtures.TestTransitionKernel() - samples = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - trace_fn=None, - seed=test_util.test_seed()) - self.assertAllClose([2], tensorshape_util.as_list(samples.shape)) - samples = self.evaluate(samples) - self.assertAllClose([1, 2], samples) + self.assertAllClose(6, kernel_results.counter_1) + self.assertAllClose(12, kernel_results.counter_2) def test_custom_trace(self): kernel = test_fixtures.TestTransitionKernel() - res = tfp.experimental.mcmc.sample_chain( + res = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, trace_fn=lambda *args: args, seed=test_util.test_seed()) + trace = res.trace - self.assertAllClose([2], tensorshape_util.as_list(res.all_states.shape)) - self.assertAllClose([2], tensorshape_util.as_list(res.trace[0].shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace[1].counter_1.shape)) - self.assertAllClose( - [2], tensorshape_util.as_list(res.trace[1].counter_2.shape)) - - res = self.evaluate(res) - self.assertAllClose([1, 2], res.all_states) - self.assertAllClose([1, 2], res.trace[0]) - self.assertAllClose([1, 2], res.trace[1].counter_1) - self.assertAllClose([2, 4], res.trace[1].counter_2) - - def test_checkpointing(self): - kernel = test_fixtures.TestTransitionKernel() - res = tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - trace_fn=None, - return_final_kernel_results=True, - seed=test_util.test_seed()) - - self.assertAllClose([2], tensorshape_util.as_list(res.all_states.shape)) - self.assertEqual((), res.trace) + self.assertAllClose([2], tensorshape_util.as_list(trace[0].shape)) self.assertAllClose( - [], tensorshape_util.as_list(res.final_kernel_results.counter_1.shape)) + [2], tensorshape_util.as_list(trace[1].counter_1.shape)) self.assertAllClose( - [], tensorshape_util.as_list(res.final_kernel_results.counter_2.shape)) - - res = self.evaluate(res) - self.assertAllClose([1, 2], res.all_states) - self.assertAllClose(2, res.final_kernel_results.counter_1) - self.assertAllClose(4, res.final_kernel_results.counter_2) + [2], tensorshape_util.as_list(trace[1].counter_2.shape)) - def test_warnings_default(self): - with warnings.catch_warnings(record=True) as triggered: - kernel = test_fixtures.TestTransitionKernel() - tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - seed=test_util.test_seed()) - self.assertTrue( - any('Tracing all kernel results by default is deprecated' in str( - warning.message) for warning in triggered)) - - def 
test_no_warnings_explicit(self): - with warnings.catch_warnings(record=True) as triggered: - kernel = test_fixtures.TestTransitionKernel() - tfp.experimental.mcmc.sample_chain( - num_results=2, - current_state=0., - kernel=kernel, - trace_fn=lambda current_state, kernel_results: kernel_results, - seed=test_util.test_seed()) - self.assertFalse( - any('Tracing all kernel results by default is deprecated' in str( - warning.message) for warning in triggered)) + trace = self.evaluate(trace) + self.assertAllClose([1, 2], trace[0]) + self.assertAllClose([1, 2], trace[1].counter_1) + self.assertAllClose([2, 4], trace[1].counter_2) def test_is_calibrated(self): with warnings.catch_warnings(record=True) as triggered: kernel = test_fixtures.TestTransitionKernel(is_calibrated=False) - tfp.experimental.mcmc.sample_chain( + tfp.experimental.mcmc.sample_chain_with_burnin( num_results=2, current_state=0., kernel=kernel, @@ -535,12 +401,12 @@ def log_prob(x): target_log_prob_fn=log_prob, num_leapfrog_steps=3, step_size=1e-3) - return tfp.experimental.mcmc.sample_chain( + results = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=5, num_burnin_steps=4, current_state=initial_state, - kernel=kernel, - trace_fn=None) + kernel=kernel) + return results.trace # Checking that shape inference doesn't fail. sample(2) @@ -548,24 +414,21 @@ def log_prob(x): def test_seed_reproducibility(self): first_fake_kernel = test_fixtures.RandomTransitionKernel() second_fake_kernel = test_fixtures.RandomTransitionKernel() - seed = samplers.sanitize_seed(test_util.test_seed()) - first_final_state = tfp.experimental.mcmc.sample_chain( + seed = test_util.test_seed(sampler_type='stateless') + first_trace = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=5, current_state=0., kernel=first_fake_kernel, - seed=seed, - ) - second_final_state = tfp.experimental.mcmc.sample_chain( + seed=seed).trace + second_trace = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=5, current_state=1., # difference should be irrelevant kernel=second_fake_kernel, - seed=seed, - ) - first_final_state, second_final_state = self.evaluate([ - first_final_state, second_final_state - ]) + seed=seed).trace + first_trace, second_trace = self.evaluate([ + first_trace, second_trace]) self.assertAllCloseNested( - first_final_state, second_final_state, rtol=1e-6) + first_trace, second_trace, rtol=1e-6) @test_util.test_graph_mode_only @@ -589,7 +452,7 @@ def target_log_prob(x, y): z = tf.linalg.triangular_solve(true_cov_chol, z[..., tf.newaxis])[..., 0] return -0.5 * tf.reduce_sum(z**2., axis=-1) - states = tfp.experimental.mcmc.sample_chain( + states = tfp.experimental.mcmc.sample_chain_with_burnin( num_results=num_results, current_state=[dtype(-2), dtype(2)], kernel=tfp.mcmc.HamiltonianMonteCarlo( @@ -598,8 +461,7 @@ def target_log_prob(x, y): num_leapfrog_steps=2), num_burnin_steps=20, num_steps_between_results=1, - trace_fn=None, - seed=test_util.test_seed()) + seed=test_util.test_seed()).trace self.assertAllEqual(dict(target_calls=1), counter) states = tf.stack(states, axis=-1) diff --git a/tensorflow_probability/python/experimental/mcmc/thinning_kernel.py b/tensorflow_probability/python/experimental/mcmc/thinning_kernel.py new file mode 100644 index 0000000000..53429a235a --- /dev/null +++ b/tensorflow_probability/python/experimental/mcmc/thinning_kernel.py @@ -0,0 +1,120 @@ +# Copyright 2020 The TensorFlow Probability Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Kernel for Thinning.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow_probability.python.experimental.mcmc import sample +from tensorflow_probability.python.mcmc import kernel as kernel_base +from tensorflow_probability.python.mcmc.internal import util as mcmc_util + + +__all__ = [ + 'ThinningKernel', +] + + +class ThinningKernel(kernel_base.TransitionKernel): + """Discards samples to perform thinning. + + `ThinningKernel` is a composable `TransitionKernel` that thins samples + returned by its `inner_kernel`. All Transition Kernels wrapping it will only + see non-discarded samples. + """ + + def __init__( + self, + inner_kernel, + num_steps_to_skip, + name=None): + """Instantiates this object. + + Args: + inner_kernel: `TransitionKernel` whose `one_step` will generate + MCMC results. + num_steps_to_skip: Integer or scalar `Tensor` representing + the number of chain steps skipped before collecting a result. + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "thinning_kernel"). + """ + self._parameters = dict( + inner_kernel=inner_kernel, + num_steps_to_skip=num_steps_to_skip, + name=name or 'thinning_kernel' + ) + + def one_step(self, current_state, previous_kernel_results, seed=None): + """Collects one non-thinned chain state. + + Args: + current_state: `Tensor` or Python `list` of `Tensor`s + representing the current state(s) of the Markov chain(s), + previous_kernel_results: `collections.namedtuple` containing `Tensor`s + representing values from previous calls to this function (or from the + `bootstrap_results` function). + seed: Optional seed for reproducible sampling. + + Returns: + new_chain_state: Newest non-discarded MCMC chain state drawn from + the `inner_kernel`. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. + """ + with tf.name_scope( + mcmc_util.make_name(self.name, 'thinned_kernel', 'one_step')): + return sample.step_kernel( + num_steps=self.num_steps_to_skip + 1, + current_state=current_state, + previous_kernel_results=previous_kernel_results, + kernel=self.inner_kernel, + return_final_kernel_results=True, + seed=seed, + name=self.name) + + def bootstrap_results(self, init_state): + """Instantiates a new kernel state with no calls. + + Args: + init_state: `Tensor` or Python `list` of `Tensor`s representing the + state(s) of the Markov chain(s). + + Returns: + kernel_results: `collections.namedtuple` of `Tensor`s representing + internal calculations made within this function. 
+ """ + return self.inner_kernel.bootstrap_results(init_state) + + @property + def is_calibrated(self): + return self.inner_kernel.is_calibrated + + @property + def inner_kernel(self): + return self._parameters['inner_kernel'] + + @property + def num_steps_to_skip(self): + return self._parameters['num_steps_to_skip'] + + @property + def name(self): + return self._parameters['name'] + + @property + def parameters(self): + return self._parameters diff --git a/tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py new file mode 100644 index 0000000000..87bfe21d63 --- /dev/null +++ b/tensorflow_probability/python/experimental/mcmc/thinning_kernel_test.py @@ -0,0 +1,188 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for ThinningKernel TransitionKernel.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports +import numpy as np + +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.mcmc.internal import test_fixtures +from tensorflow_probability.python.internal import test_util + + +@test_util.test_all_tf_execution_regimes +class ThinningTest(test_util.TestCase): + + def test_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=1,) + first_state, kernel_results = thinner.one_step( + 0., thinner.bootstrap_results(0.)) + second_state, kernel_results = thinner.one_step( + first_state, kernel_results) + first_state, second_state, kernel_results = self.evaluate([ + first_state, second_state, kernel_results]) + self.assertEqual(2, first_state) + self.assertEqual(4, second_state) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + def test_no_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=0,) + first_state, kernel_results = thinner.one_step( + 0., thinner.bootstrap_results(0.)) + second_state, kernel_results = thinner.one_step( + first_state, kernel_results) + first_state, second_state, kernel_results = self.evaluate([ + first_state, second_state, kernel_results]) + self.assertEqual(1, first_state) + self.assertEqual(2, second_state) + self.assertEqual(2, kernel_results.counter_1) + self.assertEqual(4, kernel_results.counter_2) + + def test_cold_start(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=1,) + first_state, _ = thinner.one_step( + 0., thinner.bootstrap_results(0.)) + second_state, kernel_results = thinner.one_step( + first_state, 
thinner.bootstrap_results(first_state)) + first_state, second_state, kernel_results = self.evaluate([ + first_state, second_state, kernel_results]) + self.assertEqual(2, first_state) + self.assertEqual(4, second_state) + self.assertEqual(2, kernel_results.counter_1) + self.assertEqual(4, kernel_results.counter_2) + + def test_is_calibrated(self): + calibrated_kernel = test_fixtures.TestTransitionKernel() + uncalibrated_kernel = test_fixtures.TestTransitionKernel( + is_calibrated=False) + calibrated_thinner = tfp.experimental.mcmc.ThinningKernel( + calibrated_kernel, 0) + uncalibrated_thinner = tfp.experimental.mcmc.ThinningKernel( + uncalibrated_kernel, 0) + self.assertTrue(calibrated_thinner.is_calibrated) + self.assertFalse(uncalibrated_thinner.is_calibrated) + + def test_with_composed_kernel(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + cov_reducer = tfp.experimental.mcmc.CovarianceReducer() + reducer_kernel = tfp.experimental.mcmc.WithReductions( + inner_kernel=tfp.experimental.mcmc.ThinningKernel( + inner_kernel=fake_inner_kernel, + num_steps_to_skip=2,), + reducer=cov_reducer + ) + current_state, kernel_results = 0., reducer_kernel.bootstrap_results(0.) + for _ in range(2): + current_state, kernel_results = reducer_kernel.one_step( + current_state, kernel_results) + cov = self.evaluate(cov_reducer.finalize(kernel_results.reduction_results)) + self.assertAllEqual(6, current_state) + self.assertAllEqual(6, kernel_results.inner_results.counter_1) + self.assertAllEqual(12, kernel_results.inner_results.counter_2) + self.assertNear(np.var([3, 6]), cov, err=1e-6) + + def test_tf_while(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=1,) + + def _loop_body(i, curr_state, pkr): + new_state, kernel_results = thinner.one_step( + curr_state, pkr, + ) + return (i + 1, new_state, kernel_results) + + pkr = thinner.bootstrap_results(0.) + _, final_sample, kernel_results = tf.while_loop( + lambda i, *_: i < 2, + _loop_body, + (0., 0., pkr), + ) + final_sample, kernel_results = self.evaluate([ + final_sample, kernel_results]) + self.assertEqual(4, final_sample) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + def test_tensor_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=tf.convert_to_tensor(1),) + + def _loop_body(i, curr_state, pkr): + new_state, kernel_results = thinner.one_step( + curr_state, pkr, + ) + return (i + 1, new_state, kernel_results) + + pkr = thinner.bootstrap_results(0.) + _, final_sample, kernel_results = tf.while_loop( + lambda i, _, __: i < 2, + _loop_body, + (0., 0., pkr), + ) + + final_sample, kernel_results = self.evaluate([ + final_sample, kernel_results]) + self.assertEqual(4, final_sample) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + def test_non_static_thinning(self): + fake_inner_kernel = test_fixtures.TestTransitionKernel() + num_steps_to_skip = tf.Variable(1, dtype=tf.int32) + thinner = tfp.experimental.mcmc.ThinningKernel( + fake_inner_kernel, + num_steps_to_skip=num_steps_to_skip) + + def _loop_body(i, curr_state, pkr): + new_state, kernel_results = thinner.one_step( + curr_state, pkr, + ) + return (i + 1, new_state, kernel_results) + + pkr = thinner.bootstrap_results(0.) 
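As a point of reference beyond the test fixtures, `ThinningKernel` composes with real kernels in the same way; a minimal sketch, assuming a toy random-walk target and illustrative seed (not part of this change):

import tensorflow as tf
import tensorflow_probability as tfp

inner = tfp.mcmc.RandomWalkMetropolis(
    target_log_prob_fn=lambda x: -x**2)  # toy target, for illustration only
thinned = tfp.experimental.mcmc.ThinningKernel(inner, num_steps_to_skip=4)

state = tf.zeros([])
pkr = thinned.bootstrap_results(state)
# Each `one_step` advances the inner kernel num_steps_to_skip + 1 = 5 times
# and returns only the final state and kernel results.
state, pkr = thinned.one_step(state, pkr, seed=(1, 2))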
+ _, final_sample, kernel_results = tf.while_loop( + lambda i, _, __: i < 2, + _loop_body, + (0., 0., pkr), + ) + self.evaluate([num_steps_to_skip.initializer]) + final_sample, kernel_results = self.evaluate([ + final_sample, kernel_results]) + self.assertEqual(4, final_sample) + self.assertEqual(4, kernel_results.counter_1) + self.assertEqual(8, kernel_results.counter_2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_probability/python/experimental/nn/BUILD b/tensorflow_probability/python/experimental/nn/BUILD index 959b560aa8..5b831a5b29 100644 --- a/tensorflow_probability/python/experimental/nn/BUILD +++ b/tensorflow_probability/python/experimental/nn/BUILD @@ -105,7 +105,7 @@ py_test( py_library( name = "convolutional_layers_v2", srcs = ["convolutional_layers_v2.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":layers", ":variational_base", diff --git a/tensorflow_probability/python/experimental/nn/util/BUILD b/tensorflow_probability/python/experimental/nn/util/BUILD index 7fefb0db42..3661c714e3 100644 --- a/tensorflow_probability/python/experimental/nn/util/BUILD +++ b/tensorflow_probability/python/experimental/nn/util/BUILD @@ -28,7 +28,7 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY3", deps = [ - ":im2row", + ":convolution_util", ":random_variable", ":utils", "//tensorflow_probability/python/internal:all_util", @@ -36,24 +36,32 @@ py_library( ) py_library( - name = "im2row", - srcs = ["im2row.py"], + name = "convolution_util", + srcs = ["convolution_util.py"], srcs_version = "PY3", deps = [ + ":utils", # tensorflow dep, + "//tensorflow_probability/python/internal:assert_util", + "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", ], ) py_test( - name = "im2row_test", + name = "convolution_util_test", size = "medium", - srcs = ["im2row_test.py"], + srcs = ["convolution_util_test.py"], python_version = "PY3", + shard_count = 4, srcs_version = "PY3", deps = [ + ":convolution_util", + # absl/testing:parameterized dep, + # numpy dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", ], ) diff --git a/tensorflow_probability/python/experimental/nn/util/__init__.py b/tensorflow_probability/python/experimental/nn/util/__init__.py index 3c0c4c4417..8678e5c9ba 100644 --- a/tensorflow_probability/python/experimental/nn/util/__init__.py +++ b/tensorflow_probability/python/experimental/nn/util/__init__.py @@ -17,7 +17,12 @@ from __future__ import division from __future__ import print_function -from tensorflow_probability.python.experimental.nn.util.im2row import im2row +from tensorflow_probability.python.experimental.nn.util.convolution_util import im2row +from tensorflow_probability.python.experimental.nn.util.convolution_util import im2row_index +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_fn +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_transpose_fn_with_dilation +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_transpose_fn_with_subkernels +from tensorflow_probability.python.experimental.nn.util.convolution_util import make_convolution_transpose_fn_with_subkernels_matrix from tensorflow_probability.python.experimental.nn.util.random_variable import CallOnce from 
tensorflow_probability.python.experimental.nn.util.random_variable import RandomVariable from tensorflow_probability.python.experimental.nn.util.utils import batchify_op @@ -53,10 +58,15 @@ 'flatten_rightmost', 'halflife_decay', 'im2row', + 'im2row_index', 'make_fit_op', 'make_kernel_bias', 'make_kernel_bias_posterior_mvn_diag', 'make_kernel_bias_prior_spike_and_slab', + 'make_convolution_fn', + 'make_convolution_transpose_fn_with_dilation', + 'make_convolution_transpose_fn_with_subkernels', + 'make_convolution_transpose_fn_with_subkernels_matrix', 'negloglik', 'prepare_conv_args', 'prepare_strides', diff --git a/tensorflow_probability/python/experimental/nn/util/convolution_util.py b/tensorflow_probability/python/experimental/nn/util/convolution_util.py new file mode 100644 index 0000000000..da4c00fdba --- /dev/null +++ b/tensorflow_probability/python/experimental/nn/util/convolution_util.py @@ -0,0 +1,910 @@ +# Lint as: python2, python3 +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Functions for framing `conv` as `matmul`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.experimental.nn.util import utils +from tensorflow_probability.python.internal import assert_util +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps + +__all__ = [ + 'im2row', + 'im2row_index', + 'make_convolution_fn', + 'make_convolution_transpose_fn_with_dilation', + 'make_convolution_transpose_fn_with_subkernels', + 'make_convolution_transpose_fn_with_subkernels_matrix', +] + + +def im2row(x, + block_shape, + slice_step=(1, 1), + padding='VALID', + name=None): + """Rearrange image blocks into rows. + + This function can be used to implement 2D convolution as a `matmul`, e.g., + + `tf.nn.conv2d(x, k) = tf.matmul( + tf.experimental.nn.util.im2row(x), tf.reshape(k, shape=[-1, out_size]))`. + + Args: + x: Rank 3 (or more) Tensor representing 2D images. + block_shape: Length-2 vector representing the block or "filter" shape. + slice_step: Length-2 vector specifying the convolution stride length. + Default value: `(1, 1)`. + padding: One of `'VALID'` or `'SAME'` (case insensitive). + Default value: `'VALID'`. + name: Python `str` used to describe ops created by this function. + Default value: `None` (i.e., `'im2col'`). + + Returns: + im2row_x: batch of matrices representing subblock copies of `x`. + Same batch shape as `x` but with rightmost shape: + `batch_shape + [oh * ow, block_shape[0] * block_shape[1] * channels]`, + where `oh = (h - block_shape[0] + 1) // slice_step[0]` and + `ow = (w - block_shape[1] + 1) // slice_step[1]` when `padding = 'VALID'` + and `oh = h` and `ow = w` when `padding = 'SAME'`. 
+ shape: shape `Tensor` equivalent to: + `batch_shape + [oh, ow, block_shape[0] * block_shape[1] * channels]` where + `oh, ow` are defined as above. + """ + with tf.name_scope(name or 'im2row'): + padding = _validate_padding(padding) + if padding == 'VALID': + pass # Do nothing. + elif padding == 'SAME': + raise NotImplementedError( + 'Argument padding="SAME" not implemented.') + # TODO(jvdillon): See if the following works: + # fh, fw = block_shape + # o = 1 if data_format == 'NHWC' else 0 + # n = ps.maximum(0, ps.rank(x) - 3) + # paddings = ps.pad( + # [[0, fh - 1], [0, fw - 1]], + # paddings=[[n + 1 - o, o], [0, 0]], + # constant_values=0) + # x = tf.pad(x, paddings=paddings, constant_values=0) + # padding = 'VALID' + else: + assert False # Can't be here. + x_shape = ps.shape(x) + idx, s = im2row_index( + x_shape, block_shape=block_shape, slice_step=slice_step) + flat_shape = ps.pad( + x_shape[:-3], paddings=[[0, 1]], constant_values=-1) + x = tf.gather(tf.reshape(x, flat_shape), idx, axis=-1) # == np.take + return tf.reshape(x, s) + + +def im2row_index(input_shape, + block_shape, + rank=2, + slice_step=(1, 1), + dilations=(1, 1), + dtype=tf.int32, + transpose=False, + validate_args=False, + name=None): + """Computes indexes into a flattened image for building `im2row`.""" + with tf.name_scope(name or 'im2row_index'): + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + fh, fw = prepare_tuple_argument( + block_shape, n=rank, arg_name='block_shape', + validate_args=validate_args) + sh, sw = prepare_tuple_argument( + slice_step, n=rank, arg_name='slice_step', validate_args=validate_args) + dh, dw = prepare_tuple_argument( + dilations, n=rank, arg_name='dilations', validate_args=validate_args) + + # 1) Process input arguments. + batch_shape, h, w, c = ps.split( + ps.reshape(ps.cast(input_shape, dtype=dtype), shape=[-1]), + num_or_size_splits=[-1, 1, 1, 1]) + h, w, c = h[0], w[0], c[0] + + tot_fh = dh * (fh - 1) + 1 + tot_fw = dw * (fw - 1) + 1 + + # 2) Assemble all block start positions as indexes into the flattened image. + # start_idx.shape = [fh, fw, c] + if transpose: + last_element = lambda size, step: size - (size - 1) % step - 1 + w_step = c * dw + h_step = c * w * dh + last_w = last_element(c * tot_fw, w_step) + last_h = last_element(c * w * tot_fh, h_step) + start_idx = cartesian_add([ + ps.range(last_h, -1, delta=-h_step, dtype=dtype), + ps.range(last_w, -1, delta=-w_step, dtype=dtype), + ps.range(c, delta=1, dtype=dtype), + ]) + else: + start_idx = cartesian_add([ + ps.range(c * w * tot_fh, delta=c * w * dh, dtype=dtype), + ps.range(c * tot_fw, delta=c * dw, dtype=dtype), + ps.range(c, delta=1, dtype=dtype), + ]) + + # 3) Assemble all block offsets (into flattened image). + eh = h - tot_fh + 1 + ew = w - tot_fw + 1 + + offset_idx = cartesian_add([ + ps.range(w * eh, delta=w * sh, dtype=dtype), + ps.range(ew, delta=sw, dtype=dtype), + ]) + + offset_idx = offset_idx * c + oh = (eh - 1) // sh + 1 # out height + ow = (ew - 1) // sw + 1 # out width + + # 4) Combine block start/offset pairs. 
+ # shape = [(eh // sh) * (ew // sw), fh * fw * c] + idx = cartesian_add([offset_idx, start_idx]) + new_shape = ps.concat( + [batch_shape, ps.convert_to_shape_tensor([oh, ow, fh * fw * c])], + axis=0) + return idx, new_shape + + +def cartesian_add(xs): + """Adds a list of vectors by cumulatively expanding a dimension.""" + return sum(ps.reshape(x, shape=[-1] + [1] * (len(xs) - 1 - i)) + for i, x in enumerate(xs)) + + +def _validate_padding(padding): + """Verify correctness of `padding` argument.""" + padding_ = str(padding).upper() + if padding_ in {'SAME', 'VALID'}: + return padding_ + raise ValueError( + 'Argument padding="{}" not recognized; must be one of ' + '{{"VALID", "SAME"}} (case insensitive).'.format(padding)) + + +# TODO(emilyaf): Finish docstrings. +def make_convolution_fn( + filter_shape, rank, strides, padding, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" + with tf.name_scope(name or 'conv2d'): + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + [ + filter_shape, + rank, + strides, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + fh, fw = filter_shape + + assertions = _maybe_validate_input_shapes( + ps.shape(kernel), channels_in=c_in, filter_height=fh, + filter_width=fw, validate_args=validate_args) + + with tf.control_dependencies(assertions): + if tf.get_static_value(ps.rank(kernel)) == 2: + flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0)) + flat_y = tf.nn.conv2d( + x, + filters=tf.reshape(kernel, shape=[fh, fw, c_in, -1]), + strides=strides, + padding=padding, + data_format='NHWC', + dilations=dilations) + output_shape = ps.shape(flat_y)[-3:] + return tf.reshape( + flat_y, shape=ps.concat([batch_shape, output_shape], axis=0)) + + pad_values = [ + _get_conv_padding( + xdim, filter_dim=k, stride=s, dilation=d, padding=padding) + for (xdim, k, s, d) in zip((xh, xw), filter_shape, strides, dilations) + ] + + idx, shape = im2row_index( + (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in), + block_shape=filter_shape, slice_step=strides, dilations=dilations, + dtype=dtype) + + if padding == 'SAME': + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + pad_values, paddings=[[n, 1], [0, 0]], constant_values=0) + x = tf.pad(x, paddings=paddings, constant_values=0) + + flat_shape = ps.pad( + batch_shape, paddings=[[0, 1]], constant_values=-1) + flat_x = tf.gather(tf.reshape(x, shape=flat_shape), indices=idx, axis=-1) + im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0)) + return tf.matmul(im_x, kernel[..., tf.newaxis, :, :]) + return op + + +def _get_conv_padding(xdim, filter_dim, stride, dilation, padding): + """Returns the number of zeros to pad at the start and end of an axis.""" + if padding == 'VALID': + return (0, 0) + elif padding == 'SAME': + tot_k = dilation * (filter_dim - 1) + 1 + tot_pad = tf.maximum(tot_k - ((xdim - 1) % stride + 1), 0) + pad_start = tot_pad // 2 
+ return pad_start, tot_pad - pad_start + + +def make_convolution_transpose_fn_with_dilation( + filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`. + + This version tends to be fastest on GPU. It implements the transposed + convolution as a regular convolution of an image that is dilated by + interleaving rows and columns of zeros equal to the number of strides. + + Args: + filter_shape: ... + strides: ... + padding: ... + rank: ... + dilations: ... + dtype: ... + validate_args: ... + name: ... + Returns: + convolution_transpose_fn: A callable that takes an input `Tensor` and kernel + and applies the transpose convolution operation. + """ + with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): + + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + [ + filter_shape, + rank, + strides, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + sh, sw = strides + fh, fw = filter_shape + + pad_values = [ + _get_transpose_conv_dilated_padding( + k, stride=s, dilation=d, padding=padding) + for (k, s, d) in zip(filter_shape, strides, dilations)] + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + kernel_shape = ps.shape(kernel) + assertions = _maybe_validate_input_shapes( + kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, + validate_args=validate_args) + + with tf.control_dependencies(assertions): + # If the kernel does not have batch shape, fall back to + # `conv2d_transpose` (unless dilations > 1, which is not implemented in + # `conv2d_transpose`). + if (tf.get_static_value(ps.rank(kernel)) == 2 + and all(d == 1 for d in dilations)): + return _call_conv2d_transpose( + x, kernel, filter_shape, strides, padding, dilations, + kernel_shape[-1], batch_shape, event_shape) + + idx, shape = im2row_index( + (xh * sh + sum(pad_values[0]), xw * sw + sum(pad_values[1]), c_in), + block_shape=filter_shape, slice_step=(1, 1), dilations=dilations, + dtype=dtype, transpose=True) + + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + pad_values, paddings=[[n, 1], [0, 0]], constant_values=0) + + # Interleave the rows and columns of the input with rows and columns of + # zeros equal to the number of strides. 
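+        # E.g. with strides (2, 2), a 2x2 image [[a, b], [c, d]] becomes the
+        # 4x4 image [[0, 0, 0, 0], [0, a, 0, b], [0, 0, 0, 0], [0, c, 0, d]],
+        # which is then convolved with unit stride.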
+        x_half_dilated = tf.concat(
+            [tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)], axis=0),
+                      dtype=input_dtype),
+             tf.reshape(
+                 x, shape=ps.concat([batch_shape, (xh * xw, 1, c_in)], axis=0))
+            ], axis=-2)
+        y = tf.reshape(
+            x_half_dilated,
+            shape=ps.concat([batch_shape, (xh, 1, xw * sw, c_in)], axis=0))
+
+        x = tf.reshape(
+            tf.concat(
+                [tf.zeros(
+                    ps.concat(
+                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
+                    dtype=input_dtype), y], axis=-3),
+            shape=ps.concat([batch_shape, (xh * sh, xw * sw, c_in)], axis=0))
+        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
+        flat_shape = ps.pad(batch_shape, paddings=[[0, 1]], constant_values=-1)
+        flat_x = tf.gather(
+            tf.reshape(x_pad, shape=flat_shape), indices=idx, axis=-1)
+        im_x = tf.reshape(flat_x, shape=ps.concat([batch_shape, shape], axis=0))
+        return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
+    return op
+
+
+def make_convolution_transpose_fn_with_subkernels_matrix(
+    filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32,
+    validate_args=False, name=None):
+  """Like `tf.nn.conv2d_transpose`, with batched kernels and batched `x`."""
+  with tf.name_scope(
+      name or 'make_convolution_transpose_fn_with_subkernels_matrix'):
+
+    if tf.get_static_value(rank) != 2:
+      raise NotImplementedError('Argument `rank` currently only supports `2`; '
+                                'saw "{}".'.format(rank))
+
+    strides = tf.get_static_value(strides)
+    if not isinstance(strides, int):
+      raise ValueError('Argument `strides` must be a statically known integer.'
+                       ' Saw: {}'.format(strides))
+
+    [
+        filter_shape,
+        rank,
+        _,
+        padding,
+        dilations,
+    ] = prepare_conv_args(
+        filter_shape, rank=rank, strides=strides, padding=padding,
+        dilations=dilations, validate_args=validate_args)
+
+    fh, fw = filter_shape
+    dh, dw = dilations
+
+    # Determine maximum filter height and filter width of sub-kernels.
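+    # E.g. a length-3 filter with strides=2 yields sub-kernels of length at
+    # most 2, since sub_fh = ceil(fh / strides).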
+ sub_fh = (fh - 1) // strides + 1 + sub_fw = (fw - 1) // strides + 1 + + def loop_body(i_, event_ind): + i = i_ // strides + j = i_ % strides + + i_ind = ps.range(i * fw, fw * fh, delta=strides * fw, dtype=dtype) + j_ind = ps.range(j, fw, delta=strides, dtype=dtype) + + nc = cartesian_add([i_ind, j_ind]) + ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0]) + + k = ps.reshape( + cartesian_add( + [ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype), + ps.range(ps.shape(nc)[1], dtype=dtype)]), + shape=[-1]) + last_j = strides - (fw - j - 1) % strides - 1 + last_i = strides - (fh - i - 1) % strides - 1 + kernel_ind = ps.stack( + [k, ps.ones_like(k) * last_i * strides + last_j], axis=1) + event_ind = ps.tensor_scatter_nd_update( + event_ind, ind[..., tf.newaxis], kernel_ind) + + return i_ + 1, event_ind + + event_ind = ps.zeros((fh * fw, 2), dtype=dtype) + _, event_ind = tf.while_loop( + lambda i, _: i < strides ** 2, + loop_body, + [tf.zeros([], dtype=dtype), event_ind]) + + tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding( + fh, stride=strides, dilation=dh, padding=padding) + tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding( + fw, stride=strides, dilation=dw, padding=padding) + + pad_bottom = (tot_pad_bottom - 1) // strides + 1 + pad_top = (tot_pad_top - 1) // strides + 1 + pad_right = (tot_pad_right - 1) // strides + 1 + pad_left = (tot_pad_left - 1) // strides + 1 + padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right)) + + truncate_top = pad_top * strides - tot_pad_top + truncate_left = pad_left * strides - tot_pad_left + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + + kernel_shape = ps.shape(kernel) + c_out = kernel_shape[-1] + kernel_batch = kernel_shape[:-2] + assertions = _maybe_validate_input_shapes( + kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, + validate_args=validate_args) + + with tf.control_dependencies(assertions): + + # If the kernel does not have batch shape, fall back to + # `conv2d_transpose` (unless dilations > 1, which is not implemented in + # `conv2d_transpose`). 
+ if (tf.get_static_value(ps.rank(kernel)) == 2 + and all(d == 1 for d in dilations)): + return _call_conv2d_transpose( + x, kernel=kernel, filter_shape=filter_shape, + strides=(strides,) * rank, padding=padding, dilations=dilations, + c_out=c_out, batch_shape=batch_shape, event_shape=event_shape) + + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + padding_vals, + paddings=[[n, 1], [0, 0]], + constant_values=0) + + x_pad = tf.pad(x, paddings=paddings, constant_values=0) + x_pad_shape = ps.shape(x_pad)[:-3] + flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1) + flat_x = tf.reshape(x_pad, shape=flat_shape) + + idx, s = im2row_index( + (xh + tf.reduce_sum(padding_vals[0]), + xw + tf.reduce_sum(padding_vals[1]), c_in), + block_shape=(sub_fh, sub_fw), slice_step=(1, 1), dilations=dilations + ) + + x_ = tf.gather(flat_x, indices=idx, axis=-1) + im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0)) + + # Add channels to subkernel indices + idx_event = event_ind * [[c_in, 1]] + idx_event_channels = ( + idx_event[tf.newaxis] + + tf.stack([ps.range(c_in), tf.zeros((c_in,), dtype=dtype)], + axis=-1)[:, tf.newaxis, :]) + idx_event = tf.squeeze( + tf.batch_to_space( + idx_event_channels, block_shape=[c_in], crops=[[0, 0]]), axis=0) + idx_event_broadcast = tf.broadcast_to( + idx_event, + shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0)) + + # Add cartesian product of batch indices, since scatter_nd can only be + # applied to leading dimensions. + idx_batch = tf.stack( + tf.meshgrid( + *[ps.range(b_, delta=1, dtype=dtype) + for b_ in tf.unstack(kernel_batch)], indexing='ij'), + axis=ps.size(kernel_batch)) + + idx_batch = tf.cast(idx_batch, dtype=dtype) # empty tensor is float + + idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros( + (ps.shape(idx_event)[0], 1), dtype=dtype) + idx_kernel = tf.concat( + [idx_batch_broadcast, idx_event_broadcast], axis=-1) + + kernel_mat = tf.scatter_nd( + idx_kernel, + updates=kernel, + shape=ps.cast( + ps.concat([kernel_batch, + [sub_fh * sub_fw * c_in, strides ** 2, c_out]], + axis=0), + dtype=dtype)) + + kernel_mat = tf.reshape( + kernel_mat, + shape=ps.concat( + [ps.shape(kernel_mat)[:-2], [strides ** 2 * c_out]], axis=0)) + + kernel_mat = kernel_mat[..., tf.newaxis, :, :] + out = tf.matmul(im_x, kernel_mat) + broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch) + + if strides > 1: + tot_size = tf.reduce_prod(broadcast_batch_shape) + flat_out = tf.reshape( + out, + shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0)) + out = tf.nn.depth_to_space(flat_out, block_size=strides) + + if padding == 'VALID': + out_height = fh + strides * (xh - 1) + out_width = fw + strides * (xw - 1) + elif padding == 'SAME': + out_height = xh * strides + out_width = xw * strides + + out = out[..., truncate_top:truncate_top + out_height, + truncate_left:truncate_left + out_width, :] + out = tf.reshape( + out, shape=ps.concat( + [broadcast_batch_shape, [out_height, out_width, c_out]], + axis=0)) + return out + return op + + +def make_convolution_transpose_fn_with_subkernels( + filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32, + validate_args=False, name=None): + """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`.""" + with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'): + + if tf.get_static_value(rank) != 2: + raise NotImplementedError('Argument `rank` currently only supports `2`; ' + 'saw "{}".'.format(rank)) + [ + filter_shape, + rank, + 
strides, + padding, + dilations, + ] = prepare_conv_args( + filter_shape, rank=rank, strides=strides, padding=padding, + dilations=dilations, validate_args=validate_args) + + sh, sw = strides + fh, fw = filter_shape + dh, dw = dilations + + # Determine maximum filter height and filter width of sub-kernels. + sub_fh = (fh - 1) // sh + 1 + sub_fw = (fw - 1) // sw + 1 + + def loop_body(i_, kernels_ind): + i = i_ // sw + j = i_ % sw + i_ind = ps.range((sh - i - 1)*fw, fw * fh, delta=sh*fw, dtype=dtype) + j_ind = ps.range((sw - j - 1), fw, delta=sw, dtype=dtype) + + last_j = sw - (fw - j - 1) % sw - 1 + last_i = sh - (fh - i - 1) % sh - 1 + pos = last_i * sw + last_j + + nc = cartesian_add([i_ind, j_ind]) + kernels_ind = kernels_ind.write( + sh * sw - pos - 1, ps.reverse(ps.reverse(nc, [0]), [1])) + + return i_ + 1, kernels_ind + + kernels_ind = tf.TensorArray(dtype=dtype, infer_shape=False, size=1, + dynamic_size=True) + + _, kernels_ind = tf.while_loop( + lambda i, _: i < sh * sw, + loop_body, + [0, kernels_ind]) + + tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding( + fh, stride=sh, dilation=dh, padding=padding) + tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding( + fw, stride=sw, dilation=dw, padding=padding) + + pad_bottom = (tot_pad_bottom - 1) // sh + 1 + pad_top = (tot_pad_top - 1) // sh + 1 + pad_right = (tot_pad_right - 1) // sw + 1 + pad_left = (tot_pad_left - 1) // sw + 1 + padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right)) + + truncate_top = pad_top * sh - tot_pad_top + truncate_left = pad_left * sw - tot_pad_left + + def op(x, kernel): + input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32) + x = tf.convert_to_tensor(x, dtype=input_dtype, name='x') + kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel') + + batch_shape, event_shape = ps.split( + ps.shape(x), num_or_size_splits=[-1, 3]) + xh, xw, c_in = ps.unstack(event_shape, num=3) + + kernel_shape = ps.shape(kernel) + c_out = kernel_shape[-1] + kernel_batch = kernel_shape[:-2] + assertions = _maybe_validate_input_shapes( + kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw, + validate_args=validate_args) + + with tf.control_dependencies(assertions): + # If the kernel does not have batch shape, fall back to + # `conv2d_transpose` (unless dilations > 1, which is not implemented in + # `conv2d_transpose`). 
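+        # (A rank-2 kernel has shape [fh * fw * c_in, c_out], i.e. no leading
+        # batch dimensions.)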
+ if (tf.get_static_value(ps.rank(kernel)) == 2 + and all(d == 1 for d in dilations)): + return _call_conv2d_transpose( + x, kernel, filter_shape, strides, padding, dilations, c_out, + batch_shape, event_shape) + + n = ps.maximum(0, ps.rank(x) - 3) + paddings = ps.pad( + padding_vals, + paddings=[[n, 1], [0, 0]], + constant_values=0) + x_pad = tf.pad(x, paddings=paddings, constant_values=0) + + ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1 + ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1 + + def loop_body(i, outputs): + subkernel_ind = kernels_ind.read(i) + fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2) + eh = ex_h + fh_ - 1 + ew = ex_w + fw_ - 1 + + subkernel_ind = ps.reshape( + ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] + + ps.range(c_in), shape=[-1]) + + k = tf.gather(kernel, subkernel_ind, axis=-2) + ind, shape = im2row_index( + [eh, ew, c_in], + block_shape=(fh_, fw_), + slice_step=(1, 1), + dilations=dilations) + x_i = x_pad[..., :eh, :ew, :] + x_i_shape = ps.shape(x_i) + flat_shape = ps.pad( + x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1) + flat_x = tf.reshape(x_i, flat_shape) + x_ = tf.gather(flat_x, ind, axis=-1) + im_x = tf.reshape(x_, ps.concat([x_i_shape[:-3], shape], axis=0)) + outputs = outputs.write( + i, + tf.matmul( + im_x, + tf.reshape( + k, ps.concat( + [kernel_batch, [1, fh_ * fw_* c_in, c_out]], axis=0))) + ) + return i + 1, outputs + + outputs = tf.TensorArray(dtype=input_dtype, infer_shape=False, size=1, + dynamic_size=True) + + _, outputs = tf.while_loop( + lambda i, _: i < sh * sw, + loop_body, + [0, outputs]) + + y = outputs.concat() + + m = tf.reduce_prod(ps.shape(y)[:-3]) + y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0)) + y2 = tf.batch_to_space( + y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64)) + broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch) + y2 = tf.reshape(y2, ps.concat( + [broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0)) + + if padding == 'VALID': + out_height = fh + sh * (xh - 1) + out_width = fw + sw * (xw - 1) + elif padding == 'SAME': + out_height = xh * sh + out_width = xw * sw + + return y2[..., truncate_top:truncate_top+out_height, + truncate_left:truncate_left+out_width, :] + return op + + +def _maybe_validate_input_shapes( + kernel_shape, channels_in, filter_height, filter_width, validate_args): + """Validate shapes of inputs to convolution op.""" + k_dim = kernel_shape[-2] + k_dim_ = tf.get_static_value(k_dim) + expected_k_dim = filter_height * filter_width * channels_in + expected_k_dim_ = tf.get_static_value(expected_k_dim) + assertions = [] + if expected_k_dim_ is not None and k_dim_ is not None: + if expected_k_dim_ != k_dim_: + raise ValueError( + 'The size of the second-to-rightmost dimension of `kernel` ( ={}) ' + ' must equal `filter_height * filter_width * channels_in` ( ={}), ' + 'where `channels_in` is the size of the rightmost dimension of the ' + 'input.'.format(k_dim_, expected_k_dim_)) + elif validate_args: + assertions.append( + assert_util.assert_equal( + k_dim, expected_k_dim, + message=('The size of the second-to-rightmost dimension of `kernel`' + ' must equal `filter_height * filter_width * channels_in`,' + ' where `channels_in` is the size of the rightmost ' + 'dimension of the input.'))) + return assertions + + +def _get_transpose_conv_dilated_padding(filter_dim, stride, dilation, padding): + """Zero-padding for inputs dilated by strides.""" + tot_filter_dim = filter_dim + (filter_dim - 1) * (dilation - 1) + if padding == 
'VALID': + tot_pad = 2 * (tot_filter_dim - 1) + elif padding == 'SAME': + tot_pad = tot_filter_dim + stride - 2 + + # TODO(emilyaf): Don't need to consider case where stride > kernel_dim, right? + # if filter_dim > 1: + pad_end = tot_pad // 2 + pad_start = tot_pad - pad_end - (stride - 1) # implicit pad + # else: + # pad_end = pad_start = 0 + return pad_start, pad_end + + +def _get_output_shape(rank, strides, padding, dilations, input_shape, + output_size, filter_shape, output_padding=None): + """Compute the `output_shape` and `strides` arg used by `conv_transpose`.""" + if output_padding is None: + output_padding = (None,) * rank + else: + output_padding = utils.prepare_tuple_argument( + output_padding, n=rank, arg_name='output_padding') + for stride, out_pad in zip(strides, output_padding): + if out_pad >= stride: + raise ValueError('Stride {} must be greater than output ' + 'padding {}.'.format(strides, output_padding)) + event_shape = [] + for i in range(-rank, 0): + event_shape.append(_deconv_output_length( + input_shape[i - 1], + filter_size=filter_shape[i], + padding=padding, + output_padding=output_padding[i], + stride=strides[i], + dilation=dilations[i])) + event_shape.append(output_size) + batch_shape = input_shape[:-rank-1] + output_shape = ps.concat([batch_shape, event_shape], axis=0) + strides = ps.pad(strides, paddings=[[1, 1]], constant_values=1) + return output_shape, strides + + +def _deconv_output_length(input_size, filter_size, padding, output_padding, + stride, dilation): + """Determines output length of a transposed convolution given input length. + + Args: + input_size: `int`. + filter_size: `int`. + padding: one of `"SAME"`, `"VALID"`, `"FULL"`. + output_padding: `int`, amount of padding along the output dimension. Can + be set to `None` in which case the output length is inferred. + stride: `int`. + dilation: `int`. + + Returns: + output_length: The output length (`int`). 
+ """ + assert padding in {'SAME', 'VALID', 'FULL'} + if input_size is None: + return None + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == 'VALID': + return input_size * stride + max(filter_size - stride, 0) + elif padding == 'FULL': + return input_size * stride - (stride + filter_size - 2) + elif padding == 'SAME': + return input_size * stride + if padding == 'SAME': + pad = filter_size // 2 + elif padding == 'VALID': + pad = 0 + elif padding == 'FULL': + pad = filter_size - 1 + return (input_size - 1) * stride + filter_size - 2 * pad + output_padding + + +def prepare_conv_args( + filter_shape, rank, strides, padding, dilations, validate_args=False): + """Sanitizes use provided input.""" + padding = _validate_padding(padding) # pylint: disable=protected-access + try: + rank = int(tf.get_static_value(rank)) + except TypeError: + raise TypeError('Argument `rank` must be statically known `int`.') + valid_rank = {1, 2, 3} + if rank not in valid_rank: + raise ValueError('Argument `rank` must be in {}.'.format(valid_rank)) + filter_shape = prepare_tuple_argument( + filter_shape, n=rank, arg_name='filter_shape', + validate_args=validate_args) + strides = prepare_tuple_argument( + strides, n=rank, arg_name='strides', validate_args=validate_args) + padding = utils._prepare_padding_argument(padding) # pylint: disable=protected-access + dilations = prepare_tuple_argument( + dilations, n=rank, arg_name='dilations', validate_args=validate_args) + return filter_shape, rank, strides, padding, dilations + + +# TODO(emilyaf): Replace the version in `utils` with this. +def prepare_tuple_argument(arg, n, arg_name, validate_args): + """Helper which processes `Tensor`s to tuples in standard form.""" + arg_size = ps.size(arg) + arg_size_ = tf.get_static_value(arg_size) + assertions = [] + if arg_size_ is not None: + if arg_size_ not in (1, n): + raise ValueError('The size of `{}` must be equal to `1` or to the rank ' + 'of the convolution (={}). 
Saw size = {}'.format( + arg_name, n, arg_size_)) + elif validate_args: + assertions.append(assert_util.assert_equal( + ps.logical_or(arg_size == 1, arg_size == n), + True, + message=('The size of `{}` must be equal to `1` or to the rank of the ' + 'convolution (={})'.format(arg_name, n)))) + + with tf.control_dependencies(assertions): + arg = ps.broadcast_to(arg, shape=[n]) + arg = ps.unstack(arg, num=n) + return arg + + +def _call_conv2d_transpose(x, kernel, filter_shape, strides, padding, dilations, + c_out, batch_shape, event_shape): + """Call `tf.nn.conv2d_transpose` (for kernels with no batch dimensions).""" + fh, fw = filter_shape + flat_x = tf.reshape(x, shape=ps.concat([[-1], event_shape], axis=0)) + output_shape, strides_ = _get_output_shape( + rank=2, strides=strides, padding=padding, dilations=dilations, + input_shape=ps.shape(flat_x), output_size=c_out, + filter_shape=filter_shape) + flat_y = tf.nn.conv2d_transpose( + flat_x, + filters=tf.transpose( + tf.reshape( + kernel, shape=[fh, fw, event_shape[-1], -1]), + perm=[0, 1, 3, 2]), + output_shape=output_shape, + strides=strides_, + padding=padding, + data_format='NHWC', + dilations=dilations) + return tf.reshape( + flat_y, shape=ps.concat([batch_shape, output_shape[-3:]], axis=0)) diff --git a/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py b/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py new file mode 100644 index 0000000000..6dbf9d473b --- /dev/null +++ b/tensorflow_probability/python/experimental/nn/util/convolution_util_test.py @@ -0,0 +1,561 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for batched convolutions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from absl.testing import parameterized + +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.experimental.nn.util import convolution_util +from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.internal import test_util + +tfn = tfp.experimental.nn + + +# TODO(emilyaf): Test that gradients work. 
+# pylint: disable=bad-whitespace +_CONV_TEST_CASES = ( + # input dim filter c_out strides padding dilations + ((1, 32, 32, 3), (3, 4), 2, (1, 1), 'VALID', (1, 1)), + ((5, 2, 32, 32, 3), (2, 2), 4, (1, 2), 'SAME', (1, 1)), + ((5, 2, 7, 7, 3), (2, 2), 4, (1, 2), 'SAME', (2, 1)), + ((5, 2, 13, 13, 3), (2, 2), 4, (1, 2), 'SAME', (1, 1)), + ((4, 28, 28, 2), (2, 3), 2, (2, 2), 'VALID', (1, 2)) + ) + +_CONV_TRANSPOSE_TEST_CASES = ( + # input dim filter c_out strides padding dilations + ((2, 16, 16, 3), (3, 3), 4, (2, 2), 'SAME', (1, 1)), + ((2, 16, 16, 3), (4, 4), 3, (2, 2), 'SAME', (1, 1)), + ((2, 8, 8, 2), (3, 3), 3, (1, 2), 'SAME', (1, 1)), + ((4, 9, 9, 3), (3, 3), 2, (1, 1), 'SAME', (2, 2)), + ((4, 12, 9, 3), (3, 3), 1, (2, 2), 'VALID', (1, 1)), + ((2, 12, 12, 2), (2, 3), 1, (2, 2), 'VALID', (1, 1)), + ) +# pylint: enable=bad-whitespace + + +def _make_input_and_kernel( + make_input, input_batch_shape, input_shape, kernel_batch_shape, + filter_shape, channels_out, dtype): + total_input_shape = ps.concat([input_batch_shape, input_shape], axis=0) + total_kernel_shape = ps.concat( + [kernel_batch_shape, [filter_shape[0] * filter_shape[1] * input_shape[-1], + channels_out]], axis=0) + # Use integers for numerical stability. + sample_fn = lambda s: make_input(tf.cast( # pylint: disable=g-long-lambda + tf.random.uniform( + ps.cast(s, tf.int32), minval=-10, maxval=10, dtype=tf.int32), + dtype=dtype)) + return sample_fn(total_input_shape), sample_fn(total_kernel_shape) + + +def _get_conv_transpose_fn(method): + if method == 'subkernels': + return tfn.util.make_convolution_transpose_fn_with_subkernels + elif method == 'subkernels_matrix': + return tfn.util.make_convolution_transpose_fn_with_subkernels_matrix + elif method == 'dilation': + return tfn.util.make_convolution_transpose_fn_with_dilation + else: + raise ValueError('Unsupported method for `_get_conv_transpose_fn`: {}.' 
+ ''.format(method)) + + +class _Common(object): + """Common methods for Conv/ConvTranspose tests.""" + + def assertRaisesMaybeStaticError(self, msg): + if tf.executing_eagerly() or self.use_static_shape: + return self.assertRaisesRegex(ValueError, msg) + return self.assertRaisesOpError(msg) + + def make_integer_input(self, number): + if self.use_static_shape: + return number + output = tf.Variable(number, dtype=tf.int32) + self.evaluate(output.initializer) + return output + + +@test_util.test_all_tf_execution_regimes +class Im2RowTest(test_util.TestCase): + + def test_works_like_conv2d(self): + x = tf.constant([[ + [[2], [1], [2], [0], [1]], + [[1], [3], [2], [2], [3]], + [[1], [1], [3], [3], [0]], + [[2], [2], [0], [1], [1]], + [[0], [0], [3], [1], [2]], + ]], tf.float32) # shape=[1, 5, 5, 1] + x = tf.concat([x, x], axis=-1) + k = tf.constant([ + [[[2, 0.1]], [[3, 0.2]]], + [[[0, 0.3]], [[1, 0.4]]], + ], tf.float32) # shape=[2, 2, 1, 2] + k = tf.concat([k, k], axis=-2) + strides = [1, 2] + im2row_x = tfn.util.im2row( + x, + block_shape=ps.shape(k)[:2], + slice_step=strides, + padding='VALID') + y_expected = tf.nn.conv2d(x, k, strides=strides, padding='VALID') + y_actual = tf.matmul(im2row_x, tf.reshape(k, shape=[-1, k.shape[-1]])) + [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + @parameterized.parameters((tf.int32, np.int32), (tf.int64, np.int64)) + def test_dtype(self, tf_dtype, np_dtype): + ind, _ = tfn.util.im2row_index( + input_shape=(1, 12, 16, 3), + block_shape=(2, 3), + dtype=tf_dtype) + self.assertDTypeEqual(ind, np_dtype) + + +@test_util.test_all_tf_execution_regimes +class ConvolutionUtilsTest(test_util.TestCase, _Common): + + use_static_shape = False + + def test_prepare_tuple_argument(self): + + rank = 3 + + # Test that scalars are processed to tuples. + arg = convolution_util.prepare_tuple_argument( + self.make_integer_input(2), n=rank, arg_name='arg', validate_args=True) + self.assertIsInstance(arg, list) + self.assertLen(arg, rank) + + # Test that `Tensor` args are processed correctly. + arg = convolution_util.prepare_tuple_argument( + self.make_integer_input( + [2, 3, 4]), n=rank, arg_name='arg_2', validate_args=True) + self.assertIsInstance(arg, list) + self.assertLen(arg, rank) + + with self.assertRaisesRegex( + ValueError, 'must be equal to `1` or to the rank'): + convolution_util.prepare_tuple_argument( + self.make_integer_input([1, 2]), n=rank, arg_name='invalid_arg', + validate_args=True) + + def test_prepare_conv_args(self): + [filter_shape, + rank, + strides, + padding, + dilations] = convolution_util.prepare_conv_args( + (3, 3), + rank=2, + strides=2, + padding='same', + dilations=(1, 1)) + + for arg in [filter_shape, strides, dilations]: + self.assertLen(arg, rank) + + self.assertEqual(padding, 'SAME') + + +@test_util.test_all_tf_execution_regimes +class _BatchedConvTest(test_util.TestCase, _Common): + + @parameterized.parameters(*_CONV_TEST_CASES) + def test_works_like_conv2d( + self, input_shape, filter_shape, channels_out, + strides, padding, dilations): + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=[], + input_shape=input_shape, + # Use singleton kernel_batch_shape to bypass the short circuit to tf.nn. 
+ kernel_batch_shape=[1], + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + + conv_fn = tfn.util.make_convolution_fn( + self.make_integer_input(filter_shape), + rank=2, + strides=self.make_integer_input(strides), + padding=padding, + dilations=self.make_integer_input(dilations), + validate_args=True) + y_actual = conv_fn(x, k) + + tf_kernel = tf.reshape( + k, shape=(filter_shape) + (input_shape[-1], channels_out)) + y_expected = tf.nn.conv2d( + x, tf_kernel, strides=strides, padding=padding, dilations=dilations) + [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + @parameterized.parameters( + ((1,), ()), # scalar input batch, scalar kernel batch + ((1,), (2, 3)), # non-scalar kernel batch + ((3, 4), ()), # non-scalar input batch + ((3, 1), (2,)), # broadcasting kernel and input batch shapes + ((2, 3), (2, 3),)) # same kernel and input batch shapes + def test_batching(self, input_batch_shape, kernel_batch_shape): + input_shape = (12, 12, 2) + filter_shape = (2, 2) + channels_out = 3 + strides = (1, 1) + dilations = (1, 1) + padding = 'SAME' + + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=input_batch_shape, + input_shape=input_shape, + kernel_batch_shape=kernel_batch_shape, + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + + conv_fn = tfn.util.make_convolution_fn( + filter_shape, rank=2, strides=strides, padding=padding, + dilations=dilations, validate_args=True) + y_batched = conv_fn(x, k) + + broadcast_batch_shape = ps.broadcast_shape( + input_batch_shape, kernel_batch_shape) + broadcasted_input = tf.broadcast_to( + x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0)) + broadcasted_kernel = tf.broadcast_to( + k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0)) + + flat_y = tf.reshape( + y_batched, + shape=ps.pad( + ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1)) + flat_x = tf.reshape( + broadcasted_input, + shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1)) + flat_tf_kernel = tf.reshape( + broadcasted_kernel, + shape=ps.concat([(-1,), filter_shape, (input_shape[-1], channels_out)], + axis=0)) + + y_expected = tf.vectorized_map( + lambda args: tf.nn.conv2d( # pylint: disable=g-long-lambda + args[0][tf.newaxis], + args[1], + strides=strides, + padding=padding), + elems=(flat_x, flat_tf_kernel)) + + [y_actual_, y_expected_] = self.evaluate( + [flat_y, tf.squeeze(y_expected, axis=1)]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + def test_incompatible_shapes_raises(self): + filter_shape = (3, 3) + + # Inconsistent channels in for kernel and image. + c_in_kernel = 6 + c_in_image = 8 + c_out = 12 + + k_dim = np.prod(filter_shape) * c_in_kernel + kernel = self.make_input(tf.ones((2, k_dim, c_out), dtype=tf.float32)) + x = self.make_input(tf.ones((3, 2, 16, 16, c_in_image), dtype=tf.float32)) + conv_fn = tfn.util.make_convolution_fn( + self.make_integer_input(filter_shape), + rank=2, + strides=self.make_integer_input((1, 1)), + padding='SAME', + dilations=self.make_integer_input((1, 1)), + validate_args=True) + with self.assertRaisesMaybeStaticError('size of the rightmost dimension'): + self.evaluate(conv_fn(x, kernel)) + + def test_dtype(self): + # Test int64 indices. 
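+    # The `dtype` argument sets the integer dtype of the internal gather
+    # indices; the output dtype still follows the (float) inputs.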
+ conv_fn = tfn.util.make_convolution_fn( + (2, 2), rank=2, strides=(1, 1), padding='SAME', dilations=(1, 1), + dtype=tf.int64, validate_args=True) + x = tf.ones((2, 8, 8, 2), dtype=tf.float32) + kernel = tf.ones((2, 8, 2), dtype=tf.float32) + _ = self.evaluate(conv_fn(x, kernel)) + + # Test f64 input. + conv_fn = tfn.util.make_convolution_fn( + self.make_integer_input((2, 2)), + rank=2, + strides=self.make_integer_input((1, 1)), + padding='SAME', + dilations=self.make_integer_input((1, 1)), + validate_args=True) + x = tf.ones((2, 8, 8, 2), dtype=tf.float64) + kernel = tf.ones((2, 8, 2), dtype=tf.float64) + y = self.evaluate(conv_fn(x, kernel)) + self.assertDTypeEqual(y, np.float64) + + +@test_util.test_all_tf_execution_regimes +class _BatchedConvTransposeTest(test_util.TestCase, _Common): + + dynamic_strides_ok = True + unequal_strides_ok = True + + def make_conv_fn(self, filter_shape, strides, padding, dilations): + return _get_conv_transpose_fn(self.method)( + self.make_integer_input(filter_shape), + strides=(self.make_integer_input(strides) + if self.dynamic_strides_ok else strides), + padding=padding, + dilations=self.make_integer_input(dilations), + validate_args=True) + + @parameterized.parameters(*_CONV_TRANSPOSE_TEST_CASES) + def test_works_like_conv2d_transpose( + self, input_shape, filter_shape, channels_out, strides, padding, + dilations): + + strides_tuple = strides + if not self.unequal_strides_ok: + if strides[0] != strides[1]: + # Skip this test case if the method does not support unequal strides. + return + else: + strides = strides[0] + + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=[], + input_shape=input_shape, + # Use singleton kernel_batch_shape to avoid the short circuit to + # `conv2d_transpose`. + kernel_batch_shape=[1], + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations) + y_actual = conv_fn(x, k) + output_shape, strides_ = convolution_util._get_output_shape( + rank=2, strides=strides_tuple, padding=padding, dilations=dilations, + input_shape=input_shape, output_size=channels_out, + filter_shape=filter_shape) + + tf_kernel = tf.transpose( + tf.reshape(k, ps.concat( + [filter_shape, [input_shape[-1], channels_out]], axis=0)), + perm=[0, 1, 3, 2]) + # conv2d_transpose does not support dilations > 1; use Keras instead. 
+ if any(d > 1 for d in dilations): + keras_convt = tf.keras.layers.Conv2DTranspose( + filters=channels_out, + kernel_size=filter_shape, + strides=strides, + padding=padding, + dilation_rate=dilations, + use_bias=False) + _ = keras_convt(x) # build kernel + keras_convt.kernel = tf_kernel + y_expected = keras_convt(x) + else: + y_expected = tf.nn.conv2d_transpose( + x, tf_kernel, output_shape=output_shape, + strides=strides_, padding=padding, dilations=dilations) + + [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + @parameterized.parameters( + ((1,), ()), # scalar input batch, scalar kernel batch + ((1,), (2, 3)), # non-scalar kernel batch + ((3, 4), ()), # non-scalar input batch + ((3, 1), (2,)), # broadcasting kernel and input batch shapes + ((2, 3), (2, 3),)) # same kernel and input batch shapes + def test_batching(self, input_batch_shape, kernel_batch_shape): + input_shape = (12, 12, 2) + filter_shape = (3, 3) + channels_out = 4 + strides = 2 + dilations = (1, 1) + padding = 'SAME' + + x, k = _make_input_and_kernel( + self.make_input, + input_batch_shape=input_batch_shape, + input_shape=input_shape, + kernel_batch_shape=kernel_batch_shape, + filter_shape=filter_shape, + channels_out=channels_out, + dtype=self.dtype) + + conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations) + y_batched = conv_fn(x, k) + + broadcast_batch_shape = ps.broadcast_shape( + input_batch_shape, kernel_batch_shape) + broadcasted_input = tf.broadcast_to( + x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0)) + broadcasted_kernel = tf.broadcast_to( + k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0)) + + flat_y = tf.reshape( + y_batched, + shape=ps.pad( + ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1)) + flat_x = tf.reshape( + broadcasted_input, + shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1)) + flat_tf_kernel = tf.einsum( + '...ij->...ji', + tf.reshape( + broadcasted_kernel, + shape=ps.concat( + [(-1,), filter_shape, (input_shape[-1], channels_out)], + axis=0))) + + rank = 2 + output_shape, strides_ = convolution_util._get_output_shape( + rank=rank, strides=(strides,) * rank, padding=padding, + dilations=dilations, input_shape=input_shape, output_size=channels_out, + filter_shape=filter_shape) + + y_expected = tf.vectorized_map( + lambda args: tf.nn.conv2d_transpose( # pylint: disable=g-long-lambda + args[0][tf.newaxis], + args[1], + output_shape=ps.concat([[1], output_shape], axis=0), + strides=strides_, + padding=padding), + elems=(flat_x, flat_tf_kernel)) + + [y_actual_, y_expected_] = self.evaluate( + [flat_y, tf.squeeze(y_expected, axis=1)]) + self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) + + def test_incompatible_shapes_raises(self): + filter_shape = (3, 3) + + # Inconsistent channels in for kernel and image. + c_in_kernel = 6 + c_in_image = 8 + c_out = 12 + + k_dim = np.prod(filter_shape) * c_in_kernel + kernel = self.make_input(tf.ones((2, k_dim, c_out), dtype=self.dtype)) + x = self.make_input(tf.ones((3, 2, 16, 16, c_in_image), dtype=self.dtype)) + conv_fn = self.make_conv_fn( + filter_shape, strides=1, padding='SAME', dilations=1) + + with self.assertRaisesMaybeStaticError('size of the rightmost dimension'): + self.evaluate(conv_fn(x, kernel)) + + def test_dtype(self): + # Test int64 indices. 
+ conv_fn = self.make_conv_fn((2, 2), strides=1, padding='SAME', dilations=1) + x = tf.ones((2, 8, 8, 2), dtype=tf.float32) + kernel = tf.ones((2, 8, 2), dtype=tf.float32) + _ = self.evaluate(conv_fn(x, kernel)) + + # Test f64 input. + conv_fn = self.make_conv_fn((2, 2), strides=1, padding='SAME', dilations=1) + x = tf.ones((2, 8, 8, 2), dtype=tf.float64) + kernel = tf.ones((2, 8, 2), dtype=tf.float64) + y = self.evaluate(conv_fn(x, kernel)) + self.assertDTypeEqual(y, np.float64) + + +@test_util.test_all_tf_execution_regimes +class BatchedConvStaticTest(_BatchedConvTest): + + dtype = tf.float32 + use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class BatchedConvDynamicTest(_BatchedConvTest): + + dtype = tf.float32 + use_static_shape = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithDilationsStaticTest(_BatchedConvTransposeTest): + + method = 'dilation' + dtype = tf.float32 + use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsMatrixStaticTest( + _BatchedConvTransposeTest): + + method = 'subkernels_matrix' + dtype = tf.float32 + use_static_shape = True + unequal_strides_ok = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsStaticTest(_BatchedConvTransposeTest): + + method = 'subkernels' + dtype = tf.float32 + use_static_shape = True + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithDilationsDynamicTest(_BatchedConvTransposeTest): + + method = 'dilation' + dtype = tf.float32 + use_static_shape = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsMatrixDynamicTest( + _BatchedConvTransposeTest): + + method = 'subkernels_matrix' + dtype = tf.float32 + use_static_shape = False + dynamic_strides_ok = False + unequal_strides_ok = False + + +@test_util.test_all_tf_execution_regimes +class BatchedConvTransposeWithSubkernelsDynamicTest(_BatchedConvTransposeTest): + + method = 'subkernels' + dtype = tf.float32 + use_static_shape = False + + +del _BatchedConvTest +del _BatchedConvTransposeTest + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_probability/python/experimental/nn/util/im2row.py b/tensorflow_probability/python/experimental/nn/util/im2row.py deleted file mode 100644 index a67f46f309..0000000000 --- a/tensorflow_probability/python/experimental/nn/util/im2row.py +++ /dev/null @@ -1,187 +0,0 @@ -# Lint as: python2, python3 -# Copyright 2020 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""Functions for framing `conv` as `matmul`.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.compat.v2 as tf - -from tensorflow_probability.python.internal import prefer_static - - -__all__ = [ - 'im2row', -] - - -def im2row(x, - block_shape, - slice_step=(1, 1), - data_format='NHWC', - padding='VALID', - name=None): - """Rearrange image blocks into rows. - - This function can be used to implement 2D convolution as a `matml`, e.g., - - `tf.nn.conv2d(x, k) = tf.matmul(im2row(x), tf.reshape(k, [-1, out_size]))`. - - Args: - x: Rank 3 (or more) Tensor representing 2D images. - block_shape: Length-2 vector representing the block or "filter" shape. - slice_step: Length-2 vector specifying the convolution stride length. - Default value: `(1, 1)`. - data_format: One of `'NHWC'` or `'NCHW'` (case insensitive). - Default value: `'NHWC'`. - padding: One of `'VALID'` or `'SAME'` (case insensitive). - Default value: `'VALID'`. - name: Python `str` used to describe ops created by this function. - Default value: `None` (i.e., `'im2col'`). - - Returns: - im2row_x: batch of matrices representing subblock copies of `x`. - Same batch shape as `x` but with rightmost shape: - `batch_shape + [oh * ow, block_shape[0] * block_shape[1] * channels]`, - where `oh = (h - block_shape[0] + 1) // slice_step[0]` and - `ow = (w - block_shape[1] + 1) // slice_step[1]` when `padding = 'VALID'` - and `oh = h` and `ow = w` when `padding = 'SAME'`. - shape: shape `Tensor` equivalent to: - `batch_shape + [oh, ow, block_shape[0] * block_shape[1] * channels]` where - `oh, ow` are defined as above. - """ - with tf.name_scope(name or 'im2row'): - data_format = _validate_data_format(data_format) - padding = _validate_padding(padding) - if padding == 'VALID': - pass # Do nothing. - elif padding == 'SAME': - raise NotImplementedError( - 'Argument padding="SAME" not implemented.') - # TODO(jvdillon): See if the following works: - # fh, fw = block_shape - # o = 1 if data_format == 'NHWC' else 0 - # n = prefer_static.maximum(0, prefer_static.rank(x) - 3) - # paddings = prefer_static.pad( - # [[0, fh - 1], [0, fw - 1]], - # paddings=[[n + 1 - o, o], [0, 0]], - # constant_values=0) - # x = tf.pad(x, paddings=paddings, constant_values=0) - # padding = 'VALID' - else: - assert False # Can't be here. - x_shape = prefer_static.shape(x) - idx, s = _im2row_index( - x_shape, block_shape, slice_step, data_format, padding) - flat_shape = prefer_static.pad( - x_shape[:-3], paddings=[[0, 1]], constant_values=-1) - x = tf.gather(tf.reshape(x, flat_shape), idx, axis=-1) # == np.take - return tf.reshape(x, s) - - -def _im2row_index(input_shape, - block_shape, - slice_step=(1, 1), - data_format='NHWC', - padding='VALID', - dtype=tf.int64, - name=None): - """Computes indexes into a flattened image for building `im2col`.""" - with tf.name_scope(name or 'im2row_index'): - # 1) Process input arguments. - batch_shape, s3, s2, s1 = prefer_static.split( - prefer_static.cast(input_shape, tf.int32), - num_or_size_splits=[-1, 1, 1, 1]) - fh, fw = _split_pair(block_shape) - sh, sw = _split_pair(slice_step) - data_format = _validate_data_format(data_format) - padding = _validate_padding(padding) - - # 2) Assemble all block start positions as indexes into the flattened image. 
- if data_format == 'NHWC': - h, w, c = s3[0], s2[0], s1[0] - # start_idx.shape = [fh, fw, c] - start_idx = _cartesian_add([ - prefer_static.range(c * w * fh, delta=c * w, dtype=dtype), - prefer_static.range(c * fw, delta=c, dtype=dtype), - prefer_static.range(c, delta=1, dtype=dtype), - ]) - elif data_format == 'NCHW': - c, h, w = s3[0], s2[0], s1[0] - # start_idx.shape = [c, fh, fw] - start_idx = _cartesian_add([ - prefer_static.range(w * h * c, delta=w * h, dtype=dtype), - prefer_static.range(w * fh, delta=w, dtype=dtype), - prefer_static.range(fw, delta=1, dtype=dtype), - ]) - else: - assert False # Can't be here. - - # 3) Assemble all block offsets (into flattened image). - if padding == 'VALID': - eh = h - fh + 1 # extent height - ew = w - fw + 1 # extent width - # offset_idx.shape = [eh // sh, ew // sw] - offset_idx = _cartesian_add([ - prefer_static.range(w * eh, delta=w * sh, dtype=dtype), - prefer_static.range(ew, delta=sw, dtype=dtype), - ]) - if data_format == 'NHWC': - offset_idx *= c - oh = eh // sh # out height - ow = ew // sw # out width - else: - assert False # Can't be here. - - # 4) Combine block start/offset pairs. - # shape = [(eh // sh) * (ew // sw), fh * fw * c] - idx = _cartesian_add([offset_idx, start_idx]) - new_shape = [oh, ow, fh * fw * c] - new_shape = prefer_static.concat([batch_shape, new_shape], axis=0) - return idx, new_shape - - -def _split_pair(x): - """Splits a length two vector into two scalars.""" - x = prefer_static.cast(x, dtype=tf.int32) - a, b = prefer_static.split(x, num_or_size_splits=[1, 1]) - return a[0], b[0] - - -def _cartesian_add(xs): - """Adds a list of vectors by cumulatively expanding a dimension.""" - return sum(prefer_static.reshape(x, shape=[-1] + [1]*(len(xs) - 1 - i)) - for i, x in enumerate(xs)) - - -def _validate_data_format(data_format): - """Verify correctness of `data_format` argument.""" - data_format_ = str(data_format).upper() - if data_format_ in {'NHWC', 'NCHW'}: - return data_format_ - raise ValueError( - 'Argument data_format="{}" not recognized; must be one of ' - '{{"NHWC", "NCHW"}} (case insensitive).'.format(data_format)) - - -def _validate_padding(padding): - """Verify correctness of `padding` argument.""" - padding_ = str(padding).upper() - if padding_ in {'SAME', 'VALID'}: - return padding_ - raise ValueError( - 'Argument padding="{}" not recognized; must be one of ' - '{{"VALID", "SAME"}} (case insensitive).'.format(padding)) diff --git a/tensorflow_probability/python/experimental/nn/util/im2row_test.py b/tensorflow_probability/python/experimental/nn/util/im2row_test.py deleted file mode 100644 index c5591db944..0000000000 --- a/tensorflow_probability/python/experimental/nn/util/im2row_test.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 The TensorFlow Probability Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""Tests for im2col.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports -import tensorflow.compat.v2 as tf -import tensorflow_probability as tfp - -from tensorflow_probability.python.internal import test_util - -tfn = tfp.experimental.nn - - -@test_util.test_all_tf_execution_regimes -class Im2ColTest(test_util.TestCase): - - def test_works_like_conv2d(self): - x = tf.constant([[ - [[2], [1], [2], [0], [1]], - [[1], [3], [2], [2], [3]], - [[1], [1], [3], [3], [0]], - [[2], [2], [0], [1], [1]], - [[0], [0], [3], [1], [2]], - ]], tf.float32) # shape=[1, 5, 5, 1] - x = tf.concat([x, x], axis=-1) - k = tf.constant([ - [[[2, 0.1]], [[3, 0.2]]], - [[[0, 0.3]], [[1, 0.4]]], - ], tf.float32) # shape=[2, 2, 1, 2] - k = tf.concat([k, k], axis=-2) - strides = [1, 2] - im2row_x = tfn.util.im2row( - x, - block_shape=k.shape[:2], - slice_step=strides, - padding='VALID') - y_expected = tf.nn.conv2d(x, k, strides=strides, padding='VALID') - y_actual = tf.matmul(im2row_x, tf.reshape(k, [-1, k.shape[-1]])) - [y_expected_, y_actual_] = self.evaluate([y_expected, y_actual]) - self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow_probability/python/experimental/util/BUILD b/tensorflow_probability/python/experimental/util/BUILD index f605c210be..849114ba9e 100644 --- a/tensorflow_probability/python/experimental/util/BUILD +++ b/tensorflow_probability/python/experimental/util/BUILD @@ -33,7 +33,7 @@ exports_files(["LICENSE"]) multi_substrate_py_library( name = "util", srcs = ["__init__.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", substrates_omit_deps = [ ":deferred_module", ], @@ -45,7 +45,7 @@ multi_substrate_py_library( py_library( name = "deferred_module", srcs = ["deferred_module.py"], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # numpy dep, # tensorflow dep, diff --git a/tensorflow_probability/python/experimental/vi/BUILD b/tensorflow_probability/python/experimental/vi/BUILD index f443c07f4d..9072388b64 100644 --- a/tensorflow_probability/python/experimental/vi/BUILD +++ b/tensorflow_probability/python/experimental/vi/BUILD @@ -40,12 +40,21 @@ py_library( srcs = ["surrogate_posteriors.py"], srcs_version = "PY3", deps = [ - # numpy dep, # tensorflow dep, - "//tensorflow_probability/python/distributions", - "//tensorflow_probability/python/internal:nest_util", + "//tensorflow_probability/python/bijectors", + "//tensorflow_probability/python/bijectors:softplus", + "//tensorflow_probability/python/distributions:beta", + "//tensorflow_probability/python/distributions:independent", + "//tensorflow_probability/python/distributions:joint_distribution", + "//tensorflow_probability/python/distributions:joint_distribution_auto_batched", + "//tensorflow_probability/python/distributions:joint_distribution_coroutine", + "//tensorflow_probability/python/distributions:joint_distribution_util", + "//tensorflow_probability/python/distributions:normal", + "//tensorflow_probability/python/distributions:sample", + "//tensorflow_probability/python/distributions:transformed_distribution", + "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", - "//tensorflow_probability/python/monte_carlo", + "//tensorflow_probability/python/util", ], ) diff --git 
a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py index c8c40563f8..00654bb529 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors.py @@ -23,23 +23,37 @@ import functools import tensorflow.compat.v2 as tf -import tensorflow_probability as tfp from tensorflow_probability.python import bijectors as tfb from tensorflow_probability.python import util as tfp_util +from tensorflow_probability.python.bijectors import identity as identity_bijector from tensorflow_probability.python.bijectors import softplus as softplus_lib +from tensorflow_probability.python.distributions import beta +from tensorflow_probability.python.distributions import half_normal from tensorflow_probability.python.distributions import independent +from tensorflow_probability.python.distributions import joint_distribution from tensorflow_probability.python.distributions import joint_distribution_auto_batched from tensorflow_probability.python.distributions import joint_distribution_coroutine from tensorflow_probability.python.distributions import joint_distribution_util from tensorflow_probability.python.distributions import normal +from tensorflow_probability.python.distributions import sample +from tensorflow_probability.python.distributions import transformed_distribution +from tensorflow_probability.python.distributions import truncated_normal +from tensorflow_probability.python.distributions import uniform from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static -from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.util import deprecation +from tensorflow.python.util import nest +# pylint: enable=g-direct-tensorflow-import Root = joint_distribution_coroutine.JointDistributionCoroutine.Root -_NON_STATISTICAL_PARAMS = ['name', 'validate_args', 'allow_nan_stats'] +_NON_STATISTICAL_PARAMS = [ + 'name', 'validate_args', 'allow_nan_stats', 'experimental_use_kahan_sum', + 'reinterpreted_batch_ndims' +] +_NON_TRAINABLE_PARAMS = ['low', 'high'] ASVIParameters = collections.namedtuple( 'ASVIParameters', ['prior_weight', 'mean_field_parameter']) @@ -111,8 +125,13 @@ def _not_list_of_ints(s): build_trainable_location_scale_distribution, distribution_fn=normal.Normal) +@deprecation.deprecated_args( + '2021-03-15', + '`constraining_bijectors` is deprecated, use `bijector` instead', + 'constraining_bijectors') def build_factored_surrogate_posterior( event_shape=None, + bijector=None, constraining_bijectors=None, initial_unconstrained_loc=_sample_uniform_initial_loc, initial_unconstrained_scale=1e-2, @@ -131,22 +150,22 @@ def build_factored_surrogate_posterior( Args: event_shape: `Tensor` shape, or nested structure of `Tensor` shapes, specifying the event shape(s) of the posterior variables. - constraining_bijectors: Optional `tfb.Bijector` instance, or nested - structure of such instances, defining support(s) of the posterior - variables. The structure must match that of `event_shape` and may - contain `None` values. A posterior variable will - be modeled as `tfd.TransformedDistribution(underlying_dist, - constraining_bijector)` if a corresponding constraining bijector is - specified, otherwise it is modeled as supported on the - unconstrained real line. 
+ bijector: Optional `tfb.Bijector` instance, or nested structure of such + instances, defining support(s) of the posterior variables. The structure + must match that of `event_shape` and may contain `None` values. A + posterior variable will be modeled as + `tfd.TransformedDistribution(underlying_dist, bijector)` if a + corresponding constraining bijector is specified, otherwise it is modeled + as supported on the unconstrained real line. + constraining_bijectors: Deprecated alias for `bijector`. initial_unconstrained_loc: Optional Python `callable` with signature `tensor = initial_unconstrained_loc(shape, seed)` used to sample real-valued initializations for the unconstrained representation of each variable. May alternately be a nested structure of `Tensor`s, giving specific initial locations for each variable; these must have structure matching `event_shape` and shapes determined by the - inverse image of `event_shape` under `constraining_bijectors`, which - may optionally be prefixed with a common batch shape. + inverse image of `event_shape` under `bijector`, which may optionally be + prefixed with a common batch shape. Default value: `functools.partial(tf.random.uniform, minval=-2., maxval=2., dtype=tf.float32)`. initial_unconstrained_scale: Optional scalar float `Tensor` initial @@ -198,8 +217,8 @@ def model_fn(): ```python surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=model.event_shape_tensor()[:-1], # Omit the observed `y`. - constraining_bijectors=[tfb.Softplus(), # Rate is positive. - tfb.Softplus()]) # Concentration is positive. + bijector=[tfb.Softplus(), # Rate is positive. + tfb.Softplus()]) # Concentration is positive. ``` This creates a trainable joint distribution, defined by variables in @@ -230,14 +249,13 @@ def model_fn(): ```python initial_loc = {'concentration': 0.4, 'rate': 0.2} - constraining_bijectors={'concentration': tfb.Softplus(), # Rate is positive. - 'rate': tfb.Softplus()} # Concentration is positive. + bijector={'concentration': tfb.Softplus(), # Rate is positive. + 'rate': tfb.Softplus()} # Concentration is positive. initial_unconstrained_loc = tf.nest.map_fn( - lambda b, x: b.inverse(x) if b is not None else x, - constraining_bijectors, initial_loc) + lambda b, x: b.inverse(x) if b is not None else x, bijector, initial_loc) surrogate_posterior = tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=tf.nest.map_fn(tf.shape, initial_loc), - constraining_bijectors=constraining_bijectors, + bijector=bijector, initial_unconstrained_loc=initial_unconstrained_state, initial_unconstrained_scale=1e-4) ``` @@ -245,6 +263,9 @@ def model_fn(): """ with tf.name_scope(name or 'build_factored_surrogate_posterior'): + bijector = deprecation.deprecated_argument_lookup( + 'bijector', bijector, 'constraining_bijectors', constraining_bijectors) + seed = tfp_util.SeedStream(seed, salt='build_factored_surrogate_posterior') # Convert event shapes to Tensors. @@ -252,92 +273,149 @@ def model_fn(): event_shape = nest.map_structure_up_to( shallow_structure, lambda s: tf.convert_to_tensor(s, dtype=tf.int32), event_shape) - flat_event_shapes = tf.nest.flatten(event_shape) - # For simplicity, we'll work with flattened lists of state parts and - # repack the structure at the end. 
- if constraining_bijectors is not None: - flat_bijectors = tf.nest.flatten(constraining_bijectors) + if nest.is_nested(bijector): + bijector = nest.map_structure( + lambda b: identity_bijector.Identity() if b is None else b, + bijector) + + # Support mismatched nested structures for backwards compatibility (e.g. + # non-nested `event_shape` and a single-element list of `bijector`s). + bijector = nest.pack_sequence_as(event_shape, nest.flatten(bijector)) + + event_space_bijector = tfb.JointMap(bijector, validate_args=validate_args) else: - flat_bijectors = [None for _ in flat_event_shapes] - flat_unconstrained_event_shapes = [ - b.inverse_event_shape_tensor(s) if b is not None else s - for s, b in zip(flat_event_shapes, flat_bijectors)] + event_space_bijector = bijector + + if event_space_bijector is None: + unconstrained_event_shape = event_shape + else: + unconstrained_event_shape = ( + event_space_bijector.inverse_event_shape_tensor(event_shape)) # Construct initial locations for the internal unconstrained dists. if callable(initial_unconstrained_loc): # Sample random initialization. - flat_unconstrained_locs = [initial_unconstrained_loc( - shape=s, seed=seed()) for s in flat_unconstrained_event_shapes] - else: # Use provided initialization. - flat_unconstrained_locs = nest.flatten_up_to( - shallow_structure, initial_unconstrained_loc, check_types=False) - - if nest.is_nested(initial_unconstrained_scale): - flat_unconstrained_scales = nest.flatten_up_to( - shallow_structure, initial_unconstrained_scale, check_types=False) - else: - flat_unconstrained_scales = [ - initial_unconstrained_scale for _ in flat_unconstrained_locs] + initial_unconstrained_loc = nest.map_structure( + lambda s: initial_unconstrained_loc(shape=s, seed=seed()), + unconstrained_event_shape) + + if not nest.is_nested(initial_unconstrained_scale): + initial_unconstrained_scale = nest.map_structure( + lambda _: initial_unconstrained_scale, + unconstrained_event_shape) # Extract the rank of each event, so that we build distributions with the # correct event shapes. - flat_unconstrained_event_ndims = [prefer_static.rank_from_shape(s) - for s in flat_unconstrained_event_shapes] + unconstrained_event_ndims = nest.map_structure( + prefer_static.rank_from_shape, + unconstrained_event_shape) # Build the component surrogate posteriors. - flat_component_dists = [] - for initial_loc, initial_scale, event_ndims, bijector in zip( - flat_unconstrained_locs, - flat_unconstrained_scales, - flat_unconstrained_event_ndims, - flat_bijectors): - unconstrained_dist = trainable_distribution_fn( - initial_loc=initial_loc, initial_scale=initial_scale, - event_ndims=event_ndims, validate_args=validate_args) - flat_component_dists.append( - bijector(unconstrained_dist) if bijector is not None - else unconstrained_dist) - component_distributions = tf.nest.pack_sequence_as( - event_shape, flat_component_dists) - - # Return a `Distribution` object whose events have the specified structure. 
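The refactor above drops the flatten/repack bookkeeping in favor of a structure-preserving construction: trainable unconstrained factors are bundled into one joint base distribution, and the per-variable constraining bijectors are applied jointly through a `tfb.JointMap`. A rough standalone sketch of that composition (illustrative only; the real code builds the base via `independent_joint_distribution_from_structure` from trainable variables):

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# Unconstrained base distribution: one factor per latent variable.
base = tfd.JointDistributionSequential([
    tfd.Independent(tfd.Normal(loc=[0., 0., 0.], scale=1e-2),
                    reinterpreted_batch_ndims=1),
    tfd.Normal(loc=0., scale=1e-2),
])
# Constraining bijectors applied variable-by-variable.
event_space_bijector = tfb.JointMap([tfb.Softplus(), tfb.Identity()])
surrogate = tfd.TransformedDistribution(base, event_space_bijector)
sample = surrogate.sample()  # [positive (3,)-vector, unconstrained scalar]
```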
- return ( + unconstrained_distributions = nest.map_structure_up_to( + unconstrained_event_shape, + lambda loc, scale, ndims: trainable_distribution_fn( # pylint: disable=g-long-lambda + loc, scale, ndims, validate_args=validate_args), + initial_unconstrained_loc, + initial_unconstrained_scale, + unconstrained_event_ndims) + + base_distribution = ( joint_distribution_util.independent_joint_distribution_from_structure( - component_distributions, validate_args=validate_args)) + unconstrained_distributions, validate_args=validate_args)) + if event_space_bijector is None: + return base_distribution + return transformed_distribution.TransformedDistribution( + base_distribution, event_space_bijector) + + +def _as_trainable_family(distribution): + """Substitutes prior distributions with more easily trainable ones.""" + with tf.name_scope('as_trainable_family'): + + if isinstance(distribution, half_normal.HalfNormal): + return truncated_normal.TruncatedNormal( + loc=0., + scale=distribution.scale, + low=0., + high=distribution.scale * 10.) + elif isinstance(distribution, uniform.Uniform): + return tfb.Shift(distribution.low)( + tfb.Scale(distribution.high - distribution.low)(beta.Beta( + concentration0=tf.ones( + distribution.event_shape_tensor(), dtype=distribution.dtype), + concentration1=1.))) + else: + return distribution -def _make_asvi_trainable_variables(prior): +def _make_asvi_trainable_variables(prior, + mean_field=False, + initial_prior_weight=0.5): """Generates parameter dictionaries given a prior distribution and list.""" with tf.name_scope('make_asvi_trainable_variables'): param_dicts = [] prior_dists = prior._get_single_sample_distributions() # pylint: disable=protected-access for dist in prior_dists: - actual_dist = dist.distribution if isinstance(dist, Root) else dist - dist_params = actual_dist.parameters + original_dist = dist.distribution if isinstance(dist, Root) else dist + + substituted_dist = _as_trainable_family(original_dist) + + # Grab the base distribution if it exists + try: + actual_dist = substituted_dist.distribution + except AttributeError: + actual_dist = substituted_dist + new_params_dict = {} # Build trainable ASVI representation for each distribution's parameters. 
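`_as_trainable_family` swaps priors whose native parameterizations are awkward to train for families with matching support. Roughly, under illustrative parameter values (in the real code `scale`, `low`, and `high` come from the prior being substituted):

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors
tfd = tfp.distributions

# HalfNormal(scale=2.) -> TruncatedNormal on [0, 10 * scale], approximating the
# half-line support while exposing trainable loc/scale.
trainable_half_normal = tfd.TruncatedNormal(loc=0., scale=2., low=0., high=20.)

# Uniform(low=-1., high=3.) -> Beta rescaled onto [low, high] via
# Shift(low) o Scale(high - low).
trainable_uniform = tfb.Shift(-1.)(
    tfb.Scale(4.)(tfd.Beta(concentration0=1., concentration1=1.)))
```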
- for param, value in dist_params.items(): - if param in _NON_STATISTICAL_PARAMS or value is None: + parameter_properties = actual_dist.parameter_properties( + dtype=actual_dist.dtype) + sample_shape = tf.concat( + [dist.batch_shape_tensor(), + dist.event_shape_tensor()], axis=0) + for param, value in actual_dist.parameters.items(): + if param in (_NON_STATISTICAL_PARAMS + + _NON_TRAINABLE_PARAMS) or value is None: continue - new_params_dict[param] = ASVIParameters( - prior_weight=tfp.util.TransformedVariable( - 0.5, - bijector=tfb.Sigmoid(), - name='prior_weight/{}/{}'.format(dist.name, param)), - mean_field_parameter=tfp.util.TransformedVariable( - 0.5, - bijector=dist.parameter_properties( - dtype=dist.dtype)[param].default_constraining_bijector_fn(), - name='mean_field_parameter/{}/{}'.format(dist.name, param)) - ) + try: + bijector = parameter_properties[ + param].default_constraining_bijector_fn() + except NotImplementedError: + bijector = tfb.Identity() + unconstrained_ones = tf.ones( + shape=bijector.inverse_event_shape_tensor( + parameter_properties[param].shape_fn( + sample_shape=sample_shape)), + dtype=actual_dist.dtype) + + if mean_field: + new_params_dict[param] = ASVIParameters( + prior_weight=None, + mean_field_parameter=tfp_util.TransformedVariable( + value, + bijector=bijector, + name='mean_field_parameter/{}/{}'.format(dist.name, param))) + else: + new_params_dict[param] = ASVIParameters( + prior_weight=tfp_util.TransformedVariable( + initial_prior_weight * unconstrained_ones, + bijector=tfb.Sigmoid(), + name='prior_weight/{}/{}'.format(dist.name, param)), + mean_field_parameter=tfp_util.TransformedVariable( + value, + bijector=bijector, + name='mean_field_parameter/{}/{}'.format(dist.name, param))) param_dicts.append(new_params_dict) return param_dicts # TODO(kateslin): Add support for models with prior+likelihood written as # a single JointDistribution. -def build_asvi_surrogate_posterior(prior, name=None): +def build_asvi_surrogate_posterior(prior, + mean_field=False, + initial_prior_weight=0.5, + name=None): """Builds a structured surrogate posterior inspired by conjugate updating. ASVI, or Automatic Structured Variational Inference, was proposed by @@ -360,12 +438,22 @@ def build_asvi_surrogate_posterior(prior, name=None): Args: prior: tfd.JointDistribution instance of the prior. - name: Optional string. + mean_field: Optional Python boolean. If `True`, creates a degenerate + surrogate distribution in which all variables are independent, + ignoring the prior dependence structure. Default value: `False`. + initial_prior_weight: Optional float value (either static or tensor value) + on the interval [0, 1]. A larger value creates an initial surrogate + distribution with more dependence on the prior structure. Default value: + `0.5`. + name: Optional string. Default value: `build_asvi_surrogate_posterior`. Returns: surrogate_posterior: A `tfd.JointDistributionCoroutineAutoBatched` instance whose samples have shape and structure matching that of `prior`. + Raises: + TypeError: The `prior` argument cannot be a nested `JointDistribution`. 
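The variables built above encode the core ASVI update: each surrogate parameter is a convex combination of the prior's parameter value and a free mean-field parameter, with the weight kept in [0, 1] by a Sigmoid; `initial_prior_weight` sets the weight's starting value, and `mean_field=True` drops the prior term entirely. A stripped-down scalar sketch (illustrative; in the real code the mean-field parameter is itself a `TransformedVariable` under the parameter's default constraining bijector):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

prior_loc = 3.                                # parameter value from the prior
prior_weight = tfp.util.TransformedVariable(  # lambda, constrained to [0, 1]
    0.5, bijector=tfb.Sigmoid())
mean_field_loc = tf.Variable(0.)              # freely trainable

# Surrogate parameter = lambda * prior + (1 - lambda) * mean-field.
surrogate_loc = prior_weight * prior_loc + (1. - prior_weight) * mean_field_loc
```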
+ ### Examples Consider a Brownian motion model expressed as a JointDistribution: @@ -417,9 +505,10 @@ def model_fn(): """ with tf.name_scope(name or 'build_asvi_surrogate_posterior'): - - param_dicts = _make_asvi_trainable_variables(prior) - + param_dicts = _make_asvi_trainable_variables( + prior=prior, + mean_field=mean_field, + initial_prior_weight=initial_prior_weight) def posterior_generator(): prior_gen = prior._model_coroutine() # pylint: disable=protected-access @@ -428,25 +517,61 @@ def posterior_generator(): i = 0 try: while True: - actual_dist = dist.distribution if isinstance(dist, Root) else dist - dist_params = actual_dist.parameters - temp_params_dict = {} + original_dist = dist.distribution if isinstance(dist, Root) else dist + + if isinstance(original_dist, joint_distribution.JointDistribution): + # TODO(kateslin): Build inner JD surrogate in + # _make_asvi_trainable_variables to avoid rebuilding variables. + raise TypeError( + 'Argument `prior` cannot be a nested `JointDistribution`.') - for param, value in dist_params.items(): - if param in _NON_STATISTICAL_PARAMS or value is None: - temp_params_dict[param] = value + else: + + original_dist = _as_trainable_family(original_dist) + + try: + actual_dist = original_dist.distribution + except AttributeError: + actual_dist = original_dist + + dist_params = actual_dist.parameters + temp_params_dict = {} + + for param, value in dist_params.items(): + if param in (_NON_STATISTICAL_PARAMS + + _NON_TRAINABLE_PARAMS) or value is None: + temp_params_dict[param] = value + else: + prior_weight = param_dicts[i][param].prior_weight + mean_field_parameter = param_dicts[i][ + param].mean_field_parameter + if mean_field: + temp_params_dict[param] = mean_field_parameter + else: + temp_params_dict[param] = prior_weight * value + ( + 1. - prior_weight) * mean_field_parameter + + if isinstance(original_dist, sample.Sample): + surrogate_dist = sample.Sample( + type(actual_dist)(**temp_params_dict)) else: - prior_weight = param_dicts[i][param].prior_weight - mean_field_parameter = param_dicts[i][param].mean_field_parameter - temp_params_dict[param] = prior_weight * value + ( - 1. 
- prior_weight) * mean_field_parameter + surrogate_dist = type(actual_dist)(**temp_params_dict) - surrogate_dist = type(actual_dist)(**temp_params_dict) + if isinstance(original_dist, + transformed_distribution.TransformedDistribution): + surrogate_dist = transformed_distribution.TransformedDistribution( + surrogate_dist, bijector=original_dist.bijector) - if isinstance(dist, Root): - value_out = yield Root(surrogate_dist) - else: - value_out = yield surrogate_dist + if isinstance(original_dist, independent.Independent): + surrogate_dist = independent.Independent( + surrogate_dist, + reinterpreted_batch_ndims=original_dist + .reinterpreted_batch_ndims) + + if isinstance(dist, Root): + value_out = yield Root(surrogate_dist) + else: + value_out = yield surrogate_dist dist = prior_gen.send(value_out) i += 1 @@ -456,5 +581,19 @@ def posterior_generator(): surrogate_posterior = ( joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched( posterior_generator)) + + # Ensure that the surrogate posterior structure matches that of the prior + try: + tf.nest.assert_same_structure(prior.dtype, surrogate_posterior.dtype) + except TypeError: + tokenize = lambda structure: tf.nest.pack_sequence_as( # pylint:disable=g-long-lambda + structure, [i for (i, _) in enumerate(tf.nest.flatten(structure))]) + surrogate_posterior = tfb.Restructure( + output_structure=tokenize(prior.dtype), + input_structure=tokenize(surrogate_posterior.dtype))( + surrogate_posterior) + surrogate_posterior.also_track = param_dicts return surrogate_posterior + + diff --git a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py index 7b7f8e8f90..a58fa61e87 100644 --- a/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py +++ b/tensorflow_probability/python/experimental/vi/surrogate_posteriors_test.py @@ -28,6 +28,7 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.python.experimental.vi import surrogate_posteriors +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util tfb = tfp.bijectors @@ -83,28 +84,27 @@ class FactoredSurrogatePosterior(test_util.TestCase): @parameterized.named_parameters( {'testcase_name': 'TensorEvent', 'event_shape': tf.TensorShape([4]), - 'constraining_bijectors': [tfb.Sigmoid()], + 'bijector': [tfb.Sigmoid()], 'dtype': np.float64, 'use_static_shape': True}, {'testcase_name': 'ListEvent', 'event_shape': [tf.TensorShape([3]), tf.TensorShape([]), tf.TensorShape([2, 2])], - 'constraining_bijectors': [tfb.Softplus(), None, tfb.FillTriangular()], + 'bijector': [tfb.Softplus(), None, tfb.FillTriangular()], 'dtype': np.float32, 'use_static_shape': False}, {'testcase_name': 'DictEvent', 'event_shape': {'x': tf.TensorShape([1]), 'y': tf.TensorShape([])}, - 'constraining_bijectors': None, + 'bijector': None, 'dtype': np.float64, 'use_static_shape': True}, {'testcase_name': 'NestedEvent', 'event_shape': {'x': [tf.TensorShape([1]), tf.TensorShape([1, 2])], 'y': tf.TensorShape([])}, - 'constraining_bijectors': { + 'bijector': { 'x': [tfb.Identity(), tfb.Softplus()], 'y': tfb.Sigmoid()}, 'dtype': np.float32, 'use_static_shape': True}, ) - def test_specifying_event_shape(self, event_shape, - constraining_bijectors, - dtype, use_static_shape): + def test_specifying_event_shape( + self, event_shape, bijector, dtype, use_static_shape): seed = test_util.test_seed_stream() 
surrogate_posterior = ( tfp.experimental.vi.build_factored_surrogate_posterior( @@ -113,7 +113,7 @@ def test_specifying_event_shape(self, event_shape, dtype=np.int32, use_static_shape=use_static_shape), event_shape), - constraining_bijectors=constraining_bijectors, + bijector=bijector, initial_unconstrained_loc=functools.partial( tf.random.uniform, minval=-2., maxval=2., dtype=dtype), seed=seed(), @@ -149,7 +149,7 @@ def test_specifying_event_shape(self, event_shape, 'event_shape': [4], 'initial_loc': np.array([[[0.9, 0.1, 0.5, 0.7]]]), 'implicit_batch_shape': [1, 1], - 'constraining_bijectors': tfb.Sigmoid(), + 'bijector': tfb.Sigmoid(), 'dtype': np.float32, 'use_static_shape': False}, {'testcase_name': 'ListEvent', 'event_shape': [[3], [], [2, 2]], @@ -157,29 +157,28 @@ def test_specifying_event_shape(self, event_shape, 0.1, np.array([[1., 0], [-4., 2.]])], 'implicit_batch_shape': [], - 'constraining_bijectors': [tfb.Softplus(), None, tfb.FillTriangular()], + 'bijector': [tfb.Softplus(), None, tfb.FillTriangular()], 'dtype': np.float64, 'use_static_shape': True}, {'testcase_name': 'DictEvent', 'event_shape': {'x': [2], 'y': []}, 'initial_loc': {'x': np.array([[0.9, 1.2]]), 'y': np.array([-4.1])}, 'implicit_batch_shape': [1], - 'constraining_bijectors': None, + 'bijector': None, 'dtype': np.float32, 'use_static_shape': False}, ) def test_specifying_initial_loc(self, event_shape, initial_loc, - implicit_batch_shape, - constraining_bijectors, + implicit_batch_shape, bijector, dtype, use_static_shape): initial_loc = tf.nest.map_structure( lambda s: _build_tensor(s, dtype=dtype, # pylint: disable=g-long-lambda use_static_shape=use_static_shape), initial_loc) - if constraining_bijectors is not None: + if bijector is not None: initial_unconstrained_loc = tf.nest.map_structure( lambda x, b: x if b is None else b.inverse(x), - initial_loc, constraining_bijectors) + initial_loc, bijector) else: initial_unconstrained_loc = initial_loc @@ -188,7 +187,7 @@ def test_specifying_initial_loc(self, event_shape, initial_loc, event_shape=event_shape, initial_unconstrained_loc=initial_unconstrained_loc, initial_unconstrained_scale=1e-6, - constraining_bijectors=constraining_bijectors, + bijector=bijector, validate_args=True)) self.evaluate([v.initializer for v in surrogate_posterior.trainable_variables]) @@ -222,7 +221,7 @@ def model_fn(): surrogate_posterior = ( tfp.experimental.vi.build_factored_surrogate_posterior( event_shape=model.event_shape_tensor()[:-1], - constraining_bijectors=[tfb.Softplus(), tfb.Softplus()])) + bijector=[tfb.Softplus(), tfb.Softplus()])) # Fit model. y = [0.2, 0.5, 0.3, 0.7] @@ -244,6 +243,46 @@ def model_fn(): _ = self.evaluate(posterior_mean) _ = self.evaluate(posterior_stddev) + def test_multipart_bijector(self): + dist = tfd.JointDistributionNamed({ + 'a': tfd.Exponential(1.), + 'b': tfd.Normal(0., 1.), + 'c': lambda b, a: tfd.Sample(tfd.Normal(b, a), sample_shape=[5])}) + + seed = test_util.test_seed_stream() + surrogate_posterior = ( + tfp.experimental.vi.build_factored_surrogate_posterior( + event_shape=dist.event_shape, + bijector=( + dist.experimental_default_event_space_bijector()), + initial_unconstrained_loc=functools.partial( + tf.random.uniform, minval=-2., maxval=2.), + seed=seed(), + validate_args=True)) + self.evaluate([v.initializer + for v in surrogate_posterior.trainable_variables]) + + # Test that the posterior has the specified event shape(s). 
+ self.assertAllEqualNested( + self.evaluate(dist.event_shape_tensor()), + self.evaluate(surrogate_posterior.event_shape_tensor())) + + posterior_sample_ = self.evaluate(surrogate_posterior.sample(seed=seed())) + posterior_logprob_ = self.evaluate( + surrogate_posterior.log_prob(posterior_sample_)) + + # Test that all sample Tensors have the expected shapes. + check_shape = lambda s, x: self.assertAllEqual(s, x.shape) + self.assertAllAssertsNested( + check_shape, dist.event_shape, posterior_sample_) + + # Test that samples are finite and not NaN. + self.assertAllAssertsNested(self.assertAllFinite, posterior_sample_) + + # Test that logprob is scalar, finite, and not NaN. + self.assertEmpty(posterior_logprob_.shape) + self.assertAllFinite(posterior_logprob_) + def _build_tensor(ndarray, dtype, use_static_shape): # Enforce parameterized dtype and static/dynamic testing. @@ -265,10 +304,16 @@ def test_dims_and_gradients(self): # Test that the correct number of trainable variables are being tracked prior_dists = prior_dist._get_single_sample_distributions() # pylint: disable=protected-access expected_num_trainable_vars = 0 - for dist in prior_dists: + for original_dist in prior_dists: + try: + original_dist = original_dist.distribution + except AttributeError: + pass + dist = surrogate_posteriors._as_trainable_family(original_dist) dist_params = dist.parameters for param, value in dist_params.items(): - if param not in surrogate_posteriors._NON_STATISTICAL_PARAMS and value is not None: + if (param not in surrogate_posteriors._NON_STATISTICAL_PARAMS + and value is not None and param not in ('low', 'high')): expected_num_trainable_vars += 2 # prior_weight, mean_field_parameter self.assertLen(surrogate_posterior.trainable_variables, @@ -285,9 +330,9 @@ def test_dims_and_gradients(self): # Test that the sample shape is correct three_posterior_samples = surrogate_posterior.sample(3) three_prior_samples = prior_dist.sample(3) - - self.assertAllEqualNested([s.shape for s in three_prior_samples], - [s.shape for s in three_posterior_samples]) + self.assertAllEqualNested( + [s.shape for s in tf.nest.flatten(three_prior_samples)], + [s.shape for s in tf.nest.flatten(three_posterior_samples)]) def test_fitting_surrogate_posterior(self): @@ -308,8 +353,9 @@ def test_fitting_surrogate_posterior(self): # Compute posterior statistics. with tf.control_dependencies([losses]): posterior_samples = surrogate_posterior.sample(100) - posterior_mean = [tf.reduce_mean(x) for x in posterior_samples] - posterior_stddev = [tf.math.reduce_std(x) for x in posterior_samples] + posterior_mean = tf.nest.map_structure(tf.reduce_mean, posterior_samples) + posterior_stddev = tf.nest.map_structure(tf.math.reduce_std, + posterior_samples) self.evaluate(tf1.global_variables_initializer()) _ = self.evaluate(losses) @@ -328,8 +374,16 @@ def test_make_asvi_trainable_variables(self): # Confirm that there exists correct number of trainable variables. 
for (prior_distribution, trained_vars_dict) in zip(prior_dists, trained_vars): - for param_name, prior_value in prior_distribution.parameters.items(): - if param_name not in surrogate_posteriors._NON_STATISTICAL_PARAMS and prior_value is not None: + substituted_dist = surrogate_posteriors._as_trainable_family( + prior_distribution) + try: + posterior_distribution = substituted_dist.distribution + except AttributeError: + posterior_distribution = substituted_dist + + for param_name, prior_value in posterior_distribution.parameters.items(): + if (param_name not in surrogate_posteriors._NON_STATISTICAL_PARAMS + and prior_value is not None and param_name not in ('low', 'high')): self.assertIsInstance(trained_vars_dict[param_name], surrogate_posteriors.ASVIParameters) @@ -375,9 +429,101 @@ def target_log_prob(*x): return target_log_prob +@test_util.test_all_tf_execution_regimes +class ASVISurrogatePosteriorTestEightSchools(test_util.TestCase, + _TrainableASVISurrogate): + + def make_prior_dist(self): + treatment_effects = tf.constant([28, 8, -3, 7, -1, 1, 18, 12], + dtype=tf.float32) + num_schools = ps.shape(treatment_effects)[-1] + + return tfd.JointDistributionNamed({ + 'avg_effect': + tfd.Normal(loc=0., scale=10., name='avg_effect'), + 'log_stddev': + tfd.Normal(loc=5., scale=1., name='log_stddev'), + 'school_effects': + lambda log_stddev, avg_effect: ( # pylint: disable=g-long-lambda + tfd.Independent( + tfd.Normal( + loc=avg_effect[..., None] * tf.ones(num_schools), + scale=tf.exp(log_stddev[..., None]) * tf.ones( + num_schools), + name='school_effects'), + reinterpreted_batch_ndims=1)) + }) + + def make_likelihood_model(self, x, observation_noise=None): + treatment_stddevs = tf.constant([15, 10, 16, 11, 9, 11, 10, 18], + dtype=tf.float32) + + return tfd.Independent( + tfd.Normal(loc=x['school_effects'], scale=treatment_stddevs), + reinterpreted_batch_ndims=1) + + def get_observations(self, prior_dist): + ground_truth = self.evaluate(prior_dist.sample()) + likelihood = self.make_likelihood_model(x=ground_truth) + return likelihood.sample(1) + + def get_target_log_prob(self, observations, prior_dist): + + def target_log_prob(**x): + likelihood_dist = self.make_likelihood_model(x=x) + return likelihood_dist.log_prob(observations) + prior_dist.log_prob(x) + + return target_log_prob + + +@test_util.test_all_tf_execution_regimes +class ASVISurrogatePosteriorTestHalfNormal(test_util.TestCase, + _TrainableASVISurrogate): + + def make_prior_dist(self): + + def _prior_model_fn(): + innovation_noise = 1. + yield tfd.HalfNormal( + scale=innovation_noise, validate_args=True, allow_nan_stats=False) + + return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn) + + def make_likelihood_model(self, x, observation_noise): + + def _likelihood_model(): + yield tfd.Normal( + loc=x, + scale=observation_noise, + validate_args=True, + allow_nan_stats=False) + + return tfd.JointDistributionCoroutineAutoBatched(_likelihood_model) + + def get_observations(self, prior_dist): + observation_noise = 1. 
+ ground_truth = prior_dist.sample() + likelihood = self.make_likelihood_model( + x=ground_truth, observation_noise=observation_noise) + return likelihood.sample(1) + + def get_target_log_prob(self, observations, prior_dist): + + obs = observations + def target_log_prob(*x): + observation_noise = 0.15 + likelihood_dist = self.make_likelihood_model( + x=x, observation_noise=observation_noise) + + return likelihood_dist.log_prob(obs) + prior_dist.log_prob(x) + + return target_log_prob + # TODO(kateslin): Add an ASVI surrogate posterior test for gamma distributions. # TODO(kateslin): Add an ASVI surrogate posterior test with for a model with # missing observations. +# TODO(kateslin): Add an ASVI surrogate posterior test for Uniform distribution +# to check that Beta substitution works properly if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/internal/auto_composite_tensor.py b/tensorflow_probability/python/internal/auto_composite_tensor.py index 26ab6b3ac5..82e8cf8f22 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor.py @@ -36,10 +36,11 @@ _SENTINEL = object() -_AUTO_COMPOSITE_TENSOR_VERSION = 1 +_AUTO_COMPOSITE_TENSOR_VERSION = 2 -def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None): +def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None, + prefer_static_value=()): """Extract constructor kwargs to reconstruct `obj`.""" argspec = inspect.getfullargspec(obj.__init__) if argspec.varargs or argspec.varkw: @@ -61,6 +62,10 @@ def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None): raise ValueError( 'Object did not have an attr corresponding to constructor argument ' '{k}. (Tried both `obj.{k}` and obj._{k}`).'.format(k=k)) + if k in prefer_static_value and kwargs[k] is not None: + static_val = tf.get_static_value(kwargs[k]) + if static_val is not None: + kwargs[k] = static_val return kwargs @@ -101,16 +106,22 @@ def _extract_type_spec_recursively(value): class _AutoCompositeTensorTypeSpec(tf.TypeSpec): """A tf.TypeSpec for `AutoCompositeTensor` objects.""" - __slots__ = ('_param_specs', '_non_tensor_params', '_omit_kwargs') + __slots__ = ('_param_specs', '_non_tensor_params', '_omit_kwargs', + '_prefer_static_value') - def __init__(self, param_specs, non_tensor_params, omit_kwargs): + def __init__(self, param_specs, non_tensor_params, omit_kwargs, + prefer_static_value): self._param_specs = param_specs self._non_tensor_params = non_tensor_params self._omit_kwargs = omit_kwargs + self._prefer_static_value = prefer_static_value @classmethod def from_instance(cls, instance, omit_kwargs=()): - kwargs = _extract_init_kwargs(instance, omit_kwargs) + prefer_static_value = tuple( + getattr(instance, '_composite_tensor_shape_params', ())) + kwargs = _extract_init_kwargs(instance, omit_kwargs=omit_kwargs, + prefer_static_value=prefer_static_value) non_tensor_params = {} param_specs = {} @@ -125,7 +136,8 @@ def from_instance(cls, instance, omit_kwargs=()): # Construct the spec. 
return cls(param_specs=param_specs, non_tensor_params=non_tensor_params, - omit_kwargs=omit_kwargs) + omit_kwargs=omit_kwargs, + prefer_static_value=prefer_static_value) def _to_components(self, obj): return _extract_init_kwargs(obj, limit_to=list(self._param_specs)) @@ -142,16 +154,20 @@ def _serialize(self): result = (_AUTO_COMPOSITE_TENSOR_VERSION, self._param_specs, self._non_tensor_params, - self._omit_kwargs) + self._omit_kwargs, + self._prefer_static_value) return result @classmethod def _deserialize(cls, encoded): - version, param_specs, non_tensor_params, omit_kwargs = encoded + version = encoded[0] + if version == 1: + encoded = encoded + ((),) + version = 2 if version != _AUTO_COMPOSITE_TENSOR_VERSION: raise ValueError('Expected version {}, but got {}' .format(_AUTO_COMPOSITE_TENSOR_VERSION, version)) - return cls(param_specs, non_tensor_params, omit_kwargs) + return cls(*encoded[1:]) _TypeSpecCodec = nested_structure_coder._TypeSpecCodec # pylint: disable=protected-access @@ -193,6 +209,12 @@ def auto_composite_tensor(cls=None, omit_kwargs=()): - object.attribute = [tf.constant(1.), [tf.constant(2.)]] # valid - object.attribute = ['abc', tf.constant(1.)] # invalid + If the object has a `_composite_tensor_shape_parameters` field (presumed to + have `tuple` of `str` value), the flattening code will use + `tf.get_static_value` to attempt to preserve shapes as static metadata, for + fields whose name matches a name specified in that field. Preserving static + values can be important to correctly propagating shapes through a loop. + If the decorated class `A` does not subclass `CompositeTensor`, a *new class* will be generated, which mixes in `A` and `CompositeTensor`. @@ -277,7 +299,8 @@ def body(obj): composite_tensor_subclass: A subclass of `cls` and TF CompositeTensor. """ if cls is None: - return functools.partial(auto_composite_tensor, omit_kwargs=omit_kwargs) + return functools.partial(auto_composite_tensor, + omit_kwargs=omit_kwargs) # If the declared class is already a CompositeTensor subclass, we can avoid # affecting the actual type of the returned class. 
Otherwise, we need to diff --git a/tensorflow_probability/python/internal/auto_composite_tensor_test.py b/tensorflow_probability/python/internal/auto_composite_tensor_test.py index 231f6846eb..418a995e32 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor_test.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor_test.py @@ -24,6 +24,9 @@ from tensorflow_probability.python.internal import test_util +tfd = tfp.distributions + + AutoIdentity = tfp.experimental.auto_composite_tensor( tf.linalg.LinearOperatorIdentity, omit_kwargs=('name',)) AutoDiag = tfp.experimental.auto_composite_tensor( @@ -33,6 +36,11 @@ AutoTriL = tfp.experimental.auto_composite_tensor( tf.linalg.LinearOperatorLowerTriangular, omit_kwargs=('name',)) +AutoNormal = tfp.experimental.auto_composite_tensor( + tfd.Normal, omit_kwargs=('name',)) +AutoIndependent = tfp.experimental.auto_composite_tensor( + tfd.Independent, omit_kwargs=('name',)) + @test_util.test_all_tf_execution_regimes class AutoCompositeTensorTest(test_util.TestCase): @@ -77,6 +85,18 @@ def body(lop): maximum_iterations=3) self.assertAllClose(2.**3 * tf.ones([3]), lop.matvec(tf.ones([3]))) + def test_shape_parameters(self): + dist = AutoIndependent(AutoNormal(0, tf.ones([1])), + reinterpreted_batch_ndims=1) + stream = test_util.test_seed_stream() + lp = dist.log_prob(dist.sample(seed=stream())) + lp, _ = tf.while_loop( + lambda *_: True, + lambda lp, d: (d.log_prob(d.sample(seed=stream())), d), + (lp, dist), + maximum_iterations=2) + self.evaluate(lp) + def test_nested(self): lop = AutoBlockDiag([AutoDiag(tf.ones([2]) * 2), AutoIdentity(1)]) self.assertAllClose( @@ -90,9 +110,9 @@ def test_preconditioner(self): is_self_adjoint=True, is_positive_definite=True) - tfd = tfp.experimental.distributions + tfed = tfp.experimental.distributions auto_ct_mvn_prec_linop = tfp.experimental.auto_composite_tensor( - tfd.MultivariateNormalPrecisionFactorLinearOperator, + tfed.MultivariateNormalPrecisionFactorLinearOperator, omit_kwargs=('name',)) tril = AutoTriL(**cov_linop.cholesky().parameters) momentum_distribution = auto_ct_mvn_prec_linop(precision_factor=tril) diff --git a/tensorflow_probability/python/internal/backend/numpy/BUILD b/tensorflow_probability/python/internal/backend/numpy/BUILD index 5e7c466f46..d157f36c79 100644 --- a/tensorflow_probability/python/internal/backend/numpy/BUILD +++ b/tensorflow_probability/python/internal/backend/numpy/BUILD @@ -363,7 +363,7 @@ py_library( py_library( name = "numpy_testlib", testonly = 1, - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ ":numpy", # absl/testing:parameterized dep, diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_array.py b/tensorflow_probability/python/internal/backend/numpy/numpy_array.py index 51ba5b2c98..0af0c81ffd 100644 --- a/tensorflow_probability/python/internal/backend/numpy/numpy_array.py +++ b/tensorflow_probability/python/internal/backend/numpy/numpy_array.py @@ -384,7 +384,7 @@ def _zeros_like(input, dtype=None, name=None): # pylint: disable=redefined-buil fill = utils.copy_docstring( 'tf.fill', - lambda dims, value, name=None: np.full(dims, value)) + lambda dims, value, name=None: np.full(dims, ops.convert_to_tensor(value))) gather = utils.copy_docstring( 'tf.gather', diff --git a/tensorflow_probability/python/internal/cache_util.py b/tensorflow_probability/python/internal/cache_util.py index d2f3607492..7dda003016 100644 --- a/tensorflow_probability/python/internal/cache_util.py +++ 
b/tensorflow_probability/python/internal/cache_util.py @@ -91,7 +91,7 @@ class HashableWeakRef(weakref.ref): def __init__(self, referrent, callback=None): """weakref.ref which makes tf.Tensor and np.array objects hashable. - Arguments: + Args: referrent: Object that is being referred to. callback: Optional callback to invoke when object is GCed. """ @@ -327,7 +327,7 @@ def bijector_class(self): def forward(self, x, **kwargs): """Invokes the 'forward' transformation, or looks up previous results. - Arguments: + Args: x: The singular argument passed to `bijector._forward`. **kwargs: Any auxiliary arguments passed to the function. These reflect shared context to the function, and are associated @@ -340,7 +340,7 @@ def forward(self, x, **kwargs): def inverse(self, y, **kwargs): """Invokes the 'inverse' transformation, or looks up previous results. - Arguments: + Args: y: The singular argument passed to `bijector._inverse`. **kwargs: Any auxiliary arguments passed to the function. These reflect shared context to the function, and are associated @@ -431,7 +431,7 @@ def _attributes(self, input, fn_name, **kwargs): == 0) ``` - Arguments: + Args: input: The singular ordered argument passed to the wrapped function. fn_name: `str`, name of the directed function to which `input` is an arg (typically `'_forward'` or `'_inverse'`). @@ -469,7 +469,7 @@ def _lookup(self, input, forward_name, inverse_name, **kwargs): assert cache.inverse._lookup(y, '_inverse', '_forward') == (x, attrs) ``` - Arguments: + Args: input: The singular ordered argument passed to the wrapped function. forward_name: `str`, the name of the function implementing the bijector's forward transformation (typically `'_forward'`). diff --git a/tensorflow_probability/python/internal/custom_gradient.py b/tensorflow_probability/python/internal/custom_gradient.py index 07bc930881..19d5a19d16 100644 --- a/tensorflow_probability/python/internal/custom_gradient.py +++ b/tensorflow_probability/python/internal/custom_gradient.py @@ -41,8 +41,11 @@ def custom_gradient(vjp_fwd=None, vjp_bwd=None, jvp_fn=None, Args: vjp_fwd: A function (*args) => (output, auxiliaries). - vjp_bwd: A function (auxiliaries, output_gradient) => args_gradients. - jvp_fn: A function (primals, tangents) => (primal_out, tangent_out). + vjp_bwd: A function (auxiliaries, output_gradient) => + nondiff_args_gradients. `None` gradients will be inserted into the correct + positions for `nondiff_argnums`. + jvp_fn: A function (*nondiff_args, primals, tangents) => + (primal_out, tangent_out). nondiff_argnums: Tuple of argument indices which are not differentiable. Returns: diff --git a/tensorflow_probability/python/internal/hypothesis_testlib.py b/tensorflow_probability/python/internal/hypothesis_testlib.py index a7b50bafa3..2b9c12c1cf 100644 --- a/tensorflow_probability/python/internal/hypothesis_testlib.py +++ b/tensorflow_probability/python/internal/hypothesis_testlib.py @@ -435,6 +435,7 @@ def broadcasting_params(draw, enable_vars=False, constraint_fn_for=lambda param: identity_fn, mutex_params=(), + param_strategy_fn=None, dtype=np.float32): """Streategy for drawing parameters which jointly have the given batch shape. @@ -467,6 +468,10 @@ def broadcasting_params(draw, mutually exclusive parameters (e.g., the 'probs' and 'logits' of a Categorical). At most one parameter from each set will appear in the result. + param_strategy_fn: Optional callable with signature + `strategy = param_strategy_fn(shape, dtype, constraint_fn)`. 
If provided, + overrides the default strategy for generating float-valued parameters. + Default value: `constrained_tensors`. dtype: Dtype for generated parameters. Returns: @@ -479,6 +484,8 @@ def broadcasting_params(draw, """ if event_dim is None: event_dim = draw(hps.integers(min_value=2, max_value=6)) + if param_strategy_fn is None: + param_strategy_fn = constrained_tensors params_event_ndims = params_event_ndims or {} remaining_params = set(params_event_ndims.keys()) @@ -504,12 +511,14 @@ def broadcasting_params(draw, hp.assume(len(param_shape) < 6) # TODO(axch): Can I replace `params_event_ndims` and `constraint_fn_for` - # with a map from params to `Suppport`s, and use `tensors_in_support` here - # instead of this explicit `constrained_tensors` function? - param_strategy = constrained_tensors( - constraint_fn_for(param), param_shape, dtype=dtype) - params_kwargs[param] = draw(maybe_variable( - param_strategy, enable_vars=enable_vars, dtype=dtype, name=param)) + # with a map from params to `Suppport`s, and use `tensors_in_support` here? + param_strategy = param_strategy_fn(constraint_fn=constraint_fn_for(param), + shape=param_shape, + dtype=dtype) + params_kwargs[param] = draw(maybe_variable(param_strategy, + enable_vars=enable_vars, + dtype=dtype, + name=param)) return params_kwargs diff --git a/tensorflow_probability/python/internal/prefer_static.py b/tensorflow_probability/python/internal/prefer_static.py index 48d576f9e8..d719eb8875 100644 --- a/tensorflow_probability/python/internal/prefer_static.py +++ b/tensorflow_probability/python/internal/prefer_static.py @@ -137,7 +137,8 @@ def _convert_to_shape_tensor_jax(value, dtype=None, dtype_hint=None, name=None): """Converts vectors and scalars of `int`-like to `ndarray`.""" dtype = dtype_util.as_numpy_dtype(dtype or dtype_hint or np.int32) try: - return np.array([int(v) for v in value], dtype=dtype) + return np.array([_convert_to_shape_tensor_jax(v, dtype) for v in value], + dtype=dtype) except: # JAX throws raw Exception in some cases. # pylint: disable=bare-except pass return np.array(int(value), dtype=dtype) @@ -205,7 +206,7 @@ def broadcast_shape(x_shape, y_shape): computed statically and returned as a `TensorShape`. Otherwise, a rank-1 `Tensor` will be returned. - Arguments: + Args: x_shape: A `TensorShape` or rank-1 integer `Tensor`. The input `Tensor` is broadcast against this shape. y_shape: A `TensorShape` or rank-1 integer `Tensor`. The input `Tensor` is @@ -230,7 +231,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None): If `pred` is a bool or has a constant value, we return either `true_fn()` or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. - Arguments: + Args: pred: A scalar determining whether to return the result of `true_fn` or `false_fn`. true_fn: The callable to be performed if pred is true. 
diff --git a/tensorflow_probability/python/internal/samplers.py b/tensorflow_probability/python/internal/samplers.py index 2f472527c1..71f5442b33 100644 --- a/tensorflow_probability/python/internal/samplers.py +++ b/tensorflow_probability/python/internal/samplers.py @@ -34,6 +34,7 @@ __all__ = [ 'categorical', + 'fold_in', 'gamma', 'is_stateful_seed', 'normal', @@ -88,16 +89,24 @@ def sanitize_seed(seed, salt=None, name=None): if salt is not None: salt = int(hashlib.sha512(str(salt).encode('utf-8')).hexdigest(), 16) - if JAX_MODE: - from jax import random as jaxrand # pylint: disable=g-import-not-at-top - seed = jaxrand.fold_in(seed, salt & (2**32 - 1)) - else: - seed = tf.bitwise.bitwise_xor( - seed, np.uint64([salt & (2**64 - 1)]).view(np.int32)) + seed = fold_in(seed, salt) return tf.convert_to_tensor(seed, dtype=SEED_DTYPE, name='seed') +def fold_in(seed, salt): + """Folds salt into seed to form a new seed.""" + if JAX_MODE: + from jax import random as jaxrand # pylint: disable=g-import-not-at-top + return jaxrand.fold_in(seed, salt & (2**32 - 1)) + if isinstance(salt, (six.integer_types)): + seed = tf.bitwise.bitwise_xor( + seed, np.uint64([salt & (2**64 - 1)]).view(np.int32)) + else: + seed = tf.random.experimental.stateless_fold_in(seed, salt) + return seed + + def split_seed(seed, n=2, salt=None, name=None): """Splits a seed into `n` derived seeds. diff --git a/tensorflow_probability/python/internal/test_combinations.py b/tensorflow_probability/python/internal/test_combinations.py index f436c8946e..127848f8d4 100644 --- a/tensorflow_probability/python/internal/test_combinations.py +++ b/tensorflow_probability/python/internal/test_combinations.py @@ -78,7 +78,7 @@ def should_execute_combination(self, kwargs): If the environment doesn't satisfy the dependencies of the test combination, then it can be skipped. - Arguments: + Args: kwargs: Arguments that are passed to the test combination. Returns: @@ -100,7 +100,7 @@ def context_managers(self, kwargs): The test combination will run under all context managers that all `TestCombination` instances return. - Arguments: + Args: kwargs: Arguments and their values that are passed to the test combination. @@ -119,7 +119,7 @@ class ParameterModifier(object): def __init__(self, parameter_name=None): """Construct a parameter modifier that may be specific to a parameter. - Arguments: + Args: parameter_name: A `ParameterModifier` instance may operate on a class of parameters or on a parameter with a particular name. Only `ParameterModifier` instances that are of a unique type or were @@ -135,7 +135,7 @@ def modified_arguments(self, kwargs, requested_parameters): This makes it possible to adjust user-provided arguments before passing them to the test method. - Arguments: + Args: kwargs: The combined arguments for the test. requested_parameters: The set of parameters that are defined in the signature of the test method. diff --git a/tensorflow_probability/python/layers/dense_variational_v2.py b/tensorflow_probability/python/layers/dense_variational_v2.py index 2bc6614a0c..116cb2a29d 100644 --- a/tensorflow_probability/python/layers/dense_variational_v2.py +++ b/tensorflow_probability/python/layers/dense_variational_v2.py @@ -53,7 +53,7 @@ def __init__(self, **kwargs): """Creates the `DenseVariational` layer. - Arguments: + Args: units: Positive integer, dimensionality of the output space. 
make_posterior_fn: Python callable taking `tf.size(kernel)`, `tf.size(bias)`, `dtype` and returns another callable which takes an diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index 383ca3fca4..191a33d0a5 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -1510,7 +1510,7 @@ def new(params, num_components, component_layer, def params_size(num_components, component_params_size, name=None): """Number of `params` needed to create a `MixtureSameFamily` distribution. - Arguments: + Args: num_components: Number of component distributions in the mixture distribution. component_params_size: Number of parameters needed to create a single diff --git a/tensorflow_probability/python/layers/distribution_layer_test.py b/tensorflow_probability/python/layers/distribution_layer_test.py index 8b58b2a4af..abe4683515 100644 --- a/tensorflow_probability/python/layers/distribution_layer_test.py +++ b/tensorflow_probability/python/layers/distribution_layer_test.py @@ -320,7 +320,7 @@ class DistributionLambdaSerializationTest(test_util.TestCase): def assertSerializable(self, model, batch_size=1): """Assert that a model can be saved/loaded via Keras Model.save/load_model. - Arguments: + Args: model: A Keras model that outputs a `tfd.Distribution`. batch_size: The batch size to use when checking that the model produces the same results as a serialized/deserialized copy. Default value: 1. @@ -348,7 +348,7 @@ def assertSerializable(self, model, batch_size=1): def assertExportable(self, model, batch_size=1): """Assert a Keras model supports export_saved_model/load_from_saved_model. - Arguments: + Args: model: A Keras model with Tensor output. batch_size: The batch size to use when checking that the model produces the same results as a serialized/deserialized copy. Default value: 1. diff --git a/tensorflow_probability/python/layers/initializers.py b/tensorflow_probability/python/layers/initializers.py index 530706f5ed..f3e9906dcd 100644 --- a/tensorflow_probability/python/layers/initializers.py +++ b/tensorflow_probability/python/layers/initializers.py @@ -30,7 +30,7 @@ class BlockwiseInitializer(tf.keras.initializers.Initializer): def __init__(self, initializers, sizes, validate_args=False): """Creates the `BlockwiseInitializer`. - Arguments: + Args: initializers: `list` of Keras initializers, e.g., `"glorot_uniform"` or `tf.keras.initializers.Constant(0.5413)`. sizes: `list` of `int` scalars representing the number of elements diff --git a/tensorflow_probability/python/layers/masked_autoregressive.py b/tensorflow_probability/python/layers/masked_autoregressive.py index 0a5f4b9316..4699222412 100644 --- a/tensorflow_probability/python/layers/masked_autoregressive.py +++ b/tensorflow_probability/python/layers/masked_autoregressive.py @@ -123,7 +123,7 @@ def f_inverse(x): def __init__(self, made, **kwargs): """Constructs the AutoregressiveTransform layer. - Arguments: + Args: made: A `Made` layer, which must output two parameters for each input. **kwargs: Additional keyword arguments passed to `tf.keras.Layer`. 
""" diff --git a/tensorflow_probability/python/layers/variable_input.py b/tensorflow_probability/python/layers/variable_input.py index 63dd74355a..f0e7a9b74b 100644 --- a/tensorflow_probability/python/layers/variable_input.py +++ b/tensorflow_probability/python/layers/variable_input.py @@ -83,7 +83,7 @@ def __init__(self, **kwargs): """Creates the `VariableLayer`. - Arguments: + Args: shape: integer or integer vector specifying the shape of the output of this layer. dtype: TensorFlow `dtype` of the variable created by this layer. diff --git a/tensorflow_probability/python/math/BUILD b/tensorflow_probability/python/math/BUILD index ad8b133d0e..eb66b47871 100644 --- a/tensorflow_probability/python/math/BUILD +++ b/tensorflow_probability/python/math/BUILD @@ -66,7 +66,7 @@ multi_substrate_py_library( srcs = [ "bessel.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # numpy dep, # tensorflow dep, @@ -119,7 +119,7 @@ multi_substrate_py_library( srcs = [ "gram_schmidt.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # tensorflow dep, "//tensorflow_probability/python/internal:prefer_static", @@ -199,7 +199,7 @@ multi_substrate_py_library( srcs = [ "hypergeometric.py", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ # numpy dep, # tensorflow dep, diff --git a/tensorflow_probability/python/math/minimize.py b/tensorflow_probability/python/math/minimize.py index b5a5c9f47a..ed52bbaebe 100644 --- a/tensorflow_probability/python/math/minimize.py +++ b/tensorflow_probability/python/math/minimize.py @@ -22,8 +22,6 @@ import tensorflow.compat.v2 as tf -from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import - class MinimizeTraceableQuantities(collections.namedtuple( 'MinimizeTraceableQuantities', @@ -51,31 +49,12 @@ class MinimizeTraceableQuantities(collections.namedtuple( """ -# Backwards compatibility for older `trace_fns` that took separate -# loss, grads, and params. -def _maybe_wrap_old_style_trace_fn(trace_fn): - """Returns a `trace_fn that takes the single `minimizer_state` argument.""" - - def safe_trace_fn(traceable_quantities): - """A `trace_fn that takes the single `minimizer_state` argument.""" - try: - return trace_fn(traceable_quantities) - except TypeError: - deprecated_trace_fn = deprecation.deprecated_args( - '2020-07-01', - 'The signature for `trace_fn`s passed to `minimize` has changed. ' - 'Trace functions now take a single `traceable_quantities` argument, ' - 'which is a `tfp.math.MinimizeTraceableQuantities` namedtuple ' - 'containing `traceable_quantities.loss`, ' - '`traceable_quantities.gradients`, etc. 
' - 'Please update your `trace_fn` definition.', - ('loss', 'grads', 'variables') - )(trace_fn) - return deprecated_trace_fn( - traceable_quantities.loss, - traceable_quantities.gradients, - traceable_quantities.parameters) - return safe_trace_fn +def _sanitize_traced_values(traced_values): + """Represents Python values and `None` as Tensors.""" + return tf.nest.map_structure( + lambda x: (tf.zeros([0], dtype=tf.int32) if x is None # pylint: disable=g-long-lambda + else tf.convert_to_tensor(x)), + traced_values) def _tile_last_written_value(trace_array, last_written_idx): @@ -127,7 +106,7 @@ def training_loop_body(step, trace_arrays, has_converged=None, loss=loss, gradients=grads, parameters=parameters, step=step, has_converged=has_converged, convergence_criterion_state=convergence_criterion_state) - traced_values = trace_fn(traceable_quantities) + traced_values = _sanitize_traced_values(trace_fn(traceable_quantities)) trace_arrays = tf.nest.map_structure( lambda ta, x: ta.write(step, x), trace_arrays, traced_values) potential_new_loop_vars = ( @@ -141,6 +120,7 @@ def _initialize_arrays(initial_values, num_steps, truncate_at_convergence): """Construct a structure of `TraceArray`s from initial values.""" + initial_values = _sanitize_traced_values(initial_values) num_steps_ = tf.get_static_value(tf.convert_to_tensor(num_steps)) size_is_dynamic = (num_steps_ is None or truncate_at_convergence) trace_arrays = tf.nest.map_structure( @@ -312,8 +292,6 @@ def minimize(loss_fn, """ - trace_fn = _maybe_wrap_old_style_trace_fn(trace_fn) - def convergence_detected(step, trace_arrays, has_converged=None, convergence_criterion_state=None): @@ -379,4 +357,3 @@ def convergence_detected(step, trace_arrays, trace_arrays) return tf.nest.map_structure(lambda array: array.stack(), trace_arrays) - diff --git a/tensorflow_probability/python/math/minimize_test.py b/tensorflow_probability/python/math/minimize_test.py index e8d2ec50a0..58487b44e0 100644 --- a/tensorflow_probability/python/math/minimize_test.py +++ b/tensorflow_probability/python/math/minimize_test.py @@ -19,24 +19,19 @@ from __future__ import print_function # Dependency imports -from absl.testing import parameterized import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf import tensorflow_probability as tfp - from tensorflow_probability.python.internal import test_util @test_util.test_all_tf_execution_regimes class MinimizeTests(test_util.TestCase): - @parameterized.named_parameters( - {'testcase_name': 'new_style', 'new_style_trace_fn': True}, - {'testcase_name': 'old_style', 'new_style_trace_fn': False}) - def test_custom_trace_fn(self, new_style_trace_fn): + def test_custom_trace_fn(self): init_x = np.array([0., 0.]).astype(np.float32) target_x = np.array([3., 4.]).astype(np.float32) @@ -45,15 +40,9 @@ def test_custom_trace_fn(self, new_style_trace_fn): loss_fn = lambda: tf.reduce_sum((x - target_x)**2) # The trace_fn should determine the structure and values of the results. - if new_style_trace_fn: # Takes a `MinimizerState` namedtuple. - def trace_fn(traceable_quantities): - return {'loss': traceable_quantities.loss, 'x': x, - 'sqdiff': (x - target_x)**2} - else: - def trace_fn(loss, grads, values): # Takes individual args. 
- del grads - del values - return {'loss': loss, 'x': x, 'sqdiff': (x - target_x)**2} + def trace_fn(traceable_quantities): + return {'loss': traceable_quantities.loss, 'x': x, + 'sqdiff': (x - target_x)**2} results = tfp.math.minimize(loss_fn, num_steps=100, optimizer=tf.optimizers.Adam(0.1), @@ -64,6 +53,16 @@ def trace_fn(loss, grads, values): # Takes individual args. self.assertAllClose(results_['x'][-1], target_x, atol=0.2) self.assertAllClose(results_['sqdiff'][-1], [0., 0.], atol=0.1) + def test_can_trace_all_traceable_quantities(self): + x = tf.Variable(5.0) + trace_fn = lambda traceable_quantities: traceable_quantities + results = tfp.math.minimize(loss_fn=lambda: tf.reduce_sum((x - 1.0)**2), + num_steps=10, + optimizer=tf.optimizers.Adam(0.1), + trace_fn=trace_fn) + self.evaluate(tf1.global_variables_initializer()) + self.evaluate(results) + def test_respects_trainable_variables(self): # Variables not included in `trainable_variables` should stay fixed. x = tf.Variable(5.) @@ -94,7 +93,7 @@ def test_works_when_results_have_dynamic_shape(self): num_steps=num_steps, # TODO(b/137299119) Replace with TF2 optimizer. optimizer=tf1.train.AdamOptimizer(0.1), - trace_fn=lambda loss, grads, vars: (loss, grads), + trace_fn=lambda t: (t.loss, t.gradients), trainable_variables=[x]) with tf.control_dependencies([losses]): final_x = tf.identity(x) diff --git a/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py b/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py index cc79930532..654dcbd050 100644 --- a/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py +++ b/tensorflow_probability/python/math/psd_kernels/feature_transformed_test.py @@ -83,7 +83,7 @@ def testValuesAreCorrectScalarTransform(self, feature_ndims, dims): amplitude, length_scale, feature_ndims) input_shape = [dims] * feature_ndims - bij = tfp.bijectors.AffineScalar(self.dtype(0.), self.dtype(2.)) + bij = tfp.bijectors.Scale(scale=self.dtype(2.)) # Flat multiplication by 2. def scale_transform(x, feature_ndims, param_expansion_ndims): del feature_ndims, param_expansion_ndims @@ -114,7 +114,7 @@ def testValuesAreCorrectVectorTransform(self, feature_ndims, dims): input_shape = [dims] * feature_ndims scale_diag = np.random.uniform(-1, 1, size=(dims,)).astype(self.dtype) - bij = tfp.bijectors.Affine(scale_diag=scale_diag) + bij = tfp.bijectors.ScaleMatvecDiag(scale_diag=scale_diag) # Scaling the last dimension. def vector_transform(x, feature_ndims, param_expansion_ndims): diff --git a/tensorflow_probability/python/math/root_search.py b/tensorflow_probability/python/math/root_search.py index 9d55368bdb..c442b92ab2 100644 --- a/tensorflow_probability/python/math/root_search.py +++ b/tensorflow_probability/python/math/root_search.py @@ -592,7 +592,7 @@ def bracket_root(objective_fn, xs_positive = tf.exp(tf.linspace(tf.cast(-10., dtype), tf.math.log(dtype_info.max), num_points // 2)) - xs = tf.concat([-xs_positive, xs_positive], axis=0) + xs = tf.concat([tf.reverse(-xs_positive, axis=[0]), xs_positive], axis=0) # Evaluate the objective at all points. The objective function may return # a batch of values (e.g., `objective(x) = x - batch_of_roots`). 
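The one-line change to `bracket_root` above reverses the negative half of the candidate grid, presumably so the concatenated grid is monotonically increasing (previously it ran from the smallest negative magnitude out to the most negative value before jumping back to the smallest positive value); the new `test_negative_root` and `test_root_near_zero` cases below exercise this. A small-grid sketch of the effect:

```python
import tensorflow as tf

xs_positive = tf.exp(tf.linspace(-2., 2., 3))         # [e**-2, 1., e**2]
old = tf.concat([-xs_positive, xs_positive], axis=0)  # not sorted ascending
new = tf.concat([tf.reverse(-xs_positive, axis=[0]),  # sorted ascending:
                 xs_positive], axis=0)                # [-e**2, -1., -e**-2, e**-2, 1., e**2]
```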
diff --git a/tensorflow_probability/python/math/root_search_test.py b/tensorflow_probability/python/math/root_search_test.py index db998878f3..7a894d49df 100644 --- a/tensorflow_probability/python/math/root_search_test.py +++ b/tensorflow_probability/python/math/root_search_test.py @@ -288,6 +288,20 @@ def objective_fn(x): self.assertAllTrue(low < roots) self.assertAllTrue(high > roots) + def test_negative_root(self): + root = -17.314 + low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root))) + self.assertLess(low, root) + self.assertGreater(high, root) + + def test_root_near_zero(self): + root = tf.exp(-13.) + low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root))) + self.assertLess(low, np.exp(-13.)) + self.assertGreater(high, np.exp(-13)) + self.assertAllClose(low, root, atol=1e-4) + self.assertAllClose(high, root, atol=1e-4) + def test_returns_zero_width_bracket_at_root(self): root = tf.exp(-10.) low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root))) diff --git a/tensorflow_probability/python/mcmc/sample_test.py b/tensorflow_probability/python/mcmc/sample_test.py index b3bb921046..1228768362 100644 --- a/tensorflow_probability/python/mcmc/sample_test.py +++ b/tensorflow_probability/python/mcmc/sample_test.py @@ -433,11 +433,17 @@ def model(): momentum_distribution=momentum_dist) bijector = pinned.experimental_default_event_space_bijector() kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector) + pullback_shape = bijector.inverse_event_shape(pinned.event_shape) + kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation( + kernel, + initial_running_variance=struct._make( + tfp.experimental.stats.RunningVariance.from_shape(t) + for t in pullback_shape)) state = bijector(struct._make( tfd.Uniform(-2., 2.).sample(shp) for shp in bijector.inverse_event_shape(pinned.event_shape))) self.evaluate(tfp.mcmc.sample_chain( - 3, current_state=state, kernel=kernel, seed=stream())) + 3, current_state=state, kernel=kernel, seed=stream()).all_states) if __name__ == '__main__': diff --git a/tensorflow_probability/python/mcmc/transformed_kernel_test.py b/tensorflow_probability/python/mcmc/transformed_kernel_test.py index 3b25eb8081..6c9b6be1bb 100644 --- a/tensorflow_probability/python/mcmc/transformed_kernel_test.py +++ b/tensorflow_probability/python/mcmc/transformed_kernel_test.py @@ -252,8 +252,8 @@ def target_log_prob(x, y): step_size=[1.23 / 0.75, 1.23 / 0.5], num_leapfrog_steps=2), bijector=[ - tfb.AffineScalar(scale=0.75), - tfb.AffineScalar(scale=0.5), + tfb.Scale(scale=0.75), + tfb.Scale(scale=0.5), ]) # Recall, tfp.mcmc.sample_chain calls # transformed_hmc.bootstrap_results too. 
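Several of the files in this patch migrate off the deprecated affine bijectors. The replacements follow the usual mapping (illustrative parameter values):

```python
import tensorflow_probability as tfp

tfb = tfp.bijectors

tfb.Scale(scale=2.)                        # was tfb.AffineScalar(scale=2.)
tfb.Chain([tfb.Shift(1.), tfb.Scale(2.)])  # was tfb.AffineScalar(shift=1., scale=2.)
tfb.ScaleMatvecDiag(scale_diag=[1., 2.])   # was tfb.Affine(scale_diag=[1., 2.])
```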
@@ -304,7 +304,7 @@ def test_bootstrap_correctly_untransforms(self):
   def test_copy_works(self):
     transformed = tfp.mcmc.TransformedTransitionKernel(
         inner_kernel=FakeInnerKernel(target_log_prob_fn=fake_target_log_prob),
-        bijector=tfb.AffineScalar(2.))
+        bijector=tfb.Scale(2.))
     transformed_copy = tfp.mcmc.TransformedTransitionKernel(
         **transformed.parameters)
diff --git a/tensorflow_probability/python/random/random_ops.py b/tensorflow_probability/python/random/random_ops.py
index f4ed20da2a..80c957e8ad 100644
--- a/tensorflow_probability/python/random/random_ops.py
+++ b/tensorflow_probability/python/random/random_ops.py
@@ -131,7 +131,7 @@ def spherical_uniform(
   """
   with tf.name_scope(name or 'spherical_uniform'):
     seed = samplers.sanitize_seed(seed)
-    dimension = ps.convert_to_shape_tensor(tf.cast(dimension, dtype=tf.int32))
+    dimension = ps.convert_to_shape_tensor(ps.cast(dimension, dtype=tf.int32))
     shape = ps.convert_to_shape_tensor(shape, dtype=tf.int32)
     dimension_static = tf.get_static_value(dimension)
     sample_shape = ps.concat([shape, [dimension]], axis=0)
diff --git a/tensorflow_probability/python/sts/autoregressive.py b/tensorflow_probability/python/sts/autoregressive.py
index fdd32f60b2..8752cb40fe 100644
--- a/tensorflow_probability/python/sts/autoregressive.py
+++ b/tensorflow_probability/python/sts/autoregressive.py
@@ -387,7 +387,7 @@ def __init__(self,
                       coefficients_prior,
                       coefficient_constraining_bijector),
             Parameter('level_scale', level_scale_prior,
-                      tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
+                      tfb.Chain([tfb.Scale(scale=observed_stddev),
                                  tfb.Softplus()]))
         ],
         latent_size=order,
diff --git a/tensorflow_probability/python/sts/dynamic_regression.py b/tensorflow_probability/python/sts/dynamic_regression.py
index 6ad093ece2..bef1b1e2c8 100644
--- a/tensorflow_probability/python/sts/dynamic_regression.py
+++ b/tensorflow_probability/python/sts/dynamic_regression.py
@@ -314,7 +314,7 @@ def __init__(self,
     super(DynamicLinearRegression, self).__init__(
         parameters=[
             Parameter('drift_scale', drift_scale_prior,
-                      tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
+                      tfb.Chain([tfb.Scale(scale=observed_stddev),
                                  tfb.Softplus()]))
         ],
         latent_size=num_features,
diff --git a/tensorflow_probability/python/sts/local_level.py b/tensorflow_probability/python/sts/local_level.py
index bba44f437e..8a7a3f378e 100644
--- a/tensorflow_probability/python/sts/local_level.py
+++ b/tensorflow_probability/python/sts/local_level.py
@@ -327,7 +327,7 @@ def __init__(self,
     super(LocalLevel, self).__init__(
         parameters=[
             Parameter('level_scale', level_scale_prior,
-                      tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
+                      tfb.Chain([tfb.Scale(scale=observed_stddev),
                                  tfb.Softplus()])),
        ],
         latent_size=1,
diff --git a/tensorflow_probability/python/sts/local_linear_trend.py b/tensorflow_probability/python/sts/local_linear_trend.py
index 2ccb6cb0a7..61ed35d77e 100644
--- a/tensorflow_probability/python/sts/local_linear_trend.py
+++ b/tensorflow_probability/python/sts/local_linear_trend.py
@@ -404,7 +404,7 @@ def __init__(self,
             initial_slope_prior.stddev()
         ], axis=-1))
 
-    scaled_softplus = tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
+    scaled_softplus = tfb.Chain([tfb.Scale(scale=observed_stddev),
                                  tfb.Softplus()])
     super(LocalLinearTrend, self).__init__(
         parameters=[
diff --git a/tensorflow_probability/python/sts/seasonal.py b/tensorflow_probability/python/sts/seasonal.py
index 0e665450ff..7c0755a384 100644
--- a/tensorflow_probability/python/sts/seasonal.py
+++ b/tensorflow_probability/python/sts/seasonal.py
@@ -881,7 +881,7 @@ def __init__(self,
     if allow_drift:
       parameters.append(Parameter(
           'drift_scale', drift_scale_prior,
-          tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
+          tfb.Chain([tfb.Scale(scale=observed_stddev),
                      tfb.Softplus()])))
     self._allow_drift = allow_drift
 
diff --git a/tensorflow_probability/python/sts/semilocal_linear_trend.py b/tensorflow_probability/python/sts/semilocal_linear_trend.py
index f5f00b689d..acab1d2bd1 100644
--- a/tensorflow_probability/python/sts/semilocal_linear_trend.py
+++ b/tensorflow_probability/python/sts/semilocal_linear_trend.py
@@ -429,7 +429,7 @@ def __init__(self,
     else:
       autoregressive_coef_bijector = tfb.Identity()  # unconstrained
 
-    stddev_preconditioner = tfb.AffineScalar(scale=observed_stddev)
+    stddev_preconditioner = tfb.Scale(scale=observed_stddev)
     scaled_softplus = tfb.Chain([stddev_preconditioner, tfb.Softplus()])
     super(SemiLocalLinearTrend, self).__init__(
         parameters=[
diff --git a/tensorflow_probability/python/sts/smooth_seasonal.py b/tensorflow_probability/python/sts/smooth_seasonal.py
index 3c79e6f74a..c18848f0e8 100644
--- a/tensorflow_probability/python/sts/smooth_seasonal.py
+++ b/tensorflow_probability/python/sts/smooth_seasonal.py
@@ -441,7 +441,7 @@ def __init__(self,
     if allow_drift:
       parameters.append(Parameter(
           'drift_scale', drift_scale_prior,
-          tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
+          tfb.Chain([tfb.Scale(scale=observed_stddev),
                      tfb.Softplus()])))
     self._allow_drift = allow_drift
 
diff --git a/tensorflow_probability/python/sts/sum.py b/tensorflow_probability/python/sts/sum.py
index 7fd015b094..96e2f14485 100644
--- a/tensorflow_probability/python/sts/sum.py
+++ b/tensorflow_probability/python/sts/sum.py
@@ -460,7 +460,7 @@ def __init__(self,
       parameters = [Parameter('observation_noise_scale',
                               observation_noise_scale_prior,
                               tfb.Chain([
-                                  tfb.AffineScalar(scale=observed_stddev),
+                                  tfb.Scale(scale=observed_stddev),
                                   tfb.Softplus()]))]
       for component in components:
         for parameter in component.parameters:
diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py
index 0b75390eae..e18bcad27d 100644
--- a/tensorflow_probability/python/version.py
+++ b/tensorflow_probability/python/version.py
@@ -24,7 +24,7 @@
 # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a
 # release branch, the current version is by default assumed to be a
 # 'development' version, labeled 'dev'.
-_VERSION_SUFFIX = 'rc4'
+_VERSION_SUFFIX = ''
 
 # Example, '0.4.0-dev'
 __version__ = '.'.join([
diff --git a/tensorflow_probability/python/vi/optimization.py b/tensorflow_probability/python/vi/optimization.py
index db5ad4d7fc..85e01c7f15 100644
--- a/tensorflow_probability/python/vi/optimization.py
+++ b/tensorflow_probability/python/vi/optimization.py
@@ -24,7 +24,7 @@
 from tensorflow_probability.python import math as tfp_math
 from tensorflow_probability.python.vi import csiszar_divergence
 
-_trace_loss = lambda loss, grads, variables: loss
+_trace_loss = lambda traceable_quantities: traceable_quantities.loss
 
 # Silent fallback to score-function gradients leads to difficult-to-debug
 # failures, so we force reparameterization gradients by default.
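All of the sts changes above swap `tfb.AffineScalar(scale=observed_stddev)` for `tfb.Scale(scale=observed_stddev)` inside the same constraining chain. A small sketch of what that chain computes (not from the patch; the stddev value is made up); note that `tfb.Chain` applies its bijectors right to left:

import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors

observed_stddev = tf.constant(2.5)  # illustrative value
scaled_softplus = tfb.Chain([tfb.Scale(scale=observed_stddev), tfb.Softplus()])
# forward(x) = observed_stddev * softplus(x): an unconstrained real is mapped to
# a positive scale parameter on the order of the observed stddev.
y = scaled_softplus.forward(0.)  # ~= 2.5 * log(2.)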
diff --git a/tensorflow_probability/python/vi/optimization_test.py b/tensorflow_probability/python/vi/optimization_test.py
index 546d40f020..27951c0b91 100644
--- a/tensorflow_probability/python/vi/optimization_test.py
+++ b/tensorflow_probability/python/vi/optimization_test.py
@@ -159,7 +159,7 @@ def variational_model_fn():
         num_steps=100,
         seed=test_util.test_seed(),
         sample_size=1,
-        trace_fn=lambda loss, grads, variables: (loss, q.sample(seed=42)[0]))
+        trace_fn=lambda t: (t.loss, q.sample(seed=42)[0]))
 
     self.evaluate(tf1.global_variables_initializer())
     losses_, sample_path_ = self.evaluate((losses, sample_path))
diff --git a/testing/install_test_dependencies.sh b/testing/install_test_dependencies.sh
index 5007fdb30f..1bdc028a67 100755
--- a/testing/install_test_dependencies.sh
+++ b/testing/install_test_dependencies.sh
@@ -178,7 +178,7 @@ install_python_packages() {
 
   # The following unofficial dependencies are used only by tests.
   # TODO(b/148685448): Unpin Hypothesis and coverage versions.
-  python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock scipy
+  python -m pip install $PIP_FLAGS hypothesis==3.56.5 coverage==4.4.2 matplotlib mock mpmath scipy
 
   # Install additional TFP dependencies.
   python -m pip install $PIP_FLAGS decorator 'cloudpickle>=1.3' dm-tree
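The `trace_fn` updates in optimization_test.py above, like those in minimize_test.py earlier in this patch, follow the new single-argument convention: the callback receives one structure of traceable quantities (with fields such as `loss` and `gradients`) instead of separate `loss`, `grads`, and `variables` arguments. A minimal usage sketch, not part of the patch; the quadratic loss and step counts are illustrative:

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.Variable(0.)
trace = tfp.math.minimize(
    loss_fn=lambda: (x - 3.)**2,
    num_steps=200,
    optimizer=tf.optimizers.Adam(0.1),
    # New-style trace_fn: one argument carrying the traceable quantities.
    trace_fn=lambda tq: {'loss': tq.loss, 'grad': tq.gradients[0]})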