diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml
new file mode 100644
index 0000000000..165259aade
--- /dev/null
+++ b/.github/workflows/continuous-integration.yml
@@ -0,0 +1,70 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+name: Tests
+on: [push, pull_request]
+env:
+ TEST_VENV_PATH: ~/test_virtualenv
+jobs:
+ lints:
+ name: Lints
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [3.7]
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v1
+ with:
+ fetch-depth: 20
+ - name: Setup Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Setup virtualenv
+ run: |
+ sudo apt install virtualenv
+ virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
+ - name: Lints
+ run: |
+ source ${TEST_VENV_PATH}/bin/activate
+ ./testing/run_github_lints.sh
+ tests:
+ name: Tests
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [3.7]
+ shard: [0, 1, 2, 3, 4]
+ env:
+ TEST_VENV_PATH: ~/test_virtualenv
+ SHARD: ${{ matrix.shard }}
+ NUM_SHARDS: 5
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v1
+ with:
+ fetch-depth: 1
+ - name: Setup Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Setup virtualenv
+ run: |
+ sudo apt install virtualenv
+ virtualenv -p python${{ matrix.python-version }} ${TEST_VENV_PATH}
+ - name: Tests
+ run: |
+ source ${TEST_VENV_PATH}/bin/activate
+ ./testing/run_github_tests.sh
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 1eee0329b0..0000000000
--- a/.travis.yml
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2018 The TensorFlow Probability Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-language: python
-dist: xenial
-git:
- depth: 20
- quiet: true
-
-# Specify that the lint build stage should run first; if it fails, the test
-# build stage won't run.
-stages:
- - lint
- - test
-
-# This creates the lint stage, which only runs the linter. The relevant travis
-# feature is "Build stages" (https://docs.travis-ci.com/user/build-stages/).
-# Unfortunately it doesn't support matrix expansion, so we define the test
-# stage the "old-fashioned" way down below this section.
-jobs:
- include:
- - stage: lint
- python: "3.5"
- script: ./testing/run_travis_lints.sh
-
-# The below implicitly run under a build stage called "Test". The full matrix
-# of python version(s) x SHARD will be expanded and run in parallel.
-python:
- - "3.5"
-env:
- global:
- - NUM_SHARDS=8
- matrix:
- # We shard our tests to avoid timeouts. The run_tests.sh script uses the
- # $NUM_SHARDS and $SHARD environment variables to select a partition of the
- # set of all tests.
- - SHARD=0
- - SHARD=1
- - SHARD=2
- - SHARD=3
- - SHARD=4
- - SHARD=5
- - SHARD=6
- - SHARD=7
-script: ./testing/run_travis_tests.sh
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index fa289ca67c..7266106d68 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -32,20 +32,19 @@ repository (with credit to the original author) and closes the pull request.
## Continuous Integration
-We use [Travis CI](https://travis-ci.org/tensorflow/probability) to do automated
-style checking and run unit-tests (discussed in more detail below). A build
-will be triggered when you open a pull request, or update the pull request by
-adding a commit, rebasing etc.
+We use [GitHub Actions](https://github.com/tensorflow/probability/actions) to do
+automated style checking and run unit-tests (discussed in more detail below). A
+build will be triggered when you open a pull request, or update the pull request
+by adding a commit, rebasing etc.
-We test against TensorFlow nightly on Python 2.7 and 3.6. We shard our tests
+We test against TensorFlow nightly on Python 3.7. We shard our tests
across several build jobs (identified by the `SHARD` environment variable).
-Linting, in particular, is only done on the first shard, so look at that shard's
-logs for lint errors if any.
+Lints are also done in a separate job.
All pull-requests will need to pass the automated lint and unit-tests before
-being merged. As Travis-CI tests can take a bit of time, see the following
-sections on how to run the lint checks and unit-tests locally while you're
-developing your change.
+being merged. As the tests can take a bit of time, see the following sections
+on how to run the lint checks and unit-tests locally while you're developing
+your change.
## Style
diff --git a/discussion/fun_mcmc/prefab.py b/discussion/fun_mcmc/prefab.py
index 8b3a2cefa4..2f323acdb7 100644
--- a/discussion/fun_mcmc/prefab.py
+++ b/discussion/fun_mcmc/prefab.py
@@ -347,7 +347,7 @@ def kernel(adaptive_hmc_state):
hmc_state.state,
axis=tuple(range(chain_ndims)) if chain_ndims else None,
window_size=int(np.prod(hmc_state.target_log_prob.shape)) *
- variance_window_steps)
+ variance_window_steps) # pytype: disable=wrong-arg-types
if num_adaptation_steps is not None:
# Take care of adaptation for variance and step size.
diff --git a/spinoffs/inference_gym/inference_gym/BUILD b/spinoffs/inference_gym/inference_gym/BUILD
index f0b13cc51b..e9ce9db7a3 100644
--- a/spinoffs/inference_gym/inference_gym/BUILD
+++ b/spinoffs/inference_gym/inference_gym/BUILD
@@ -16,7 +16,7 @@
# A package for target densities and benchmarking of inference algorithms
# against the same.
-# [internal] load pytype.bzl (pytype_library, pytype_strict_library)
+# [internal] load pytype.bzl (pytype_strict_library)
# [internal] load dummy dependency
package(
@@ -42,7 +42,6 @@ py_library(
],
)
-# pytype
py_library(
name = "using_numpy",
srcs = ["using_numpy.py"],
@@ -56,7 +55,6 @@ py_library(
],
)
-# pytype
py_library(
name = "using_jax",
srcs = ["using_jax.py"],
@@ -71,7 +69,6 @@ py_library(
],
)
-# pytype
py_library(
name = "using_tensorflow",
srcs = ["using_tensorflow.py"],
diff --git a/spinoffs/oryx/oryx/core/interpreters/harvest.py b/spinoffs/oryx/oryx/core/interpreters/harvest.py
index c376f25662..5807ce9814 100644
--- a/spinoffs/oryx/oryx/core/interpreters/harvest.py
+++ b/spinoffs/oryx/oryx/core/interpreters/harvest.py
@@ -333,10 +333,16 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
if is_map:
# TODO(sharadmv): figure out if invars are mapped or unmapped
params = params.copy()
+ out_axes_thunk = params['out_axes_thunk']
+ @jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
+ def new_out_axes_thunk():
+ out_axes = out_axes_thunk()
+ assert all(out_axis == 0 for out_axis in out_axes)
+ return (0,) * out_tree().num_leaves
new_params = dict(
params,
- in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
- params['in_axes'])
+ in_axes=(0,) * len(tree_util.tree_leaves(plants)) + params['in_axes'],
+ out_axes_thunk=new_out_axes_thunk)
else:
new_params = dict(params)
all_args, all_tree = tree_util.tree_flatten((plants, vals))
@@ -344,11 +350,10 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
if 'donated_invars' in params:
new_params['donated_invars'] = ((False,) * num_plants
+ params['donated_invars'])
- f, aux = harvest_eval(f, self, context.settings, all_tree)
+ f, out_tree = harvest_eval(f, self, context.settings, all_tree)
out_flat = primitive.bind(
f, *all_args, **new_params, name=jax_util.wrap_name(name, 'harvest'))
- out_tree = aux()
- out, reaps = tree_util.tree_unflatten(out_tree, out_flat)
+ out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
out_tracers = safe_map(self.pure, out)
reap_tracers = tree_util.tree_map(self.pure, reaps)
if primitive is nest_p and reap_tracers:
diff --git a/spinoffs/oryx/oryx/core/interpreters/inverse/core.py b/spinoffs/oryx/oryx/core/interpreters/inverse/core.py
index 0044c247eb..e8c86ef2da 100644
--- a/spinoffs/oryx/oryx/core/interpreters/inverse/core.py
+++ b/spinoffs/oryx/oryx/core/interpreters/inverse/core.py
@@ -178,7 +178,7 @@ def wrapped(*args, **kwargs):
flat_incells = [InverseAndILDJ.unknown(aval) for aval in flat_forward_avals]
flat_outcells = safe_map(InverseAndILDJ.new, flat_args)
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr.jaxpr,
- flat_constcells, flat_incells, flat_outcells)
+ flat_constcells, flat_incells, flat_outcells) # pytype: disable=wrong-arg-types
flat_incells = [env.read(invar) for invar in jaxpr.jaxpr.invars]
if any(not flat_incell.top() for flat_incell in flat_incells):
raise ValueError('Cannot invert function.')
@@ -332,7 +332,7 @@ def hop_inverse_rule(prim):
def initial_ildj(incells, outcells, *, jaxpr, num_consts, **_):
const_cells, incells = jax_util.split_list(incells, [num_consts])
env = propagate.propagate(InverseAndILDJ, ildj_registry, jaxpr, const_cells,
- incells, outcells)
+ incells, outcells) # pytype: disable=wrong-arg-types
new_incells = [env.read(invar) for invar in jaxpr.invars]
new_outcells = [env.read(outvar) for outvar in jaxpr.outvars]
return const_cells + new_incells, new_outcells, None
@@ -377,6 +377,12 @@ def remove_slice(cell):
new_params = dict(params, in_axes=new_in_axes)
if 'donated_invars' in params:
new_params['donated_invars'] = (False,) * len(flat_vals)
+ if 'out_axes' in params:
+ assert all(out_axis == 0 for out_axis in params['out_axes'])
+ new_params['out_axes_thunk'] = jax_util.HashableFunction(
+ lambda: (0,) * aux().num_leaves,
+ closure=('ildj', params['out_axes']))
+ del new_params['out_axes']
subenv_vals = prim.bind(f, *flat_vals, **new_params)
subenv_tree = aux()
subenv = tree_util.tree_unflatten(subenv_tree, subenv_vals)
diff --git a/spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py b/spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py
index 0e7dd28644..06ecf51609 100644
--- a/spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py
+++ b/spinoffs/oryx/oryx/core/interpreters/inverse/inverse_test.py
@@ -249,6 +249,14 @@ def f(x, y):
onp.testing.assert_allclose(y, np.ones(2))
onp.testing.assert_allclose(ildj_, 0., atol=1e-6, rtol=1e-6)
+ def test_inverse_of_reshape(self):
+ def f(x):
+ return np.reshape(x, (4,))
+ f_inv = core.inverse_and_ildj(f, np.ones((2, 2)))
+ x, ildj_ = f_inv(np.ones(4))
+ onp.testing.assert_allclose(x, np.ones((2, 2)))
+ onp.testing.assert_allclose(ildj_, 0.)
+
def test_sigmoid_ildj(self):
def naive_sigmoid(x):
# This is the default JAX implementation of sigmoid.
diff --git a/spinoffs/oryx/oryx/core/interpreters/inverse/rules.py b/spinoffs/oryx/oryx/core/interpreters/inverse/rules.py
index 2ea0ea170f..52ce54e65c 100644
--- a/spinoffs/oryx/oryx/core/interpreters/inverse/rules.py
+++ b/spinoffs/oryx/oryx/core/interpreters/inverse/rules.py
@@ -166,9 +166,8 @@ def reshape_ildj(incells, outcells, **params):
))], None
elif outcell.top() and not incell.top():
val = outcell.val
- ndslice = NDSlice.new(np.reshape(val, incell.aval.shape))
new_incells = [
- InverseAndILDJ(incell.aval, [ndslice])
+ InverseAndILDJ.new(np.reshape(val, incell.aval.shape))
]
return new_incells, outcells, None
return incells, outcells, None
diff --git a/spinoffs/oryx/oryx/core/interpreters/unzip.py b/spinoffs/oryx/oryx/core/interpreters/unzip.py
index c7cbd613dd..a9d294fc14 100644
--- a/spinoffs/oryx/oryx/core/interpreters/unzip.py
+++ b/spinoffs/oryx/oryx/core/interpreters/unzip.py
@@ -288,19 +288,29 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
in_pvals = [pval if pval.is_known() or in_axis is None else
unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
for pval, in_axis in zip(in_pvals, params['in_axes'])]
+ out_axes_thunk = params['out_axes_thunk']
+ @jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
+ def new_out_axes_thunk():
+ out_axes = out_axes_thunk()
+ assert all(out_axis == 0 for out_axis in out_axes)
+ _, num_outputs, _ = aux()
+ return (0,) * num_outputs
+ new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
+ else:
+ new_params = params
pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
keys = tuple(t.is_key() for t in tracers)
new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
- out_flat = call_primitive.bind(fun, *in_consts, **params)
- success, results = aux()
+ out_flat = call_primitive.bind(fun, *in_consts, **new_params)
+ success, _, results = aux()
if not success:
out_pvs, out_keys, jaxpr, env = results
out_pv_consts, consts = jax_util.split_list(out_flat, [len(out_pvs)])
- out_tracers = self._bound_output_tracers(call_primitive, params, jaxpr,
- consts, env, tracers, out_pvs,
- out_pv_consts, out_keys, name,
- is_map)
+ out_tracers = self._bound_output_tracers(call_primitive, new_params,
+ jaxpr, consts, env, tracers,
+ out_pvs, out_pv_consts,
+ out_keys, name, is_map)
return out_tracers
init_name = jax_util.wrap_name(name, 'init')
apply_name = jax_util.wrap_name(name, 'apply')
@@ -319,15 +329,16 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
[len(apply_pvs)])
variable_tracers = self._bound_output_tracers(
- call_primitive, params, init_jaxpr, init_consts, init_env, key_tracers,
- init_pvs, init_pv_consts, [True] * len(init_pvs), init_name, is_map)
+ call_primitive, new_params, init_jaxpr, init_consts, init_env,
+ key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
+ init_name, is_map)
unflat_variables = tree_util.tree_unflatten(variable_tree, variable_tracers)
if call_primitive is harvest.nest_p:
variable_dict = harvest.sow(
dict(safe_zip(variable_names, unflat_variables)),
tag=settings.tag,
- name=params['scope'],
+ name=new_params['scope'],
mode='strict')
unflat_variables = tuple(variable_dict[name] for name in variable_names)
else:
@@ -342,7 +353,7 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
variable_tracers = tree_util.tree_leaves(unflat_variables)
out_tracers = self._bound_output_tracers(
- call_primitive, params, apply_jaxpr, apply_consts, apply_env,
+ call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
apply_keys, apply_name, is_map)
return out_tracers
@@ -365,6 +376,11 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
tuple(v for v, t in zip(params['donated_invars'], in_tracers)
if not t.pval.is_known()))
new_params['donated_invars'] = new_donated_invars
+ if is_map:
+ out_axes = params['out_axes_thunk']()
+ assert all(out_axis == 0 for out_axis in out_axes)
+ new_params['out_axes'] = (0,) * len(out_tracers)
+ del new_params['out_axes_thunk']
eqn = pe.new_eqn_recipe(
tuple(const_tracers + env_tracers + in_tracers), out_tracers, primitive,
new_params, source_info_util.current()) # pytype: disable=wrong-arg-types
@@ -442,14 +458,16 @@ def unzip_eval_wrapper(pvs, *consts):
out = (
tuple(init_pv_consts) + tuple(init_consts) + tuple(apply_pv_consts) +
tuple(apply_consts))
- yield out, (success, ((init_pvs, len(init_consts), apply_pvs),
- (init_jaxpr, apply_jaxpr), (init_env,
- apply_env), metadata))
+ yield out, (success, len(out),
+ ((init_pvs, len(init_consts), apply_pvs),
+ (init_jaxpr, apply_jaxpr),
+ (init_env, apply_env),
+ metadata))
else:
jaxpr, (out_pvals, out_keys, consts, env) = result
out_pvs, out_consts = jax_util.unzip2(out_pvals)
out = tuple(out_consts) + tuple(consts)
- yield out, (success, (out_pvs, out_keys, jaxpr, env))
+ yield out, (success, len(out), (out_pvs, out_keys, jaxpr, env))
@lu.transformation
diff --git a/spinoffs/oryx/oryx/core/state/module.py b/spinoffs/oryx/oryx/core/state/module.py
index 0d566dcf3c..70a840aefd 100644
--- a/spinoffs/oryx/oryx/core/state/module.py
+++ b/spinoffs/oryx/oryx/core/state/module.py
@@ -168,4 +168,4 @@ def variables(self) -> Dict[str, Any]:
@ppl.log_prob.register(Module)
def module_log_prob(module, *args, **kwargs):
- return log_prob.log_prob(module, *args, **kwargs)
+ return log_prob.log_prob(module, *args, **kwargs) # pytype: disable=wrong-arg-count
diff --git a/spinoffs/oryx/oryx/experimental/mcmc/kernels.py b/spinoffs/oryx/oryx/experimental/mcmc/kernels.py
index 5895afd3eb..bcf02b84f3 100644
--- a/spinoffs/oryx/oryx/experimental/mcmc/kernels.py
+++ b/spinoffs/oryx/oryx/experimental/mcmc/kernels.py
@@ -52,8 +52,8 @@ def step(key, state):
def _sample(key, state):
return ppl.random_variable(
- bd.Independent(
- bd.Normal(state, scale),
+ bd.Independent( # pytype: disable=module-attr
+ bd.Normal(state, scale), # pytype: disable=module-attr
reinterpreted_batch_ndims=np.ndim(state)))(
key)
@@ -176,7 +176,7 @@ def momentum_distribution(key):
def _sample(key, s):
return ppl.random_variable(
- bd.Sample(bd.Normal(0., 1.),
+ bd.Sample(bd.Normal(0., 1.), # pytype: disable=module-attr
sample_shape=s.shape))(key).astype(s.dtype)
return tree_util.tree_multimap(_sample, momentum_keys, state)
diff --git a/spinoffs/oryx/oryx/experimental/nn/combinator.py b/spinoffs/oryx/oryx/experimental/nn/combinator.py
index 8f8be28f0b..759de0b6f9 100644
--- a/spinoffs/oryx/oryx/experimental/nn/combinator.py
+++ b/spinoffs/oryx/oryx/experimental/nn/combinator.py
@@ -40,7 +40,7 @@ def initialize(cls, init_key, *args):
"""
in_specs, layer_inits = args[:-1], args[-1]
layers = state.init(list(layer_inits), name='layers')(init_key, *in_specs)
- return base.LayerParams(tuple(layers))
+ return base.LayerParams(tuple(layers)) # pytype: disable=wrong-arg-types
@classmethod
def spec(cls, *args):
diff --git a/tensorflow_probability/examples/jupyter_notebooks/JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb b/tensorflow_probability/examples/jupyter_notebooks/JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb
index 954921edae..a2032d9379 100644
--- a/tensorflow_probability/examples/jupyter_notebooks/JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb
+++ b/tensorflow_probability/examples/jupyter_notebooks/JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb
@@ -1,1566 +1,1614 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "i3puWgvKeyWu"
- },
- "source": [
- "# Auto-Batched Joint Distributions: A Gentle Tutorial"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ZrwVQsM9TiUw"
- },
- "source": [
- "##### Copyright 2020 The TensorFlow Authors.\n",
- "\n",
- "Licensed under the Apache License, Version 2.0 (the \"License\");"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "cellView": "form",
- "id": "CpDUTVKYTowI"
- },
- "outputs": [],
- "source": [
- "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
- "# you may not use this file except in compliance with the License.\n",
- "# You may obtain a copy of the License at\n",
- "#\n",
- "# https://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed to in writing, software\n",
- "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
- "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
- "# See the License for the specific language governing permissions and\n",
- "# limitations under the License."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ltPJCG6pAUoc"
- },
- "source": [
- "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/probability/examples/Modeling_with_JointDistribution\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- " \u003ctd\u003e\n",
- " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/probability/tensorflow_probability/examples/jupyter_notebooks/Modeling_with_JointDistribution.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
- " \u003c/td\u003e\n",
- "\u003c/table\u003e"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "zzaOJSXagzMY"
- },
- "source": [
- "### Introduction"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "cIvB2CSBe49Z"
- },
- "source": [
- "TensorFlow Probability (TFP) offers a number of `JointDistribution` abstractions that make probabilistic inference easier by allowing a user to easily express a probabilistic graphical model in a near-mathematical form; the abstraction generates methods for sampling from the model and evaluating the log probability of samples from the model. In this tutorial, we review \"autobatched\" variants, which were developed after the original `JointDistribution` abstractions. Relative to the original, non-autobatched abstractions, the autobatched versions are simpler to use and more ergonomic, allowing many models to be expressed with less boilerplate. In this colab, we explore a simple model in (perhaps tedious) detail, making clear the problems autobatching solves, and (hopefully) teaching the reader more about TFP shape concepts along the way.\n",
- "\n",
- "Prior to the introduction of autobatching, there were a few different variants of `JointDistribution`, corresponding to different syntactic styles for expressing probabilistic models: `JointDistributionSequential`, `JointDistributionNamed`, and`JointDistributionCoroutine`. Auobatching exists as a mixin, so we now have `AutoBatched` variants of all of these. In this tutorial, we explore the differences between `JointDistributionSequential` and `JointDistributionSequentialAutoBatched`; however, everything we do here is applicable to the other variants with essentially no changes.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "uiR4-VOt9NFX"
- },
- "source": [
- "### Dependencies \u0026 Prerequisites\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "coUnDhkpT5_6"
- },
- "outputs": [],
- "source": [
- "#@title Import and set ups{ display-mode: \"form\" }\n",
- "\n",
- "import functools\n",
- "import numpy as np\n",
- "\n",
- "import tensorflow.compat.v2 as tf\n",
- "tf.enable_v2_behavior()\n",
- "\n",
- "import tensorflow_probability as tfp\n",
- "\n",
- "tfd = tfp.distributions"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KohBmaTn5W7I"
- },
- "source": [
- "### Prerequisite: A Bayesian Regression Problem"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "vChyK0vr9XD8"
- },
- "source": [
- "We'll consider a very simple Bayesian regression scenario:\n",
- "\\begin{align*}\n",
- "m \u0026 \\sim \\text{Normal}(0, 1) \\\\\n",
- "b \u0026 \\sim \\text{Normal}(0, 1) \\\\\n",
- "Y \u0026 \\sim \\text{Normal}(mX + b, 1)\n",
- "\\end{align*}\n",
- "\n",
- "In this model, `m` and `b` are drawn from standard normals, and the observations `Y` are drawn from a normal distribution whose mean depends on the random variables `m` and `b`, and some (nonrandom, known) covariates `X`. (For simplicity, in this example, we assume the scale of all random variables is known.)\n",
- "\n",
- "To perform inference in this model, we'd need to know both the covariates `X` and the observations `Y`, but for the purposes of this tutorial, we'll only need `X`, so we define a simple dummy `X`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "id": "UIpJ_cXUVabB"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([0, 1, 2, 3, 4, 5, 6])"
- ]
- },
- "execution_count": 3,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "X = np.arange(7)\n",
- "X"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "CIBpupyt9GTT"
- },
- "source": [
- "### Desiderata"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "j2uzL_uI9tqO"
- },
- "source": [
- "In probabilistic inference, we often want to perform two basic operations:\n",
- "- `sample`: Drawing samples from the model.\n",
- "- `log_prob`: Computing the log probability of a sample from the model.\n",
- "\n",
- "The key contribution of TFP's `JointDistribution` abstractions (as well as of many other approaches to probabilistic programming) is to allow users to write a model *once* and have access to both `sample` and `log_prob` computations.\n",
- "\n",
- "Noting that we have 7 points in our data set (`X.shape = (7,)`), we can now state the desiderata for an excellent `JointDistribution`:\n",
- "\n",
- "* `sample()` should produce a list of `Tensors` having shape `[(), (), (7,)`], corresponding to the scalar slope, scalar bias, and vector observations, respectively.\n",
- "* `log_prob(sample())` should produce a scalar: the log probability of a particular slope, bias, and observations.\n",
- "* `sample([5, 3])` should produce a list of `Tensors` having shape `[(5, 3), (5, 3), (5, 3, 7)]`, representing a `(5, 3)`-*batch* of samples from the model.\n",
- "* `log_prob(sample([5, 3]))` should produce a `Tensor` with shape (5, 3).\n",
- "\n",
- "We'll now look at a succession of `JointDistribution` models, see how to achieve the above desiderata, and hopefully learn a little more about TFP shapes along the way. \n",
- "\n",
- "Spoiler alert: The approach that satisfies the above desiderata without added boilerplate is [autobatching](#scrollTo=_h7sJ2bkfOS7). "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QiII0ypZcyTY"
- },
- "source": [
- "### First Attempt; `JointDistributionSequential`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "id": "kY501q-QVR9g"
- },
- "outputs": [],
- "source": [
- "jds = tfd.JointDistributionSequential([\n",
- " tfd.Normal(loc=0., scale=1.), # m\n",
- " tfd.Normal(loc=0., scale=1.), # b\n",
- " lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
- "])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hzNPPqJ-BwA-"
- },
- "source": [
- "This is more or less a direct translation of the model into code. The slope `m` and bias `b` are straightforward. `Y` is defined using a `lambda`-function: the general pattern is that a `lambda`-function of $k$ arguments in a `JointDistributionSequential` (JDS) uses the previous $k$ distributions in the model. Note the \"reverse\" order."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5jIvsQSOD81N"
- },
- "source": [
- "We'll call `sample_distributions`, which returns both a sample *and* the underlying \"sub-distributions\" that were used to generate the sample. (We could have produced just the sample by calling `sample`; later in the tutorial it will be convenient to have the distributions as well.) The sample we produce is fine:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "y05IrsfiaxCh"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[\u003ctf.Tensor: shape=(), dtype=float32, numpy=0.08079692\u003e,\n",
- " \u003ctf.Tensor: shape=(), dtype=float32, numpy=-1.5032883\u003e,\n",
- " \u003ctf.Tensor: shape=(7,), dtype=float32, numpy=\n",
- " array([-1.906176 , 0.53724945, -0.30291188, -0.86593336, -0.00641394,\n",
- " -0.58248115, -2.907504 ], dtype=float32)\u003e]"
- ]
- },
- "execution_count": 5,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dists, s = jds.sample_distributions()\n",
- "s"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "o7E1WkoCEB12"
- },
- "source": [
- "But `log_prob` produces a result with an undesired shape:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "xR0lbgjNay4X"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(7,), dtype=float32, numpy=\n",
- "array([-3.9711766, -5.8103094, -4.429552 , -3.9680157, -4.578788 ,\n",
- " -4.02357 , -5.674173 ], dtype=float32)\u003e"
- ]
- },
- "execution_count": 6,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds.log_prob(s)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1mMIs28LEJqN"
- },
- "source": [
- "And multiple sampling doesn't work:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "id": "LbfRiIsfc9Hf"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
- ]
- }
- ],
- "source": [
- "try:\n",
- " jds.sample([5, 3])\n",
- "except tf.errors.InvalidArgumentError as e:\n",
- " print(e)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Rnvtz3SQHrVL"
- },
- "source": [
- "Let's try to understand what's going wrong."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Dp30JPCmHyuz"
- },
- "source": [
- "### A Brief Review: Batch and Event Shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "w24fZn3kH2uF"
- },
- "source": [
- "In TFP, an ordinary (not a `JointDistribution`) probability distribution has an *event shape* and a *batch shape*, and understanding the difference is crucial to effective use of TFP:\n",
- "\n",
- "* Event shape describes the shape of a single draw from the distribution; the draw may be dependent across dimensions. For scalar distributions, the event shape is []. For a 5-dimensional MultivariateNormal, the event shape is [5].\n",
- "* Batch shape describes independent, not identically distributed draws, aka a \"batch\" of distributions. Representing a batch of distributions in a single Python object is one of the key ways TFP achieves efficiency at scale.\n",
- "\n",
- "For our purposes, a critical fact to keep in mind is that if we call `log_prob` on a single sample from a distribution, the result will always have a shape that matches (i.e., has as rightmost dimensions) the *batch* shape.\n",
- "\n",
- "For a more in-depth discussion of shapes, see [the \"Undersanding TensorFlow Distributions Shapes\" tutorial](https://www.tensorflow.org/probability/examples/Understanding_TensorFlow_Distributions_Shapes).\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nONZMjl-KtTz"
- },
- "source": [
- "### Why Doesn't `log_prob(sample())` Produce a Scalar? "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "VUKyGzkOJiuD"
- },
- "source": [
- "Let's use our knowledge of batch and event shape to explore what's happening with `log_prob(sample())`. Here's our sample again:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "id": "ijRGAnSBJwCG"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[\u003ctf.Tensor: shape=(), dtype=float32, numpy=0.08079692\u003e,\n",
- " \u003ctf.Tensor: shape=(), dtype=float32, numpy=-1.5032883\u003e,\n",
- " \u003ctf.Tensor: shape=(7,), dtype=float32, numpy=\n",
- " array([-1.906176 , 0.53724945, -0.30291188, -0.86593336, -0.00641394,\n",
- " -0.58248115, -2.907504 ], dtype=float32)\u003e]"
- ]
- },
- "execution_count": 8,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "s"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NAzBAsu3OoLv"
- },
- "source": [
- "And here are our distributions:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "id": "_xtIUKf8Nq3G"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[\u003ctfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32\u003e,\n",
- " \u003ctfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32\u003e,\n",
- " \u003ctfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32\u003e]"
- ]
- },
- "execution_count": 9,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dists"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "LzkLnoZyFeU_"
- },
- "source": [
- "The log probability is computed by summing the log probabilities of the sub-distributions at the (matched) elements of the parts:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "5XTDKVMPO5qg"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[\u003ctf.Tensor: shape=(), dtype=float32, numpy=-0.9222026\u003e,\n",
- " \u003ctf.Tensor: shape=(), dtype=float32, numpy=-2.0488763\u003e,\n",
- " \u003ctf.Tensor: shape=(7,), dtype=float32, numpy=\n",
- " array([-1.0000978 , -2.8392305 , -1.4584732 , -0.99693686, -1.6077087 ,\n",
- " -1.0524913 , -2.703094 ], dtype=float32)\u003e]"
- ]
- },
- "execution_count": 10,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "log_prob_parts = [dist.log_prob(ss) for (dist, ss) in zip(dists, s)]\n",
- "log_prob_parts"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "id": "QoWsVGx8N1IJ"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(7,), dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e"
- ]
- },
- "execution_count": 11,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "np.sum(log_prob_parts) - jds.log_prob(s)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ZJFvR4ZNFngd"
- },
- "source": [
- "So, one level of explanation is that the log probability calculation is returning a 7-Tensor because the third subcomponent of `log_prob_parts` is a 7-Tensor. But why?"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "zdpKnguOPOrr"
- },
- "source": [
- "Well, we see that the last element of `dists`, which corresponds to our distribution over `Y` in the mathematial formulation, has a `batch_shape` of `[7]`. In other words, our distribution over `Y` is a batch of 7 independent normals (with different means and, in this case, the same scale)."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0WXzlR_diTuZ"
- },
- "source": [
- "We now understand what's wrong: in JDS, the distribution over `Y` has `batch_shape=[7]`, a sample from the JDS represents scalars for `m` and `b` and a \"batch\" of 7 independent normals. and `log_prob` computes 7 separate log-probabilities, each of which represents the log probability of drawing `m` and `b` and a single observation `Y[i]` at some `X[i]`."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "s9RI0oxCi_En"
- },
- "source": [
- "### Fixing `log_prob(sample())` with `Independent`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "EOL1hllzjDcF"
- },
- "source": [
- "Recall that `dists[2]` has `event_shape=[]` and `batch_shape=[7]`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "id": "TA05J9VwjCLu"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctfp.distributions.Normal 'JointDistributionSequential_sample_distributions_Normal' batch_shape=[7] event_shape=[] dtype=float32\u003e"
- ]
- },
- "execution_count": 12,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dists[2]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_xQ5ORIqjPAz"
- },
- "source": [
- "By using TFP's `Independent` metadistribution, which converts batch dimensions to event dimensions, we can convert this into a distribution with `event_shape=[7]` and `batch_shape=[]` (we'll rename it `y_dist_i` because it's a distribution on `Y`, with the `_i` standing in for our `Independent` wrapping): "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "id": "Aa_SPItTjLBO"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctfp.distributions.Independent 'IndependentJointDistributionSequential_sample_distributions_Normal' batch_shape=[] event_shape=[7] dtype=float32\u003e"
- ]
- },
- "execution_count": 13,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)\n",
- "y_dist_i"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JrRjuDhhmBEr"
- },
- "source": [
- "Now, the `log_prob` of a 7-vector is a scalar:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "id": "y9yZs-kwdLGa"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(), dtype=float32, numpy=-11.658031\u003e"
- ]
- },
- "execution_count": 14,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "y_dist_i.log_prob(s[2])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "RqNEen4Ujkhh"
- },
- "source": [
- "Under the covers, `Independent` sums over the batch:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "id": "SxYr1McJkWFx"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(), dtype=float32, numpy=0.0\u003e"
- ]
- },
- "execution_count": 15,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "y_dist_i.log_prob(s[2]) - tf.reduce_sum(dists[2].log_prob(s[2]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "00lD003YkojA"
- },
- "source": [
- "And indeed, we can use this to construct a new `jds_i` (the `i` again stands for `Independent`) where `log_prob` returns a scalar:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "id": "1jwoSeNWkhT6"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(), dtype=float32, numpy=-14.62911\u003e"
- ]
- },
- "execution_count": 16,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_i = tfd.JointDistributionSequential([\n",
- " tfd.Normal(loc=0., scale=1.), # m\n",
- " tfd.Normal(loc=0., scale=1.), # b\n",
- " lambda b, m: tfd.Independent( # Y\n",
- " tfd.Normal(loc=m*X + b, scale=1.),\n",
- " reinterpreted_batch_ndims=1)\n",
- "])\n",
- "\n",
- "jds_i.log_prob(s)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hYY3CNBXlAIZ"
- },
- "source": [
- "A couple notes:\n",
- "- `jds_i.log_prob(s)` is *not* the same as `tf.reduce_sum(jds.log_prob(s))`. The former produces the \"correct\" log probability of the joint distribution. The latter sums over a 7-Tensor, each element of which is the sum of the log probability of `m`, `b`, and a single element of the log probability of `Y`, so it overcounts `m` and `b`. (`log_prob(m) + log_prob(b) + log_prob(Y)` returns a result rather than throwing an exception because TFP follows TF and NumPy's broadcasting rules; adding a scalar to a vector produces a vector-sized result.)\n",
- "- In this particular case, we could have solved the problem and achieved the same result using `MultivariateNormalDiag` instead of `Independent(Normal(...))`. `MultivariateNormalDiag` is a vector-valued distribution (i.e., it already has vector event-shape). Indeeed `MultivariateNormalDiag` could be (but isn't) implemented as a composition of `Independent` and `Normal`. It's worthwhile to remember that given a vector `V`, samples from `n1 = Normal(loc=V)`, and `n2 = MultivariateNormalDiag(loc=V)` are indistinguishable; the difference beween these distributions is that `n1.log_prob(n1.sample())` is a vector and `n2.log_prob(n2.sample())` is a scalar."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "b-iFi65ZmvpB"
- },
- "source": [
- "### Multiple Samples?"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PZcEBJS_nAhA"
- },
- "source": [
- "Drawing multiple samples still doesn't work:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "id": "PkvYmB3jm2sI"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
- ]
- }
- ],
- "source": [
- "try:\n",
- " jds_i.sample([5, 3])\n",
- "except tf.errors.InvalidArgumentError as e:\n",
- " print(e)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "b9Jh0MTCn0Mr"
- },
- "source": [
- "Let's think about why. When we call `jds_i.sample([5, 3])`, we'll first draw samples for `m` and `b`, each with shape `(5, 3)`. Next, we're going to try to construct a `Normal` distribution via:\n",
- "```\n",
- "tfd.Normal(loc=m*X + b, scale=1.)\n",
- "```\n",
- "\n",
- "But if `m` has shape `(5, 3)` and `X` has shape `7`, we can't multiply them together, and indeed this is the error we're hitting:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "id": "ei9Z2Nozp8Dy"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
- ]
- }
- ],
- "source": [
- "m = tfd.Normal(0., 1.).sample([5, 3])\n",
- "try:\n",
- " m * X\n",
- "except tf.errors.InvalidArgumentError as e:\n",
- " print(e)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1uqaIx2LlaeP"
- },
- "source": [
- "To resolve this issue, let's think about what properties the distribution over `Y` has to have. If we've called `jds_i.sample([5, 3])`, then we know `m` and `b` will both have shape `(5, 3)`. What shape should a call to `sample` on the `Y` distribution produce? The obvious answer is `(5, 3, 7)`: for each batch point, we want a sample with the same size as `X`. We can achieve this by using TensorFlow's broadcasting capabilities, adding extra dimensions:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "id": "-22Bg8Yfr6tg"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "TensorShape([5, 3, 1])"
- ]
- },
- "execution_count": 19,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "m[..., tf.newaxis].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "id": "7k21MOvlsHGe"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "TensorShape([5, 3, 7])"
- ]
- },
- "execution_count": 20,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "(m[..., tf.newaxis] * X).shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "5AEBbcjVsXQR"
- },
- "source": [
- "Adding an axis to both `m` and `b`, we can define a new JDS that supports multiple samples:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "id": "9rJ9WCVQsW0S"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[\u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- " array([[-1.0641694 , 0.88205844, -0.3132895 ],\n",
- " [ 0.7708484 , -0.08183189, -0.543864 ],\n",
- " [-0.46075284, -1.8269578 , 0.30572248],\n",
- " [-1.4730763 , -1.749881 , -0.18791775],\n",
- " [-1.1432608 , -0.03570032, -0.47378683]], dtype=float32)\u003e,\n",
- " \u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- " array([[-0.05995331, 0.4670131 , 0.39853612],\n",
- " [-0.50897926, 0.55372673, -0.44930768],\n",
- " [ 2.12264 , -0.8941609 , -0.22456498],\n",
- " [-0.28325766, -0.6039566 , -0.7982028 ],\n",
- " [ 1.6194319 , -1.5981796 , -1.0267515 ]], dtype=float32)\u003e,\n",
- " \u003ctf.Tensor: shape=(5, 3, 7), dtype=float32, numpy=\n",
- " array([[[ 1.23666501e+00, -2.72573185e+00, -1.06902647e+00,\n",
- " -3.99592471e+00, -5.51451778e+00, -4.81725502e+00,\n",
- " -5.56694984e+00],\n",
- " [ 4.49036419e-01, 8.42256904e-01, 1.65697992e+00,\n",
- " 2.83218813e+00, 2.02821064e+00, 5.30640173e+00,\n",
- " 7.88480282e+00],\n",
- " [ 1.04598761e-01, -6.91915929e-01, -1.45380819e+00,\n",
- " -2.99107218e+00, -7.48243392e-01, -2.39095449e-01,\n",
- " -2.95680499e+00]],\n",
- " \n",
- " [[ 1.42532825e+00, -3.29503775e-01, 2.74825788e+00,\n",
- " 4.71045971e-01, 2.95442867e+00, 5.41281986e+00,\n",
- " 4.12423992e+00],\n",
- " [ 1.13851607e+00, 1.34247184e+00, 5.38553715e-01,\n",
- " 4.75679219e-01, 1.15889467e-01, 2.28273201e+00,\n",
- " 1.66366085e-01],\n",
- " [-8.50983739e-01, -2.25449228e+00, -1.62029576e+00,\n",
- " -2.47048473e+00, -8.28547478e-02, -1.62208068e+00,\n",
- " -3.38254881e+00]],\n",
- " \n",
- " [[ 2.75410676e+00, 1.73929715e+00, 1.65932381e+00,\n",
- " 1.43238759e+00, 7.23003149e-01, -4.07665223e-01,\n",
- " -5.24324298e-01],\n",
- " [-3.93893182e-01, -1.79903293e+00, -3.79906535e+00,\n",
- " -4.41074371e+00, -9.76827240e+00, -9.46045876e+00,\n",
- " -1.14899712e+01],\n",
- " [-1.37748170e+00, 5.45929432e-01, -8.51358235e-01,\n",
- " 2.76324749e-02, 5.16971350e-01, -6.29880428e-01,\n",
- " 2.23690033e+00]],\n",
- " \n",
- " [[ 2.06451297e+00, -2.04346943e+00, -3.22309828e+00,\n",
- " -5.45961189e+00, -5.86767960e+00, -7.99706030e+00,\n",
- " -8.01118088e+00],\n",
- " [-1.71845675e+00, -2.55129766e+00, -2.98688173e+00,\n",
- " -4.69979382e+00, -6.89284897e+00, -1.11423817e+01,\n",
- " -1.29737835e+01],\n",
- " [-1.13922238e-01, -1.64989650e-01, -1.72910857e+00,\n",
- " -2.97116470e+00, -2.48031807e+00, -2.05811620e+00,\n",
- " -1.51430011e+00]],\n",
- " \n",
- " [[ 2.13675165e+00, 1.30672932e+00, -3.27593088e-03,\n",
- " -1.38755083e+00, -1.46972406e+00, -3.88024116e+00,\n",
- " -4.52536440e+00],\n",
- " [-2.77965927e+00, -1.04991031e+00, -1.96163297e+00,\n",
- " -1.44081473e+00, -6.46156311e-01, -3.07756782e+00,\n",
- " -3.05591631e+00],\n",
- " [-6.00465536e-01, -1.80835783e+00, -2.14595556e+00,\n",
- " -2.22402120e+00, -2.26174808e+00, -3.47439361e+00,\n",
- " -3.31842375e+00]]], dtype=float32)\u003e]"
- ]
- },
- "execution_count": 21,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ia = tfd.JointDistributionSequential([\n",
- " tfd.Normal(loc=0., scale=1.), # m\n",
- " tfd.Normal(loc=0., scale=1.), # b\n",
- " lambda b, m: tfd.Independent( # Y\n",
- " tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),\n",
- " reinterpreted_batch_ndims=1)\n",
- "])\n",
- "\n",
- "ss = jds_ia.sample([5, 3])\n",
- "ss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "id": "8fsYEy6Fla0o"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- "array([[-13.1261215, -13.386831 , -14.021704 ],\n",
- " [-15.311695 , -11.299437 , -13.955756 ],\n",
- " [-11.30703 , -14.554242 , -12.285254 ],\n",
- " [-13.204155 , -15.515974 , -11.195538 ],\n",
- " [-12.403549 , -12.712912 , -9.4606905]], dtype=float32)\u003e"
- ]
- },
- "execution_count": 22,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ia.log_prob(ss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "6ArLyKqJtY3Z"
- },
- "source": [
- "As an extra check, we'll verify that the log probability for a single batch point matches what we had before:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {
- "id": "9_2lIJyJtpyW"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(), dtype=float32, numpy=0.0\u003e"
- ]
- },
- "execution_count": 23,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "(jds_ia.log_prob(ss)[3, 1] -\n",
- " jds_i.log_prob([ss[0][3, 1],\n",
- " ss[1][3, 1],\n",
- " ss[2][3, 1, :]]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_h7sJ2bkfOS7"
- },
- "source": [
- "\u003ca id='AutoBatching-For-The-Win'\u003e\u003c/a\u003e\n",
- "### AutoBatching For The Win\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "J7nqIUMxuKzw"
- },
- "source": [
- "Excellent! We now have a version of JointDistribution that handles all our desiderata: `log_prob` returns a scalar thanks to the use of `tfd.Independent`, and multiple samples work now that we fixed broadcasting by adding extra axes.\n",
- "\n",
- "What if I told you there was an easier, better way? There is, and it's called `JointDistributionSequentialAutoBatched` (JDSAB):"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "id": "LZtVljb0fRx2"
- },
- "outputs": [],
- "source": [
- "jds_ab = tfd.JointDistributionSequentialAutoBatched([\n",
- " tfd.Normal(loc=0., scale=1.), # m\n",
- " tfd.Normal(loc=0., scale=1.), # b\n",
- " lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
- "])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "id": "gpvjnvXqu2Mk"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(), dtype=float32, numpy=-10.550432\u003e"
- ]
- },
- "execution_count": 25,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ab.log_prob(jds.sample())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "id": "Js3luiUfns_R"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- "array([[-16.063435 , -11.415724 , -13.347199 ],\n",
- " [-13.534442 , -20.753754 , -11.381274 ],\n",
- " [-10.44528 , -12.624834 , -10.739721 ],\n",
- " [-16.03442 , -13.358179 , -11.850428 ],\n",
- " [ -9.4756365, -11.457652 , -10.145042 ]], dtype=float32)\u003e"
- ]
- },
- "execution_count": 26,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ss = jds_ab.sample([5, 3])\n",
- "jds_ab.log_prob(ss)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {
- "id": "v1ppa6F6bdkv"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- "array([[0., 0., 0.],\n",
- " [0., 0., 0.],\n",
- " [0., 0., 0.],\n",
- " [0., 0., 0.],\n",
- " [0., 0., 0.]], dtype=float32)\u003e"
- ]
- },
- "execution_count": 27,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ab.log_prob(ss) - jds_ia.log_prob(ss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "xy-kuUbYwFB3"
- },
- "source": [
- "How does this work? While you could attempt to [read the code](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py#L426) for a deep understanding, we'll give a brief overview which is sufficient for most use cases:\n",
- "- Recall that our first problem was that our distribution for `Y` had `batch_shape=[7]` and `event_shape=[]`, and we used `Independent` to convert the batch dimension to an event dimension. JDSAB ignores the batch shapes of component distributions; instead it treats batch shape as an overall property of the model, which is assumed to be `[]` (unless specified otherwise by setting `batch_ndims \u003e 0`). The effect is equivalent to using tfd.Independent to convert *all* batch dimensions of component distributions into event dimensions, as we did manually above.\n",
- "- Our second problem was a need to massage the shapes of `m` and `b` so that they could broadcast appropriately with `X` when creating multiple samples. With JDSAB, you write a model to generate a single sample, and we \"lift\" the entire model to generate multiple samples using TensorFlow's [vectorized_map](https://www.tensorflow.org/api_docs/python/tf/vectorized_map). (This feature is analagous to JAX's [vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap).)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jUsWfVGqJiph"
- },
- "source": [
- "Exploring the batch shape issue in more detail, we can compare the batch shapes of our original \"bad\" joint distribution `jds`, our batch-fixed distributions `jds_i` and `jds_ia`, and our autobatched `jds_ab`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "id": "298I732fJDk5"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[TensorShape([]), TensorShape([]), TensorShape([7])]"
- ]
- },
- "execution_count": 28,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds.batch_shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {
- "id": "SBmdWrUuJGx0"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[TensorShape([]), TensorShape([]), TensorShape([])]"
- ]
- },
- "execution_count": 29,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_i.batch_shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {
- "id": "vD71eqN2JMhx"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[TensorShape([]), TensorShape([]), TensorShape([])]"
- ]
- },
- "execution_count": 30,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ia.batch_shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "id": "qHmvRcxBJOAZ"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "TensorShape([])"
- ]
- },
- "execution_count": 31,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ab.batch_shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ozegq0diJuOL"
- },
- "source": [
- "We see that the original `jds` has subdistributions with different batch shapes. `jds_i` and `jds_ia` fix this by creating subdistributions with the same (empty) batch shape. `jds_ab` has only a single (empty) batch shape."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "bMm55xqV1dz6"
- },
- "source": [
- "It's worth noting that `JointDistributionSequentialAutoBatched` offers some additional generality for free. Suppose we make the covariates `X` (and, implicitly, the observations `Y`) two-dimensional:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {
- "id": "1WfK-XbR1tXU"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array([[ 0, 1, 2, 3, 4, 5, 6],\n",
- " [ 7, 8, 9, 10, 11, 12, 13]])"
- ]
- },
- "execution_count": 32,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "X = np.arange(14).reshape((2, 7))\n",
- "X"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "VOnnkZooSj2C"
- },
- "source": [
- "Our `JointDistributionSequentialAutoBatched` works with no changes (we need to redefine the model because the shape of `X` is cached by `jds_ab.log_prob`):"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {
- "id": "6WwMvoY71qph"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[\u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- " array([[-1.0845535 , -1.1255777 , -0.77237695],\n",
- " [-1.2722294 , 1.9274628 , 0.75446165],\n",
- " [ 1.214832 , 2.03594 , 0.68272597],\n",
- " [-0.5651716 , 1.6402307 , 0.6128305 ],\n",
- " [-0.01167952, 1.2298371 , -1.2706645 ]], dtype=float32)\u003e,\n",
- " \u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- " array([[-0.5194242 , 0.2823965 , -0.9434134 ],\n",
- " [ 0.43568254, -0.37366644, -1.9174438 ],\n",
- " [-0.8661425 , -1.4302185 , 0.44063085],\n",
- " [ 0.36433375, -0.38744366, 0.6491046 ],\n",
- " [ 0.91218525, 0.36210015, -0.00910723]], dtype=float32)\u003e,\n",
- " \u003ctf.Tensor: shape=(5, 3, 2, 7), dtype=float32, numpy=\n",
- " array([[[[ -0.16874743, -0.38901854, -2.40703 , -5.930318 ,\n",
- " -3.416317 , -7.0882726 , -6.631361 ],\n",
- " [ -8.920654 , -10.499766 , -10.377804 , -12.9798355 ,\n",
- " -11.721172 , -14.460028 , -14.922584 ]],\n",
- " \n",
- " [[ 0.50552297, -0.9746385 , -2.047492 , -3.0749147 ,\n",
- " -4.5619793 , -6.072114 , -5.1145515 ],\n",
- " [ -7.2961216 , -8.094927 , -10.25211 , -12.26688 ,\n",
- " -12.046576 , -15.34705 , -15.152906 ]],\n",
- " \n",
- " [[ -0.8465157 , -2.6433449 , -0.76057017, -3.1688592 ,\n",
- " -4.687352 , -5.183547 , -5.0896225 ],\n",
- " [ -6.222906 , -8.103443 , -7.795763 , -8.36684 ,\n",
- " -10.562037 , -9.326081 , -9.593762 ]]],\n",
- " \n",
- " \n",
- " [[[ 1.054948 , -2.203673 , -3.035731 , -4.800442 ,\n",
- " -5.2899976 , -5.9240775 , -6.730611 ],\n",
- " [ -6.4754405 , -7.446973 , -10.764748 , -12.194825 ,\n",
- " -11.556754 , -14.941436 , -14.943226 ]],\n",
- " \n",
- " [[ 0.87307787, 1.3859878 , 2.6136284 , 5.4836617 ,\n",
- " 5.8579865 , 10.494877 , 11.823118 ],\n",
- " [ 11.510672 , 14.746766 , 16.719799 , 18.618593 ,\n",
- " 21.580097 , 22.609585 , 25.759428 ]],\n",
- " \n",
- " [[ -2.0380569 , -2.2008557 , 0.43357986, 0.32134444,\n",
- " 0.36675143, 2.9957676 , 1.6615164 ],\n",
- " [ 3.2243397 , 3.220036 , 4.315905 , 6.7883563 ,\n",
- " 6.503477 , 8.810654 , 5.883856 ]]],\n",
- " \n",
- " \n",
- " [[[ -0.477881 , 1.4766507 , 1.5208708 , 3.147714 ,\n",
- " 2.9273605 , 5.7710776 , 7.128166 ],\n",
- " [ 7.3486524 , 7.48754 , 8.853534 , 11.846103 ,\n",
- " 13.041363 , 12.164903 , 13.826527 ]],\n",
- " \n",
- " [[ -2.935304 , 0.5696763 , 2.1498902 , 6.319368 ,\n",
- " 7.923173 , 8.151863 , 11.570858 ],\n",
- " [ 14.339904 , 14.18277 , 18.049622 , 19.047941 ,\n",
- " 22.653297 , 25.26222 , 25.464987 ]],\n",
- " \n",
- " [[ 1.0329808 , -0.10444701, 0.99885136, 2.5327475 ,\n",
- " 2.0721416 , 1.9450207 , 4.6753073 ],\n",
- " [ 6.184873 , 8.452423 , 7.8260746 , 7.713975 ,\n",
- " 7.0077796 , 10.046227 , 10.1453085 ]]],\n",
- " \n",
- " \n",
- " [[[ 0.3361371 , -0.62899804, 1.2562443 , -1.935529 ,\n",
- " -1.4381697 , -1.5268946 , -3.8008852 ],\n",
- " [ -4.1968484 , -6.028409 , -4.970623 , -4.9823346 ,\n",
- " -5.6923776 , -6.535574 , -5.5532475 ]],\n",
- " \n",
- " [[ -2.0243526 , 3.3777661 , 0.97641647, 4.6852875 ,\n",
- " 7.6430597 , 5.8280125 , 9.0458555 ],\n",
- " [ 10.250172 , 12.831018 , 13.659218 , 16.075794 ,\n",
- " 16.925209 , 16.90435 , 19.38226 ]],\n",
- " \n",
- " [[ 1.2758106 , 0.83274007, 2.1775467 , 3.1251085 ,\n",
- " 3.9337432 , 2.543648 , 5.1000204 ],\n",
- " [ 5.8442574 , 6.0312934 , 6.379141 , 8.768039 ,\n",
- " 9.291983 , 8.260785 , 8.451964 ]]],\n",
- " \n",
- " \n",
- " [[[ 3.0444725 , 0.73759735, 2.5216937 , 0.04277879,\n",
- " 0.9555798 , -0.614954 , 1.0725826 ],\n",
- " [ 3.0648081 , 1.0510775 , 0.9096012 , 0.28714108,\n",
- " 1.4371622 , 2.1362674 , 1.9903467 ]],\n",
- " \n",
- " [[ 0.05708131, 1.2491966 , 1.9845967 , 3.4259818 ,\n",
- " 5.5484996 , 7.8822956 , 7.0572023 ],\n",
- " [ 9.535346 , 11.390023 , 10.360718 , 12.881494 ,\n",
- " 11.301062 , 13.86196 , 16.829353 ]],\n",
- " \n",
- " [[ -0.5573631 , -1.0938222 , -3.0080914 , -3.1928232 ,\n",
- " -4.713949 , -7.016099 , -6.185412 ],\n",
- " [ -8.42309 , -9.375599 , -10.624992 , -11.47895 ,\n",
- " -14.62926 , -14.905938 , -18.084822 ]]]], dtype=float32)\u003e]"
- ]
- },
- "execution_count": 33,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ab = tfd.JointDistributionSequentialAutoBatched([\n",
- " tfd.Normal(loc=0., scale=1.), # m\n",
- " tfd.Normal(loc=0., scale=1.), # b\n",
- " lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
- "])\n",
- "\n",
- "ss = jds_ab.sample([5, 3])\n",
- "ss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {
- "id": "GLvHMTpnSyvH"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "\u003ctf.Tensor: shape=(5, 3), dtype=float32, numpy=\n",
- "array([[-23.592081, -20.392092, -20.310911],\n",
- " [-25.823744, -22.132751, -23.761002],\n",
- " [-21.39077 , -27.747965, -25.098429],\n",
- " [-21.14306 , -29.653296, -21.353765],\n",
- " [-24.754295, -23.107279, -20.329145]], dtype=float32)\u003e"
- ]
- },
- "execution_count": 34,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jds_ab.log_prob(ss)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AI40r2oETnVP"
- },
- "source": [
- "On the other hand, our carefully crafted `JointDistributionSequential` no longer works:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "metadata": {
- "id": "tfYkdBIi0wJl"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]\n"
- ]
- }
- ],
- "source": [
- "jds_ia = tfd.JointDistributionSequential([\n",
- " tfd.Normal(loc=0., scale=1.), # m\n",
- " tfd.Normal(loc=0., scale=1.), # b\n",
- " lambda b, m: tfd.Independent( # Y\n",
- " tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),\n",
- " reinterpreted_batch_ndims=1)\n",
- "])\n",
- "\n",
- "try:\n",
- " jds_ia.sample([5, 3])\n",
- "except tf.errors.InvalidArgumentError as e:\n",
- " print(e)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "WLERQvFNTwQJ"
- },
- "source": [
- "To fix this, we'd have to add a second `tf.newaxis` to both `m` and `b` match the shape, and increase `reinterpreted_batch_ndims` to 2 in the call to `Independent`. In this case, letting the auto-batching machinery handle the shape issues is shorter, easier, and more ergonomic."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HIgCF6yJXpHE"
- },
- "source": [
- "Once again, we note that while this notebook explored `JointDistributionSequentialAutoBatched`, the other variants of `JointDistribution` have equivalent `AutoBatched`. (For users of `JointDistributionCoroutine`, `JointDistributionCoroutineAutoBatched` has the additional benefit that you no longer need to specify `Root` nodes; if you've never used `JointDistributionCoroutine` you can safely ignore this statement.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "mHacIM0iUW09"
- },
- "source": [
- "### Concluding Thoughts"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "kXAC7GDWUaaY"
- },
- "source": [
- "In this notebook, we introduced `JointDistributionSequentialAutoBatched` and worked through a simple example in detail. Hopefully you learned something about TFP shapes and about autobatching!"
- ]
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "i3puWgvKeyWu"
+ },
+ "source": [
+ "# Auto-Batched Joint Distributions: A Gentle Tutorial"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZrwVQsM9TiUw"
+ },
+ "source": [
+ "##### Copyright 2020 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\");"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "CpDUTVKYTowI"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ltPJCG6pAUoc"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zzaOJSXagzMY"
+ },
+ "source": [
+ "### Introduction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cIvB2CSBe49Z"
+ },
+ "source": [
+ "TensorFlow Probability (TFP) offers a number of `JointDistribution` abstractions that make probabilistic inference easier by allowing a user to easily express a probabilistic graphical model in a near-mathematical form; the abstraction generates methods for sampling from the model and evaluating the log probability of samples from the model. In this tutorial, we review \"autobatched\" variants, which were developed after the original `JointDistribution` abstractions. Relative to the original, non-autobatched abstractions, the autobatched versions are simpler to use and more ergonomic, allowing many models to be expressed with less boilerplate. In this colab, we explore a simple model in (perhaps tedious) detail, making clear the problems autobatching solves, and (hopefully) teaching the reader more about TFP shape concepts along the way.\n",
+ "\n",
+ "Prior to the introduction of autobatching, there were a few different variants of `JointDistribution`, corresponding to different syntactic styles for expressing probabilistic models: `JointDistributionSequential`, `JointDistributionNamed`, and`JointDistributionCoroutine`. Auobatching exists as a mixin, so we now have `AutoBatched` variants of all of these. In this tutorial, we explore the differences between `JointDistributionSequential` and `JointDistributionSequentialAutoBatched`; however, everything we do here is applicable to the other variants with essentially no changes.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uiR4-VOt9NFX"
+ },
+ "source": [
+ "### Dependencies & Prerequisites\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "coUnDhkpT5_6"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Import and set ups{ display-mode: \"form\" }\n",
+ "\n",
+ "import functools\n",
+ "import numpy as np\n",
+ "\n",
+ "import tensorflow.compat.v2 as tf\n",
+ "tf.enable_v2_behavior()\n",
+ "\n",
+ "import tensorflow_probability as tfp\n",
+ "\n",
+ "tfd = tfp.distributions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KohBmaTn5W7I"
+ },
+ "source": [
+ "### Prerequisite: A Bayesian Regression Problem"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vChyK0vr9XD8"
+ },
+ "source": [
+ "We'll consider a very simple Bayesian regression scenario:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align*}\n",
+ "m & \\sim \\text{Normal}(0, 1) \\\\\n",
+ "b & \\sim \\text{Normal}(0, 1) \\\\\n",
+ "Y & \\sim \\text{Normal}(mX + b, 1)\n",
+ "\\end{align*}\n",
+ "$$\n",
+ "\n",
+ "In this model, `m` and `b` are drawn from standard normals, and the observations `Y` are drawn from a normal distribution whose mean depends on the random variables `m` and `b`, and some (nonrandom, known) covariates `X`. (For simplicity, in this example, we assume the scale of all random variables is known.)\n",
+ "\n",
+ "To perform inference in this model, we'd need to know both the covariates `X` and the observations `Y`, but for the purposes of this tutorial, we'll only need `X`, so we define a simple dummy `X`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "UIpJ_cXUVabB",
+ "outputId": "15ab9b27-586c-46d2-f9a1-ae36cfcec736"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0, 1, 2, 3, 4, 5, 6])"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X = np.arange(7)\n",
+ "X"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CIBpupyt9GTT"
+ },
+ "source": [
+ "### Desiderata"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "j2uzL_uI9tqO"
+ },
+ "source": [
+ "In probabilistic inference, we often want to perform two basic operations:\n",
+ "- `sample`: Drawing samples from the model.\n",
+ "- `log_prob`: Computing the log probability of a sample from the model.\n",
+ "\n",
+ "The key contribution of TFP's `JointDistribution` abstractions (as well as of many other approaches to probabilistic programming) is to allow users to write a model *once* and have access to both `sample` and `log_prob` computations.\n",
+ "\n",
+ "Noting that we have 7 points in our data set (`X.shape = (7,)`), we can now state the desiderata for an excellent `JointDistribution`:\n",
+ "\n",
+ "* `sample()` should produce a list of `Tensors` having shape `[(), (), (7,)`], corresponding to the scalar slope, scalar bias, and vector observations, respectively.\n",
+ "* `log_prob(sample())` should produce a scalar: the log probability of a particular slope, bias, and observations.\n",
+ "* `sample([5, 3])` should produce a list of `Tensors` having shape `[(5, 3), (5, 3), (5, 3, 7)]`, representing a `(5, 3)`-*batch* of samples from the model.\n",
+ "* `log_prob(sample([5, 3]))` should produce a `Tensor` with shape (5, 3).\n",
+ "\n",
+ "We'll now look at a succession of `JointDistribution` models, see how to achieve the above desiderata, and hopefully learn a little more about TFP shapes along the way. \n",
+ "\n",
+ "Spoiler alert: The approach that satisfies the above desiderata without added boilerplate is [autobatching](#scrollTo=_h7sJ2bkfOS7). "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QiII0ypZcyTY"
+ },
+ "source": [
+ "### First Attempt; `JointDistributionSequential`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "kY501q-QVR9g"
+ },
+ "outputs": [],
+ "source": [
+ "jds = tfd.JointDistributionSequential([\n",
+ " tfd.Normal(loc=0., scale=1.), # m\n",
+ " tfd.Normal(loc=0., scale=1.), # b\n",
+ " lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hzNPPqJ-BwA-"
+ },
+ "source": [
+ "This is more or less a direct translation of the model into code. The slope `m` and bias `b` are straightforward. `Y` is defined using a `lambda`-function: the general pattern is that a `lambda`-function of $k$ arguments in a `JointDistributionSequential` (JDS) uses the previous $k$ distributions in the model. Note the \"reverse\" order."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5jIvsQSOD81N"
+ },
+ "source": [
+ "We'll call `sample_distributions`, which returns both a sample *and* the underlying \"sub-distributions\" that were used to generate the sample. (We could have produced just the sample by calling `sample`; later in the tutorial it will be convenient to have the distributions as well.) The sample we produce is fine:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "y05IrsfiaxCh",
+ "outputId": "aef09be2-527f-4352-a968-a314ebb48faa"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
}
- ],
- "metadata": {
- "colab": {
- "collapsed_sections": [],
- "name": "JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb",
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
+ ],
+ "source": [
+ "dists, sample = jds.sample_distributions()\n",
+ "sample"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o7E1WkoCEB12"
+ },
+ "source": [
+ "But `log_prob` produces a result with an undesired shape:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "xR0lbgjNay4X",
+ "outputId": "ce642821-2450-4bc0-b65b-d0c94f1dd15f"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
}
+ ],
+ "source": [
+ "jds.log_prob(sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1mMIs28LEJqN"
+ },
+ "source": [
+ "And multiple sampling doesn't work:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "LbfRiIsfc9Hf",
+ "outputId": "d93c1e83-9623-4ead-b50a-272d8f27ceae"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
+ ]
+ }
+ ],
+ "source": [
+ "try:\n",
+ " jds.sample([5, 3])\n",
+ "except tf.errors.InvalidArgumentError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Rnvtz3SQHrVL"
+ },
+ "source": [
+ "Let's try to understand what's going wrong."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Dp30JPCmHyuz"
+ },
+ "source": [
+ "### A Brief Review: Batch and Event Shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "w24fZn3kH2uF"
+ },
+ "source": [
+ "In TFP, an ordinary (not a `JointDistribution`) probability distribution has an *event shape* and a *batch shape*, and understanding the difference is crucial to effective use of TFP:\n",
+ "\n",
+ "* Event shape describes the shape of a single draw from the distribution; the draw may be dependent across dimensions. For scalar distributions, the event shape is []. For a 5-dimensional MultivariateNormal, the event shape is [5].\n",
+ "* Batch shape describes independent, not identically distributed draws, aka a \"batch\" of distributions. Representing a batch of distributions in a single Python object is one of the key ways TFP achieves efficiency at scale.\n",
+ "\n",
+ "For our purposes, a critical fact to keep in mind is that if we call `log_prob` on a single sample from a distribution, the result will always have a shape that matches (i.e., has as rightmost dimensions) the *batch* shape.\n",
+ "\n",
+ "For a more in-depth discussion of shapes, see [the \"Understanding TensorFlow Distributions Shapes\" tutorial](https://www.tensorflow.org/probability/examples/Understanding_TensorFlow_Distributions_Shapes).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nONZMjl-KtTz"
+ },
+ "source": [
+ "### Why Doesn't `log_prob(sample())` Produce a Scalar? "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VUKyGzkOJiuD"
+ },
+ "source": [
+ "Let's use our knowledge of batch and event shape to explore what's happening with `log_prob(sample())`. Here's our sample again:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "ijRGAnSBJwCG",
+ "outputId": "807fdfa9-b05f-4188-c544-2614c7210e4c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sample"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NAzBAsu3OoLv"
+ },
+ "source": [
+ "And here are our distributions:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "_xtIUKf8Nq3G",
+ "outputId": "aecd827f-d06e-45d1-c3de-845c8ccd5b0d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dists"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LzkLnoZyFeU_"
+ },
+ "source": [
+ "The log probability is computed by summing the log probabilities of the sub-distributions at the (matched) elements of the parts:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "5XTDKVMPO5qg",
+ "outputId": "ac4e837a-9f43-433a-afea-83e9a4c4e0b1"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "log_prob_parts = [dist.log_prob(s) for (dist, s) in zip(dists, sample)]\n",
+ "log_prob_parts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "QoWsVGx8N1IJ",
+ "outputId": "369342f9-cc85-4be1-8ed9-3b0c6cf04373"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "np.sum(log_prob_parts) - jds.log_prob(sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZJFvR4ZNFngd"
+ },
+ "source": [
+ "So, one level of explanation is that the log probability calculation is returning a 7-Tensor because the third subcomponent of `log_prob_parts` is a 7-Tensor. But why?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zdpKnguOPOrr"
+ },
+ "source": [
+ "Well, we see that the last element of `dists`, which corresponds to our distribution over `Y` in the mathematial formulation, has a `batch_shape` of `[7]`. In other words, our distribution over `Y` is a batch of 7 independent normals (with different means and, in this case, the same scale)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0WXzlR_diTuZ"
+ },
+ "source": [
+ "We now understand what's wrong: in JDS, the distribution over `Y` has `batch_shape=[7]`, a sample from the JDS represents scalars for `m` and `b` and a \"batch\" of 7 independent normals. and `log_prob` computes 7 separate log-probabilities, each of which represents the log probability of drawing `m` and `b` and a single observation `Y[i]` at some `X[i]`."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "s9RI0oxCi_En"
+ },
+ "source": [
+ "### Fixing `log_prob(sample())` with `Independent`"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EOL1hllzjDcF"
+ },
+ "source": [
+ "Recall that `dists[2]` has `event_shape=[]` and `batch_shape=[7]`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "TA05J9VwjCLu",
+ "outputId": "b795bb88-8806-42a6-b5c3-d6c68bf8133e"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dists[2]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_xQ5ORIqjPAz"
+ },
+ "source": [
+ "By using TFP's `Independent` metadistribution, which converts batch dimensions to event dimensions, we can convert this into a distribution with `event_shape=[7]` and `batch_shape=[]` (we'll rename it `y_dist_i` because it's a distribution on `Y`, with the `_i` standing in for our `Independent` wrapping): "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "Aa_SPItTjLBO",
+ "outputId": "901d1d16-a7b2-4e16-de9e-7c051db73d37"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_dist_i = tfd.Independent(dists[2], reinterpreted_batch_ndims=1)\n",
+ "y_dist_i"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JrRjuDhhmBEr"
+ },
+ "source": [
+ "Now, the `log_prob` of a 7-vector is a scalar:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "y9yZs-kwdLGa",
+ "outputId": "787974e2-fa3b-4532-ecda-6e313cd62b80"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_dist_i.log_prob(sample[2])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RqNEen4Ujkhh"
+ },
+ "source": [
+ "Under the covers, `Independent` sums over the batch:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "SxYr1McJkWFx",
+ "outputId": "54b64a21-c3ee-41ca-97f0-91000245bbf5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_dist_i.log_prob(sample[2]) - tf.reduce_sum(dists[2].log_prob(sample[2]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "00lD003YkojA"
+ },
+ "source": [
+ "And indeed, we can use this to construct a new `jds_i` (the `i` again stands for `Independent`) where `log_prob` returns a scalar:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "1jwoSeNWkhT6",
+ "outputId": "7d7838a4-93a0-4866-ecc0-e2413d4fc972"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_i = tfd.JointDistributionSequential([\n",
+ " tfd.Normal(loc=0., scale=1.), # m\n",
+ " tfd.Normal(loc=0., scale=1.), # b\n",
+ " lambda b, m: tfd.Independent( # Y\n",
+ " tfd.Normal(loc=m*X + b, scale=1.),\n",
+ " reinterpreted_batch_ndims=1)\n",
+ "])\n",
+ "\n",
+ "jds_i.log_prob(sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hYY3CNBXlAIZ"
+ },
+ "source": [
+ "A couple notes:\n",
+ "- `jds_i.log_prob(s)` is *not* the same as `tf.reduce_sum(jds.log_prob(s))`. The former produces the \"correct\" log probability of the joint distribution. The latter sums over a 7-Tensor, each element of which is the sum of the log probability of `m`, `b`, and a single element of the log probability of `Y`, so it overcounts `m` and `b`. (`log_prob(m) + log_prob(b) + log_prob(Y)` returns a result rather than throwing an exception because TFP follows TF and NumPy's broadcasting rules; adding a scalar to a vector produces a vector-sized result.)\n",
+ "- In this particular case, we could have solved the problem and achieved the same result using `MultivariateNormalDiag` instead of `Independent(Normal(...))`. `MultivariateNormalDiag` is a vector-valued distribution (i.e., it already has vector event-shape). Indeeed `MultivariateNormalDiag` could be (but isn't) implemented as a composition of `Independent` and `Normal`. It's worthwhile to remember that given a vector `V`, samples from `n1 = Normal(loc=V)`, and `n2 = MultivariateNormalDiag(loc=V)` are indistinguishable; the difference beween these distributions is that `n1.log_prob(n1.sample())` is a vector and `n2.log_prob(n2.sample())` is a scalar."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b-iFi65ZmvpB"
+ },
+ "source": [
+ "### Multiple Samples?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PZcEBJS_nAhA"
+ },
+ "source": [
+ "Drawing multiple samples still doesn't work:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "PkvYmB3jm2sI",
+ "outputId": "6fdf7cbe-b08e-4a44-a171-99dca44b2c77"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
+ ]
+ }
+ ],
+ "source": [
+ "try:\n",
+ " jds_i.sample([5, 3])\n",
+ "except tf.errors.InvalidArgumentError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b9Jh0MTCn0Mr"
+ },
+ "source": [
+ "Let's think about why. When we call `jds_i.sample([5, 3])`, we'll first draw samples for `m` and `b`, each with shape `(5, 3)`. Next, we're going to try to construct a `Normal` distribution via:\n",
+ "```\n",
+ "tfd.Normal(loc=m*X + b, scale=1.)\n",
+ "```\n",
+ "\n",
+ "But if `m` has shape `(5, 3)` and `X` has shape `7`, we can't multiply them together, and indeed this is the error we're hitting:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "ei9Z2Nozp8Dy",
+ "outputId": "75c788b6-9d43-4cbc-db6b-a3bf88394c79"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Incompatible shapes: [5,3] vs. [7] [Op:Mul]\n"
+ ]
+ }
+ ],
+ "source": [
+ "m = tfd.Normal(0., 1.).sample([5, 3])\n",
+ "try:\n",
+ " m * X\n",
+ "except tf.errors.InvalidArgumentError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1uqaIx2LlaeP"
+ },
+ "source": [
+ "To resolve this issue, let's think about what properties the distribution over `Y` has to have. If we've called `jds_i.sample([5, 3])`, then we know `m` and `b` will both have shape `(5, 3)`. What shape should a call to `sample` on the `Y` distribution produce? The obvious answer is `(5, 3, 7)`: for each batch point, we want a sample with the same size as `X`. We can achieve this by using TensorFlow's broadcasting capabilities, adding extra dimensions:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "-22Bg8Yfr6tg",
+ "outputId": "ac7b7cfe-7a92-438f-9f39-0c5ae99b230c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorShape([5, 3, 1])"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m[..., tf.newaxis].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "7k21MOvlsHGe",
+ "outputId": "d981d02b-fc70-4657-e145-5d7e996ff7d2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorShape([5, 3, 7])"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(m[..., tf.newaxis] * X).shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5AEBbcjVsXQR"
+ },
+ "source": [
+ "Adding an axis to both `m` and `b`, we can define a new JDS that supports multiple samples:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "9rJ9WCVQsW0S",
+ "outputId": "a91f10f4-1180-42ef-99e2-115153a99414"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ia = tfd.JointDistributionSequential([\n",
+ " tfd.Normal(loc=0., scale=1.), # m\n",
+ " tfd.Normal(loc=0., scale=1.), # b\n",
+ " lambda b, m: tfd.Independent( # Y\n",
+ " tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),\n",
+ " reinterpreted_batch_ndims=1)\n",
+ "])\n",
+ "\n",
+ "shaped_sample = jds_ia.sample([5, 3])\n",
+ "shaped_sample"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "8fsYEy6Fla0o",
+ "outputId": "3dc6c656-d990-4fc5-ddf9-0ec8d92cd7af"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ia.log_prob(shaped_sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6ArLyKqJtY3Z"
+ },
+ "source": [
+ "As an extra check, we'll verify that the log probability for a single batch point matches what we had before:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "9_2lIJyJtpyW",
+ "outputId": "ad08f15a-05e0-4f20-a157-bfa4834378c8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(jds_ia.log_prob(shaped_sample)[3, 1] -\n",
+ " jds_i.log_prob([shaped_sample[0][3, 1],\n",
+ " shaped_sample[1][3, 1],\n",
+ " shaped_sample[2][3, 1, :]]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_h7sJ2bkfOS7"
+ },
+ "source": [
+ "\n",
+ "### AutoBatching For The Win\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "J7nqIUMxuKzw"
+ },
+ "source": [
+ "Excellent! We now have a version of JointDistribution that handles all our desiderata: `log_prob` returns a scalar thanks to the use of `tfd.Independent`, and multiple samples work now that we fixed broadcasting by adding extra axes.\n",
+ "\n",
+ "What if I told you there was an easier, better way? There is, and it's called `JointDistributionSequentialAutoBatched` (JDSAB):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "LZtVljb0fRx2"
+ },
+ "outputs": [],
+ "source": [
+ "jds_ab = tfd.JointDistributionSequentialAutoBatched([\n",
+ " tfd.Normal(loc=0., scale=1.), # m\n",
+ " tfd.Normal(loc=0., scale=1.), # b\n",
+ " lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
+ "])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "gpvjnvXqu2Mk",
+ "outputId": "8cdc7f45-fe5d-4640-b413-ee5a03f35f78"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ab.log_prob(jds.sample())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "Js3luiUfns_R",
+ "outputId": "b864c734-5667-43f6-fc38-bfc4fcb5ce1c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "shaped_sample = jds_ab.sample([5, 3])\n",
+ "jds_ab.log_prob(shaped_sample)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "v1ppa6F6bdkv",
+ "outputId": "ccd1b9eb-3f80-416a-90ee-d190b4776fe1"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ab.log_prob(shaped_sample) - jds_ia.log_prob(shaped_sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xy-kuUbYwFB3"
+ },
+ "source": [
+ "How does this work? While you could attempt to [read the code](https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py#L426) for a deep understanding, we'll give a brief overview which is sufficient for most use cases:\n",
+ "- Recall that our first problem was that our distribution for `Y` had `batch_shape=[7]` and `event_shape=[]`, and we used `Independent` to convert the batch dimension to an event dimension. JDSAB ignores the batch shapes of component distributions; instead it treats batch shape as an overall property of the model, which is assumed to be `[]` (unless specified otherwise by setting `batch_ndims > 0`). The effect is equivalent to using tfd.Independent to convert *all* batch dimensions of component distributions into event dimensions, as we did manually above.\n",
+ "- Our second problem was a need to massage the shapes of `m` and `b` so that they could broadcast appropriately with `X` when creating multiple samples. With JDSAB, you write a model to generate a single sample, and we \"lift\" the entire model to generate multiple samples using TensorFlow's [vectorized_map](https://www.tensorflow.org/api_docs/python/tf/vectorized_map). (This feature is analagous to JAX's [vmap](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#Auto-vectorization-with-vmap).)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jUsWfVGqJiph"
+ },
+ "source": [
+ "Exploring the batch shape issue in more detail, we can compare the batch shapes of our original \"bad\" joint distribution `jds`, our batch-fixed distributions `jds_i` and `jds_ia`, and our autobatched `jds_ab`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "298I732fJDk5",
+ "outputId": "b97e2ea3-e89a-4922-ee20-59544c5134d7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[TensorShape([]), TensorShape([]), TensorShape([7])]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds.batch_shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "SBmdWrUuJGx0",
+ "outputId": "7101d876-68fc-4f7c-8ac9-47fa3cb3a5a0"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[TensorShape([]), TensorShape([]), TensorShape([])]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_i.batch_shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "vD71eqN2JMhx",
+ "outputId": "fac0f5bc-c6aa-4ac2-b9ac-7e7b101828d7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[TensorShape([]), TensorShape([]), TensorShape([])]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ia.batch_shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "qHmvRcxBJOAZ",
+ "outputId": "fc21534e-f019-46f8-d136-82afa055e0ff"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TensorShape([])"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ab.batch_shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ozegq0diJuOL"
+ },
+ "source": [
+ "We see that the original `jds` has subdistributions with different batch shapes. `jds_i` and `jds_ia` fix this by creating subdistributions with the same (empty) batch shape. `jds_ab` has only a single (empty) batch shape."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "bMm55xqV1dz6"
+ },
+ "source": [
+ "It's worth noting that `JointDistributionSequentialAutoBatched` offers some additional generality for free. Suppose we make the covariates `X` (and, implicitly, the observations `Y`) two-dimensional:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "1WfK-XbR1tXU",
+ "outputId": "b76cfaaa-acc7-4413-ed42-0db2b8f539bd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[ 0, 1, 2, 3, 4, 5, 6],\n",
+ " [ 7, 8, 9, 10, 11, 12, 13]])"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "X = np.arange(14).reshape((2, 7))\n",
+ "X"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "VOnnkZooSj2C"
+ },
+ "source": [
+ "Our `JointDistributionSequentialAutoBatched` works with no changes (we need to redefine the model because the shape of `X` is cached by `jds_ab.log_prob`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "6WwMvoY71qph",
+ "outputId": "66ac655d-de0f-4647-a490-443467cc1555"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[,\n",
+ " ,\n",
+ " ]"
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ab = tfd.JointDistributionSequentialAutoBatched([\n",
+ " tfd.Normal(loc=0., scale=1.), # m\n",
+ " tfd.Normal(loc=0., scale=1.), # b\n",
+ " lambda b, m: tfd.Normal(loc=m*X + b, scale=1.) # Y\n",
+ "])\n",
+ "\n",
+ "shaped_sample = jds_ab.sample([5, 3])\n",
+ "shaped_sample"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "GLvHMTpnSyvH",
+ "outputId": "736817b2-ca10-48b5-8618-bbd892cb6f34"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 0,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "jds_ab.log_prob(shaped_sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AI40r2oETnVP"
+ },
+ "source": [
+ "On the other hand, our carefully crafted `JointDistributionSequential` no longer works:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "id": "tfYkdBIi0wJl",
+ "outputId": "517af7fb-913a-4ac5-cf0c-c758e9942bbd"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Incompatible shapes: [5,3,1] vs. [2,7] [Op:Mul]\n"
+ ]
+ }
+ ],
+ "source": [
+ "jds_ia = tfd.JointDistributionSequential([\n",
+ " tfd.Normal(loc=0., scale=1.), # m\n",
+ " tfd.Normal(loc=0., scale=1.), # b\n",
+ " lambda b, m: tfd.Independent( # Y\n",
+ " tfd.Normal(loc=m[..., tf.newaxis]*X + b[..., tf.newaxis], scale=1.),\n",
+ " reinterpreted_batch_ndims=1)\n",
+ "])\n",
+ "\n",
+ "try:\n",
+ " jds_ia.sample([5, 3])\n",
+ "except tf.errors.InvalidArgumentError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WLERQvFNTwQJ"
+ },
+ "source": [
+ "To fix this, we'd have to add a second `tf.newaxis` to both `m` and `b` match the shape, and increase `reinterpreted_batch_ndims` to 2 in the call to `Independent`. In this case, letting the auto-batching machinery handle the shape issues is shorter, easier, and more ergonomic."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HIgCF6yJXpHE"
+ },
+ "source": [
+ "Once again, we note that while this notebook explored `JointDistributionSequentialAutoBatched`, the other variants of `JointDistribution` have equivalent `AutoBatched`. (For users of `JointDistributionCoroutine`, `JointDistributionCoroutineAutoBatched` has the additional benefit that you no longer need to specify `Root` nodes; if you've never used `JointDistributionCoroutine` you can safely ignore this statement.)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "mHacIM0iUW09"
+ },
+ "source": [
+ "### Concluding Thoughts"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "kXAC7GDWUaaY"
+ },
+ "source": [
+ "In this notebook, we introduced `JointDistributionSequentialAutoBatched` and worked through a simple example in detail. Hopefully you learned something about TFP shapes and about autobatching!"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "JointDistributionAutoBatched_A_Gentle_Tutorial.ipynb",
+ "toc_visible": true
},
- "nbformat": 4,
- "nbformat_minor": 0
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py
index 601c74228f..345dbe2511 100644
--- a/tensorflow_probability/python/__init__.py
+++ b/tensorflow_probability/python/__init__.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import functools
+import types
from tensorflow_probability.python.internal import all_util
from tensorflow_probability.python.internal import lazy_loader
@@ -77,6 +78,24 @@ def _validate_tf_environment(package):
)
+# Declare these explicitly to appease pytype, which otherwise misses them,
+# presumably due to lazy loading.
+bijectors: types.ModuleType
+debugging: types.ModuleType
+distributions: types.ModuleType
+experimental: types.ModuleType
+glm: types.ModuleType
+layers: types.ModuleType
+math: types.ModuleType
+mcmc: types.ModuleType
+monte_carlo: types.ModuleType
+optimizer: types.ModuleType
+random: types.ModuleType
+stats: types.ModuleType
+sts: types.ModuleType
+util: types.ModuleType
+vi: types.ModuleType
+
_allowed_symbols = [
'bijectors',
'debugging',
diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD
index 2ecf6e472b..0df4e10958 100644
--- a/tensorflow_probability/python/bijectors/BUILD
+++ b/tensorflow_probability/python/bijectors/BUILD
@@ -72,6 +72,7 @@ multi_substrate_py_library(
":inline",
":invert",
":iterated_sigmoid_centered",
+ ":joint_map",
":kumaraswamy_cdf",
":lambertw_transform",
":masked_autoregressive",
@@ -88,6 +89,7 @@ multi_substrate_py_library(
":real_nvp",
":reciprocal",
":reshape",
+ ":restructure",
":scale",
":scale_matvec_diag",
":scale_matvec_linear_operator",
@@ -117,6 +119,7 @@ multi_substrate_py_library(
srcs = [
"bijector.py",
"chain.py",
+ "composition.py",
],
deps = [
# numpy dep,
@@ -317,10 +320,21 @@ multi_substrate_py_library(
# tensorflow dep,
"//tensorflow_probability/python/internal:distribution_util",
"//tensorflow_probability/python/internal:dtype_util",
+ "//tensorflow_probability/python/internal:nest_util",
"//tensorflow_probability/python/internal:tensorshape_util",
],
)
+multi_substrate_py_library(
+ name = "joint_map",
+ srcs = ["joint_map.py"],
+ deps = [
+ ":bijector",
+ # tensorflow dep,
+ "//tensorflow_probability/python/internal:nest_util",
+ ],
+)
+
multi_substrate_py_library(
name = "cholesky_outer_product",
srcs = ["cholesky_outer_product.py"],
@@ -770,6 +784,17 @@ multi_substrate_py_library(
],
)
+multi_substrate_py_library(
+ name = "restructure",
+ srcs = ["restructure.py"],
+ deps = [
+ ":bijector",
+ # tensorflow dep,
+ "//tensorflow_probability/python/internal:nest_util",
+ "//tensorflow_probability/python/internal:prefer_static",
+ ],
+)
+
multi_substrate_py_library(
name = "fill_scale_tril",
srcs = ["fill_scale_tril.py"],
@@ -1471,6 +1496,19 @@ multi_substrate_py_test(
],
)
+multi_substrate_py_test(
+ name = "joint_map_test",
+ size = "small",
+ srcs = ["joint_map_test.py"],
+ deps = [
+ ":bijector_test_util",
+ ":bijectors",
+ # numpy dep,
+ # tensorflow dep,
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
multi_substrate_py_test(
name = "kumaraswamy_cdf_test",
size = "small",
@@ -1726,6 +1764,18 @@ multi_substrate_py_test(
],
)
+multi_substrate_py_test(
+ name = "restructure_test",
+ size = "small",
+ srcs = ["restructure_test.py"],
+ deps = [
+ ":bijector_test_util",
+ ":bijectors",
+ # tensorflow dep,
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
multi_substrate_py_test(
name = "fill_scale_tril_test",
size = "small",
diff --git a/tensorflow_probability/python/bijectors/__init__.py b/tensorflow_probability/python/bijectors/__init__.py
index 86eaf037cc..863832c2fe 100644
--- a/tensorflow_probability/python/bijectors/__init__.py
+++ b/tensorflow_probability/python/bijectors/__init__.py
@@ -31,6 +31,7 @@
from tensorflow_probability.python.bijectors.chain import Chain
from tensorflow_probability.python.bijectors.cholesky_outer_product import CholeskyOuterProduct
from tensorflow_probability.python.bijectors.cholesky_to_inv_cholesky import CholeskyToInvCholesky
+from tensorflow_probability.python.bijectors.composition import Composition
from tensorflow_probability.python.bijectors.correlation_cholesky import CorrelationCholesky
from tensorflow_probability.python.bijectors.cumsum import Cumsum
from tensorflow_probability.python.bijectors.discrete_cosine_transform import DiscreteCosineTransform
@@ -53,6 +54,7 @@
from tensorflow_probability.python.bijectors.inline import Inline
from tensorflow_probability.python.bijectors.invert import Invert
from tensorflow_probability.python.bijectors.iterated_sigmoid_centered import IteratedSigmoidCentered
+from tensorflow_probability.python.bijectors.joint_map import JointMap
from tensorflow_probability.python.bijectors.kumaraswamy_cdf import KumaraswamyCDF
from tensorflow_probability.python.bijectors.lambertw_transform import LambertWTail
from tensorflow_probability.python.bijectors.masked_autoregressive import AutoregressiveNetwork
@@ -73,6 +75,7 @@
from tensorflow_probability.python.bijectors.real_nvp import RealNVP
from tensorflow_probability.python.bijectors.reciprocal import Reciprocal
from tensorflow_probability.python.bijectors.reshape import Reshape
+from tensorflow_probability.python.bijectors.restructure import Restructure
from tensorflow_probability.python.bijectors.scale import Scale
from tensorflow_probability.python.bijectors.scale_matvec_diag import ScaleMatvecDiag
from tensorflow_probability.python.bijectors.scale_matvec_linear_operator import ScaleMatvecLinearOperator
@@ -113,6 +116,7 @@
"Chain",
"CholeskyOuterProduct",
"CholeskyToInvCholesky",
+ "Composition",
"CorrelationCholesky",
"Cumsum",
"DiscreteCosineTransform",
@@ -133,6 +137,7 @@
"Inline",
"Invert",
"IteratedSigmoidCentered",
+ "JointMap",
"KumaraswamyCDF",
"LambertWTail",
"Log",
@@ -152,6 +157,7 @@
"RealNVP",
"Reciprocal",
"Reshape",
+ "Restructure",
"Scale",
"ScaleMatvecDiag",
"ScaleMatvecLinearOperator",
diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py
index 4be79019bc..1bbaddc9b7 100644
--- a/tensorflow_probability/python/bijectors/bijector.py
+++ b/tensorflow_probability/python/bijectors/bijector.py
@@ -214,6 +214,8 @@ def _forward_log_det_jacobian(self, x):
(e.g. a bijector which pads an extra dimension at the end, might have
`forward_min_event_ndims=0` and `inverse_min_event_ndims=1`.
+ ##### Additional Considerations for "Multi Tensor" Bijectors
+
Bijectors which operate on structures of `Tensor` require structured
`min_event_ndims` matching the structure of the inputs. In these cases,
`min_event_ndims` describes both the minimum dimensionality *and* the
@@ -225,8 +227,12 @@ def _forward_log_det_jacobian(self, x):
inverse_min_event_ndims=[-axis] * len(sizes)
```
- In these cases, the leftmost `min_event_ndims[i]` elements of
- `tensor[i].shape` must be identical for all structured inputs `i`.
+ Note: By default, we require `shape(x[i])[-event_ndims:-min_event_ndims]` to
+ be identical for all elements `i` of the structured input `x`. Specifically,
+ broadcasting over non-minimal event-dims is not allowed for structured inputs.
+ In cases where broadcasting is used as a "computational shorthand" for a dense
+ operation (that is, the _broadcasted_ inputs are assumed to be independent),
+ users should set `bijector._allow_event_shape_broadcasting = True`.
Finally, some bijectors that operate on structures of inputs may not know
the minimum structured rank of their inputs without calltime shape information
@@ -537,12 +543,14 @@ def __init__(self,
# structures, so it is important that we retain the original containers.
self._forward_min_event_ndims = self._no_dependency(forward_min_event_ndims)
self._inverse_min_event_ndims = self._no_dependency(inverse_min_event_ndims)
+ self._has_static_min_event_ndims = None not in (
+ nest.flatten([forward_min_event_ndims, inverse_min_event_ndims]))
- # By default, allow broadcasting within the `event_shape`, and do not check
- # for differences in the total number of degrees of freedom between forward
- # and inverse event shapes. This may cause incorrect LDJ values.
- # Subclasses may set this value to `False`.
- self._allow_event_shape_broadcasting = True
+ # Whether to allow broadcasting over the (non-minimal) event-shape for
+ # structured inputs. When `False` (default), assert that LDJ reduction
+ # shapes are identical for all components of nested inputs. When `True`,
+ # event-shape broadcasing is allow, but LDJ may be incorrect.
+ self._allow_event_shape_broadcasting = False
# Batch shape implied by the bijector's parameters, for use in validating
# LDJ shapes (currently only used in multipart bijectors.)
@@ -577,6 +585,11 @@ def inverse_min_event_ndims(self):
"""
return self._inverse_min_event_ndims
+ @property
+ def has_static_min_event_ndims(self):
+ """Returns True if the bijector has statically-known `min_event_ndims`."""
+ return self._has_static_min_event_ndims
+
@property
def is_constant_jacobian(self):
"""Returns true iff the Jacobian matrix is not a function of x.
@@ -610,6 +623,11 @@ def _is_scalar(self):
return (tf.get_static_value(self._forward_min_event_ndims) == 0 and
tf.get_static_value(self._inverse_min_event_ndims) == 0)
+ @property
+ def _is_permutation(self):
+ """Whether `y` is purely a reordering / restructuring of `x`."""
+ return False
+
@property
def validate_args(self):
"""Returns True if Tensor arguments will be validated."""
@@ -780,7 +798,7 @@ def forward_event_shape_tensor(self,
indicating event-portion shape after applying `forward`.
"""
with self._name_and_control_scope(name):
- # Use statically-known dtype attribute to infer structure.
+ # Use statically-known structure from min_event_ndims.
input_shape_dtype = nest_util.broadcast_structure(
self.forward_min_event_ndims, tf.int32)
input_shape = nest_util.convert_to_nested_tensor(
@@ -817,7 +835,9 @@ def forward_event_shape(self, input_shape):
"""
# Use statically-known dtype attribute to infer structure.
input_shape = nest.map_structure_up_to(
- self.forward_min_event_ndims, tf.TensorShape, input_shape)
+ self.forward_min_event_ndims, tf.TensorShape,
+ nest_util.coerce_structure(self.forward_min_event_ndims, input_shape),
+ check_types=False)
return nest.map_structure_up_to(
self.inverse_min_event_ndims, tf.TensorShape,
self._forward_event_shape(input_shape))
@@ -877,7 +897,9 @@ def inverse_event_shape(self, output_shape):
"""
# Use statically-known dtype attribute to infer structure.
output_shape = nest.map_structure_up_to(
- self.inverse_min_event_ndims, tf.TensorShape, output_shape)
+ self.inverse_min_event_ndims, tf.TensorShape,
+ nest_util.coerce_structure(self.inverse_min_event_ndims, output_shape),
+ check_types=False)
return nest.map_structure_up_to(
self.forward_min_event_ndims, tf.TensorShape,
self._inverse_event_shape(output_shape))
@@ -1001,7 +1023,7 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
ildj: the inverse log det jacobian at `y`. Also updates the cache as
needed.
"""
- if any(nd is None for nd in nest.flatten(self.inverse_min_event_ndims)):
+ if not self.has_static_min_event_ndims:
raise NotImplementedError(
'Subclasses without static `inverse_min_event_ndims` must override '
'`_call_inverse_log_det_jacobian`.')
@@ -1015,7 +1037,8 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
reduce_shape, assertions = ldj_reduction_shape(
nest.map_structure(ps.shape, y),
- event_ndims=event_ndims,
+ event_ndims=nest_util.coerce_structure(
+ self.inverse_min_event_ndims, event_ndims),
min_event_ndims=self._inverse_min_event_ndims,
parameter_batch_shape=self._parameter_batch_shape,
allow_event_shape_broadcasting=self._allow_event_shape_broadcasting,
@@ -1113,7 +1136,7 @@ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
'forward_log_det_jacobian cannot be implemented for non-injective '
'transforms.')
- if any(nd is None for nd in nest.flatten(self.forward_min_event_ndims)):
+ if not self.has_static_min_event_ndims:
raise NotImplementedError(
'Subclasses without static `forward_min_event_ndims` must override '
'`_call_forward_log_det_jacobian`.')
@@ -1127,9 +1150,10 @@ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
reduce_shape, assertions = ldj_reduction_shape(
nest.map_structure(ps.shape, x),
- event_ndims=event_ndims,
- parameter_batch_shape=self._parameter_batch_shape,
+ event_ndims=nest_util.coerce_structure(
+ self.forward_min_event_ndims, event_ndims),
min_event_ndims=self._forward_min_event_ndims,
+ parameter_batch_shape=self._parameter_batch_shape,
allow_event_shape_broadcasting=self._allow_event_shape_broadcasting,
validate_args=self.validate_args)
@@ -1198,20 +1222,26 @@ def _inverse_dtype(self, output_dtype, **kwargs):
def forward_dtype(self, dtype=UNSPECIFIED, name='forward_dtype', **kwargs):
"""Returns the dtype returned by `forward` for the provided input."""
with tf.name_scope('{}/{}'.format(self.name, name)):
- input_dtype = nest_util.broadcast_structure(
- self.forward_min_event_ndims, self.dtype)
- if dtype is not UNSPECIFIED:
+ if dtype is UNSPECIFIED:
+ # We pass the the broadcasted input structure through `_forward_dtype`
+ # rather than directly returning the output structure, allowing
+ # subclasses to alter results based on `**kwargs`.
+ input_dtype = nest_util.broadcast_structure(
+ self.forward_min_event_ndims, self.dtype)
+ else:
# Make sure inputs are compatible with statically-known dtype.
input_dtype = nest.map_structure_up_to(
- input_dtype,
- lambda x, dt: dtype_util.convert_to_dtype(x, dtype=dt),
- dtype, input_dtype)
+ self.forward_min_event_ndims,
+ lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
+ nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
+ check_types=False)
output_dtype = self._forward_dtype(input_dtype, **kwargs)
try:
# kwargs may alter dtypes themselves, but we currently require
# structure to be statically known.
- nest.assert_same_structure(self.inverse_min_event_ndims, output_dtype)
+ nest.assert_same_structure(self.inverse_min_event_ndims, output_dtype,
+ check_types=False)
except Exception as err:
raise NotImplementedError(
'Changing output structure in `forward_dtype` '
@@ -1219,22 +1249,28 @@ def forward_dtype(self, dtype=UNSPECIFIED, name='forward_dtype', **kwargs):
return output_dtype
def inverse_dtype(self, dtype=UNSPECIFIED, name='inverse_dtype', **kwargs):
- """Returns the dtype returned by forward for the provided input."""
+ """Returns the dtype returned by `inverse` for the provided input."""
with tf.name_scope('{}/{}'.format(self.name, name)):
- output_dtype = nest_util.broadcast_structure(
- self.inverse_min_event_ndims, self.dtype)
- if dtype is not UNSPECIFIED:
+ if dtype is UNSPECIFIED:
+ # We pass the the broadcasted output structure through `_inverse_dtype`
+ # rather than directly returning the input structure, allowing
+ # subclasses to alter results based on `**kwargs`.
+ output_dtype = nest_util.broadcast_structure(
+ self.inverse_min_event_ndims, self.dtype)
+ else:
# Make sure inputs are compatible with statically-known dtype.
output_dtype = nest.map_structure_up_to(
- output_dtype,
- lambda y, dt: dtype_util.convert_to_dtype(y, dtype=dt),
- dtype, output_dtype)
+ self.inverse_min_event_ndims,
+ lambda y: dtype_util.convert_to_dtype(y, dtype=self.dtype),
+ nest_util.coerce_structure(self.inverse_min_event_ndims, dtype),
+ check_types=False)
input_dtype = self._inverse_dtype(output_dtype, **kwargs)
try:
# kwargs may alter dtypes themselves, but we currently require
# structure to be statically known.
- nest.assert_same_structure(self.forward_min_event_ndims, input_dtype)
+ nest.assert_same_structure(self.forward_min_event_ndims, input_dtype,
+ check_types=False)
except Exception as err:
raise NotImplementedError(
'Changing output structure in `inverse_dtype` '
@@ -1243,28 +1279,26 @@ def inverse_dtype(self, dtype=UNSPECIFIED, name='inverse_dtype', **kwargs):
def forward_event_ndims(self, event_ndims, **kwargs):
"""Returns the number of event dimensions produced by `forward`."""
- if self._forward_min_event_ndims is None:
+ if not self.has_static_min_event_ndims:
raise NotImplementedError(
'Subclasses without static min_event_ndims must override '
'`forward_event_ndims`')
ldj_reduce_ndims = ldj_reduction_ndims(
- event_ndims,
- self._forward_min_event_ndims,
- self.validate_args)
+ nest_util.coerce_structure(self.forward_min_event_ndims, event_ndims),
+ self._forward_min_event_ndims)
return nest.map_structure(
lambda ndims: ldj_reduce_ndims + ndims,
self._inverse_min_event_ndims)
def inverse_event_ndims(self, event_ndims, **kwargs):
"""Returns the number of event dimensions produced by `inverse`."""
- if self._inverse_min_event_ndims is None:
+ if not self.has_static_min_event_ndims:
raise NotImplementedError(
'Subclasses without static min_event_ndims must override '
'`inverse_event_ndims`')
ldj_reduce_ndims = ldj_reduction_ndims(
- event_ndims,
- self._inverse_min_event_ndims,
- validate_args=self.validate_args)
+ nest_util.coerce_structure(self.inverse_min_event_ndims, event_ndims),
+ self._inverse_min_event_ndims)
return nest.map_structure(
lambda ndims: ldj_reduce_ndims + ndims,
self._forward_min_event_ndims)
@@ -1429,7 +1463,6 @@ def ldj_reduction_ndims(event_ndims,
Raises:
ValueError: When the structured difference between `event_ndims` and
`min_event_ndims` is not the same for all elements.
- ValueError: If the resulting `reduction_ndims` is negative.
"""
with tf.name_scope(name):
assertions = []
@@ -1467,20 +1500,8 @@ def ldj_reduction_ndims(event_ndims,
'input. Saw event_ndims={}, min_event_ndims={}.'
).format(event_ndims, min_event_ndims)))
- # Finally, make sure the difference is positive.
+ # Now the we know they're all the same, just choose the first.
result = flat_differences[0]
- if differences_all_static:
- if result < 0:
- raise ValueError('`event_ndims` must be at least {}. Saw: {}'
- .format(min_event_ndims, event_ndims))
- elif validate_args:
- with tf.control_dependencies(assertions):
- assertions.append(
- assert_util.assert_greater_equal(
- result, 0,
- message='`event_ndims` must be at least {}. Saw: {}'.format(
- min_event_ndims, event_ndims)))
-
if assertions:
with tf.control_dependencies(assertions):
result = tf.identity(result)
@@ -1594,6 +1615,20 @@ def ldj_reduction_shape(shape_structure,
event_ndims, min_event_ndims, validate_args=validate_args,
name='reduce_ndims')
+ # Make sure the number of dimensions we're reducing over is non-negative.
+ reduce_ndims_ = tf.get_static_value(reduce_ndims)
+ if reduce_ndims_ is not None:
+ if reduce_ndims_ < 0:
+ raise ValueError('`event_ndims must be at least {}. Saw: {}.'
+ .format(event_ndims, min_event_ndims))
+ elif validate_args:
+ with tf.control_dependencies(assertions):
+ assertions.append(
+ assert_util.assert_non_negative(
+ reduce_ndims,
+ message='`event_ndims` must be at least {}. Saw: {}'.format(
+ min_event_ndims, event_ndims)))
+
# Make sure inputs have rank greater than event_ndims.
rank_structure = nest.map_structure_up_to(
event_ndims, ps.size, shape_structure)
@@ -1609,8 +1644,9 @@ def ldj_reduction_shape(shape_structure,
with tf.control_dependencies(assertions):
assertions.append(
assert_util.assert_greater_equal(
- rank, ndims, message=('Input must have rank at least {}.'
- 'Saw: {}'.format(ndims, rank))))
+ rank, tf.cast(ndims, dtype_util.convert_to_dtype(rank)),
+ message=('Input must have rank at least {}.'
+ 'Saw: {}'.format(ndims, rank))))
# Get the non-minimal portion of the event shape over which to reduce LDJ.
ldj_reduce_shapes = nest.flatten(
diff --git a/tensorflow_probability/python/bijectors/bijector_properties_test.py b/tensorflow_probability/python/bijectors/bijector_properties_test.py
index a1277657b2..c6b73b37f1 100644
--- a/tensorflow_probability/python/bijectors/bijector_properties_test.py
+++ b/tensorflow_probability/python/bijectors/bijector_properties_test.py
@@ -660,6 +660,11 @@ def exception(bijector):
grads = tape.gradient(ldj, wrt_vars)
assert_no_none_grad(bijector, 'inverse_log_det_jacobian', wrt_vars, grads)
+ # Verify that `_is_permutation` implies constant zero Jacobian.
+ if bijector._is_permutation:
+ self.assertTrue(bijector._is_constant_jacobian)
+ self.assertAllEqual(ldj, 0.)
+
# Check that the outputs of forward_dtype and inverse_dtype match the dtypes
# of the outputs of forward and inverse.
self.assertAllEqualNested(ys.dtype, bijector.forward_dtype(xs.dtype))
diff --git a/tensorflow_probability/python/bijectors/blockwise.py b/tensorflow_probability/python/bijectors/blockwise.py
index 6efbffa06b..925b0b72c6 100644
--- a/tensorflow_probability/python/bijectors/blockwise.py
+++ b/tensorflow_probability/python/bijectors/blockwise.py
@@ -21,11 +21,16 @@
import numpy as np
import tensorflow.compat.v2 as tf
-from tensorflow_probability.python.bijectors import bijector as bijector_base
+from tensorflow_probability.python.bijectors import chain
+from tensorflow_probability.python.bijectors import invert
+from tensorflow_probability.python.bijectors import joint_map
+from tensorflow_probability.python.bijectors import split
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util
+from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
+
__all__ = [
'Blockwise',
]
@@ -37,7 +42,7 @@ def _get_static_splits(splits):
return splits if static_splits is None else static_splits
-class Blockwise(bijector_base.Bijector):
+class Blockwise(chain.Chain):
"""Bijector which applies a list of bijectors to blocks of a `Tensor`.
More specifically, given [F_0, F_1, ... F_n] which are scalar or vector
@@ -110,127 +115,126 @@ def __init__(self,
if not name:
name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
name = name.replace('/', '')
+
with tf.name_scope(name) as name:
+ for b in bijectors:
+ if (nest.is_nested(b.forward_min_event_ndims)
+ or nest.is_nested(b.inverse_min_event_ndims)):
+ raise ValueError('Bijectors must all be single-part.')
+ elif isinstance(b.forward_min_event_ndims, int):
+ if b.forward_min_event_ndims != b.inverse_min_event_ndims:
+ raise ValueError('Rank-changing bijectors are not supported.')
+ elif b.forward_min_event_ndims > 1:
+ raise ValueError('Only scalar and vector event-shape '
+ 'bijectors are supported at this time.')
+
+ b_joint = joint_map.JointMap(list(bijectors), name='jointmap')
+
+ block_sizes = (
+ np.ones(len(bijectors), dtype=np.int32)
+ if block_sizes is None else
+ _validate_block_sizes(block_sizes, bijectors, validate_args))
+ b_split = split.Split(
+ block_sizes, name='split', validate_args=validate_args)
+
+ if maybe_changes_size:
+ i_block_sizes = _validate_block_sizes(
+ ps.concat(b_joint.forward_event_shape_tensor(
+ ps.split(block_sizes, len(bijectors))), axis=0),
+ bijectors, validate_args)
+ maybe_changes_size = not tf.get_static_value(
+ ps.reduce_all(block_sizes == i_block_sizes))
+ b_concat = invert.Invert(
+ (split.Split(i_block_sizes, name='isplit')
+ if maybe_changes_size else b_split),
+ name='concat')
+
+ self._maybe_changes_size = maybe_changes_size
super(Blockwise, self).__init__(
- forward_min_event_ndims=1,
+ bijectors=[b_concat, b_joint, b_split],
validate_args=validate_args,
parameters=parameters,
name=name)
- if not bijectors:
- raise ValueError('`bijectors` must not be empty.')
-
- for bijector in bijectors:
- if (bijector.forward_min_event_ndims > 1 or
- bijector.inverse_min_event_ndims > 1):
- # TODO(siege): In the future, it can be reasonable to support N-D
- # bijectors by concatenating along some specific axis, broadcasting
- # low-D bijectors appropriately.
- raise NotImplementedError('Only scalar and vector event-shape '
- 'bijectors are supported at this time.')
-
- self._bijectors = bijectors
- self._maybe_changes_size = maybe_changes_size
+ @property
+ def _b_joint(self):
+ return self._bijectors[1]
- if block_sizes is None:
- block_sizes = np.ones(len(bijectors), dtype=np.int32)
- self._block_sizes = ps.convert_to_shape_tensor(
- block_sizes, name='block_sizes', dtype_hint=tf.int32)
+ @property
+ def _b_split(self):
+ return self._bijectors[-1]
- self._block_sizes = _validate_block_sizes(self._block_sizes, bijectors,
- validate_args)
+ @property
+ def _b_concat(self):
+ return self._bijectors[0].bijector
@property
def bijectors(self):
- return self._bijectors
+ return self._b_joint.bijectors
@property
def block_sizes(self):
- return self._block_sizes
+ return self._b_split.split_sizes
- def _output_block_sizes(self):
- return [
- b.forward_event_shape_tensor(bs[tf.newaxis])[0]
- for b, bs in zip(self.bijectors,
- tf.unstack(self.block_sizes, num=len(self.bijectors)))
- ]
+ @property
+ def inverse_block_sizes(self):
+ return self._b_concat.split_sizes
+
+ def _forward(self, x, **kwargs):
+ y = super(Blockwise, self)._forward(x, **kwargs)
+ if not self._maybe_changes_size:
+ tensorshape_util.set_shape(y, x.shape)
+ return y
+
+ def _inverse(self, y, **kwargs):
+ x = super(Blockwise, self)._inverse(y, **kwargs)
+ if not self._maybe_changes_size:
+ tensorshape_util.set_shape(x, y.shape)
+ return x
def _forward_event_shape(self, input_shape):
+ if not self._maybe_changes_size:
+ return input_shape
input_shape = tensorshape_util.with_rank_at_least(input_shape, 1)
- static_block_sizes = tf.get_static_value(self.block_sizes)
+ static_block_sizes = tf.get_static_value(self.inverse_block_sizes)
if static_block_sizes is None:
return tensorshape_util.concatenate(input_shape[:-1], [None])
-
- output_size = sum(
- b.forward_event_shape([bs])[0]
- for b, bs in zip(self.bijectors, static_block_sizes))
-
+ output_size = sum(static_block_sizes)
return tensorshape_util.concatenate(input_shape[:-1], [output_size])
- def _forward_event_shape_tensor(self, input_shape):
- output_size = ps.reduce_sum(self._output_block_sizes())
- return ps.concat([input_shape[:-1], output_size[tf.newaxis]], -1)
-
def _inverse_event_shape(self, output_shape):
+ if not self._maybe_changes_size:
+ return output_shape
output_shape = tensorshape_util.with_rank_at_least(output_shape, 1)
static_block_sizes = tf.get_static_value(self.block_sizes)
if static_block_sizes is None:
return tensorshape_util.concatenate(output_shape[:-1], [None])
-
input_size = sum(static_block_sizes)
-
return tensorshape_util.concatenate(output_shape[:-1], [input_size])
- def _inverse_event_shape_tensor(self, output_shape):
- input_size = ps.reduce_sum(self.block_sizes)
- return ps.concat([output_shape[:-1], input_size[tf.newaxis]], -1)
-
- def _forward(self, x, **kwargs):
- split_x = tf.split(x, _get_static_splits(self.block_sizes), axis=-1,
- num=len(self.bijectors))
- # TODO(b/162023850): Sanitize the kwargs better.
- split_y = [
- b.forward(x_, **kwargs.get(b.name, {}))
- for b, x_ in zip(self.bijectors, split_x)
- ]
- y = tf.concat(split_y, axis=-1)
+ def _forward_event_shape_tensor(self, x, **kwargs):
if not self._maybe_changes_size:
- tensorshape_util.set_shape(y, x.shape)
- return y
+ return x
+ return super(Blockwise, self)._forward_event_shape_tensor(x, **kwargs)
- def _inverse(self, y, **kwargs):
- split_y = tf.split(y, _get_static_splits(self._output_block_sizes()),
- axis=-1, num=len(self.bijectors))
- split_x = [
- b.inverse(y_, **kwargs.get(b.name, {}))
- for b, y_ in zip(self.bijectors, split_y)
- ]
- x = tf.concat(split_x, axis=-1)
+ def _inverse_event_shape_tensor(self, y, **kwargs):
if not self._maybe_changes_size:
- tensorshape_util.set_shape(x, y.shape)
- return x
+ return y
+ return super(Blockwise, self)._inverse_event_shape_tensor(y, **kwargs)
+
+ def _walk_forward(self, step_fn, x, **kwargs):
+ return super(Blockwise, self)._walk_forward(
+ step_fn, x, **{self._b_joint.name: kwargs})
- def _forward_log_det_jacobian(self, x, **kwargs):
- split_x = tf.split(x, _get_static_splits(self.block_sizes), axis=-1,
- num=len(self.bijectors))
- fldjs = [
- b.forward_log_det_jacobian(x_, event_ndims=1, **kwargs.get(b.name, {}))
- for b, x_ in zip(self.bijectors, split_x)
- ]
- return sum(fldjs)
-
- def _inverse_log_det_jacobian(self, y, **kwargs):
- split_y = tf.split(y, _get_static_splits(self._output_block_sizes()),
- axis=-1, num=len(self.bijectors))
- ildjs = [
- b.inverse_log_det_jacobian(y_, event_ndims=1, **kwargs.get(b.name, {}))
- for b, y_ in zip(self.bijectors, split_y)
- ]
- return sum(ildjs)
+ def _walk_inverse(self, step_fn, x, **kwargs):
+ return super(Blockwise, self)._walk_inverse(
+ step_fn, x, **{self._b_joint.name: kwargs})
def _validate_block_sizes(block_sizes, bijectors, validate_args):
"""Helper to validate block sizes."""
+ block_sizes = ps.convert_to_shape_tensor(
+ block_sizes, name='block_sizes', dtype_hint=tf.int32)
block_sizes_shape = block_sizes.shape
if tensorshape_util.is_fully_defined(block_sizes_shape):
if (tensorshape_util.rank(block_sizes_shape) != 1 or
@@ -240,6 +244,7 @@ def _validate_block_sizes(block_sizes, bijectors, validate_args):
'`bijectors`. Got a `Tensor` with shape {} and `bijectors` of '
'length {}'.format(block_sizes_shape, len(bijectors)))
return block_sizes
+
elif validate_args:
message = ('`block_sizes` must be `None`, or a vector of the same length '
'as `bijectors`.')
@@ -248,6 +253,8 @@ def _validate_block_sizes(block_sizes, bijectors, validate_args):
tf.size(block_sizes), len(bijectors), message=message),
assert_util.assert_equal(tf.rank(block_sizes), 1)
]):
- return tf.identity(block_sizes)
- else:
- return block_sizes
+ block_sizes = tf.identity(block_sizes)
+
+ # Set the shape if missing to pass statically known structure to split.
+ tensorshape_util.set_shape(block_sizes, [len(bijectors)])
+ return block_sizes
diff --git a/tensorflow_probability/python/bijectors/blockwise_test.py b/tensorflow_probability/python/bijectors/blockwise_test.py
index 4c07105f2f..31dc836190 100644
--- a/tensorflow_probability/python/bijectors/blockwise_test.py
+++ b/tensorflow_probability/python/bijectors/blockwise_test.py
@@ -42,7 +42,10 @@ class BlockwiseBijectorTest(test_util.TestCase):
def testExplicitBlocks(self, dynamic_shape, batch_shape):
block_sizes = tf.convert_to_tensor(value=[2, 1, 3])
block_sizes = tf1.placeholder_with_default(
- block_sizes, shape=None if dynamic_shape else block_sizes.shape)
+ block_sizes,
+ shape=([None] * len(block_sizes.shape)
+ if dynamic_shape else
+ block_sizes.shape))
exp = tfb.Exp()
sp = tfb.Softplus()
aff = tfb.Affine(scale_diag=[2., 3., 4.])
@@ -237,11 +240,6 @@ def testRaisesEmptyBijectors(self):
with self.assertRaisesRegexp(ValueError, '`bijectors` must not be empty'):
tfb.Blockwise(bijectors=[])
- def testRaisesBadBijectors(self):
- with self.assertRaisesRegexp(NotImplementedError,
- 'Only scalar and vector event-shape'):
- tfb.Blockwise(bijectors=[tfb.Reshape(event_shape_out=[1, 1])])
-
def testRaisesBadBlocks(self):
with self.assertRaisesRegexp(
ValueError,
@@ -287,10 +285,10 @@ def testKwargs(self):
blockwise.inverse_log_det_jacobian(
x, event_ndims=1, inner0={'arg': 7}, inner1={'arg': 8})
- bijectors[0]._forward.assert_called_with(mock.ANY, arg=1)
- bijectors[1]._forward.assert_called_with(mock.ANY, arg=2)
- bijectors[0]._inverse.assert_called_with(mock.ANY, arg=3)
- bijectors[1]._inverse.assert_called_with(mock.ANY, arg=4)
+ bijectors[0]._forward.assert_any_call(mock.ANY, arg=1)
+ bijectors[1]._forward.assert_any_call(mock.ANY, arg=2)
+ bijectors[0]._inverse.assert_any_call(mock.ANY, arg=3)
+ bijectors[1]._inverse.assert_any_call(mock.ANY, arg=4)
bijectors[0]._forward_log_det_jacobian.assert_called_with(mock.ANY, arg=5)
bijectors[1]._forward_log_det_jacobian.assert_called_with(mock.ANY, arg=6)
bijectors[0]._inverse_log_det_jacobian.assert_called_with(mock.ANY, arg=7)
diff --git a/tensorflow_probability/python/bijectors/chain.py b/tensorflow_probability/python/bijectors/chain.py
index d0714698c1..69aa55ef9e 100644
--- a/tensorflow_probability/python/bijectors/chain.py
+++ b/tensorflow_probability/python/bijectors/chain.py
@@ -14,101 +14,25 @@
# ============================================================================
"""Chain bijector."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
-from tensorflow_probability.python.bijectors import bijector
-from tensorflow_probability.python.internal import distribution_util
-from tensorflow_probability.python.internal import dtype_util
-from tensorflow_probability.python.internal import prefer_static
-from tensorflow_probability.python.internal import tensorshape_util
+from tensorflow_probability.python.bijectors import bijector as bijector_lib
+from tensorflow_probability.python.bijectors import composition
+from tensorflow_probability.python.internal import nest_util
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
__all__ = [
- "Chain",
+ 'Chain',
]
-def _use_static_shape(input_tensor, ndims):
- return (tensorshape_util.is_fully_defined(input_tensor.shape) and
- isinstance(ndims, int))
-
-
-def _compute_min_event_ndims(bijector_list, compute_forward=True):
- """Computes the min_event_ndims associated with the give list of bijectors.
-
- Given a list `bijector_list` of bijectors, compute the min_event_ndims that is
- associated with the composition of bijectors in that list.
-
- min_event_ndims is the # of right most dimensions for which the bijector has
- done necessary computation on (i.e. the non-broadcastable part of the
- computation).
-
- We can derive the min_event_ndims for a chain of bijectors as follows:
-
- In the case where there are no rank changing bijectors, this will simply be
- `max(b.forward_min_event_ndims for b in bijector_list)`. This is because the
- bijector with the most forward_min_event_ndims requires the most dimensions,
- and hence the chain also requires operating on those dimensions.
-
- However in the case of rank changing, more care is needed in determining the
- exact amount of dimensions. Padding dimensions causes subsequent bijectors to
- operate on the padded dimensions, and Removing dimensions causes bijectors to
- operate more left.
-
- Args:
- bijector_list: List of bijectors to be composed by chain.
- compute_forward: Boolean. If True, computes the min_event_ndims associated
- with a forward call to Chain, and otherwise computes the min_event_ndims
- associated with an inverse call to Chain. The latter is the same as the
- min_event_ndims associated with a forward call to Invert(Chain(....)).
-
- Returns:
- min_event_ndims
- """
- min_event_ndims = 0
- # This is a mouthful, but what this encapsulates is that if not for rank
- # changing bijectors, we'd only need to compute the largest of the min
- # required ndims. Hence "max_min". Due to rank changing bijectors, we need to
- # account for synthetic rank growth / synthetic rank decrease from a rank
- # changing bijector.
- rank_changed_adjusted_max_min_event_ndims = 0
-
- if compute_forward:
- bijector_list = reversed(bijector_list)
-
- for b in bijector_list:
- if compute_forward:
- current_min_event_ndims = b.forward_min_event_ndims
- current_inverse_min_event_ndims = b.inverse_min_event_ndims
- else:
- current_min_event_ndims = b.inverse_min_event_ndims
- current_inverse_min_event_ndims = b.forward_min_event_ndims
-
- # New dimensions were touched.
- if rank_changed_adjusted_max_min_event_ndims < current_min_event_ndims:
- min_event_ndims += (
- current_min_event_ndims - rank_changed_adjusted_max_min_event_ndims)
- rank_changed_adjusted_max_min_event_ndims = max(
- current_min_event_ndims, rank_changed_adjusted_max_min_event_ndims)
-
- # If the number of dimensions has increased via forward, then
- # inverse_min_event_ndims > forward_min_event_ndims, and hence the
- # dimensions we computed on, have moved left (so we have operated
- # on additional dimensions).
- # Conversely, if the number of dimensions has decreased via forward,
- # then we have inverse_min_event_ndims < forward_min_event_ndims,
- # and so we will have operated on fewer right most dimensions.
-
- number_of_changed_dimensions = (
- current_min_event_ndims - current_inverse_min_event_ndims)
- rank_changed_adjusted_max_min_event_ndims -= number_of_changed_dimensions
- return min_event_ndims
-
-
-class Chain(bijector.Bijector):
+class Chain(composition.Composition):
"""Bijector which applies a sequence of bijectors.
Example Use:
@@ -158,6 +82,7 @@ class Chain(bijector.Bijector):
def __init__(self,
bijectors=None,
validate_args=False,
+ validate_event_size=True,
parameters=None,
name=None):
"""Instantiates `Chain` bijector.
@@ -167,6 +92,16 @@ def __init__(self,
bijector equivalent to the `Identity` bijector.
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
+ validate_event_size: Checks that bijectors are not applied to inputs with
+ incomplete support (that is, inputs where one or more elements are a
+ deterministic transformation of the others). For example, the following
+ LDJ would be incorrect:
+ `Chain([Scale(), SoftmaxCentered()]).forward_log_det_jacobian([1], [1])`
+ The jacobian contribution from `Scale` applies to a 2-dimensional input,
+ but the output from `SoftMaxCentered` is a 1-dimensional input embedded
+ in a 2-dimensional space. Setting `validate_event_size=True` (default)
+ prints warnings in these cases. When `validate_args` is also `True`, the
+ warning is promoted to an exception.
parameters: Locals dict captured by subclass constructor, to be used for
copy/slice re-instantiation operators.
name: Python `str`, name given to ops managed by this object. Default:
@@ -176,164 +111,47 @@ def __init__(self,
ValueError: if bijectors have different dtypes.
"""
parameters = dict(locals()) if parameters is None else parameters
- if name is None:
- name = ("identity" if not bijectors else
- "_of_".join(["chain"] + [b.name for b in bijectors]))
- name = name.replace("/", "")
- with tf.name_scope(name) as name:
- if bijectors is None:
- bijectors = ()
- self._bijectors = bijectors
- for a_bijector in bijectors:
- if not a_bijector._is_injective: # pylint: disable=protected-access
- raise NotImplementedError(
- "Invert is not implemented for non-injective bijector "
- "({})".format(a_bijector.name))
+ if name is None:
+ name = ('identity' if not bijectors else
+ '_of_'.join(['chain'] + [b.name for b in bijectors]))
+ name = name.replace('/', '')
- inverse_min_event_ndims = _compute_min_event_ndims(
- bijectors, compute_forward=False)
- forward_min_event_ndims = _compute_min_event_ndims(
- bijectors, compute_forward=True)
+ if bijectors:
+ f_min_event_ndims, i_min_event_ndims = _infer_min_event_ndims(bijectors)
+ else:
+ # If there are no bijectors, treat this like a single-part Identity.
+ f_min_event_ndims = i_min_event_ndims = None
+ with tf.name_scope(name) as name:
super(Chain, self).__init__(
- forward_min_event_ndims=forward_min_event_ndims,
- inverse_min_event_ndims=inverse_min_event_ndims,
- is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors),
+ bijectors=bijectors or (),
+ forward_min_event_ndims=f_min_event_ndims,
+ inverse_min_event_ndims=i_min_event_ndims,
validate_args=validate_args,
+ validate_event_size=validate_event_size,
parameters=parameters,
name=name)
- @property
- def bijectors(self):
- return self._bijectors
-
- def _shape_helper(self, func_name, input_shape, reverse):
- new_shape = input_shape
- for b in reversed(self.bijectors) if reverse else self.bijectors:
- func = getattr(b, func_name, None)
- if func is None:
- raise ValueError("unable to call %s on bijector %s (%s)" %
- (func_name, b.name, func))
- new_shape = func(new_shape)
- return new_shape
-
- def _forward_event_shape(self, input_shape):
- return self._shape_helper("forward_event_shape", input_shape,
- reverse=True)
-
- def _forward_event_shape_tensor(self, input_shape):
- return self._shape_helper(
- "forward_event_shape_tensor", input_shape, reverse=True)
-
- def _inverse_event_shape(self, output_shape):
- return self._shape_helper("inverse_event_shape", output_shape,
- reverse=False)
-
- def _inverse_event_shape_tensor(self, output_shape):
- return self._shape_helper("inverse_event_shape_tensor", output_shape,
- reverse=False)
-
def _is_increasing(self, **kwargs):
# desc(desc)=>asc, asc(asc)=>asc, other cases=>desc.
is_increasing = True
- for b in self.bijectors:
- is_increasing = prefer_static.equal(
+ for b in self._bijectors:
+ is_increasing = ps.equal(
is_increasing, b._internal_is_increasing(**kwargs.get(b.name, {}))) # pylint: disable=protected-access
return is_increasing
- def _inverse(self, y, **kwargs):
- for b in self.bijectors:
- y = b.inverse(y, **kwargs.get(b.name, {}))
- return y
-
- def _inverse_log_det_jacobian(self, y, **kwargs):
- y = tf.convert_to_tensor(y, name="y")
- ildj = tf.cast(0., dtype=dtype_util.base_dtype(y.dtype))
+ def _walk_forward(self, step_fn, x, **kwargs):
+ """Applies `transform_fn` to `x` sequentially over nested bijectors."""
+ for bij in reversed(self._bijectors):
+ x = step_fn(bij, x, **kwargs.get(bij.name, {}))
+ return x # Now `y`
- if not self.bijectors:
- return ildj
-
- # TODO(b/162764645): Remove explicit event_ndims in Composite CL.
- event_ndims = self.inverse_min_event_ndims
-
- if _use_static_shape(y, event_ndims):
- event_shape = y.shape[tensorshape_util.rank(y.shape) - event_ndims:]
- else:
- event_shape = tf.shape(y)[tf.rank(y) - event_ndims:]
-
- # TODO(b/129973548): Document and simplify.
- for b in self.bijectors:
- ildj = ildj + b.inverse_log_det_jacobian(
- y, event_ndims=event_ndims, **kwargs.get(b.name, {}))
-
- if _use_static_shape(y, event_ndims):
- event_shape = b.inverse_event_shape(event_shape)
- event_ndims = tensorshape_util.rank(event_shape)
- else:
- event_shape = b.inverse_event_shape_tensor(event_shape)
- event_shape_ = distribution_util.maybe_get_static_value(event_shape)
- event_ndims = tf.size(event_shape)
- event_ndims_ = event_ndims
-
- if event_ndims_ is not None and event_shape_ is not None:
- event_ndims = event_ndims_
- event_shape = event_shape_
-
- y = b.inverse(y, **kwargs.get(b.name, {}))
- return ildj
-
- def _inverse_dtype(self, dtype, **kwargs):
- for b in self.bijectors:
- dtype = b.inverse_dtype(dtype, **kwargs.get(b.name, {}))
- return dtype
-
- def _forward(self, x, **kwargs):
- # TODO(b/162023850): Sanitize the kwargs better.
- for b in reversed(self.bijectors):
- x = b.forward(x, **kwargs.get(b.name, {}))
- return x
-
- def _forward_log_det_jacobian(self, x, **kwargs):
- x = tf.convert_to_tensor(x, name="x")
-
- fldj = tf.cast(0., dtype=dtype_util.base_dtype(x.dtype))
-
- if not self.bijectors:
- return fldj
-
- event_ndims = self.forward_min_event_ndims
-
- if _use_static_shape(x, event_ndims):
- event_shape = x.shape[tensorshape_util.rank(x.shape) - event_ndims:]
- else:
- event_shape = tf.shape(x)[tf.rank(x) - event_ndims:]
-
- # TODO(b/129973548): Document and simplify.
- for b in reversed(self.bijectors):
- fldj = fldj + b.forward_log_det_jacobian(
- x, event_ndims=event_ndims, **kwargs.get(b.name, {}))
- if _use_static_shape(x, event_ndims):
- event_shape = b.forward_event_shape(event_shape)
- event_ndims = tensorshape_util.rank(event_shape)
- else:
- event_shape = b.forward_event_shape_tensor(event_shape)
- event_shape_ = distribution_util.maybe_get_static_value(event_shape)
- event_ndims = tf.size(event_shape)
- event_ndims_ = event_ndims
-
- if event_ndims_ is not None and event_shape_ is not None:
- event_ndims = event_ndims_
- event_shape = event_shape_
-
- x = b.forward(x, **kwargs.get(b.name, {}))
-
- return fldj
-
- def _forward_dtype(self, dtype, **kwargs):
- for b in reversed(self.bijectors):
- dtype = b.forward_dtype(dtype, **kwargs.get(b.name, {}))
- return dtype
+ def _walk_inverse(self, step_fn, y, **kwargs):
+ """Applies `transform_fn` to `y` sequentially over nested bijectors."""
+ for bij in self._bijectors:
+ y = step_fn(bij, y, **kwargs.get(bij.name, {}))
+ return y # Now `x`
@property
def _composite_tensor_nonshape_params(self):
@@ -345,4 +163,85 @@ def _composite_tensor_nonshape_params(self):
identifies the keys of parameters that are expected to be tensors, except
those that are shape-related.
"""
- return ("bijectors",)
+ return ('bijectors',)
+
+
+def _infer_min_event_ndims(bijectors):
+ """Computes `min_event_ndims` for a sequence of bijectors."""
+ # Find the index of the first bijector with statically-known min_event_ndims.
+ try:
+ idx = next(i for i, b in enumerate(bijectors)
+ if b.has_static_min_event_ndims)
+ except StopIteration:
+ # If none of the nested bijectors have static min_event_ndims, give up
+ # and return tail-structures filled with `None`.
+ return (
+ nest_util.broadcast_structure(
+ bijectors[-1].forward_min_event_ndims, None),
+ nest_util.broadcast_structure(
+ bijectors[0].inverse_min_event_ndims, None))
+
+ # Accumulator tracking the maximum value of "min_event_ndims - ndims".
+ rolling_offset = 0
+
+ def update_event_ndims(input_event_ndims,
+ input_min_event_ndims,
+ output_min_event_ndims):
+ """Returns output_event_ndims and updates rolling_offset as needed."""
+ nonlocal rolling_offset
+ ldj_reduce_ndims = bijector_lib.ldj_reduction_ndims(
+ input_event_ndims, input_min_event_ndims)
+ # Update rolling_offset when batch_ndims are negative.
+ rolling_offset = ps.maximum(rolling_offset, -ldj_reduce_ndims)
+ return nest.map_structure(lambda nd: ldj_reduce_ndims + nd,
+ output_min_event_ndims)
+
+ def sanitize_event_ndims(event_ndims):
+ """Updates `rolling_offset` when event_ndims are negative."""
+ nonlocal rolling_offset
+ max_missing_ndims = -ps.reduce_min(nest.flatten(event_ndims))
+ rolling_offset = ps.maximum(rolling_offset, max_missing_ndims)
+ return event_ndims
+
+ # Wrappers for Bijector.forward_event_ndims and Bijector.inverse_event_ndims
+ # that recursively walk into Composition bijectors when static min_event_ndims
+ # is not available.
+
+ def update_f_event_ndims(bij, event_ndims):
+ event_ndims = nest_util.coerce_structure(
+ bij.inverse_min_event_ndims, event_ndims)
+ if bij.has_static_min_event_ndims:
+ return update_event_ndims(
+ input_event_ndims=event_ndims,
+ input_min_event_ndims=bij.inverse_min_event_ndims,
+ output_min_event_ndims=bij.forward_min_event_ndims)
+ elif isinstance(bij, composition.Composition):
+ return bij._call_walk_inverse(update_f_event_ndims, event_ndims) # pylint: disable=protected-access
+ else:
+ return sanitize_event_ndims(bij.inverse_event_ndims(event_ndims))
+
+ def update_i_event_ndims(bij, event_ndims):
+ event_ndims = nest_util.coerce_structure(
+ bij.forward_min_event_ndims, event_ndims)
+ if bij.has_static_min_event_ndims:
+ return update_event_ndims(
+ input_event_ndims=event_ndims,
+ input_min_event_ndims=bij.forward_min_event_ndims,
+ output_min_event_ndims=bij.inverse_min_event_ndims)
+ elif isinstance(bij, composition.Composition):
+ return bij._call_walk_forward(update_i_event_ndims, event_ndims) # pylint: disable=protected-access
+ else:
+ return sanitize_event_ndims(bij.forward_event_ndims(event_ndims))
+
+ # Initialize event_ndims to the first statically-known min_event_ndims in
+ # the Chain of bijectors.
+ f_event_ndims = i_event_ndims = bijectors[idx].inverse_min_event_ndims
+ for b in bijectors[idx:]:
+ f_event_ndims = update_f_event_ndims(b, f_event_ndims)
+ for b in reversed(bijectors[:idx]):
+ i_event_ndims = update_i_event_ndims(b, i_event_ndims)
+
+ # Shift both event_ndims to satisfy min_event_ndims for nested components.
+ return (nest.map_structure(lambda nd: rolling_offset + nd, f_event_ndims),
+ nest.map_structure(lambda nd: rolling_offset + nd, i_event_ndims))
+
diff --git a/tensorflow_probability/python/bijectors/chain_test.py b/tensorflow_probability/python/bijectors/chain_test.py
index 02b4451023..1bf97c1a0f 100644
--- a/tensorflow_probability/python/bijectors/chain_test.py
+++ b/tensorflow_probability/python/bijectors/chain_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
# Dependency imports
+import mock
import numpy as np
import tensorflow.compat.v1 as tf1
@@ -69,6 +70,16 @@ def testBijectorIdentity(self):
self.assertAllClose(
0., self.evaluate(chain.forward_log_det_jacobian(x, event_ndims=1)))
+ def testNestedDtype(self):
+ chain = tfb.Chain([
+ tfb.Identity(),
+ tfb.Scale(tf.constant(2., tf.float64)),
+ tfb.Identity()
+ ])
+
+ self.assertAllClose(tf.constant([2, 4, 6], tf.float64),
+ self.evaluate(chain.forward([1, 2, 3])))
+
def testScalarCongruency(self):
chain = tfb.Chain((tfb.Exp(), tfb.Softplus()))
bijector_test_util.assert_scalar_congruency(
@@ -155,6 +166,29 @@ def testMinEventNdimsShapeChangingAddRemoveDims(self):
self.assertEqual(4, chain.forward_min_event_ndims)
self.assertEqual(1, chain.inverse_min_event_ndims)
+ def testMinEventNdimsWithJointMap(self):
+ jm_0 = tfb.JointMap([ShapeChanging(1, 1), ShapeChanging(3, 1)])
+ split = ShapeChanging(1, [1, 1])
+ concat = ShapeChanging([1, 1], 1)
+ jm_1 = tfb.JointMap([ShapeChanging(1, 0), ShapeChanging(1, 1)])
+
+ self.assertFalse(jm_0.has_static_min_event_ndims)
+ self.assertFalse(jm_1.has_static_min_event_ndims)
+ self.assertTrue(split.has_static_min_event_ndims)
+ self.assertTrue(concat.has_static_min_event_ndims)
+
+ # Decidable. Inner bijectors have static min_event_ndims.
+ chain = tfb.Chain([jm_0, split, concat, jm_1])
+ self.assertTrue(chain.has_static_min_event_ndims)
+ self.assertAllEqualNested([4, 3], chain.forward_min_event_ndims)
+ self.assertAllEqualNested([3, 1], chain.inverse_min_event_ndims)
+
+ # Undecidable. None of the nested bijectors have known event_ndims.
+ chain = tfb.Chain([jm_0, jm_1])
+ self.assertFalse(chain.has_static_min_event_ndims)
+ self.assertAllEqualNested([None, None], chain.forward_min_event_ndims)
+ self.assertAllEqualNested([None, None], chain.inverse_min_event_ndims)
+
def testChainExpAffine(self):
scale_diag = np.array([1., 2., 3.], dtype=np.float32)
chain = tfb.Chain([tfb.Exp(), tfb.Affine(scale_diag=scale_diag)])
@@ -233,13 +267,42 @@ def ldj(_):
# The shape of `ildj` is known statically to be scalar; its value is
# not statically known.
self.assertTrue(tensorshape_util.is_fully_defined(ildj.shape))
- self.assertEqual(self.evaluate(ildj), -9.)
+
+ # `ldj_reduce_shape` uses `prefer_static` to get input shapes. That means
+ # that we respect statically-known shape information where present.
+ # In this case, the manually-assigned static shape is incorrect.
+ self.assertEqual(self.evaluate(ildj), -7.)
# Ditto.
fldj = chain.forward_log_det_jacobian([0.], event_ndims=0)
self.assertTrue(tensorshape_util.is_fully_defined(fldj.shape))
self.assertEqual(self.evaluate(fldj), 3.)
+ def testDofChangeError(self):
+ exp = tfb.Exp()
+ smc = tfb.SoftmaxCentered()
+
+ # Increase in event-size is the last step. No problems here.
+ safe_bij = tfb.Chain([smc, exp], validate_args=True)
+ self.evaluate(safe_bij.forward_log_det_jacobian([1., 2., 3.], 1))
+
+ # Increase in event-size before Exp.
+ raise_bij = tfb.Chain([exp, smc], validate_args=True)
+ with self.assertRaisesRegex((ValueError, tf.errors.InvalidArgumentError),
+ r".+degrees of freedom.+"):
+ self.evaluate(raise_bij.forward_log_det_jacobian([1., 2., 3.], 1))
+
+ # When validate_args is False, warns instead of raising.
+ warn_bij = tfb.Chain([exp, smc], validate_args=False)
+ with mock.patch.object(tf, "print", return_value=tf.no_op()) as mock_print:
+ self.evaluate(warn_bij.forward_log_det_jacobian([1., 2., 3.], 1))
+ print_args, _ = mock_print.call_args
+ self.assertRegex(print_args[0], r"WARNING:.+degrees of freedom")
+
+ # When validate_event_shape is False, neither warns nor raises.
+ ignore_bij = tfb.Chain([exp, smc], validate_event_size=False)
+ self.evaluate(ignore_bij.forward_log_det_jacobian([1., 2., 3.], 1))
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow_probability/python/bijectors/composition.py b/tensorflow_probability/python/bijectors/composition.py
new file mode 100644
index 0000000000..f26277a4c6
--- /dev/null
+++ b/tensorflow_probability/python/bijectors/composition.py
@@ -0,0 +1,562 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Composition base class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import sys
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.bijectors import bijector
+from tensorflow_probability.python.internal import assert_util
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import nest_util
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
+
+
+__all__ = [
+ 'Composition',
+]
+
+
+def pack_structs_like(template, *structures):
+ """Converts a tuple of structs like `template` to a structure of tuples."""
+ if not structures:
+ return nest.map_structure(lambda x: (), template)
+ return nest.map_structure_up_to(template, (lambda *args: args),
+ *structures, check_types=False)
+
+
+def unpack_structs_like(template, packed):
+ """Converts a structure of tuples like `template` to a tuple of structures."""
+ return tuple(nest.pack_sequence_as(template, flat) for flat in
+ zip(*nest.flatten_up_to(template, packed, check_types=False)))
+
+
+def _event_size(tensor_structure, event_ndims):
+ """Returns the number of elements in the event-portion of a structure."""
+ event_shapes = nest.map_structure(
+ lambda t, nd: ps.slice(ps.shape(t), [ps.rank(t)-nd], [nd]),
+ tensor_structure, event_ndims)
+ return sum(ps.reduce_prod(shape) for shape in nest.flatten(event_shapes))
+
+
+def _max_precision_sum(a, b):
+ """Coerces `a` or `b` to the higher-precision dtype, and returns the sum."""
+ if not dtype_util.base_equal(a.dtype, b.dtype):
+ if dtype_util.size(a.dtype) >= dtype_util.size(b.dtype):
+ b = tf.cast(b, a.dtype)
+ else:
+ a = tf.cast(a, b.dtype)
+ return a + b
+
+
+class Composition(bijector.Bijector):
+ """Base class for Composition bijectors (Chain, JointMap).
+
+ A Composition represents a partially ordered set of invertible
+ transformations. These transformations may happen in series (Chain), in
+ parallel (JointMap), or they could be an arbitrary DAG. Composition handles
+ the common machinery of such transformations, delegating graph-traversal to
+ `_walk_forward` and `_walk_inverse` (which must be overridden by subclasses).
+
+ The `_walk_{direction}` methods take a `step_fn`, a single (structured)
+ `argument` (representing zipped `*args`), and arbitrary `**kwargs`. They are
+ responsible for invoking `step_fn(bij, bij_inputs, **bij_kwds)`
+ for each nested bijector. See `Chain` and `JointMap` for examples.
+
+ These methods are typically invoked using `_call_walk_{direction}`, which
+ wraps `step_fn` and converts structured `*args` into a single structure of
+ tuples, allowing users to provide a `step_fn` with multiple positional
+ arguments (e.g., `foward_log_det_jacobian`).
+
+ In practice, Bijector methods are defined in the base-class, and users
+ should not need to invoke `walk` methods directly.
+ """
+
+ def __init__(self,
+ bijectors,
+ forward_min_event_ndims,
+ inverse_min_event_ndims,
+ name,
+ parameters,
+ validate_event_size=False,
+ **kwargs):
+ """Instantiates a Composition of bijectors.
+
+ Args:
+ bijectors: A nest-compatible structure of bijector instances.
+ forward_min_event_ndims: A (structure of) integer describing both the
+ multi-part structure of inputs to `forward` and the _aligned_ mininimum
+ valid event-ndims. Compositions that allow different relative ranks
+ should pass structures of `None`.
+ inverse_min_event_ndims: A (structure of) integer describing both the
+ multi-part structure of inputs to `inverse` and the _aligned_ mininimum
+ valid event-ndims. Compositions that allow different relative ranks
+ should pass structures of `None`.
+ name: Name of this bijector.
+ parameters: Dictionary of parameters used to initialize this bijector.
+ These must be the exact values passed to `__init__`.
+ validate_event_size: Checks that bijectors are not applied to inputs with
+ incomplete support. For example, the following LDJ would be incorrect:
+ `Chain([Scale(), SoftmaxCentered()]).forward_log_det_jacobian([1], [1])`
+ The jacobian contribution from `Scale` applies to a 2-dimensional input,
+ but the output from `SoftMaxCentered` is a 1-dimensional input embedded
+ in a 2-dimensional space. Setting `validate_event_size=True` (default)
+ prints warnings in these cases. When `validate_args` is also `True`, the
+ warning is promoted to an exception.
+ **kwargs: Additional parameters forwarded to the bijector base-class.
+ """
+
+ with tf.name_scope(name):
+ is_constant_jacobian = True
+ is_injective = True
+ is_permutation = True
+ for bij in nest.flatten(bijectors):
+ is_injective &= bij._is_injective
+ is_constant_jacobian &= bij.is_constant_jacobian
+ is_permutation &= bij._is_permutation
+
+ super(Composition, self).__init__(
+ forward_min_event_ndims=forward_min_event_ndims,
+ inverse_min_event_ndims=inverse_min_event_ndims,
+ is_constant_jacobian=is_constant_jacobian,
+ parameters=parameters,
+ name=name,
+ **kwargs)
+
+ # Copy the nested structure so we don't mutate arguments during tracking.
+ self._bijectors = nest.map_structure(lambda b: b, bijectors)
+ self._validate_event_size = validate_event_size
+ self.__is_injective = is_injective
+ self.__is_permutation = is_permutation
+
+ @property
+ def bijectors(self):
+ return self._bijectors
+
+ @property
+ def validate_event_size(self):
+ return self._validate_event_size
+
+ @property
+ def _is_injective(self):
+ return self.__is_injective
+
+ @property
+ def _is_permutation(self):
+ return self.__is_permutation
+
+ # pylint: disable=redefined-builtin
+
+ def _call_walk_forward(self, step_fn, *args, **kwargs):
+ """Prepares args and calls `_walk_forward`.
+
+ Converts a tuple of structured positional arguments to a structure of
+ argument tuples, and wraps `step_fn` to unpack inputs and re-pack
+ returned values. This way, users may invoke walks using `map_structure`
+ semantics, and the concrete `_walk` implementations can operate on
+ a single structure of inputs (without worrying about tuple unpacking).
+
+
+ For example, the `forward` method looks roughly like:
+ ```python
+
+ MyComposition()._call_walk_forward(
+ lambda bij, x, **kwargs: bij.forward(x, **kwargs),
+ composite_inputs, **composite_kwargs)
+ ```
+
+ More complex methods may need to mutate external state from `step_fn`:
+ ```python
+
+ shape_trace = {}
+
+ def trace_step(bijector, x_shape):
+ shape_trace[bijector.name] = x_shape
+ return bijector.forward_event_shape(x_shape)
+
+ # Calling this populates the `shape_trace` dictionary
+ composition.walk_forward(trace_step, composite_input_shape)
+ ```
+
+ Args:
+ step_fn: Callable applied to each wrapped bijector.
+ Must accept a bijector instance followed by `len(args)` positional
+ arguments whose structures match `bijector.forward_min_event_ndims`,
+ and return `len(args)` structures matching
+ `bijector.inverse_min_event_ndims`.
+ *args: Input arguments propagated to nested bijectors.
+ **kwargs: Keyword arguments forwarded to `_walk_forward`.
+ Returns:
+ The transformed output. If multiple positional arguments are provided, a
+ tuple of matching length will be returned.
+ """
+ args = tuple(nest_util.coerce_structure(self.forward_min_event_ndims, x)
+ for x in args)
+
+ if len(args) == 1:
+ return self._walk_forward(step_fn, *args, **kwargs)
+
+ # Convert a tuple of structures to a structure of tuples. This
+ # allows `_walk` methods to route aligned structures of inputs/outputs
+ # independently, obviates the need for conditional tuple unpacking.
+ packed_args = pack_structs_like(self.forward_min_event_ndims, *args)
+
+ def transform_wrapper(bij, packed_xs, **nested):
+ xs = unpack_structs_like(bij.forward_min_event_ndims, packed_xs)
+ ys = step_fn(bij, *xs, **nested)
+ return pack_structs_like(bij.inverse_min_event_ndims, *ys)
+
+ packed_result = self._walk_forward(
+ transform_wrapper, packed_args, **kwargs)
+ return unpack_structs_like(self.inverse_min_event_ndims, packed_result)
+
+ def _call_walk_inverse(self, step_fn, *args, **kwargs):
+ """Prepares args and calls `_walk_inverse`.
+
+ Converts a tuple of structured positional arguments to a structure of
+ argument tuples, and wraps `step_fn` to unpack inputs and re-pack
+ returned values. This way, users may invoke walks using `map_structure`
+ semantics, and the concrete `_walk` implementations can operate on
+ single-structure of inputs (without worrying about tuple unpacking).
+
+ For example, the `inverse` method looks roughly like:
+ ```python
+
+ MyComposition()._call_walk_inverse(
+ lambda bij, y, **kwargs: bij.inverse(y, **kwargs),
+ composite_inputs, **composite_kwargs)
+ ```
+
+ More complex methods may need to mutate external state from `step_fn`:
+ ```python
+
+ shape_trace = {}
+
+ def trace_step(bijector, y_shape):
+ shape_trace[bijector.name] = y_shape
+ return bijector.inverse_event_shape(y_shape)
+
+ # Calling this populates the `shape_trace` dictionary
+ composition.walk_forward(trace_step, composite_y_shape)
+ ```
+
+ Args:
+ step_fn: Callable applied to each wrapped bijector.
+ Must accept a bijector instance followed by `len(args)` positional
+ arguments whose structures match `bijector.inverse_min_event_ndims`,
+ and return `len(args)` structures matching
+ `bijector.forward_min_event_ndims`.
+ *args: Input arguments propagated to nested bijectors.
+ **kwargs: Keyword arguments forwarded to `_walk_inverse`.
+ Returns:
+ The transformed output. If multiple positional arguments are provided, a
+ tuple of matching length will be returned.
+ """
+ args = tuple(nest_util.coerce_structure(self.inverse_min_event_ndims, y)
+ for y in args)
+
+ if len(args) == 1:
+ return self._walk_inverse(step_fn, *args, **kwargs)
+
+ # Convert a tuple of structures to a structure of tuples. This
+ # allows `_walk` methods to route aligned structures of inputs/outputs
+ # independently, obviates the need for conditional tuple unpacking.
+ packed_args = pack_structs_like(self.inverse_min_event_ndims, *args)
+
+ def transform_wrapper(bij, packed_ys, **nested):
+ ys = unpack_structs_like(bij.inverse_min_event_ndims, packed_ys)
+ xs = step_fn(bij, *ys, **nested)
+ return pack_structs_like(bij.forward_min_event_ndims, *xs)
+
+ packed_result = self._walk_inverse(
+ transform_wrapper, packed_args, **kwargs)
+ return unpack_structs_like(self.forward_min_event_ndims, packed_result)
+
+ ### Abstract Methods
+
+ @abc.abstractmethod
+ def _walk_forward(self, step_fn, argument, **kwargs):
+ """Subclass stub for forward-mode traversals.
+
+ The `_walk_{direction}` methods define how arguments are routed through
+ nested bijectors, expressing the directed topology of the underlying graph.
+
+ Arguments:
+ step_fn: A method taking a bijector, a single positional argument
+ matching `bijector.forward_min_event_ndims`, and arbitrary **kwargs,
+ and returning a structure matching `bijector.inverse_min_event_ndims`.
+ In cases where multiple structured inputs are required, use
+ `_call_walk_forward` instead.
+ argument: A (structure of) Tensor matching `self.forward_min_event_ndims`.
+ **kwargs: Keyword arguments to be forwarded to nested bijectors.
+ """
+ raise NotImplementedError('{}._walk_forward is not implemented'.format(
+ type(self).__name__))
+
+ @abc.abstractmethod
+ def _walk_inverse(self, step_fn, argument, **kwargs):
+ """Subclass stub for inverse-mode traversals.
+
+ The `_walk_{direction}` methods define how arguments are routed through
+ nested bijectors, expressing the directed topology of the underlying graph.
+
+ Arguments:
+ step_fn: A method taking a bijector, a single positional argument
+ matching `bijector.inverse_min_event_ndims`, and arbitrary **kwargs,
+ and returning a structure matching `bijector.forward_min_event_ndims`.
+ In cases where multiple structured inputs are required, use
+ `_call_walk_inverse` instead.
+ argument: A (structure of) Tensor matching `self.inverse_min_event_ndims`.
+ **kwargs: Keyword arguments to be forwarded to nested bijectors.
+ """
+ raise NotImplementedError('{}._walk_inverse is not implemented'.format(
+ type(self).__name__))
+
+ ###
+ ### Nontrivial Methods
+ ###
+
+ ## LDJ Methods
+
+ # DAGs of bijectors do not generally have statically-known `min_event_ndims`
+ # in the way that most other bijectors do.
+
+ # Consider a single bijector that applies Exp() to a 2-tuple of Tensors with
+ # shapes `[2, 3]` and `[3, 2]`. Valid values for `event_ndims` may have
+ # different relative-ranks (e.g., `(2,2)` and `(1,0)`). Meanwhile,
+ # passing `(1,1)` would result in a broadcasting exception. This being the
+ # case, we cannot return a "minimally-reduced LDJ" without knowing both the
+ # event-dimensionality _and_ the shapes of inputs. As such, we forego
+ # intermediate LDJ caching entirely, and request fully-reduced LDJ from nested
+ # bijectors. This requires us to change the signature of
+ # `_{direction}_log_det_jacobian` to include `event_ndims`.
+
+ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
+ """Compute forward_log_det_jacobian over the composition."""
+ with self._name_and_control_scope(name):
+ dtype = self.inverse_dtype(**kwargs)
+ x = nest_util.convert_to_nested_tensor(
+ x, name='x', dtype_hint=dtype,
+ dtype=None if bijector.SKIP_DTYPE_CHECKS else dtype,
+ allow_packing=True)
+ event_ndims = nest_util.coerce_structure(
+ self.forward_min_event_ndims, event_ndims)
+ return self._forward_log_det_jacobian(x, event_ndims, **kwargs)
+
+ def _forward_log_det_jacobian(self, x, event_ndims, **kwargs):
+ # Container for accumulated LDJ.
+ ldj_sum = tf.zeros([], dtype=tf.float32)
+ # Container for accumulated assertions.
+ assertions = []
+
+ def step(bij, x, x_event_ndims, increased_dof, **kwargs): # pylint: disable=missing-docstring
+ nonlocal ldj_sum
+
+ # Compute the LDJ for this step, and add it to the rolling sum.
+ component_ldj = tf.convert_to_tensor(
+ bij.forward_log_det_jacobian(x, x_event_ndims, **kwargs),
+ dtype_hint=ldj_sum.dtype)
+
+ if not dtype_util.is_floating(component_ldj.dtype):
+ raise TypeError(('Nested bijector "{}" of Composition "{}" returned '
+ 'LDJ with a non-floating dtype: {}')
+ .format(bij.name, self.name, component_ldj.dtype))
+ ldj_sum = _max_precision_sum(ldj_sum, component_ldj)
+
+ # Transform inputs for the next bijector.
+ y = bij.forward(x, **kwargs)
+ y_event_ndims = bij.forward_event_ndims(x_event_ndims, **kwargs)
+
+ # Check if the inputs to this bijector have increased degrees of freedom
+ # due to some upstream bijector. We assume that the upstream bijector
+ # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
+ # case it doesn't matter).
+ increased_dof = ps.reduce_any(nest.flatten(increased_dof))
+ if self.validate_event_size:
+ assertions.append(self._maybe_warn_increased_dof(
+ component_name=bij.name,
+ component_ldj=component_ldj,
+ increased_dof=increased_dof))
+ increased_dof |= (_event_size(y, y_event_ndims)
+ > _event_size(x, x_event_ndims))
+
+ increased_dof = nest_util.broadcast_structure(y, increased_dof)
+ return y, y_event_ndims, increased_dof
+
+ increased_dof = nest_util.broadcast_structure(event_ndims, False)
+ self._call_walk_forward(step, x, event_ndims, increased_dof, **kwargs)
+ with tf.control_dependencies([x for x in assertions if x is not None]):
+ return tf.identity(ldj_sum, name='fldj')
+
+ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
+ """Compute inverse_log_det_jacobian over the composition."""
+ with self._name_and_control_scope(name):
+ dtype = self.forward_dtype(**kwargs)
+ y = nest_util.convert_to_nested_tensor(
+ y, name='y', dtype_hint=dtype,
+ dtype=None if bijector.SKIP_DTYPE_CHECKS else dtype,
+ allow_packing=True)
+ event_ndims = nest_util.coerce_structure(
+ self.inverse_min_event_ndims, event_ndims)
+ return self._inverse_log_det_jacobian(y, event_ndims, **kwargs)
+
+ def _inverse_log_det_jacobian(self, y, event_ndims, **kwargs):
+ # Container for accumulated LDJ.
+ ldj_sum = tf.convert_to_tensor(0., dtype=tf.float32)
+ # Container for accumulated assertions.
+ assertions = []
+
+ def step(bij, y, y_event_ndims, increased_dof=False, **kwargs): # pylint: disable=missing-docstring
+ nonlocal ldj_sum
+
+ # Compute the LDJ for this step, and add it to the rolling sum.
+ component_ldj = tf.convert_to_tensor(
+ bij.inverse_log_det_jacobian(y, y_event_ndims, **kwargs),
+ dtype_hint=ldj_sum.dtype)
+
+ if not dtype_util.is_floating(component_ldj.dtype):
+ raise TypeError(('Nested bijector "{}" of Composition "{}" returned '
+ 'LDJ with a non-floating dtype: {}')
+ .format(bij.name, self.name, component_ldj.dtype))
+ ldj_sum = _max_precision_sum(ldj_sum, component_ldj)
+
+ # Transform inputs for the next bijector.
+ x = bij.inverse(y, **kwargs)
+ x_event_ndims = bij.inverse_event_ndims(y_event_ndims, **kwargs)
+
+ # Check if the inputs to this bijector have increased degrees of freedom
+ # due to some upstream bijector. We assume that the upstream bijector
+ # produced a valid LDJ, but this one does not (unless LDJ is 0, in which
+ # case it doesn't matter).
+ increased_dof = ps.reduce_any(nest.flatten(increased_dof))
+ if self.validate_event_size:
+ assertions.append(self._maybe_warn_increased_dof(
+ component_name=bij.name,
+ component_ldj=component_ldj,
+ increased_dof=increased_dof))
+ increased_dof |= (_event_size(x, x_event_ndims)
+ > _event_size(y, y_event_ndims))
+
+ increased_dof = nest_util.broadcast_structure(x, increased_dof)
+ return x, x_event_ndims, increased_dof
+
+ increased_dof = nest_util.broadcast_structure(event_ndims, False)
+ self._call_walk_inverse(step, y, event_ndims, increased_dof, **kwargs)
+ with tf.control_dependencies([x for x in assertions if x is not None]):
+ return tf.identity(ldj_sum, name='ildj')
+
+ def _maybe_warn_increased_dof(self,
+ component_name,
+ component_ldj,
+ increased_dof):
+ """Warns or raises when `increased_dof` is True."""
+ # Short-circuit when the component LDJ is statically zero.
+ if (tf.get_static_value(tf.rank(component_ldj)) == 0
+ and tf.get_static_value(component_ldj) == 0):
+ return
+
+ # Short-circuit when increased_dof is statically False.
+ increased_dof_ = tf.get_static_value(increased_dof)
+ if increased_dof_ is False: # pylint: disable=g-bool-id-comparison
+ return
+
+ error_message = (
+ 'Nested component "{}" in composition "{}" operates on inputs '
+ 'with increased degrees of freedom. This may result in an '
+ 'incorrect log_det_jacobian.'
+ ).format(component_name, self.name)
+
+ # When validate_args is True, we raise on increased DoF.
+ if self._validate_args:
+ if increased_dof_:
+ raise ValueError(error_message)
+ return assert_util.assert_equal(False, increased_dof, error_message)
+
+ # Otherwise, we print a warning and continue.
+ return ps.cond(
+ pred=increased_dof,
+ false_fn=tf.no_op,
+ true_fn=lambda: tf.print( # pylint: disable=g-long-lambda
+ 'WARNING: ' + error_message, output_stream=sys.stderr))
+
+ ###
+ ### Trivial traversals
+ ###
+
+ def _forward(self, x, **kwargs):
+ return self._call_walk_forward(
+ lambda b, x, **kwargs: b.forward(x, **kwargs),
+ x, **kwargs)
+
+ def _inverse(self, y, **kwargs):
+ if not self._is_injective: # pylint: disable=protected-access
+ raise NotImplementedError(
+ 'Invert is not implemented for compositions of '
+ 'non-injective bijectors.')
+ return self._call_walk_inverse(
+ lambda b, y, **kwargs: b.inverse(y, **kwargs),
+ y, **kwargs)
+
+ def _forward_event_shape_tensor(self, x, **kwargs):
+ return self._call_walk_forward(
+ lambda b, x, **kwds: b.forward_event_shape_tensor(x, **kwds),
+ x, **kwargs)
+
+ def _inverse_event_shape_tensor(self, y, **kwargs):
+ return self._call_walk_inverse(
+ lambda b, y, **kwds: b.inverse_event_shape_tensor(y, **kwds),
+ y, **kwargs)
+
+ def _forward_event_shape(self, x, **kwargs):
+ return self._call_walk_forward(
+ lambda b, x, **kwds: b.forward_event_shape(x, **kwds),
+ x, **kwargs)
+
+ def _inverse_event_shape(self, y, **kwargs):
+ return self._call_walk_inverse(
+ lambda b, y, **kwds: b.inverse_event_shape(y, **kwds),
+ y, **kwargs)
+
+ def _forward_dtype(self, x, **kwargs):
+ return self._call_walk_forward(
+ lambda b, x, **kwds: b.forward_dtype(x, **kwds),
+ x, **kwargs)
+
+ def _inverse_dtype(self, y, **kwargs):
+ return self._call_walk_inverse(
+ lambda b, y, **kwds: b.inverse_dtype(y, **kwds),
+ y, **kwargs)
+
+ def forward_event_ndims(self, event_ndims, **kwargs):
+ if self._has_static_min_event_ndims:
+ return super(Composition, self).forward_event_ndims(event_ndims, **kwargs)
+ return self._call_walk_forward(
+ lambda b, nd, **kwds: b.forward_event_ndims(nd, **kwds),
+ event_ndims, **kwargs)
+
+ def inverse_event_ndims(self, event_ndims, **kwargs):
+ if self._has_static_min_event_ndims:
+ return super(Composition, self).inverse_event_ndims(event_ndims, **kwargs)
+ return self._call_walk_inverse(
+ lambda b, nd, **kwds: b.inverse_event_ndims(nd, **kwds),
+ event_ndims, **kwargs)
+
diff --git a/tensorflow_probability/python/bijectors/fill_scale_tril.py b/tensorflow_probability/python/bijectors/fill_scale_tril.py
index 0718e11c80..134fe81b03 100644
--- a/tensorflow_probability/python/bijectors/fill_scale_tril.py
+++ b/tensorflow_probability/python/bijectors/fill_scale_tril.py
@@ -123,5 +123,6 @@ def __init__(self,
[transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
fill_triangular.FillTriangular()],
validate_args=validate_args,
+ validate_event_size=False,
parameters=parameters,
name=name)
diff --git a/tensorflow_probability/python/bijectors/fill_triangular.py b/tensorflow_probability/python/bijectors/fill_triangular.py
index 5a5d2c62c6..c788a7e147 100644
--- a/tensorflow_probability/python/bijectors/fill_triangular.py
+++ b/tensorflow_probability/python/bijectors/fill_triangular.py
@@ -100,6 +100,10 @@ def _forward_log_det_jacobian(self, x):
def _inverse_log_det_jacobian(self, y):
return tf.zeros([], dtype=y.dtype)
+ @property
+ def _is_permutation(self):
+ return True
+
def _forward_event_shape(self, input_shape):
batch_shape, d = input_shape[:-1], tf.compat.dimension_value(
input_shape[-1])
diff --git a/tensorflow_probability/python/bijectors/gev_cdf.py b/tensorflow_probability/python/bijectors/gev_cdf.py
index b40bed524c..c30c060993 100644
--- a/tensorflow_probability/python/bijectors/gev_cdf.py
+++ b/tensorflow_probability/python/bijectors/gev_cdf.py
@@ -120,14 +120,18 @@ def _is_increasing(cls):
def _forward(self, x):
loc = tf.convert_to_tensor(self.loc)
scale = tf.convert_to_tensor(self.scale)
- concentration = tf.convert_to_tensor(self.concentration)
+ conc = tf.convert_to_tensor(self.concentration)
with tf.control_dependencies(
self._maybe_assert_valid_x(
- x, loc=loc, scale=scale, concentration=concentration)):
+ x, loc=loc, scale=scale, concentration=conc)):
z = (x - loc) / scale
+
+ equal_zero = tf.equal(conc, 0.)
+ # deal with case that gradient is N/A when conc = 0
+ safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
t = tf.where(
- tf.equal(concentration, 0.), tf.math.exp(-z),
- tf.math.exp(-tf.math.log1p(z * concentration) / concentration))
+ equal_zero, tf.math.exp(-z),
+ tf.math.exp(-tf.math.log1p(z * safe_conc) / safe_conc))
return tf.exp(-t)
def _inverse(self, y):
@@ -135,24 +139,34 @@ def _inverse(self, y):
t = -tf.math.log(y)
conc = tf.convert_to_tensor(self.concentration)
+
+ equal_zero = tf.equal(conc, 0.)
+ # deal with case that gradient is N/A when conc = 0
+ safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
+
z = tf.where(
- tf.equal(conc, 0.), -tf.math.log(t),
- tf.math.expm1(-tf.math.log(t) * conc) / conc)
+ equal_zero, -tf.math.log(t),
+ tf.math.expm1(-tf.math.log(t) * safe_conc) / safe_conc)
return self.loc + self.scale * z
def _forward_log_det_jacobian(self, x):
loc = tf.convert_to_tensor(self.loc)
scale = tf.convert_to_tensor(self.scale)
- concentration = tf.convert_to_tensor(self.concentration)
+ conc = tf.convert_to_tensor(self.concentration)
with tf.control_dependencies(
self._maybe_assert_valid_x(
- x, loc=loc, scale=scale, concentration=concentration)):
+ x, loc=loc, scale=scale, concentration=conc)):
z = (x - loc) / scale
+
+ equal_zero = tf.equal(conc, 0.)
+ # deal with case that gradient is N/A when conc = 0
+ safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
+
log_t = tf.where(
- tf.equal(concentration, 0.), -z,
- -tf.math.log1p(z * concentration) / concentration)
- return (tf.math.multiply_no_nan(concentration + 1., log_t) -
+ equal_zero, -z,
+ -tf.math.log1p(z * safe_conc) / safe_conc)
+ return (tf.math.multiply_no_nan(conc + 1., log_t) -
tf.math.exp(log_t) - tf.math.log(scale))
def _inverse_log_det_jacobian(self, y):
diff --git a/tensorflow_probability/python/bijectors/identity.py b/tensorflow_probability/python/bijectors/identity.py
index 0fcd2fff9d..9122374863 100644
--- a/tensorflow_probability/python/bijectors/identity.py
+++ b/tensorflow_probability/python/bijectors/identity.py
@@ -66,6 +66,10 @@ def __init__(self, validate_args=False, name="identity"):
def _is_increasing(cls):
return True
+ @property
+ def _is_permutation(self):
+ return True
+
def _forward(self, x):
return x
diff --git a/tensorflow_probability/python/bijectors/invert.py b/tensorflow_probability/python/bijectors/invert.py
index b369a01f28..6c87546533 100644
--- a/tensorflow_probability/python/bijectors/invert.py
+++ b/tensorflow_probability/python/bijectors/invert.py
@@ -97,6 +97,10 @@ def inverse_event_shape_tensor(self, output_shape):
def bijector(self):
return self._bijector
+ @property
+ def _is_permutation(self):
+ return self.bijector._is_permutation # pylint: disable=protected-access
+
def _internal_is_increasing(self, **kwargs):
return self.bijector._internal_is_increasing(**kwargs) # pylint: disable=protected-access
@@ -112,8 +116,14 @@ def inverse_log_det_jacobian(self, y, event_ndims, **kwargs):
def forward_log_det_jacobian(self, x, event_ndims, **kwargs):
return self.bijector.inverse_log_det_jacobian(x, event_ndims, **kwargs)
- def forward_dtype(self, dtype, **kwargs):
+ def forward_dtype(self, dtype=bijector_lib.UNSPECIFIED, **kwargs):
return self.bijector.inverse_dtype(dtype, **kwargs)
- def inverse_dtype(self, dtype, **kwargs):
+ def inverse_dtype(self, dtype=bijector_lib.UNSPECIFIED, **kwargs):
return self.bijector.forward_dtype(dtype, **kwargs)
+
+ def inverse_event_ndims(self, event_ndims, **kwargs):
+ return self.bijector.forward_event_ndims(event_ndims, **kwargs)
+
+ def forward_event_ndims(self, event_ndims, **kwargs):
+ return self.bijector.inverse_event_ndims(event_ndims, **kwargs)
diff --git a/tensorflow_probability/python/bijectors/joint_map.py b/tensorflow_probability/python/bijectors/joint_map.py
new file mode 100644
index 0000000000..6a332477b0
--- /dev/null
+++ b/tensorflow_probability/python/bijectors/joint_map.py
@@ -0,0 +1,127 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""JointMap bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+from tensorflow_probability.python.bijectors import composition
+from tensorflow_probability.python.internal import nest_util
+from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
+
+
+__all__ = [
+ 'JointMap',
+]
+
+
+class JointMap(composition.Composition):
+ """Bijector which applies a structure of bijectors in parallel.
+
+ This is the "structured" counterpart to `Chain`. Whereas `Chain` applies an
+ ordered sequence, JointMap applies a structure of transformations to a
+ matching structure of inputs.
+
+ Example Use:
+
+ ```python
+ exp = Exp()
+ scale = Scale(2.)
+ parallel = JointMap({'a': exp, 'b': scale})
+ x = {'a': 1., 'b': 2.}
+
+ parallel.forward(x)
+ # = {'a': exp.forward(x['a']), 'b': scale.forward(x['b'])}
+ # = {'a': tf.exp(1.), 'b': 2. * 2.}
+
+ parallel.inverse(x)
+ # = {'a': exp.inverse(x['a']), 'b': scale.inverse(x['b'])}
+ # = {'a': tf.log(1.), 'b': 2. / 2.}
+ ```
+
+ Bijectors need not be a dictionary; it could be a list, tuple, list of
+ dictionaries, or anything else supported by `tf.nest.map_structure`.
+ """
+
+ def __init__(self,
+ bijectors=None,
+ validate_args=False,
+ parameters=None,
+ name=None):
+ """Instantiates `JointMap` bijector.
+
+ Args:
+ bijectors: Structure of bijector instances to apply in parallel.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ parameters: Locals dict captured by subclass constructor, to be used for
+ copy/slice re-instantiation operators.
+ name: Python `str`, name given to ops managed by this object.
+ Default value: automatically constructed, e.g.,
+ `jointmap_of_exp_and_softplus`.
+
+ Raises:
+ ValueError: if bijectors have different dtypes.
+ """
+ parameters = dict(locals()) if parameters is None else parameters
+
+ if not bijectors:
+ raise ValueError('`bijectors` must not be empty.')
+
+ if name is None:
+ name = ('jointmap_of_' +
+ '_and_'.join([b.name for b in nest.flatten(bijectors)]))
+ name = name.replace('/', '')
+ with tf.name_scope(name) as name:
+ # Structured dtypes are based on the non-wrapped input.
+ # Keep track of the non-wrapped structure of bijectors to correctly
+ # wrap inputs/outputs in _walk methods.
+ self._nested_structure = self._no_dependency(
+ nest.map_structure(lambda b: None, bijectors))
+
+ super(JointMap, self).__init__(
+ bijectors=bijectors,
+ validate_args=validate_args,
+ parameters=parameters,
+ name=name,
+ # JointMap and other bijectors that operate independently on
+ # parts of structured inputs do not have statically-known
+ # `min_event_ndims`. Infer the input/output structures, and fill them
+ # with `None`.
+ forward_min_event_ndims=nest.map_structure(
+ lambda b: nest_util.broadcast_structure( # pylint: disable=g-long-lambda
+ b.forward_min_event_ndims, None), bijectors),
+ inverse_min_event_ndims=nest.map_structure(
+ lambda b: nest_util.broadcast_structure( # pylint: disable=g-long-lambda
+ b.forward_min_event_ndims, None), bijectors),
+ )
+
+ def _walk_forward(self, step_fn, xs, **kwargs):
+ """Applies `transform_fn` to `x` in parallel over nested bijectors."""
+ # Set check_types to False to support bij-structures wrapped by Trackable.
+ return nest.map_structure_up_to(
+ self._nested_structure,
+ lambda bij, x: step_fn(bij, x, **kwargs.get(bij.name, {})), # pylint: disable=unnecessary-lambda
+ self._bijectors, xs, check_types=False)
+
+ def _walk_inverse(self, step_fn, ys, **kwargs):
+ """Applies `transform_fn` to `y` in parallel over nested bijectors."""
+ # Set check_types to False to support bij-structures wrapped by Trackable.
+ return nest.map_structure_up_to(
+ self._nested_structure,
+ lambda bij, y: step_fn(bij, y, **kwargs.get(bij.name, {})), # pylint: disable=unnecessary-lambda
+ self._bijectors, ys, check_types=False)
diff --git a/tensorflow_probability/python/bijectors/joint_map_test.py b/tensorflow_probability/python/bijectors/joint_map_test.py
new file mode 100644
index 0000000000..d6818601b6
--- /dev/null
+++ b/tensorflow_probability/python/bijectors/joint_map_test.py
@@ -0,0 +1,134 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""JointMap Tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import numpy as np
+import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import bijectors as tfb
+from tensorflow_probability.python.internal import test_util
+
+
+@test_util.test_all_tf_execution_regimes
+class JointMapBijectorTest(test_util.TestCase):
+ """Tests the correctness of the Y = JointMap({nested}) transformation."""
+
+ def assertShapeIs(self, expect_shape, observed):
+ self.assertEqual(expect_shape, np.asarray(observed))
+
+ def testBijector(self):
+ bij = tfb.JointMap({
+ 'a': tfb.Exp(),
+ 'b': tfb.Scale(2.),
+ 'c': tfb.Shift(3.)
+ })
+
+ a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32) # shape=[1, 2, 2]
+ b = np.asarray([[0, 4]], dtype=np.float32) # shape=[1, 2]
+ c = np.asarray([[5, 6]], dtype=np.float32) # shape=[1, 2]
+
+ inputs = {'a': a, 'b': b, 'c': c} # Could be inputs to forward or inverse.
+ event_ndims = {'a': 1, 'b': 0, 'c': 0}
+
+ self.assertStartsWith(bij.name, 'jointmap_of_exp_and_scale')
+ self.assertAllCloseNested({'a': np.exp(a), 'b': b * 2., 'c': c + 3},
+ self.evaluate(bij.forward(inputs)))
+ self.assertAllCloseNested({'a': np.log(a), 'b': b / 2., 'c': c - 3},
+ self.evaluate(bij.inverse(inputs)))
+
+ fldj = self.evaluate(bij.forward_log_det_jacobian(inputs, event_ndims))
+ self.assertEqual((1, 2), fldj.shape)
+ self.assertAllClose(np.sum(a, axis=-1) + np.log(2), fldj)
+
+ ildj = self.evaluate(bij.inverse_log_det_jacobian(inputs, event_ndims))
+ self.assertEqual((1, 2), ildj.shape)
+ self.assertAllClose(-np.log(a).sum(axis=-1) - np.log(2), ildj)
+
+ def testBijectorWithDeepStructure(self):
+ bij = tfb.JointMap({
+ 'a': tfb.Exp(),
+ 'bc': tfb.JointMap([
+ tfb.Scale(2.),
+ tfb.Shift(3.)
+ ])})
+
+ a = np.asarray([[[1, 2], [2, 3]]], dtype=np.float32) # shape=[1, 2, 2]
+ b = np.asarray([[0, 4]], dtype=np.float32) # shape=[1, 2]
+ c = np.asarray([[5, 6]], dtype=np.float32) # shape=[1, 2]
+
+ inputs = {'a': a, 'bc': [b, c]} # Could be inputs to forward or inverse.
+ event_ndims = {'a': 1, 'bc': [0, 0]}
+
+ self.assertStartsWith(bij.name, 'jointmap_of_exp_and_jointmap_of_')
+ self.assertAllCloseNested({'a': np.exp(a), 'bc': [b * 2., c + 3]},
+ self.evaluate(bij.forward(inputs)))
+ self.assertAllCloseNested({'a': np.log(a), 'bc': [b / 2., c - 3]},
+ self.evaluate(bij.inverse(inputs)))
+
+ fldj = self.evaluate(bij.forward_log_det_jacobian(inputs, event_ndims))
+ self.assertEqual((1, 2), fldj.shape)
+ self.assertAllClose(np.sum(a, axis=-1) + np.log(2), fldj)
+
+ ildj = self.evaluate(bij.inverse_log_det_jacobian(inputs, event_ndims))
+ self.assertEqual((1, 2), ildj.shape)
+ self.assertAllClose(-np.log(a).sum(axis=-1) - np.log(2), ildj)
+
+ def testBatchShapeBroadcasts(self):
+ bij = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(10.)},
+ validate_args=True)
+ self.assertStartsWith(bij.name, 'jointmap_of_exp_and_scale')
+
+ a = np.asarray([[[1, 2]], [[2, 3]]], dtype=np.float32) # shape=[2, 1, 2]
+ b = np.asarray([[0, 1, 2]], dtype=np.float32) # shape=[1, 3]
+
+ inputs = {'a': a, 'b': b} # Could be inputs to forward or inverse.
+
+ self.assertAllClose(
+ a.sum(axis=-1) + np.log(10.),
+ self.evaluate(bij.forward_log_det_jacobian(inputs, {'a': 1, 'b': 0})))
+
+ self.assertAllClose(
+ a.sum(axis=-1) + 3 * np.log(10.),
+ self.evaluate(bij.forward_log_det_jacobian(inputs, {'a': 1, 'b': 1})))
+
+ @test_util.disable_test_for_backend(
+ disable_numpy=True,
+ reason='NumPy backend overrides dtypes in __init__.')
+ def testMixedDtypeLogDetJacobian(self):
+ bij = tfb.JointMap({
+ 'a': tfb.Scale(tf.constant(1, dtype=tf.float16)),
+ 'b': tfb.Scale(tf.constant(2, dtype=tf.float32)),
+ 'c': tfb.Scale(tf.constant(3, dtype=tf.float64))
+ })
+
+ fldj = bij.forward_log_det_jacobian(
+ x={'a': 4, 'b': 5, 'c': 6},
+ event_ndims=dict.fromkeys('abc', 0))
+ self.assertDTypeEqual(fldj, np.float64)
+ self.assertAllClose(np.log(1) + np.log(2) + np.log(3), self.evaluate(fldj))
+
+ def test_inverse_has_event_ndims(self):
+ bij_reshape = tfb.Invert(tfb.JointMap([tfb.Reshape([])]))
+ bij_reshape.inverse_event_ndims([10]) # expect [9]
+ self.assertEqual(bij_reshape.inverse_event_ndims([10]), [9])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/bijectors/lambertw_transform_test.py b/tensorflow_probability/python/bijectors/lambertw_transform_test.py
index 8e6ca9ad2d..d0f611259b 100644
--- a/tensorflow_probability/python/bijectors/lambertw_transform_test.py
+++ b/tensorflow_probability/python/bijectors/lambertw_transform_test.py
@@ -119,9 +119,8 @@ def testTailBijectorLogDetJacobian(self, value, delta, expected):
else:
value = np.float64(value)
expected = np.float64(expected)
- self.assertAllClose(ht._inverse_log_det_jacobian(
- tf.convert_to_tensor(value)),
- expected)
+ self.assertAllClose(expected,
+ ht.inverse_log_det_jacobian(value, event_ndims=0))
class LambertWGaussianizationTest(test_util.TestCase, parameterized.TestCase):
diff --git a/tensorflow_probability/python/bijectors/masked_autoregressive.py b/tensorflow_probability/python/bijectors/masked_autoregressive.py
index 8d8bfd0f09..b6ee69e146 100644
--- a/tensorflow_probability/python/bijectors/masked_autoregressive.py
+++ b/tensorflow_probability/python/bijectors/masked_autoregressive.py
@@ -323,7 +323,7 @@ def _bijector_fn(x, **condition_kwargs):
bijectors.append(shift_lib.Shift(shift))
if log_scale is not None:
bijectors.append(scale_lib.Scale(log_scale=log_scale))
- return chain.Chain(bijectors)
+ return chain.Chain(bijectors, validate_event_size=False)
bijector_fn = _bijector_fn
diff --git a/tensorflow_probability/python/bijectors/permute.py b/tensorflow_probability/python/bijectors/permute.py
index 0f14b92336..6e66356866 100644
--- a/tensorflow_probability/python/bijectors/permute.py
+++ b/tensorflow_probability/python/bijectors/permute.py
@@ -132,6 +132,11 @@ def permutation(self):
def axis(self):
return self._axis
+ @property
+ def _is_permutation(self):
+ # Definitely a permutation.
+ return True
+
def _forward(self, x):
y = tf.gather(x, self.permutation, axis=self.axis)
tensorshape_util.set_shape(y, x.shape)
diff --git a/tensorflow_probability/python/bijectors/reshape.py b/tensorflow_probability/python/bijectors/reshape.py
index dc23b33904..c92000596a 100644
--- a/tensorflow_probability/python/bijectors/reshape.py
+++ b/tensorflow_probability/python/bijectors/reshape.py
@@ -190,6 +190,10 @@ def _parameter_control_dependencies(self, is_init):
self._event_shape_out, self.validate_args))
return assertions
+ @property
+ def _is_permutation(self):
+ return True
+
def _forward(self, x):
output_shape, output_tensorshape = _replace_event_shape_in_shape_tensor(
ps.shape(x), self._event_shape_in, self._event_shape_out,
diff --git a/tensorflow_probability/python/bijectors/restructure.py b/tensorflow_probability/python/bijectors/restructure.py
new file mode 100644
index 0000000000..a103e18240
--- /dev/null
+++ b/tensorflow_probability/python/bijectors/restructure.py
@@ -0,0 +1,224 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Restructure Bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.bijectors import bijector
+from tensorflow_probability.python.internal import nest_util
+from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
+
+
+__all__ = [
+ 'Restructure',
+]
+
+
+def unique_token_set(source_structure):
+ """Checks that structured tokens are unique, and returns the set of values."""
+ flat_tokens = nest.flatten(source_structure)
+ flat_token_set = set(flat_tokens)
+ if len(flat_tokens) != len(flat_token_set):
+ raise ValueError('Restructure tokens must be unique. Saw: {}'
+ .format(source_structure))
+ return flat_token_set
+
+
+class Restructure(bijector.Bijector):
+ """Converts between nested structures of Tensors.
+
+ This is useful when constructing non-trivial chains of multipart bijectors.
+ It partitions inputs into different logical "blocks", which may be fed as
+ arguments to downstream multipart bijectors.
+
+ Example Usage:
+ ```python
+
+ # What restructure does:
+ restructure = Restructure({
+ 'foo': [0, 1],
+ 'bar': [3, 2],
+ 'baz': [4, 5, 6]
+ })
+
+ # Note that x is a *python-list* of tensors.
+ # To permute elements of an individual Tensor, see `tfb.Permute`.
+ x = [1, 2, 4, 8, 16, 32, 64]
+
+ assert restructure.forward(x) == {
+ 'foo': [1, 2],
+ 'bar': [8, 4],
+ 'baz': [16, 32, 64]
+ }
+
+ # Where restructure is useful:
+ complex_bijector = Chain([
+ # Apply different transformations to each block.
+ JointMap({
+ 'foo': ScaleMatVecLinearOperator(...), # Operates on the full block
+ 'bar': ScaleMatVecLinearOperator(...), # Operates on the full block
+ 'baz': [Exp(), Scale(10.), Shift(-1.)] # Different bijectors for each
+ }),
+ # Group the tensor into logical blocks.
+ Restructure({
+ 'foo': [0, 1],
+ 'bar': [3, 2],
+ 'baz': [4, 5, 6],
+ }),
+ # Split an input tensor into 7 chunks.
+ Split([2, 4, 6, 8, 10, 12, 14])
+ ])
+ ```
+ """
+
+ def __init__(self,
+ output_structure,
+ input_structure=None,
+ name='restructure'):
+ """Creates a `Restructure` bijector.
+
+ Args:
+ output_structure: A tf.nest-compatible structure of tokens describing the
+ output of `forward` (equivalently, the input of `inverse`).
+ input_structure: A tf.nest-compatible structure of tokens describing the
+ input to `forward`. If unspecified, a default structure is inferred from
+ `output_structure`. The default structure expects a `list` if tokens are
+ integers, or a `dict` if the tokens are strings.
+ name: Name of this bijector.
+ Raises:
+ ValueError: If tokens are duplicated, or a required default structure
+ cannot be inferred.
+ """
+ parameters = dict(locals())
+
+ # Get the flat set of tokens, making sure they're unique.
+ output_tokens = unique_token_set(output_structure)
+
+ # Create a default input_structure when it isn't provided.
+ if input_structure is None:
+ # If all tokens are strings, assume input is a dict.
+ if all(isinstance(tok, six.string_types) for tok in output_tokens):
+ input_structure = {token: token for token in output_tokens}
+
+ # If tokens are contiguous 0-based ints, return a list.
+ elif (all(isinstance(tok, six.integer_types) for tok in output_tokens)
+ and output_tokens == set(range(len(output_tokens)))):
+ input_structure = list(range(len(output_tokens)))
+
+ # Otherwise, we cannot infer a default structure.
+ else:
+ raise ValueError(('Tokens in output_structure must be all strings or '
+ 'contiguous 0-based indices when input_structure '
+ 'is not specified. Saw: {}'
+ ).format(output_tokens))
+
+ # If input_structure _is_ provided, make sure tokens are unique
+ # and that they match the output_structure tokens.
+ else:
+ input_tokens = unique_token_set(output_structure)
+ if input_tokens != output_tokens:
+ raise ValueError(('The `input_structure` tokens must match the '
+ '`output_structure` tokens exactly. Missing from '
+ '`input_structure`: {}. Missing from '
+ '`output_structure`: {}.').format(
+ output_tokens - input_tokens,
+ input_tokens - output_tokens))
+
+ self._input_structure = self._no_dependency(input_structure)
+ self._output_structure = self._no_dependency(output_structure)
+ super(Restructure, self).__init__(
+ forward_min_event_ndims=nest_util.broadcast_structure(
+ self._input_structure, None),
+ inverse_min_event_ndims=nest_util.broadcast_structure(
+ self._output_structure, None),
+ is_constant_jacobian=True,
+ validate_args=False,
+ parameters=parameters,
+ name=name)
+
+ @property
+ def _is_permutation(self):
+ return True
+
+ def _forward(self, x):
+ flat_dict = {}
+ nest.map_structure_up_to(
+ self._input_structure, flat_dict.setdefault,
+ self._input_structure, x)
+ result = nest.map_structure(flat_dict.pop, self._output_structure)
+ assert not flat_dict # Should never happen!
+ return result
+
+ def _inverse(self, y):
+ flat_dict = {}
+ nest.map_structure_up_to(
+ self._output_structure, flat_dict.setdefault,
+ self._output_structure, y)
+ result = nest.map_structure(flat_dict.pop, self._input_structure)
+ assert not flat_dict # Should never happen!
+ return result
+
+ ### Shape/ndims/etc transformations do the same thing as forward/inverse.
+
+ def _forward_event_shape(self, x_shape, **kwargs):
+ return self._forward(x_shape)
+
+ def _inverse_event_shape(self, y_shape, **kwargs):
+ return self._inverse(y_shape)
+
+ def _forward_event_shape_tensor(self, x_shape, **kwargs):
+ return self._forward(x_shape)
+
+ def _inverse_event_shape_tensor(self, y_shape, **kwargs):
+ return self._inverse(y_shape)
+
+ def _forward_dtype(self, x_dtype, **kwargs):
+ return self._forward(x_dtype)
+
+ def _inverse_dtype(self, y_dtype, **kwargs):
+ return self._inverse(y_dtype)
+
+ def forward_event_ndims(self, x_ndims, **kwargs):
+ return self._forward(x_ndims)
+
+ def inverse_event_ndims(self, y_ndims, **kwargs):
+ return self._inverse(y_ndims)
+
+ ### Skip convert-to-tensor/caching so we can rearrange nested sub-structures.
+
+ def _call_forward(self, x, name, **kwargs):
+ with self._name_and_control_scope(name):
+ return self._forward(x, **kwargs)
+
+ def _call_inverse(self, y, name, **kwargs):
+ with self._name_and_control_scope(name):
+ return self._inverse(y, **kwargs)
+
+ ### Restructure always has constant 0 LDJ.
+ # Override top-level methods, since min_event_ndims is undefined.
+
+ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
+ with self._name_and_control_scope(name):
+ return tf.zeros([], tf.float32)
+
+ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
+ with self._name_and_control_scope(name):
+ return tf.zeros([], tf.float32)
diff --git a/tensorflow_probability/python/bijectors/restructure_test.py b/tensorflow_probability/python/bijectors/restructure_test.py
new file mode 100644
index 0000000000..e441a5590a
--- /dev/null
+++ b/tensorflow_probability/python/bijectors/restructure_test.py
@@ -0,0 +1,123 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Restructure Tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import bijectors as tfb
+from tensorflow_probability.python.internal import test_util
+
+
+@test_util.test_all_tf_execution_regimes
+class RestructureBijectorTest(test_util.TestCase):
+ """Tests the correctness of the Y = Restructure({nested}) transformation."""
+
+ def testListToStructure(self):
+ bij = tfb.Restructure({
+ 'foo': [1, 2],
+ 'bar': 0,
+ 'baz': (3, 4)
+ })
+
+ x = [[1, 2, 3], [4, 5, 6], 7., 8., 9.]
+ x_ndims = [1, 1, 0, 0, 0]
+
+ y = {
+ 'foo': [[4, 5, 6], 7.],
+ 'bar': [1, 2, 3],
+ 'baz': (8., 9.),
+ }
+ y_ndims = {'foo': [1, 0], 'bar': 1, 'baz': (0, 0)}
+
+ # Invert assertion arguments to infer structure from bijector output.
+ self.assertAllEqualNested(bij.forward(x), y, check_types=True)
+ self.assertAllEqualNested(bij.inverse(y), x, check_types=True)
+
+ self.assertAllEqualNested(
+ 0., self.evaluate(bij.forward_log_det_jacobian(x, x_ndims)))
+ self.assertAllEqualNested(
+ 0., self.evaluate(bij.inverse_log_det_jacobian(y, y_ndims)))
+
+ def testDictToStructure(self):
+ bij = tfb.Restructure({
+ 'foo': ['b', 'c'],
+ 'bar': 'a',
+ 'baz': ('d', 'e')
+ })
+
+ x = {'a': [1, 2, 3],
+ 'b': [4, 5, 6],
+ 'c': 7., 'd': 8., 'e': 9.}
+ x_ndims = {'a': 1, 'b': 1, 'c': 0, 'd': 0, 'e': 0}
+
+ y = {'foo': [[4, 5, 6], 7.],
+ 'bar': [1, 2, 3],
+ 'baz': (8., 9.)}
+ y_ndims = {'foo': [1, 0], 'bar': 1, 'baz': (0, 0)}
+
+ # Invert assertion arguments to infer structure from bijector output.
+ self.assertAllEqualNested(bij.forward(x), y, check_types=True)
+ self.assertAllEqualNested(bij.inverse(y), x, check_types=True)
+
+ self.assertAllEqualNested(
+ 0., self.evaluate(bij.forward_log_det_jacobian(x, x_ndims)))
+ self.assertAllEqualNested(
+ 0., self.evaluate(bij.inverse_log_det_jacobian(y, y_ndims)))
+
+ def testStructureToStructure(self):
+ bij = tfb.Restructure(
+ input_structure={'foo': [0, 1], 'bar': 2, 'baz': (3, 4)},
+ output_structure={'zip': [1, 2, 3], 'zap': 0, 'zop': 4})
+
+ x = {'foo': [0., [1.]],
+ 'bar': [[2.]],
+ 'baz': ([[[3.]]], [[[[4.]]]])}
+ x_ndims = {'foo': [0, 1], 'bar': 2, 'baz': (3, 4)}
+
+ y = {'zip': [[1.], [[2.]], [[[3.]]]],
+ 'zap': 0.,
+ 'zop': [[[[4.]]]]}
+ y_ndims = {'zip': [1, 2, 3], 'zap': 0, 'zop': 4}
+
+ # Invert assertion arguments to infer structure from bijector output.
+ self.assertAllEqualNested(bij.forward(x), y, check_types=True)
+ self.assertAllEqualNested(bij.inverse(y), x, check_types=True)
+
+ self.assertAllEqualNested(
+ 0., self.evaluate(bij.forward_log_det_jacobian(x, x_ndims)))
+ self.assertAllEqualNested(
+ 0., self.evaluate(bij.inverse_log_det_jacobian(y, y_ndims)))
+
+ def testEventNdims(self):
+ bij = tfb.Restructure(
+ input_structure={'foo': [0, 1], 'bar': 2, 'baz': (3, 4)},
+ output_structure={'zip': [1, 2, 3], 'zap': 0, 'zop': 4})
+
+ x_ndims = {'foo': [10, 11], 'bar': 12, 'baz': (13, 14)}
+ y_ndims = {'zip': [11, 12, 13], 'zap': 10, 'zop': 14}
+
+ self.assertAllEqualNested(
+ y_ndims, bij.forward_event_ndims(x_ndims), check_types=True)
+ self.assertAllEqualNested(
+ x_ndims, bij.inverse_event_ndims(y_ndims), check_types=True)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/bijectors/sigmoid_test.py b/tensorflow_probability/python/bijectors/sigmoid_test.py
index 701e45a395..208383b2b0 100644
--- a/tensorflow_probability/python/bijectors/sigmoid_test.py
+++ b/tensorflow_probability/python/bijectors/sigmoid_test.py
@@ -71,7 +71,7 @@ class ShiftedScaledSigmoidBijectorTest(test_util.TestCase):
"""Tests correctness of Sigmoid with `low` and `high` parameters set."""
def testBijector(self):
- low = np.array([-3., 0., 5.]).astype(np.float32)
+ low = np.array([[-3.], [0.], [5.]]).astype(np.float32)
high = 12.
bijector = tfb.Sigmoid(low=low, high=high, validate_args=True)
@@ -79,10 +79,11 @@ def testBijector(self):
equivalent_bijector = tfb.Chain([
tfb.Shift(shift=low), tfb.Scale(scale=high-low), tfb.Sigmoid()])
- x = [[[1.], [2.], [-5.], [-0.3]]]
+ x = [[[1., 2., -5., -0.3]]]
y = self.evaluate(equivalent_bijector.forward(x))
self.assertAllClose(y, self.evaluate(bijector.forward(x)))
- self.assertAllClose(x, self.evaluate(bijector.inverse(y)[..., -1:]))
+ self.assertAllClose(
+ x, self.evaluate(bijector.inverse(y)[..., :1, :]), rtol=1e-5)
self.assertAllClose(
self.evaluate(equivalent_bijector.inverse_log_det_jacobian(
y, event_ndims=1)),
diff --git a/tensorflow_probability/python/bijectors/soft_clip.py b/tensorflow_probability/python/bijectors/soft_clip.py
index b7be8de803..6804115fdb 100644
--- a/tensorflow_probability/python/bijectors/soft_clip.py
+++ b/tensorflow_probability/python/bijectors/soft_clip.py
@@ -288,15 +288,16 @@ def _forward(self, x):
return self._chain.forward(x)
def _forward_log_det_jacobian(self, x):
- return self._chain._forward_log_det_jacobian(x) # pylint: disable=protected-access
+ return self._chain.forward_log_det_jacobian(x, self.forward_min_event_ndims)
def _inverse(self, y):
with tf.control_dependencies(self._assert_valid_inverse_input(y)):
- return self._chain._inverse(y) # pylint: disable=protected-access
+ return self._chain.inverse(y) # pylint: disable=protected-access
def _inverse_log_det_jacobian(self, y):
with tf.control_dependencies(self._assert_valid_inverse_input(y)):
- return self._chain._inverse_log_det_jacobian(y) # pylint: disable=protected-access
+ return self._chain.inverse_log_det_jacobian(
+ y, self.inverse_min_event_ndims)
def _assert_valid_inverse_input(self, y):
assertions = []
diff --git a/tensorflow_probability/python/bijectors/split.py b/tensorflow_probability/python/bijectors/split.py
index 3e7189378c..a6a207e41b 100644
--- a/tensorflow_probability/python/bijectors/split.py
+++ b/tensorflow_probability/python/bijectors/split.py
@@ -123,6 +123,10 @@ def split_sizes(self):
def axis(self):
return self._axis
+ @property
+ def _is_permutation(self):
+ return True
+
def _inverse(self, y):
"""Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).
@@ -320,11 +324,15 @@ def _inverse_event_shape(self, output_shapes):
to the number of splits.
Returns:
- inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
- after applying `inverse`. Possibly unknown.
+ inverse_event_shape: `TensorShape` indicating event-portion shape after
+ applying `inverse`. Possibly unknown.
"""
self._validate_output_shapes(output_shapes)
- shapes = [tf.TensorShape(s).as_list() for s in output_shapes]
+ shapes = []
+ for s in output_shapes:
+ if tensorshape_util.rank(s) is None:
+ return tf.TensorShape(None)
+ shapes.append(tf.TensorShape(s).as_list())
axis = tf.get_static_value(self.axis)
if self.split_sizes is None:
@@ -369,9 +377,10 @@ def _forward_dtype(self, dtype):
return [dtype] * self.num_splits
def _inverse_dtype(self, dtype):
- if any(d != dtype[0] for d in dtype):
+ dtype = set(dtype) - {None}
+ if len(dtype) > 1:
raise ValueError('All dtypes must be equivalent.')
- return dtype[0]
+ return dtype.pop() if dtype else None
def _parameter_control_dependencies(self, is_init):
assertions = []
diff --git a/tensorflow_probability/python/bijectors/transpose.py b/tensorflow_probability/python/bijectors/transpose.py
index 37e7e91bb2..0fc532357e 100644
--- a/tensorflow_probability/python/bijectors/transpose.py
+++ b/tensorflow_probability/python/bijectors/transpose.py
@@ -189,6 +189,10 @@ def perm(self):
def rightmost_transposed_ndims(self):
return self._rightmost_transposed_ndims
+ @property
+ def _is_permutation(self):
+ return True
+
def _is_increasing(self):
if self.forward_min_event_ndims == 0:
return True
diff --git a/tensorflow_probability/python/build_defs.bzl b/tensorflow_probability/python/build_defs.bzl
index 3fb79cb28a..7e66c9c826 100644
--- a/tensorflow_probability/python/build_defs.bzl
+++ b/tensorflow_probability/python/build_defs.bzl
@@ -266,6 +266,7 @@ def multi_substrate_py_test(
jax_size = None,
numpy_size = None,
srcs = [],
+ main = None,
deps = [],
tags = [],
numpy_tags = [],
@@ -286,6 +287,9 @@ def multi_substrate_py_test(
numpy_size: A size override for the numpy target.
srcs: As with `py_test`. These will have a `genrule` emitted to rewrite
NumPy and JAX variants, writing the test file into a subdirectory.
+ main: As with `py_test`. If this does not match "{name}.py", then we
+ suppress the genrule that rewrites "{name}.py", since the typical
+ use-case of the `main` argument is a secondary, i.e. GPU, test.
deps: As with `py_test`. The list is rewritten to depend on
substrate-specific libraries for substrate variants.
tags: Tags global to this test target. NumPy also gets a `'tfp_numpy'`
@@ -309,7 +313,7 @@ def multi_substrate_py_test(
name = "{}.tf".format(name),
size = size,
srcs = srcs,
- main = "{}.py".format(name),
+ main = main or "{}.py".format(name),
deps = deps,
tags = tags,
srcs_version = srcs_version,
@@ -321,18 +325,19 @@ def multi_substrate_py_test(
if "numpy" not in disabled_substrates:
numpy_srcs = _substrate_srcs(srcs, "numpy")
- native.genrule(
- name = "rewrite_{}_numpy".format(name),
- srcs = srcs,
- outs = numpy_srcs,
- cmd = "$(location {}) $(SRCS) > $@".format(REWRITER_TARGET),
- exec_tools = [REWRITER_TARGET],
- )
+ if main == None or main == "{}.py".format(name):
+ native.genrule(
+ name = "rewrite_{}_numpy".format(name),
+ srcs = srcs,
+ outs = numpy_srcs,
+ cmd = "$(location {}) $(SRCS) > $@".format(REWRITER_TARGET),
+ exec_tools = [REWRITER_TARGET],
+ )
native.py_test(
name = "{}.numpy".format(name),
size = numpy_size or size,
srcs = numpy_srcs,
- main = _substrate_src("{}.py".format(name), "numpy"),
+ main = _substrate_src(main or "{}.py".format(name), "numpy"),
deps = _substrate_deps(deps, "numpy"),
tags = tags + ["tfp_numpy"] + numpy_tags,
srcs_version = srcs_version,
@@ -344,20 +349,21 @@ def multi_substrate_py_test(
if "jax" not in disabled_substrates:
jax_srcs = _substrate_srcs(srcs, "jax")
- native.genrule(
- name = "rewrite_{}_jax".format(name),
- srcs = srcs,
- outs = jax_srcs,
- cmd = "$(location {}) $(SRCS) --numpy_to_jax > $@".format(REWRITER_TARGET),
- exec_tools = [REWRITER_TARGET],
- )
+ if main == None or main == "{}.py".format(name):
+ native.genrule(
+ name = "rewrite_{}_jax".format(name),
+ srcs = srcs,
+ outs = jax_srcs,
+ cmd = "$(location {}) $(SRCS) --numpy_to_jax > $@".format(REWRITER_TARGET),
+ exec_tools = [REWRITER_TARGET],
+ )
jax_deps = _substrate_deps(deps, "jax")
# [internal] Add JAX build dep
native.py_test(
name = "{}.jax".format(name),
size = jax_size or size,
srcs = jax_srcs,
- main = _substrate_src("{}.py".format(name), "jax"),
+ main = _substrate_src(main or "{}.py".format(name), "jax"),
deps = jax_deps,
tags = tags + ["tfp_jax"] + jax_tags,
srcs_version = srcs_version,
diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD
index 2781a094d0..7c0c1fb2e6 100644
--- a/tensorflow_probability/python/distributions/BUILD
+++ b/tensorflow_probability/python/distributions/BUILD
@@ -63,6 +63,7 @@ multi_substrate_py_library(
":empirical",
":exp_gamma",
":exponential",
+ ":exponentially_modified_gaussian",
":finite_discrete",
":gamma",
":gamma_gamma",
@@ -71,6 +72,7 @@ multi_substrate_py_library(
":generalized_normal",
":generalized_pareto",
":geometric",
+ ":gev",
":gumbel",
":half_cauchy",
":half_normal",
@@ -741,6 +743,24 @@ multi_substrate_py_library(
],
)
+multi_substrate_py_library(
+ name = "gev",
+ srcs = ["gev.py"],
+ deps = [
+ ":kullback_leibler",
+ ":transformed_distribution",
+ ":uniform",
+ # numpy dep,
+ # tensorflow dep,
+ "//tensorflow_probability/python/bijectors:gev_cdf",
+ "//tensorflow_probability/python/bijectors:invert",
+ "//tensorflow_probability/python/bijectors:softplus",
+ "//tensorflow_probability/python/internal:distribution_util",
+ "//tensorflow_probability/python/internal:dtype_util",
+ "//tensorflow_probability/python/internal:tensor_util",
+ ],
+)
+
multi_substrate_py_library(
name = "half_cauchy",
srcs = ["half_cauchy.py"],
@@ -1110,6 +1130,7 @@ multi_substrate_py_library(
":normal",
# numpy dep,
# tensorflow dep,
+ "//tensorflow_probability/python/bijectors:bijector",
"//tensorflow_probability/python/bijectors:chain",
"//tensorflow_probability/python/bijectors:cholesky_outer_product",
"//tensorflow_probability/python/bijectors:correlation_cholesky",
@@ -1121,6 +1142,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:tensor_util",
"//tensorflow_probability/python/internal:tensorshape_util",
+ "//tensorflow_probability/python/math:numeric",
"//tensorflow_probability/python/math:special",
"//tensorflow_probability/python/util:seed_stream",
],
@@ -1886,6 +1908,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:reparameterization",
"//tensorflow_probability/python/internal:tensor_util",
+ "//tensorflow_probability/python/random:random_ops",
],
)
@@ -2528,6 +2551,19 @@ multi_substrate_py_test(
],
)
+multi_substrate_py_test(
+ name = "gev_test",
+ srcs = ["gev_test.py"],
+ jax_size = "medium",
+ deps = [
+ # numpy dep,
+ # scipy dep,
+ # tensorflow dep,
+ "//tensorflow_probability",
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
multi_substrate_py_test(
name = "half_cauchy_test",
srcs = ["half_cauchy_test.py"],
@@ -2600,6 +2636,27 @@ multi_substrate_py_test(
],
)
+multi_substrate_py_test(
+ name = "independent_test_gpu",
+ srcs = ["independent_test.py"],
+ disabled_substrates = ["numpy"],
+ jax_size = "medium",
+ main = "independent_test.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ tags = ["requires-gpu-nvidia"],
+ deps = [
+ # hypothesis dep,
+ # numpy dep,
+ # scipy dep,
+ # tensorflow dep,
+ "//tensorflow_probability",
+ "//tensorflow_probability/python/internal:hypothesis_testlib",
+ "//tensorflow_probability/python/internal:tensorshape_util",
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
multi_substrate_py_test(
name = "inverse_gamma_test",
size = "medium",
@@ -2660,12 +2717,13 @@ multi_substrate_py_test(
size = "large",
srcs = ["joint_distribution_auto_batched_test.py"],
numpy_tags = ["notap"],
- shard_count = 2,
+ shard_count = 7,
deps = [
# numpy dep,
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:test_util",
+ # tensorflow/compiler/jit dep,
],
)
@@ -3279,6 +3337,24 @@ multi_substrate_py_test(
],
)
+multi_substrate_py_test(
+ name = "sample_test_gpu",
+ srcs = ["sample_test.py"],
+ disabled_substrates = ["numpy"],
+ jax_size = "medium",
+ main = "sample_test.py",
+ python_version = "PY3",
+ srcs_version = "PY3",
+ tags = ["requires-gpu-nvidia"],
+ deps = [
+ # absl/testing:parameterized dep,
+ # numpy dep,
+ # tensorflow dep,
+ "//tensorflow_probability",
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
multi_substrate_py_test(
name = "sinh_arcsinh_test",
srcs = ["sinh_arcsinh_test.py"],
@@ -3562,6 +3638,7 @@ py_test(
"//tensorflow_probability/python/internal:tensor_util",
"//tensorflow_probability/python/internal:tensorshape_util",
"//tensorflow_probability/python/internal:test_util",
+ "//tensorflow_probability/python/math/psd_kernels:hypothesis_testlib",
# tensorflow/compiler/jit dep,
],
)
@@ -3673,3 +3750,33 @@ py_binary(
"//tensorflow_probability/python/distributions:hypothesis_testlib",
],
)
+
+multi_substrate_py_library(
+ name = "exponentially_modified_gaussian",
+ srcs = ["exponentially_modified_gaussian.py"],
+ deps = [
+ ":distribution",
+ ":exponential",
+ ":normal",
+ # tensorflow dep,
+ "//tensorflow_probability/python/internal:dtype_util",
+ "//tensorflow_probability/python/internal:prefer_static",
+ "//tensorflow_probability/python/internal:reparameterization",
+ "//tensorflow_probability/python/internal:tensor_util",
+ ],
+)
+
+multi_substrate_py_test(
+ name = "exponentially_modified_gaussian_test",
+ srcs = ["exponentially_modified_gaussian_test.py"],
+ jax_size = "medium",
+ # Disable numpy test for now because a bug in the types returned by special_math.ndtr
+ numpy_tags = ["notap"],
+ deps = [
+ # numpy dep,
+ # scipy dep,
+ # tensorflow dep,
+ "//tensorflow_probability",
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
diff --git a/tensorflow_probability/python/distributions/__init__.py b/tensorflow_probability/python/distributions/__init__.py
index 5d75a05724..8a9f81e225 100644
--- a/tensorflow_probability/python/distributions/__init__.py
+++ b/tensorflow_probability/python/distributions/__init__.py
@@ -41,6 +41,7 @@
from tensorflow_probability.python.distributions.distribution import Distribution
from tensorflow_probability.python.distributions.doublesided_maxwell import DoublesidedMaxwell
from tensorflow_probability.python.distributions.empirical import Empirical
+from tensorflow_probability.python.distributions.exponentially_modified_gaussian import ExponentiallyModifiedGaussian
from tensorflow_probability.python.distributions.exp_gamma import ExpGamma
from tensorflow_probability.python.distributions.exp_gamma import ExpInverseGamma
from tensorflow_probability.python.distributions.exponential import Exponential
@@ -53,6 +54,7 @@
from tensorflow_probability.python.distributions.generalized_pareto import GeneralizedPareto
from tensorflow_probability.python.distributions.geometric import Geometric
from tensorflow_probability.python.distributions.gumbel import Gumbel
+from tensorflow_probability.python.distributions.gev import GeneralizedExtremeValue
from tensorflow_probability.python.distributions.half_cauchy import HalfCauchy
from tensorflow_probability.python.distributions.half_normal import HalfNormal
from tensorflow_probability.python.distributions.half_student_t import HalfStudentT
@@ -171,6 +173,7 @@
'DoublesidedMaxwell',
'VectorDeterministic',
'Empirical',
+ 'ExponentiallyModifiedGaussian',
'ExpGamma',
'ExpInverseGamma',
'Exponential',
@@ -185,6 +188,7 @@
'GaussianProcessRegressionModel',
'VariationalGaussianProcess',
'Gumbel',
+ 'GeneralizedExtremeValue',
'HalfCauchy',
'HalfNormal',
'HalfStudentT',
diff --git a/tensorflow_probability/python/distributions/distribution.py b/tensorflow_probability/python/distributions/distribution.py
index b938bd062c..9ec0f3129e 100644
--- a/tensorflow_probability/python/distributions/distribution.py
+++ b/tensorflow_probability/python/distributions/distribution.py
@@ -527,7 +527,7 @@ def __init__(self,
for i, t in enumerate(graph_parents):
if t is None or not tf.is_tensor(t):
raise ValueError('Graph parent item %d is not a Tensor; %s.' % (i, t))
- self._dtype = dtype
+ self._dtype = self._no_dependency(dtype)
self._reparameterization_type = reparameterization_type
self._allow_nan_stats = allow_nan_stats
self._validate_args = validate_args
@@ -605,14 +605,6 @@ def parameter_properties(cls, dtype=tf.float32, num_classes=None):
instances.
"""
with tf.name_scope('parameter_properties'):
- # Instead of a dtype, subclass implementations take an `eps` argument
- # representing a small value in the requested dtype. This may be used to
- # avoid constraint boundaries, e.g., Softplus(low=eps) will avoid
- # infinitesimally small values for a scale param. The dtype
- # may be recovered as `eps.dtype`.
- # Numpy defines `eps` using the difference between 1.0 and the next
- # smallest representable float larger than 1.0. This is approximately
- # 1.19e-07 in float32, 2.22e-16 in float64, and 0.00098 in float16.
return cls._parameter_properties(dtype, num_classes=num_classes)
@classmethod
@@ -817,13 +809,14 @@ def copy(self, **override_parameters_kwargs):
return slicing.batch_slice(self, self._params_event_ndims(),
override_parameters_kwargs, Ellipsis)
except NotImplementedError:
- parameters = dict(self.parameters, **override_parameters_kwargs)
- d = type(self)(**parameters)
- # pylint: disable=protected-access
- d._parameters = parameters
- d._parameters_sanitized = True
- # pylint: enable=protected-access
- return d
+ pass
+ parameters = dict(self.parameters, **override_parameters_kwargs)
+ d = type(self)(**parameters)
+ # pylint: disable=protected-access
+ d._parameters = self._no_dependency(parameters)
+ d._parameters_sanitized = True
+ # pylint: enable=protected-access
+ return d
def _batch_shape_tensor(self):
raise NotImplementedError(
@@ -1309,7 +1302,7 @@ def variance(self, name='variance', **kwargs):
return self._variance(**kwargs)
except NotImplementedError:
try:
- return tf.square(self._stddev(**kwargs))
+ return tf.nest.map_structure(tf.square, self._stddev(**kwargs))
except NotImplementedError:
pass
raise
@@ -1344,7 +1337,7 @@ def stddev(self, name='stddev', **kwargs):
return self._stddev(**kwargs)
except NotImplementedError:
try:
- return tf.sqrt(self._variance(**kwargs))
+ return tf.nest.map_structure(tf.sqrt, self._variance(**kwargs))
except NotImplementedError:
pass
raise
diff --git a/tensorflow_probability/python/distributions/distribution_properties_test.py b/tensorflow_probability/python/distributions/distribution_properties_test.py
index 282cd9a86f..02f9064f06 100644
--- a/tensorflow_probability/python/distributions/distribution_properties_test.py
+++ b/tensorflow_probability/python/distributions/distribution_properties_test.py
@@ -44,6 +44,7 @@
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util
+from tensorflow_probability.python.math.psd_kernels import hypothesis_testlib as kernel_hps
WORKING_PRECISION_TEST_BLOCK_LIST = (
@@ -115,7 +116,7 @@ def eligibility_filter(name):
hp.note('Trying distribution {}'.format(
self.evaluate_dict(dist.parameters)))
seed = test_util.test_seed()
- with tfp_hps.no_tf_rank_errors():
+ with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
samples = dist.sample(5, seed=seed)
self.assertIn(samples.dtype, [tf.float32, tf.int32])
self.assertEqual(dist.log_prob(samples).dtype, tf.float32)
@@ -125,7 +126,7 @@ def log_prob_function(dist, x):
dist64 = tf.nest.map_structure(
tensor_to_f64, tfe.as_composite(dist), expand_composites=True)
- with tfp_hps.no_tf_rank_errors():
+ with tfp_hps.no_tf_rank_errors(), kernel_hps.no_pd_errors():
result64 = log_prob_function(dist64, tensor_to_f64(samples))
self.assertEqual(result64.dtype, tf.float64)
diff --git a/tensorflow_probability/python/distributions/empirical.py b/tensorflow_probability/python/distributions/empirical.py
index 31510e1749..7401d8dd8c 100644
--- a/tensorflow_probability/python/distributions/empirical.py
+++ b/tensorflow_probability/python/distributions/empirical.py
@@ -123,7 +123,7 @@ def __init__(self,
"""Initialize `Empirical` distributions.
Args:
- samples: Numeric `Tensor` of shape [B1, ..., Bk, S, E1, ..., En]`,
+ samples: Numeric `Tensor` of shape `[B1, ..., Bk, S, E1, ..., En]`,
`k, n >= 0`. Samples or batches of samples on which the distribution
is based. The first `k` dimensions index into a batch of independent
distributions. Length of `S` dimension determines number of samples
diff --git a/tensorflow_probability/python/distributions/exponential.py b/tensorflow_probability/python/distributions/exponential.py
index a59ac6c32b..d8bfed874c 100644
--- a/tensorflow_probability/python/distributions/exponential.py
+++ b/tensorflow_probability/python/distributions/exponential.py
@@ -71,6 +71,7 @@ class Exponential(gamma.Gamma):
def __init__(self,
rate,
+ force_probs_to_zero_outside_support=False,
validate_args=False,
allow_nan_stats=True,
name="Exponential"):
@@ -79,6 +80,17 @@ def __init__(self,
Args:
rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
positive values.
+ force_probs_to_zero_outside_support: Python `bool`. When `True`, negative
+ and non-integer values are evaluated "strictly": `cdf` returns
+ `0`, `sf` returns `1`, and `log_cdf` and `log_sf` correspond. When
+ `False`, the implementation is free to save computation (and TF graph
+ size) by evaluating something that matches the Exponential cdf at
+ non-negative values `x` but produces an unrestricted result on
+ other inputs. In the case of Exponential distribution, the `cdf`
+ formula in this case happens to be the continuous function
+ `1 - exp(rate * value)`.
+ Note that this function is not itself a cdf function.
+ Default value: `False`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@@ -99,6 +111,8 @@ def __init__(self,
rate,
name="rate",
dtype=dtype_util.common_dtype([rate], dtype_hint=tf.float32))
+ self._force_probs_to_zero_outside_support = (
+ force_probs_to_zero_outside_support)
super(Exponential, self).__init__(
concentration=1.,
rate=self._rate,
@@ -120,12 +134,29 @@ def _parameter_properties(cls, dtype, num_classes=None):
def rate(self):
return self._rate
+ @property
+ def force_probs_to_zero_outside_support(self):
+ """Return 0 probabilities on non-integer inputs."""
+ return self._force_probs_to_zero_outside_support
+
def _cdf(self, value):
- return -tf.math.expm1(-self.rate * value)
+ cdf = -tf.math.expm1(-self.rate * value)
+
+ if self.force_probs_to_zero_outside_support:
+ # Set cdf = 0 when value is less than 0.
+ cdf = tf.where(value < 0., tf.zeros_like(cdf), cdf)
+
+ return cdf
def _log_survival_function(self, value):
rate = tf.convert_to_tensor(self._rate)
- return self._log_prob(value, rate=rate) - tf.math.log(rate)
+ log_sf = self._log_prob(value, rate=rate) - tf.math.log(rate)
+
+ if self.force_probs_to_zero_outside_support:
+ # Set log_survival_function = 0 when value is less than 0.
+ log_sf = tf.where(value < 0., tf.zeros_like(log_sf), log_sf)
+
+ return log_sf
def _sample_n(self, n, seed=None):
rate = tf.convert_to_tensor(self.rate)
diff --git a/tensorflow_probability/python/distributions/exponentially_modified_gaussian.py b/tensorflow_probability/python/distributions/exponentially_modified_gaussian.py
new file mode 100644
index 0000000000..0ab75add55
--- /dev/null
+++ b/tensorflow_probability/python/distributions/exponentially_modified_gaussian.py
@@ -0,0 +1,244 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""The exponentially modified Gaussian distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import tensorflow.compat.v2 as tf
+from tensorflow_probability.python.bijectors import identity as identity_bijector
+from tensorflow_probability.python.bijectors import softplus as softplus_bijector
+from tensorflow_probability.python.distributions import distribution
+from tensorflow_probability.python.distributions import exponential as exponential_lib
+from tensorflow_probability.python.distributions import normal as normal_lib
+from tensorflow_probability.python.internal import assert_util
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import parameter_properties
+from tensorflow_probability.python.internal import prefer_static
+from tensorflow_probability.python.internal import reparameterization
+from tensorflow_probability.python.internal import samplers
+from tensorflow_probability.python.internal import special_math
+from tensorflow_probability.python.internal import tensor_util
+
+__all__ = [
+ 'ExponentiallyModifiedGaussian',
+]
+
+
+class ExponentiallyModifiedGaussian(distribution.Distribution):
+ """Exponentially modified Gaussian distribution.
+
+ #### Mathematical details
+
+ The exponentially modified Gaussian distribution is the sum of a normal
+ distribution and an exponential distribution.
+ ```none
+ X ~ Normal(loc, scale)
+ Y ~ Exponential(rate)
+ Z = X + Y
+ ```
+ is equivalent to
+ ```none
+ Z ~ ExponentiallyModifiedGaussian(loc, scale, rate)
+ ```
+
+ #### Examples
+ ```python
+ tfd = tfp.distributions
+
+ # Define a single scalar ExponentiallyModifiedGaussian distribution
+ dist = tfd.ExponentiallyModifiedGaussian(loc=0., scale=1., rate=3.)
+
+ # Evaluate the pdf at 1, returing a scalar.
+ dist.prob(1.)
+ ```
+
+
+ """
+
+ def __init__(self,
+ loc,
+ scale,
+ rate,
+ validate_args=False,
+ allow_nan_stats=True,
+ name='ExponentiallyModifiedGaussian'):
+ """Construct an exponentially-modified Gaussian distribution.
+
+ The Gaussian distribution has mean `loc` and stddev `scale`,
+ and Exponential distribution has rate parameter `rate`.
+
+ The parameters `loc`, `scale`, and `rate` must be shaped in a way that
+ supports broadcasting (e.g. `loc + scale + rate` is a valid operation).
+ Args:
+ loc: Floating-point `Tensor`; the means of the distribution(s).
+ scale: Floating-point `Tensor`; the stddevs of the distribution(s). Must
+ contain only positive values.
+ rate: Floating-point `Tensor`; the rate parameter for the exponential
+ distribution.
+ validate_args: Python `bool`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
+ (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
+ result is undefined. When `False`, an exception is raised if one or more
+ of the statistic's batch members are undefined.
+ name: Python `str` name prefixed to Ops created by this class.
+
+ Raises:
+ TypeError: if `loc`, `scale`, and `rate` are not all the same `dtype`.
+ """
+ parameters = dict(locals())
+ with tf.name_scope(name) as name:
+ dtype = dtype_util.common_dtype([loc, scale, rate], dtype_hint=tf.float32)
+ self._loc = tensor_util.convert_nonref_to_tensor(
+ loc, dtype=dtype, name='loc')
+ self._scale = tensor_util.convert_nonref_to_tensor(
+ scale, dtype=dtype, name='scale')
+ self._rate = tensor_util.convert_nonref_to_tensor(
+ rate, dtype=dtype, name='rate')
+ super(ExponentiallyModifiedGaussian, self).__init__(
+ dtype=dtype,
+ reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
+ validate_args=validate_args,
+ allow_nan_stats=allow_nan_stats,
+ parameters=parameters,
+ name=name)
+
+ @staticmethod
+ def _param_shapes(sample_shape):
+ return dict(
+ zip(('loc', 'scale', 'rate'),
+ ([tf.convert_to_tensor(sample_shape, dtype=tf.int32)] * 3)))
+
+ @classmethod
+ def _parameter_properties(cls, dtype, num_classes=None):
+ return dict(
+ loc=parameter_properties.ParameterProperties(),
+ scale=parameter_properties.ParameterProperties(
+ default_constraining_bijector_fn=(
+ lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
+ rate=parameter_properties.ParameterProperties(
+ default_constraining_bijector_fn=(
+ lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))))
+
+ @classmethod
+ def _params_event_ndims(cls):
+ return dict(loc=0, scale=0, rate=0)
+
+ @property
+ def loc(self):
+ """Distribution parameter for the mean of the normal distribution."""
+ return self._loc
+
+ @property
+ def scale(self):
+ """Distribution parameter for standard deviation of the normal distribution."""
+ return self._scale
+
+ @property
+ def rate(self):
+ """Distribution parameter for rate parameter of exponential distribution."""
+ return self._rate
+
+ def _batch_shape_tensor(self, loc=None, scale=None, rate=None):
+ return prefer_static.broadcast_shape(
+ prefer_static.shape(self.loc if loc is None else loc),
+ prefer_static.broadcast_shape(
+ prefer_static.shape(self.scale if scale is None else scale),
+ prefer_static.shape(self.rate if rate is None else rate)))
+
+ def _batch_shape(self):
+ return tf.broadcast_static_shape(
+ self.loc.shape,
+ tf.broadcast_static_shape(self.scale.shape, self.rate.shape))
+
+ def _event_shape_tensor(self):
+ return tf.constant([], dtype=tf.int32)
+
+ def _event_shape(self):
+ return tf.TensorShape([])
+
+ def _sample_n(self, n, seed=None):
+ normal_seed, exp_seed = samplers.split_seed(seed, salt='emg_sample')
+ # need to make sure component distributions are broadcast appropriately
+ # for correct generation of samples
+ loc = tf.convert_to_tensor(self.loc)
+ rate = tf.convert_to_tensor(self.rate)
+ scale = tf.convert_to_tensor(self.scale)
+ batch_shape = self._batch_shape_tensor(loc, scale, rate)
+ loc_broadcast = tf.broadcast_to(loc, batch_shape)
+ rate_broadcast = tf.broadcast_to(rate, batch_shape)
+ normal_dist = normal_lib.Normal(loc=loc_broadcast, scale=scale)
+ exp_dist = exponential_lib.Exponential(rate_broadcast)
+ x = normal_dist.sample(n, normal_seed)
+ y = exp_dist.sample(n, exp_seed)
+ return x + y
+
+ def _log_prob(self, x):
+ loc = tf.convert_to_tensor(self.loc)
+ rate = tf.convert_to_tensor(self.rate)
+ scale = tf.convert_to_tensor(self.scale)
+ two = dtype_util.as_numpy_dtype(x.dtype)(2.)
+ z = (x - loc) / scale
+ w = rate * scale
+ return (tf.math.log(rate) + w / two * (w - 2 * z) +
+ special_math.log_ndtr(z - w))
+
+ def _cdf(self, x):
+ rate = tf.convert_to_tensor(self.rate)
+ x_centralized = x - self.loc
+ u = rate * x_centralized
+ v = rate * self.scale
+ vsquared = tf.square(v)
+ return special_math.ndtr(x_centralized / self.scale) - tf.exp(
+ -u + vsquared / 2. + special_math.log_ndtr((u - vsquared) / v))
+
+ def _mean(self):
+ return self.loc + 1 / self.rate
+
+ def _variance(self):
+ return tf.square(self.scale) + 1 / tf.square(self.rate)
+
+ def _parameter_control_dependencies(self, is_init):
+ assertions = []
+
+ if is_init:
+ try:
+ self._batch_shape()
+ except ValueError:
+ raise ValueError(
+ 'Arguments `loc`, `scale`, and `rate` must have compatible shapes; '
+ 'loc.shape={}, scale.shape={}, rate.shape={}.'.format(
+ self.loc.shape, self.scale.shape, self.rate.shape))
+ # We don't bother checking the shapes in the dynamic case because
+ # all member functions access both arguments anyway.
+
+ if is_init != tensor_util.is_ref(self.scale):
+ assertions.append(assert_util.assert_positive(
+ self.scale, message='Argument `scale` must be positive.'))
+
+ if is_init != tensor_util.is_ref(self.rate):
+ assertions.append(assert_util.assert_positive(
+ self.rate, message='Argument `rate` must be positive.'))
+
+ return assertions
+
+ def _default_event_space_bijector(self):
+ return identity_bijector.Identity(validate_args=self.validate_args)
diff --git a/tensorflow_probability/python/distributions/exponentially_modified_gaussian_test.py b/tensorflow_probability/python/distributions/exponentially_modified_gaussian_test.py
new file mode 100644
index 0000000000..5c307a0b77
--- /dev/null
+++ b/tensorflow_probability/python/distributions/exponentially_modified_gaussian_test.py
@@ -0,0 +1,269 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for ExponentiallyModifiedGaussian Distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+# Dependency imports
+
+import numpy as np
+from scipy import stats as sp_stats
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+from tensorflow_probability.python import distributions as tfd
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import test_util
+
+
+class _ExponentiallyModifiedGaussianTest(object):
+
+ def _test_param_shapes(self, sample_shape, expected):
+ param_shapes = tfd.ExponentiallyModifiedGaussian.param_shapes(sample_shape)
+ mu_shape, sigma_shape, lambda_shape = param_shapes['loc'], param_shapes[
+ 'scale'], param_shapes['rate']
+ self.assertAllEqual(expected, self.evaluate(mu_shape))
+ self.assertAllEqual(expected, self.evaluate(sigma_shape))
+ self.assertAllEqual(expected, self.evaluate(lambda_shape))
+ mu = tf.zeros(mu_shape, dtype=self.dtype)
+ sigma = tf.ones(sigma_shape, dtype=self.dtype)
+ rate = tf.ones(lambda_shape, dtype=self.dtype)
+ self.assertAllEqual(
+ expected,
+ self.evaluate(
+ tf.shape(
+ tfd.ExponentiallyModifiedGaussian(
+ mu, sigma, rate,
+ validate_args=True).sample(seed=test_util.test_seed()))))
+
+ def _test_param_static_shapes(self, sample_shape, expected):
+ param_shapes = tfd.ExponentiallyModifiedGaussian.param_static_shapes(
+ sample_shape)
+ mu_shape, sigma_shape, lambda_shape = param_shapes['loc'], param_shapes[
+ 'scale'], param_shapes['rate']
+ self.assertEqual(expected, mu_shape)
+ self.assertEqual(expected, sigma_shape)
+ self.assertEqual(expected, lambda_shape)
+
+ # Currently fails for numpy due to a bug in the types returned by
+ # special_math.ndtr
+ # As of now, numpy testing is disabled in the BUILD file
+ def testSampleLikeArgsGetDistDType(self):
+ zero = dtype_util.as_numpy_dtype(self.dtype)(0.)
+ one = dtype_util.as_numpy_dtype(self.dtype)(1.)
+ dist = tfd.ExponentiallyModifiedGaussian(zero, one, one)
+ self.assertEqual(self.dtype, dist.dtype)
+ for method in ('log_prob', 'prob', 'log_cdf', 'cdf',
+ 'log_survival_function', 'survival_function'):
+ self.assertEqual(self.dtype, getattr(dist, method)(one).dtype, msg=method)
+
+ def testParamShapes(self):
+ sample_shape = [10, 3, 4]
+ self._test_param_shapes(sample_shape, sample_shape)
+ self._test_param_shapes(tf.constant(sample_shape), sample_shape)
+
+ def testParamStaticShapes(self):
+ sample_shape = [10, 3, 4]
+ self._test_param_static_shapes(sample_shape, sample_shape)
+ self._test_param_static_shapes(tf.TensorShape(sample_shape), sample_shape)
+
+ def testExponentiallyModifiedGaussianLogPDF(self):
+ batch_size = 6
+ mu = tf.constant([3.0] * batch_size, dtype=self.dtype)
+ sigma = tf.constant([math.sqrt(10.0)] * batch_size, dtype=self.dtype)
+ rate = tf.constant([2.] * batch_size, dtype=self.dtype)
+ x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=self.dtype)
+ exgaussian = tfd.ExponentiallyModifiedGaussian(
+ loc=mu, scale=sigma, rate=rate, validate_args=True)
+
+ log_pdf = exgaussian.log_prob(x)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()), log_pdf.shape)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()),
+ self.evaluate(log_pdf).shape)
+ self.assertAllEqual(exgaussian.batch_shape, log_pdf.shape)
+ self.assertAllEqual(exgaussian.batch_shape, self.evaluate(log_pdf).shape)
+
+ pdf = exgaussian.prob(x)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()), pdf.shape)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()),
+ self.evaluate(pdf).shape)
+ self.assertAllEqual(exgaussian.batch_shape, pdf.shape)
+ self.assertAllEqual(exgaussian.batch_shape, self.evaluate(pdf).shape)
+
+ expected_log_pdf = sp_stats.exponnorm(
+ 1. / (self.evaluate(rate) * self.evaluate(sigma)),
+ loc=self.evaluate(mu),
+ scale=self.evaluate(sigma)).logpdf(x)
+ self.assertAllClose(
+ expected_log_pdf, self.evaluate(log_pdf), atol=1e-5, rtol=1e-5)
+ self.assertAllClose(
+ np.exp(expected_log_pdf), self.evaluate(pdf), atol=1e-5, rtol=1e-5)
+
+ def testExponentiallyModifiedGaussianLogPDFMultidimensional(self):
+ batch_size = 6
+ mu = tf.constant([[3.0, -3.0]] * batch_size, dtype=self.dtype)
+ sigma = tf.constant(
+ [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size, dtype=self.dtype)
+ rate = tf.constant([[2., 3.]] * batch_size, dtype=self.dtype)
+ x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=self.dtype).T
+ exgaussian = tfd.ExponentiallyModifiedGaussian(
+ loc=mu, scale=sigma, rate=rate, validate_args=True)
+
+ log_pdf = exgaussian.log_prob(x)
+ log_pdf_values = self.evaluate(log_pdf)
+ self.assertEqual(log_pdf.shape, (6, 2))
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()), log_pdf.shape)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()),
+ self.evaluate(log_pdf).shape)
+ self.assertAllEqual(exgaussian.batch_shape, log_pdf.shape)
+ self.assertAllEqual(exgaussian.batch_shape, self.evaluate(log_pdf).shape)
+
+ pdf = exgaussian.prob(x)
+ pdf_values = self.evaluate(pdf)
+ self.assertEqual(pdf.shape, (6, 2))
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()), pdf.shape)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()), pdf_values.shape)
+ self.assertAllEqual(exgaussian.batch_shape, pdf.shape)
+ self.assertAllEqual(exgaussian.batch_shape, pdf_values.shape)
+
+ expected_log_pdf = sp_stats.exponnorm(
+ 1. / (self.evaluate(rate) * self.evaluate(sigma)),
+ loc=self.evaluate(mu),
+ scale=self.evaluate(sigma)).logpdf(x)
+ self.assertAllClose(expected_log_pdf, log_pdf_values, atol=1e-5, rtol=1e-5)
+ self.assertAllClose(
+ np.exp(expected_log_pdf), pdf_values, atol=1e-5, rtol=1e-5)
+
+ def testExponentiallyModifiedGaussianCDF(self):
+ batch_size = 50
+ mu = self._rng.randn(batch_size)
+ sigma = self._rng.rand(batch_size) + 1.0
+ rate = self._rng.rand(batch_size) + 1.0
+ x = np.linspace(-8.0, 8.0, batch_size).astype(self.dtype)
+
+ exgaussian = tfd.ExponentiallyModifiedGaussian(
+ loc=mu, scale=sigma, rate=rate, validate_args=True)
+ cdf = exgaussian.cdf(x)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()), cdf.shape)
+ self.assertAllEqual(
+ self.evaluate(exgaussian.batch_shape_tensor()),
+ self.evaluate(cdf).shape)
+ self.assertAllEqual(exgaussian.batch_shape, cdf.shape)
+ self.assertAllEqual(exgaussian.batch_shape, self.evaluate(cdf).shape)
+ expected_cdf = sp_stats.exponnorm(
+ 1. / (rate * sigma), loc=mu, scale=sigma).cdf(x)
+ self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
+
+ @test_util.numpy_disable_gradient_test
+ def testFiniteGradientAtDifficultPoints(self):
+
+ def make_fn(attr):
+ x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(self.dtype)
+ return lambda m, s, l: getattr( # pylint: disable=g-long-lambda
+ tfd.ExponentiallyModifiedGaussian(
+ loc=m, scale=s, rate=l, validate_args=True), attr)(
+ x)
+
+ for attr in ['cdf', 'log_prob']:
+ value, grads = self.evaluate(
+ tfp.math.value_and_gradient(
+ make_fn(attr), [
+ tf.constant(0, self.dtype),
+ tf.constant(1, self.dtype),
+ tf.constant(1, self.dtype)
+ ]))
+ self.assertAllFinite(value)
+ self.assertAllFinite(grads[0])
+ self.assertAllFinite(grads[1])
+
+ def testNegativeSigmaFails(self):
+ with self.assertRaisesOpError('Argument `scale` must be positive.'):
+ exgaussian = tfd.ExponentiallyModifiedGaussian(
+ loc=[tf.constant(1., dtype=self.dtype)],
+ scale=[tf.constant(-5., dtype=self.dtype)],
+ rate=[tf.constant(1., dtype=self.dtype)],
+ validate_args=True,
+ name='G')
+ self.evaluate(exgaussian.mean())
+
+ def testExponentiallyModifiedGaussianShape(self):
+ mu = tf.constant([-3.0] * 5, dtype=self.dtype)
+ sigma = tf.constant(11.0, dtype=self.dtype)
+ rate = tf.constant(6.0, dtype=self.dtype)
+ exgaussian = tfd.ExponentiallyModifiedGaussian(
+ loc=mu, scale=sigma, rate=rate, validate_args=True)
+
+ self.assertEqual(self.evaluate(exgaussian.batch_shape_tensor()), [5])
+ self.assertEqual(exgaussian.batch_shape, tf.TensorShape([5]))
+ self.assertAllEqual(self.evaluate(exgaussian.event_shape_tensor()), [])
+ self.assertEqual(exgaussian.event_shape, tf.TensorShape([]))
+
+ def testVariableScale(self):
+ x = tf.Variable(1., dtype=self.dtype)
+ d = tfd.ExponentiallyModifiedGaussian(
+ loc=tf.constant(0., dtype=self.dtype),
+ scale=x,
+ rate=tf.constant(1., dtype=self.dtype),
+ validate_args=True)
+ self.evaluate([v.initializer for v in d.variables])
+ self.assertIs(x, d.scale)
+ with self.assertRaisesOpError('Argument `scale` must be positive.'):
+ with tf.control_dependencies([x.assign(-1.)]):
+ self.evaluate(d.mean())
+
+ def testIncompatibleArgShapes(self):
+ with self.assertRaisesRegexp(Exception, r'compatible shapes'):
+ d = tfd.ExponentiallyModifiedGaussian(
+ loc=tf.zeros([2, 3], dtype=self.dtype),
+ scale=tf.ones([4, 1], dtype=self.dtype),
+ rate=tf.ones([2, 3], dtype=self.dtype),
+ validate_args=True)
+ self.evaluate(d.mean())
+
+
+@test_util.test_all_tf_execution_regimes
+class ExponentiallyModifiedGaussianTestFloat32(
+ test_util.TestCase, _ExponentiallyModifiedGaussianTest):
+ dtype = np.float32
+
+ def setUp(self):
+ self._rng = np.random.RandomState(123)
+ super(ExponentiallyModifiedGaussianTestFloat32, self).setUp()
+
+
+@test_util.test_all_tf_execution_regimes
+class ExponentiallyModifiedGaussianTestFloat64(
+ test_util.TestCase, _ExponentiallyModifiedGaussianTest):
+ dtype = np.float64
+
+ def setUp(self):
+ self._rng = np.random.RandomState(123)
+ super(ExponentiallyModifiedGaussianTestFloat64, self).setUp()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/distributions/geometric.py b/tensorflow_probability/python/distributions/geometric.py
index b95a9b4776..2a54f3d249 100644
--- a/tensorflow_probability/python/distributions/geometric.py
+++ b/tensorflow_probability/python/distributions/geometric.py
@@ -59,6 +59,7 @@ class Geometric(distribution.Distribution):
def __init__(self,
logits=None,
probs=None,
+ force_probs_to_zero_outside_support=False,
validate_args=False,
allow_nan_stats=True,
name='Geometric'):
@@ -75,6 +76,16 @@ def __init__(self,
represents the probability of success for independent Geometric
distributions and must be in the range `(0, 1]`. Only one of `logits`
or `probs` should be specified.
+ force_probs_to_zero_outside_support: Python `bool`. When `True`, negative
+ and non-integer values are evaluated "strictly": `log_prob` returns
+ `-inf`, `prob` returns `0`, and `cdf` and `sf` correspond. When
+ `False`, the implementation is free to save computation (and TF graph
+ size) by evaluating something that matches the Geometric pmf at integer
+ values `k` but produces an unrestricted result on other inputs. In the
+ case of Geometric distribution, the `log_prob` formula in this case
+ happens to be the continuous function `k * log(1 - probs) + log(probs)`.
+ Note that this function is not a normalized probability log-density.
+ Default value: `False`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
@@ -95,6 +106,8 @@ def __init__(self,
probs, dtype=dtype, name='probs')
self._logits = tensor_util.convert_nonref_to_tensor(
logits, dtype=dtype, name='logits')
+ self._force_probs_to_zero_outside_support = (
+ force_probs_to_zero_outside_support)
super(Geometric, self).__init__(
dtype=dtype,
reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
@@ -122,6 +135,11 @@ def probs(self):
"""Input argument `probs`."""
return self._probs
+ @property
+ def force_probs_to_zero_outside_support(self):
+ """Return 0 probabilities on non-integer inputs."""
+ return self._force_probs_to_zero_outside_support
+
def _batch_shape_tensor(self):
x = self._probs if self._logits is None else self._logits
return ps.shape(x)
@@ -182,7 +200,16 @@ def _log_prob(self, x):
if not self.validate_args:
# For consistency with cdf, we take the floor.
x = tf.floor(x)
- return tf.math.xlog1py(x, -probs) + tf.math.log(probs)
+
+ log_probs = tf.math.xlog1py(x, -probs) + tf.math.log(probs)
+
+ if self.force_probs_to_zero_outside_support:
+ # Set log_prob = -inf when value is less than 0, ie prob = 0.
+ log_probs = tf.where(
+ x < 0.,
+ dtype_util.as_numpy_dtype(x.dtype)(-np.inf),
+ log_probs)
+ return log_probs
def _entropy(self):
logits, probs = self._logits_and_probs_no_checks()
diff --git a/tensorflow_probability/python/distributions/gev.py b/tensorflow_probability/python/distributions/gev.py
new file mode 100644
index 0000000000..e752df6ecb
--- /dev/null
+++ b/tensorflow_probability/python/distributions/gev.py
@@ -0,0 +1,277 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""The GeneralizedExtremeValue distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.bijectors import gev_cdf as gev_cdf_bijector
+from tensorflow_probability.python.bijectors import invert as invert_bijector
+from tensorflow_probability.python.bijectors import softplus as softplus_bijector
+from tensorflow_probability.python.distributions import transformed_distribution
+from tensorflow_probability.python.distributions import uniform
+from tensorflow_probability.python.internal import distribution_util
+from tensorflow_probability.python.internal import dtype_util
+from tensorflow_probability.python.internal import parameter_properties
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import tensor_util
+
+
+class GeneralizedExtremeValue(transformed_distribution.TransformedDistribution):
+ """The scalar GeneralizedExtremeValue distribution.
+
+ This distribution is a common choice for modeling the maximum value of a
+ suitably normalized sequence of random variables. This distribution is closely
+ related to Gumbel and Weibull distributions, with Gumbel in particular being a
+ special case of this distribution with `concentration = 0`.
+
+ #### Mathematical details
+
+ The probability density function (pdf) of this distribution is,
+
+ ```none
+ pdf(x; loc, scale, conc) = t(x; loc, scale, conc) ** (1 + conc) * exp(
+ -t(x; loc, scale, conc) ) / scale
+ where t(x) =
+ * (1 + conc * (x - loc) / scale) ) ** (-1 / conc) when conc != 0;
+ * exp(-(x - loc) / scale) when conc = 0.
+ ```
+
+ where `concentration = conc`.
+
+ The cumulative density function of this distribution is,
+
+ ```cdf(x; mu, sigma) = exp(-t(x))```
+
+ The generalized extreme value distribution is a member of the
+ [location-scale family](https://en.wikipedia.org/wiki/Location-scale_family),
+ i.e., it can be constructed as,
+
+ ```none
+ X ~ GeneralizedExtremeValue(loc=0, scale=1, concentration=conc)
+ Y = loc + scale * X
+ ```
+
+ #### Examples
+
+ Examples of initialization of one or a batch of distributions.
+
+ ```python
+ tfd = tfp.distributions
+
+ # Define a single scalar generalized extreme values distribution.
+ dist = tfd.GeneralizedExtremeValue(loc=0., scale=3., concentration=0.9)
+
+ # Evaluate the cdf at 1, returning a scalar.
+ dist.cdf(1.)
+
+ # Define a batch of two scalar valued generalized extreme values.
+ # The first has loc 1 and scale 11, the second 2 and 22.
+ dist = tfd.GeneralizedExtremeValue(loc=[1, 2.], scale=[11, 22.])
+
+ # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
+ # returning a length two tensor.
+ dist.prob([0, 1.5])
+
+ # Get 3 samples, returning a 3 x 2 tensor.
+ dist.sample([3])
+ ```
+
+ Arguments are broadcast when possible.
+
+ ```python
+ # Define a batch of two scalar valued GEV distributions.
+ # Both have location 1, but different concentrations.
+ dist = tfd.GeneralizedExtremeValue(loc=1., scale=1, concentration=[0, 0.9])
+
+ # Evaluate the pdf of both distributions on the same point, 3.0,
+ # returning a length 2 tensor.
+ dist.prob(3.0)
+ ```
+
+ """
+
+ def __init__(self,
+ loc,
+ scale,
+ concentration,
+ validate_args=False,
+ allow_nan_stats=True,
+ name='GeneralizedExtremeValue'):
+ """Construct generalized extreme value distribution.
+
+ The parameters `loc`, `scale`, and `concentration` must be shaped in a way
+ that supports broadcasting (e.g. `loc + scale` + `concentration` is valid).
+
+ Args:
+ loc: Floating point tensor, the location parameter of the distribution(s).
+ scale: Floating point tensor, the scales of the distribution(s).
+ scale must contain only positive values.
+ concentration: Floating point tensor, the concentration of
+ the distribution(s).
+ validate_args: Python `bool`, default `False`. When `True` distribution
+ parameters are checked for validity despite possibly degrading runtime
+ performance. When `False` invalid inputs may silently render incorrect
+ outputs.
+ Default value: `False`.
+ allow_nan_stats: Python `bool`, default `True`. When `True`,
+ statistics (e.g., mean, mode, variance) use the value `NaN` to
+ indicate the result is undefined. When `False`, an exception is raised
+ if one or more of the statistic's batch members are undefined.
+ Default value: `True`.
+ name: Python `str` name prefixed to Ops created by this class.
+ Default value: `'GeneralizedExtremeValue'`.
+
+ Raises:
+ TypeError: if loc and scale are different dtypes.
+ """
+ parameters = dict(locals())
+ with tf.name_scope(name) as name:
+ dtype = dtype_util.common_dtype([loc, scale, concentration],
+ dtype_hint=tf.float32)
+ loc = tensor_util.convert_nonref_to_tensor(
+ loc, name='loc', dtype=dtype)
+ scale = tensor_util.convert_nonref_to_tensor(
+ scale, name='scale', dtype=dtype)
+ concentration = tensor_util.convert_nonref_to_tensor(
+ concentration, name='concentration', dtype=dtype)
+ dtype_util.assert_same_float_dtype([loc, scale, concentration])
+ # Positive scale is asserted by the incorporated GEV bijector.
+ self._gev_bijector = gev_cdf_bijector.GeneralizedExtremeValueCDF(
+ loc=loc, scale=scale, concentration=concentration,
+ validate_args=validate_args)
+
+ batch_shape = distribution_util.get_broadcast_shape(loc, scale,
+ concentration)
+ # Because the uniform sampler generates samples in `[0, 1)` this would
+ # cause samples to lie in `(inf, -inf]` instead of `(inf, -inf)`. To fix
+ # this, we use `np.finfo(dtype_util.as_numpy_dtype(self.dtype).tiny`
+ # because it is the smallest, positive, 'normal' number.
+ super(GeneralizedExtremeValue, self).__init__(
+ # TODO(b/137665504): Use batch-adding meta-distribution to set the
+ # batch shape instead of tf.ones.
+ distribution=uniform.Uniform(
+ low=np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny,
+ high=tf.ones(batch_shape, dtype=dtype),
+ allow_nan_stats=allow_nan_stats),
+ # The GEV bijector encodes the CDF function as the forward,
+ # and hence needs to be inverted.
+ bijector=invert_bijector.Invert(
+ self._gev_bijector, validate_args=validate_args),
+ parameters=parameters,
+ name=name)
+
+ @classmethod
+ def _parameter_properties(cls, dtype, num_classes=None):
+ # pylint: disable=g-long-lambda
+ return dict(
+ loc=parameter_properties.ParameterProperties(),
+ scale=parameter_properties.ParameterProperties(
+ default_constraining_bijector_fn=(
+ lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))),
+ concentration=parameter_properties.ParameterProperties())
+ # pylint: enable=g-long-lambda
+
+ @property
+ def loc(self):
+ """Distribution parameter for the location."""
+ return self._gev_bijector.loc
+
+ @property
+ def scale(self):
+ """Distribution parameter for scale."""
+ return self._gev_bijector.scale
+
+ @property
+ def concentration(self):
+ """Distribution parameter for shape."""
+ return self._gev_bijector.concentration
+
+ def _entropy(self):
+ scale = tf.broadcast_to(self.scale,
+ ps.broadcast_shape(ps.shape(self.scale),
+ ps.shape(self.loc)))
+ euler_gamma = tf.constant(np.euler_gamma, self.dtype)
+ return 1. + tf.math.log(scale) + euler_gamma * (1. + self.concentration)
+
+ def _log_prob(self, x):
+ with tf.control_dependencies(self._gev_bijector._maybe_assert_valid_x(x)): # pylint: disable=protected-access
+ scale = tf.convert_to_tensor(self.scale)
+ z = (x - self.loc) / scale
+
+ conc = tf.convert_to_tensor(self.concentration)
+ equal_zero = tf.equal(conc, 0.)
+ safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
+ log_t = tf.where(equal_zero, -z,
+ -tf.math.log1p(z * safe_conc) / safe_conc)
+
+ return (conc + 1) * log_t - tf.exp(log_t) - tf.math.log(scale)
+
+ def _mean(self):
+ conc = tf.convert_to_tensor(self.concentration)
+ equal_zero = tf.equal(conc, 0.)
+ less_than_one = tf.less(conc, 1.)
+ safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
+
+ mean_zero = tf.fill(tf.shape(conc), tf.constant(np.euler_gamma, self.dtype))
+ mean_fin = tf.math.expm1(tf.math.lgamma(1. - safe_conc)) / safe_conc
+ mean_inf = tf.fill(tf.shape(conc), tf.constant(np.inf, self.dtype))
+
+ mean_z = tf.where(equal_zero,
+ mean_zero,
+ tf.where(less_than_one,
+ mean_fin,
+ mean_inf))
+
+ return self.loc + self.scale * mean_z
+
+ def _stddev(self):
+ conc = tf.convert_to_tensor(self.concentration)
+ equal_zero = tf.equal(conc, 0.)
+ less_than_half = tf.less(conc, 0.5)
+
+ g1_square = tf.exp(tf.math.lgamma(1. - conc)) ** 2
+ g2 = tf.exp(tf.math.lgamma(1. - 2. * conc))
+ safe_conc = tf.where(equal_zero, tf.ones([], self.dtype), conc)
+
+ std_z = tf.where(equal_zero,
+ tf.fill(tf.shape(conc),
+ tf.constant(np.pi / np.sqrt(6), self.dtype)),
+ tf.where(less_than_half,
+ tf.math.sqrt(g2 - g1_square) / tf.abs(safe_conc),
+ tf.fill(tf.shape(conc),
+ tf.constant(np.inf, self.dtype)))
+ )
+
+ return self.scale * tf.ones_like(self.loc) * std_z
+
+ def _mode(self):
+ conc = tf.convert_to_tensor(self.concentration)
+ equal_zero = tf.equal(conc, 0.)
+ safe_conc = tf.where(equal_zero, tf.ones_like(conc), conc)
+
+ mode_z = tf.where(equal_zero,
+ tf.zeros_like(conc),
+ tf.math.expm1(-conc * tf.math.log1p(conc)) / safe_conc)
+
+ return self.loc + self.scale * mode_z
+
+ def _parameter_control_dependencies(self, is_init):
+ return self._gev_bijector._parameter_control_dependencies(is_init) # pylint: disable=protected-access
diff --git a/tensorflow_probability/python/distributions/gev_test.py b/tensorflow_probability/python/distributions/gev_test.py
new file mode 100644
index 0000000000..d382b249f4
--- /dev/null
+++ b/tensorflow_probability/python/distributions/gev_test.py
@@ -0,0 +1,468 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for GEV."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+import numpy as np
+from scipy import stats
+
+import tensorflow.compat.v1 as tf1
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+
+from tensorflow_probability.python.internal import test_util
+
+tfd = tfp.distributions
+
+
+class _GEVTest(object):
+
+ def make_tensor(self, x):
+ x = tf.cast(x, self._dtype)
+ return tf1.placeholder_with_default(
+ x, shape=x.shape if self._use_static_shape else None)
+
+ def testGEVShape(self):
+ loc = np.array([3.0] * 5, dtype=self._dtype)
+ scale = np.array([3.0] * 5, dtype=self._dtype)
+ conc = np.array([3.0] * 5, dtype=self._dtype)
+ gev = tfd.GeneralizedExtremeValue(loc=loc, scale=scale,
+ concentration=conc,
+ validate_args=True)
+
+ self.assertEqual((5,), self.evaluate(gev.batch_shape_tensor()))
+ self.assertEqual(tf.TensorShape([5]), gev.batch_shape)
+ self.assertAllEqual([], self.evaluate(gev.event_shape_tensor()))
+ self.assertEqual(tf.TensorShape([]), gev.event_shape)
+
+ def testInvalidScale(self):
+ scale = [-.01, 0., 2.]
+ with self.assertRaisesOpError('Argument `scale` must be positive.'):
+ gev = tfd.GeneralizedExtremeValue(loc=0., scale=scale, concentration=1.,
+ validate_args=True)
+ self.evaluate(gev.mean())
+
+ scale = tf.Variable([.01])
+ self.evaluate(scale.initializer)
+ gev = tfd.GeneralizedExtremeValue(loc=0., scale=scale, concentration=1.,
+ validate_args=True)
+ self.assertIs(scale, gev.scale)
+ self.evaluate(gev.mean())
+ with tf.control_dependencies([scale.assign([-.01])]):
+ with self.assertRaisesOpError('Argument `scale` must be positive.'):
+ self.evaluate(gev.mean())
+
+ def testGEVLogPdf(self):
+ batch_size = 6
+ loc = np.array([0.] * batch_size, dtype=self._dtype)
+ scale = np.array([3.] * batch_size, dtype=self._dtype)
+ conc = np.array([2.] * batch_size, dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+ x = np.array([2., 3., 4., 5., 6., 7.], dtype=self._dtype)
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+ log_pdf = gev.log_prob(self.make_tensor(x))
+ self.assertAllClose(
+ gev_dist.logpdf(x),
+ self.evaluate(log_pdf))
+
+ pdf = gev.prob(x)
+ self.assertAllClose(
+ gev_dist.pdf(x), self.evaluate(pdf))
+
+ def testGEVLogPdfMultidimensional(self):
+ batch_size = 6
+ loc = np.array([[-2.0, -4.0, -5.0]] * batch_size, dtype=self._dtype)
+ scale = np.array([1.0], dtype=self._dtype)
+ conc = np.array([[0.0, 1.0, 2.0]] * batch_size, dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+ x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=self._dtype).T
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+ log_pdf = gev.log_prob(self.make_tensor(x))
+ self.assertAllClose(
+ self.evaluate(log_pdf), gev_dist.logpdf(x))
+
+ pdf = gev.prob(self.make_tensor(x))
+ self.assertAllClose(
+ self.evaluate(pdf), gev_dist.pdf(x))
+
+ def testGEVCDF(self):
+ batch_size = 6
+ loc = np.array([0.] * batch_size, dtype=self._dtype)
+ scale = np.array([3.] * batch_size, dtype=self._dtype)
+ conc = np.array([2.] * batch_size, dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+ x = np.array([2., 3., 4., 5., 6., 7.], dtype=self._dtype)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ log_cdf = gev.log_cdf(self.make_tensor(x))
+ self.assertAllClose(
+ self.evaluate(log_cdf), gev_dist.logcdf(x))
+
+ cdf = gev.cdf(self.make_tensor(x))
+ self.assertAllClose(
+ self.evaluate(cdf), gev_dist.cdf(x))
+
+ def testGEVCdfMultidimensional(self):
+ batch_size = 6
+ loc = np.array([[-2.0, -4.0, -5.0]] * batch_size, dtype=self._dtype)
+ scale = np.array([1.0], dtype=self._dtype)
+ conc = np.array([[0.0, 1.0, 2.0]] * batch_size, dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+ x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=self._dtype).T
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ log_cdf = gev.log_cdf(self.make_tensor(x))
+ self.assertAllClose(
+ self.evaluate(log_cdf),
+ gev_dist.logcdf(x))
+
+ cdf = gev.cdf(self.make_tensor(x))
+ self.assertAllClose(
+ self.evaluate(cdf),
+ gev_dist.cdf(x))
+
+ def testGEVMean(self):
+ loc = np.array([2.0], dtype=self._dtype)
+ scale = np.array([1.5], dtype=self._dtype)
+ conc = np.array([-0.9, 0.0], dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+ self.assertAllClose(self.evaluate(gev.mean()),
+ gev_dist.mean())
+
+ conc_with_inf_mean = np.array([2.], dtype=self._dtype)
+ gev_with_inf_mean = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc_with_inf_mean),
+ validate_args=True)
+ self.assertAllClose(self.evaluate(gev_with_inf_mean.mean()),
+ [np.inf])
+
+ def testGEVVariance(self):
+ loc = np.array([2.0], dtype=self._dtype)
+ scale = np.array([1.5], dtype=self._dtype)
+ conc = np.array([-0.9, 0.0], dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ self.assertAllClose(self.evaluate(gev.variance()),
+ gev_dist.var())
+
+ conc_with_inf_var = np.array([1.5], dtype=self._dtype)
+ gev_with_inf_var = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc_with_inf_var),
+ validate_args=True)
+ self.assertAllClose(self.evaluate(gev_with_inf_var.variance()),
+ [np.inf])
+
+ def testGEVStd(self):
+ loc = np.array([2.0], dtype=self._dtype)
+ scale = np.array([1.5], dtype=self._dtype)
+ conc = np.array([-0.9, 0.0], dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ self.assertAllClose(self.evaluate(gev.stddev()),
+ gev_dist.std())
+
+ conc_with_inf_std = np.array([1.5], dtype=self._dtype)
+ gev_with_inf_std = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc_with_inf_std),
+ validate_args=True)
+ self.assertAllClose(self.evaluate(gev_with_inf_std.stddev()),
+ [np.inf])
+
+ def testGEVMode(self):
+ loc = np.array([2.0], dtype=self._dtype)
+ scale = np.array([1.5], dtype=self._dtype)
+ conc = np.array([-0.9, 0.0, 1.5], dtype=self._dtype)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ np_mode_z = np.where(conc == 0., 0., ((conc+1)**(-conc) - 1.) / conc)
+ np_mode = loc + np_mode_z * scale
+ self.assertAllClose(self.evaluate(gev.mode()), np_mode)
+
+ def testGEVSample(self):
+ loc = self._dtype(4.0)
+ scale = self._dtype(1.0)
+ conc = self._dtype(0.2)
+ n = int(1e6)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ samples = gev.sample(n, seed=test_util.test_seed())
+ sample_values = self.evaluate(samples)
+ self.assertEqual((n,), sample_values.shape)
+ self.assertAllClose(
+ gev_dist.mean(),
+ sample_values.mean(), rtol=.01)
+ self.assertAllClose(
+ gev_dist.var(),
+ sample_values.var(), rtol=.01)
+
+ def testGEVSampleMultidimensionalMean(self):
+ loc = np.array([2.0, 4.0, 5.0], dtype=self._dtype)
+ scale = np.array([1.0, 0.8, 0.5], dtype=self._dtype)
+ conc = np.array([0.2], dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+ n = int(1e6)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ samples = gev.sample(n, seed=test_util.test_seed())
+ sample_values = self.evaluate(samples)
+ self.assertAllClose(
+ gev_dist.mean(),
+ sample_values.mean(axis=0),
+ rtol=.03,
+ atol=0)
+
+ def testGEVSampleMultidimensionalVar(self):
+ loc = np.array([2.0, 4.0, 5.0], dtype=self._dtype)
+ scale = np.array([1.0, 0.8, 0.5], dtype=self._dtype)
+ conc = np.array([0.2], dtype=self._dtype)
+ gev_dist = stats.genextreme(-conc, loc=loc, scale=scale)
+ n = int(1e6)
+
+ gev = tfd.GeneralizedExtremeValue(
+ loc=self.make_tensor(loc),
+ scale=self.make_tensor(scale),
+ concentration=self.make_tensor(conc),
+ validate_args=True)
+
+ samples = gev.sample(n, seed=test_util.test_seed())
+ sample_values = self.evaluate(samples)
+ self.assertAllClose(
+ gev_dist.var(),
+ sample_values.var(axis=0),
+ rtol=.03,
+ atol=0)
+
+ @test_util.numpy_disable_gradient_test
+ def testFiniteGradientAtDifficultPoints(self):
+ def make_fn(dtype, attr):
+ x = np.array([1.]).astype(dtype)
+ return lambda m, s, p: getattr( # pylint: disable=g-long-lambda
+ tfd.GeneralizedExtremeValue(loc=m, scale=s,
+ concentration=p, validate_args=True),
+ attr)(x)
+
+ loc = np.array([1.0], dtype=self._dtype)
+ scale = np.array([1.5], dtype=self._dtype)
+ conc = np.array([-0.7, 0.0, 0.5, 1.], dtype=self._dtype)
+
+ for attr in ['log_prob', 'prob', 'cdf', 'log_cdf']:
+ value, grads = self.evaluate(tfp.math.value_and_gradient(
+ make_fn(self._dtype, attr),
+ [self.make_tensor(loc), # loc
+ self.make_tensor(scale), # scale
+ self.make_tensor(conc)])) # conc
+ self.assertAllFinite(value)
+ self.assertAllFinite(grads[0]) # d/d loc
+ self.assertAllFinite(grads[1]) # d/d scale
+ self.assertAllFinite(grads[2]) # d/d conc
+
+ def testBroadcastingParams(self):
+
+ def _check(gev_dist):
+ self.assertEqual(gev_dist.mean().shape, (3,))
+ self.assertEqual(gev_dist.variance().shape, (3,))
+ self.assertEqual(gev_dist.entropy().shape, (3,))
+ self.assertEqual(gev_dist.log_prob(6.).shape, (3,))
+ self.assertEqual(gev_dist.prob(6.).shape, (3,))
+ self.assertEqual(gev_dist.sample(
+ 37, seed=test_util.test_seed()).shape, (37, 3,))
+
+ _check(
+ tfd.GeneralizedExtremeValue(loc=[
+ 2.,
+ 3.,
+ 4.,
+ ], scale=2., concentration=1., validate_args=True))
+ _check(
+ tfd.GeneralizedExtremeValue(loc=3., scale=[
+ 2.,
+ 3.,
+ 4.,
+ ], concentration=1., validate_args=True))
+ _check(
+ tfd.GeneralizedExtremeValue(loc=3., scale=3., concentration=[
+ 2.,
+ 3.,
+ 4.,
+ ], validate_args=True))
+
+ def testBroadcastingPdfArgs(self):
+
+ def _assert_shape(gev_dist, arg, shape):
+ self.assertEqual(gev_dist.log_prob(arg).shape, shape)
+ self.assertEqual(gev_dist.prob(arg).shape, shape)
+
+ def _check(gev_dist):
+ _assert_shape(gev_dist, 5., (3,))
+ xs = np.array([5., 6., 7.], dtype=np.float32)
+ _assert_shape(gev_dist, xs, (3,))
+ xs = np.array([xs])
+ _assert_shape(gev_dist, xs, (1, 3))
+ xs = xs.T
+ _assert_shape(gev_dist, xs, (3, 3))
+
+ _check(
+ tfd.GeneralizedExtremeValue(loc=[
+ -2.,
+ -3.,
+ -4.,
+ ], scale=2., concentration=1., validate_args=True))
+ _check(
+ tfd.GeneralizedExtremeValue(loc=-6., scale=[
+ 2.,
+ 3.,
+ 4.,
+ ], concentration=1., validate_args=True))
+ _check(
+ tfd.GeneralizedExtremeValue(loc=-7., scale=3., concentration=[
+ 2.,
+ 3.,
+ 4.,
+ ], validate_args=True))
+
+ def _check2d(gev_dist):
+ _assert_shape(gev_dist, 5., (1, 3))
+ xs = np.array([5., 6., 7.], dtype=np.float32)
+ _assert_shape(gev_dist, xs, (1, 3))
+ xs = np.array([xs])
+ _assert_shape(gev_dist, xs, (1, 3))
+ xs = xs.T
+ _assert_shape(gev_dist, xs, (3, 3))
+
+ _check2d(
+ tfd.GeneralizedExtremeValue(loc=[[
+ -2.,
+ -3.,
+ -4.,
+ ]], scale=2., concentration=1., validate_args=True))
+ _check2d(
+ tfd.GeneralizedExtremeValue(loc=-7., scale=[[
+ 2.,
+ 3.,
+ 4.,
+ ]], concentration=1., validate_args=True))
+ _check2d(
+ tfd.GeneralizedExtremeValue(loc=-7., scale=3., concentration=[[
+ 2.,
+ 3.,
+ 4.,
+ ]], validate_args=True))
+
+ def _check2d_rows(gev_dist):
+ _assert_shape(gev_dist, 5., (3, 1))
+ xs = np.array([5., 6., 7.], dtype=np.float32) # (3,)
+ _assert_shape(gev_dist, xs, (3, 3))
+ xs = np.array([xs]) # (1,3)
+ _assert_shape(gev_dist, xs, (3, 3))
+ xs = xs.T # (3,1)
+ _assert_shape(gev_dist, xs, (3, 1))
+
+ _check2d_rows(
+ tfd.GeneralizedExtremeValue(
+ loc=[[-2.], [-3.], [-4.]], scale=2., concentration=1.,
+ validate_args=True))
+ _check2d_rows(
+ tfd.GeneralizedExtremeValue(
+ loc=-7., scale=[[2.], [3.], [4.]], concentration=1.,
+ validate_args=True))
+ _check2d_rows(
+ tfd.GeneralizedExtremeValue(
+ loc=-7., scale=3., concentration=[[2.], [3.], [4.]],
+ validate_args=True))
+
+
+@test_util.test_all_tf_execution_regimes
+class GEVTestStaticShape(test_util.TestCase, _GEVTest):
+ _dtype = np.float32
+ _use_static_shape = True
+
+
+@test_util.test_all_tf_execution_regimes
+class GEVTestFloat64StaticShape(test_util.TestCase, _GEVTest):
+ _dtype = np.float64
+ _use_static_shape = True
+
+
+@test_util.test_all_tf_execution_regimes
+class GEVTestDynamicShape(test_util.TestCase, _GEVTest):
+ _dtype = np.float32
+ _use_static_shape = False
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/distributions/half_normal.py b/tensorflow_probability/python/distributions/half_normal.py
index fa63739e0e..725d1944e0 100644
--- a/tensorflow_probability/python/distributions/half_normal.py
+++ b/tensorflow_probability/python/distributions/half_normal.py
@@ -154,11 +154,14 @@ def _sample_n(self, n, seed=None):
shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed)
return tf.abs(sampled * scale)
- def _prob(self, x):
+ def _log_prob(self, x):
scale = tf.convert_to_tensor(self.scale)
- coeff = math.sqrt(2) / scale / math.sqrt(np.pi)
- pdf = coeff * tf.exp(-0.5 * (x / scale)**2)
- return pdf * tf.cast(x >= 0, self.dtype)
+ log_unnormalized = -0.5 * (x / scale)**2
+ log_normalization = tf.math.log(scale) + tf.constant(0.5 * np.log(np.pi/2.),
+ dtype=self.dtype)
+ return tf.where(x >= 0,
+ log_unnormalized - log_normalization,
+ tf.constant(-np.inf, dtype=self.dtype))
def _cdf(self, x):
truncated_x = tf.nn.relu(x)
diff --git a/tensorflow_probability/python/distributions/half_normal_test.py b/tensorflow_probability/python/distributions/half_normal_test.py
index c7c49ca912..c4b5032861 100644
--- a/tensorflow_probability/python/distributions/half_normal_test.py
+++ b/tensorflow_probability/python/distributions/half_normal_test.py
@@ -83,9 +83,9 @@ def testParamStaticShapes(self):
self._testParamStaticShapes(tf.TensorShape(sample_shape), sample_shape)
def testHalfNormalLogPDF(self):
- batch_size = 6
- scale = tf.constant([3.0] * batch_size)
- x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
+
+ x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0, 60.], dtype=np.float32)
+ scale = tf.constant([3.0] * len(x))
halfnorm = tfd.HalfNormal(scale=scale, validate_args=False)
log_pdf = halfnorm.log_prob(x)
diff --git a/tensorflow_probability/python/distributions/hypothesis_testlib.py b/tensorflow_probability/python/distributions/hypothesis_testlib.py
index d97f24d082..3247d8aef6 100644
--- a/tensorflow_probability/python/distributions/hypothesis_testlib.py
+++ b/tensorflow_probability/python/distributions/hypothesis_testlib.py
@@ -72,6 +72,7 @@
'GeneralizedPareto',
'Geometric',
'Gumbel',
+ 'GeneralizedExtremeValue',
'HalfCauchy',
'HalfNormal',
'HalfStudentT',
diff --git a/tensorflow_probability/python/distributions/independent.py b/tensorflow_probability/python/distributions/independent.py
index 0b575f3963..18f7cc4731 100644
--- a/tensorflow_probability/python/distributions/independent.py
+++ b/tensorflow_probability/python/distributions/independent.py
@@ -22,6 +22,7 @@
import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.distributions import distribution as distribution_lib
from tensorflow_probability.python.distributions import kullback_leibler
from tensorflow_probability.python.internal import assert_util
@@ -97,6 +98,7 @@ def __init__(self,
distribution,
reinterpreted_batch_ndims=None,
validate_args=False,
+ experimental_use_kahan_sum=False,
name=None):
"""Construct an `Independent` distribution.
@@ -110,6 +112,11 @@ def __init__(self,
validate_args: Python `bool`. Whether to validate input with asserts.
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
+ experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
+ summation to aggregate independent underlying log_prob values, which
+ improves against the precision of a naive float32 sum. This can be
+ noticeable in particular for large dimensions in float32. See CPU caveat
+ on `tfp.math.reduce_kahan_sum`.
name: The name for ops managed by the distribution.
Default value: `Independent + distribution.name`.
@@ -118,6 +125,7 @@ def __init__(self,
`distribution.batch_ndims`
"""
parameters = dict(locals())
+ self._experimental_use_kahan_sum = experimental_use_kahan_sum
with tf.name_scope(name or ('Independent' + distribution.name)) as name:
self._distribution = distribution
@@ -239,18 +247,23 @@ def _event_shape(self):
def _sample_n(self, n, seed, **kwargs):
return self.distribution.sample(sample_shape=n, seed=seed, **kwargs)
+ def _sum_fn(self):
+ if self._experimental_use_kahan_sum:
+ return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
+ return tf.math.reduce_sum
+
def _log_prob(self, x, **kwargs):
return self._reduce(
- tf.reduce_sum, self.distribution.log_prob(x, **kwargs))
+ self._sum_fn(), self.distribution.log_prob(x, **kwargs))
def _log_cdf(self, x, **kwargs):
- return self._reduce(tf.reduce_sum, self.distribution.log_cdf(x, **kwargs))
+ return self._reduce(self._sum_fn(), self.distribution.log_cdf(x, **kwargs))
def _entropy(self, **kwargs):
# NOTE: If self._reinterpreted_batch_ndims is None, we could avoid a read
# of self.distribution.batch_shape_tensor() in `self._reduce` here by
# passing in `tf.shape(self.distribution.entropy())` to use instead.
- return self._reduce(tf.reduce_sum, self.distribution.entropy(**kwargs))
+ return self._reduce(self._sum_fn(), self.distribution.entropy(**kwargs))
def _mean(self, **kwargs):
return self.distribution.mean(**kwargs)
diff --git a/tensorflow_probability/python/distributions/independent_test.py b/tensorflow_probability/python/distributions/independent_test.py
index 0966e58635..9959641ca6 100644
--- a/tensorflow_probability/python/distributions/independent_test.py
+++ b/tensorflow_probability/python/distributions/independent_test.py
@@ -18,8 +18,11 @@
from __future__ import division
from __future__ import print_function
+import os
+
# Dependency imports
+from absl.testing import parameterized
import numpy as np
from scipy import stats as sp_stats
import tensorflow.compat.v1 as tf1
@@ -30,13 +33,12 @@
from tensorflow_probability.python.internal import test_util
+JAX_MODE = False
+
+
@test_util.test_all_tf_execution_regimes
class IndependentDistributionTest(test_util.TestCase):
- def setUp(self):
- super(IndependentDistributionTest, self).setUp()
- self._rng = np.random.RandomState(42)
-
def assertRaises(self, error_class, msg):
if tf.executing_eagerly():
return self.assertRaisesRegexp(error_class, msg)
@@ -269,7 +271,7 @@ def _testMnistLike(self, static_shape):
sample_shape = [4, 5]
batch_shape = [10]
image_shape = [28, 28, 1]
- logits = 3 * self._rng.random_sample(
+ logits = 3 * np.random.random_sample(
batch_shape + image_shape).astype(np.float32) - 1
def expected_log_prob(x, logits):
@@ -308,7 +310,7 @@ def expected_log_prob(x, logits):
self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape)
self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape)
self.assertAllClose(
- expected_log_prob(x_, logits), actual_log_prob_x, rtol=1e-6, atol=0.)
+ expected_log_prob(x_, logits), actual_log_prob_x, rtol=1.5e-6, atol=0.)
def testMnistLikeStaticShape(self):
self._testMnistLike(static_shape=True)
@@ -497,6 +499,31 @@ def testChangingVariableShapes(self):
self.assertAllEqual(
(2, 3), tf.shape(dist.log_prob(np.zeros((2, 3, 7, 1, 1, 1)))))
+ @parameterized.named_parameters(dict(testcase_name=''),
+ dict(testcase_name='_jit', jit=True))
+ def test_kahan_precision(self, jit=False):
+ maybe_jit = lambda f: f
+ if jit:
+ self.skip_if_no_xla()
+ maybe_jit = tf.function(experimental_compile=True)
+ stream = test_util.test_seed_stream()
+ n = 20_000
+ samps = tfd.Poisson(rate=1.).sample(n, seed=stream())
+ log_rate = tf.fill([n], tfd.Normal(0, .2).sample(seed=stream()))
+ pois = tfd.Poisson(log_rate=log_rate)
+ lp_fn = maybe_jit(tfd.Independent(pois, reinterpreted_batch_ndims=1,
+ experimental_use_kahan_sum=True).log_prob)
+ lp = lp_fn(samps)
+ pois64 = tfd.Poisson(log_rate=tf.cast(log_rate, tf.float64))
+ lp64 = tfd.Independent(pois64, reinterpreted_batch_ndims=1).log_prob(
+ tf.cast(samps, tf.float64))
+ # Evaluate together to ensure we use the same samples.
+ lp, lp64 = self.evaluate((tf.cast(lp, tf.float64), lp64))
+ # Fails ~75% CPU, 1-75% GPU --vary_seed runs w/o experimental_use_kahan_sum.
+ self.assertAllClose(lp64, lp, rtol=0., atol=.01)
+
if __name__ == '__main__':
+ # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term.
+ os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'
tf.test.main()
diff --git a/tensorflow_probability/python/distributions/joint_distribution.py b/tensorflow_probability/python/distributions/joint_distribution.py
index 6772c8275c..d7d8082a8e 100644
--- a/tensorflow_probability/python/distributions/joint_distribution.py
+++ b/tensorflow_probability/python/distributions/joint_distribution.py
@@ -25,7 +25,7 @@
import tensorflow.compat.v2 as tf
-from tensorflow_probability.python.bijectors import bijector as bijector_lib
+from tensorflow_probability.python.bijectors import composition
from tensorflow_probability.python.bijectors import identity as identity_bijector
from tensorflow_probability.python.distributions import distribution as distribution_lib
from tensorflow_probability.python.internal import assert_util
@@ -334,6 +334,11 @@ def sample_distributions(self, sample_shape=(), seed=None, value=None,
with self._name_and_control_scope(name):
ds, xs = self._call_flat_sample_distributions(sample_shape, seed, value,
**kwargs)
+ if not sample_shape and value is None:
+ # This is a single sample with no pinned values; this call will cache
+ # the distributions if they are not already cached.
+ self._get_single_sample_distributions(candidate_dists=ds)
+
return self._model_unflatten(ds), self._model_unflatten(xs)
def log_prob_parts(self, value, name='log_prob_parts'):
@@ -424,8 +429,13 @@ def _sample_n(self, sample_shape, seed, value=None, **kwargs):
raise ValueError('Supplied both `value` and keyword arguments to '
'parameterize sampling. Supplied keywords were: '
'{}'.format(keywords))
- _, xs = self._call_flat_sample_distributions(sample_shape, seed, value,
- **kwargs)
+ ds, xs = self._call_flat_sample_distributions(sample_shape, seed, value,
+ **kwargs)
+ if not sample_shape and value is None:
+ # This is a single sample with no pinned values; this call will cache
+ # the distributions if they are not already cached.
+ self._get_single_sample_distributions(candidate_dists=ds)
+
return self._model_unflatten(xs)
def _map_measure_over_dists(self, attr, value):
@@ -465,10 +475,6 @@ def _call_flat_sample_distributions(
value = self._model_flatten(value)
ds, xs = self._flat_sample_distributions(sample_shape, seed, value)
- if not sample_shape and value is None:
- # Maybe cache these distributions.
- self._get_single_sample_distributions(candidate_dists=ds)
-
return ds, xs
# Override the base method to capture *args and **kwargs, so we can
@@ -732,139 +738,75 @@ def maybe_check_wont_broadcast(flat_xs, validate_args):
return tuple(tf.identity(x) for x in flat_xs)
-# TODO(b/162764645): Implement as a Composite bijector.
-# The Composite CL generalizes Chain to arbitrary bijector DAGs. It will:
-# 1) Define an abstract `CompositeBijector` class (for any bijector that
-# wraps other bijectors, and does nothing else)
-# 2) Express `Chain` and friends (including this) in terms of Composite.
-# 3) Introduce `JointMap` (this class sans coroutine)
-# 4) Introduce `Restructure`, as Chain+JM are pretty useless without it.
-class _DefaultJointBijector(bijector_lib.Bijector):
+# pylint: disable=protected-access
+class _DefaultJointBijector(composition.Composition):
"""Minimally-viable event space bijector for `JointDistribution`."""
- # TODO(b/148485798): Support joint bijectors in TransformedDistribution.
- def __init__(self, jd):
+ def __init__(self, jd, parameters=None):
+ parameters = dict(locals()) if parameters is None else parameters
+
with tf.name_scope('default_joint_bijector') as name:
- structure = tf.nest.map_structure(lambda _: None, jd.dtype)
+ bijectors = tuple(
+ d.experimental_default_event_space_bijector()
+ for d in jd._get_single_sample_distributions())
+ i_min_event_ndims = tf.nest.map_structure(
+ prefer_static.size, jd.event_shape)
+ f_min_event_ndims = jd._model_unflatten([
+ b.inverse_event_ndims(nd) for b, nd in
+ zip(bijectors, jd._model_flatten(i_min_event_ndims))])
super(_DefaultJointBijector, self).__init__(
- forward_min_event_ndims=structure,
- inverse_min_event_ndims=structure,
+ bijectors=bijectors,
+ forward_min_event_ndims=f_min_event_ndims,
+ inverse_min_event_ndims=i_min_event_ndims,
validate_args=jd.validate_args,
+ parameters=parameters,
name=name)
self._jd = jd
- def _check_inputs_not_none(self, value):
- if any(x is None for x in tf.nest.flatten(value)):
- raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
+ def _conditioned_bijectors(self, samples, constrained=False):
+ if samples is None:
+ return self.bijectors
- # pylint: disable=protected-access
- def _evaluate_bijector(self, bijector_fn, values):
+ bijectors = []
gen = self._jd._model_coroutine()
- outputs = []
- d = next(gen)
- index = 0
- try:
- while True:
- dist = d.distribution if type(d).__name__ == 'Root' else d
- bijector = dist.experimental_default_event_space_bijector()
-
- # For discrete distributions, the default event space bijector is None.
- # For a joint distribution's discrete components, we want the behavior
- # of the Identity bijector.
- bijector = (identity_bijector.Identity()
- if bijector is None else bijector)
-
- out, y = bijector_fn(bijector, values[index])
- outputs.append(out)
- d = gen.send(y)
- index += 1
- except StopIteration:
- pass
- return outputs
-
- def _event_shapes(self, input_shapes, event_shape_attr):
- """For forward/inverse static event shapes."""
- input_shapes = self._jd._model_flatten(input_shapes)
- support_bijectors = [
- d.experimental_default_event_space_bijector()
- for d in self._jd._get_single_sample_distributions()]
- output_shapes = [
- getattr(bijector, event_shape_attr)(input_shape)
- for (bijector, input_shape) in zip(support_bijectors, input_shapes)]
- return self._jd._model_unflatten(output_shapes)
-
- # We override the public methods so that the `default_event_space_bijector`s
- # of the component distributions, instead of that of the `JointDistribution`,
- # hit the global bijector cache.
- def forward(self, values, name=None):
- with tf.name_scope(name or 'forward'):
- values = self._jd._model_flatten(values)
- self._check_inputs_not_none(values)
-
- def bijector_fn(bijector, value):
- y = bijector.forward(value)
- return y, y
-
- out = self._evaluate_bijector(bijector_fn, values)
- return self._jd._model_unflatten(out)
-
- def inverse(self, values, name=None):
- with tf.name_scope(name or 'inverse'):
- self._check_inputs_not_none(values)
- values = self._jd._model_flatten(values)
-
- def bijector_fn(bijector, value):
- x = bijector.inverse(value)
- return x, value
-
- out = self._evaluate_bijector(bijector_fn, values)
- return self._jd._model_unflatten(out)
-
- def forward_log_det_jacobian(self, values, event_ndims, name=None):
- with tf.name_scope(name or 'forward_log_det_jacobian'):
- self._check_inputs_not_none(values)
- values = self._jd._model_flatten(values)
- event_ndims = self._jd._model_flatten(event_ndims)
-
- def bijector_fn(bijector, value):
- x, event_ndims = value
- y = bijector.forward(x)
- fldj = bijector.forward_log_det_jacobian(x, event_ndims)
- return fldj, y
-
- fldjs = self._evaluate_bijector(bijector_fn,
- list(zip(values, event_ndims)))
- return sum(fldjs)
-
- def inverse_log_det_jacobian(self, values, event_ndims, name=None):
- with tf.name_scope(name or 'inverse_log_det_jacobian'):
- self._check_inputs_not_none(values)
- values = self._jd._model_flatten(values)
- event_ndims = self._jd._model_flatten(event_ndims)
-
- def bijector_fn(bijector, value):
- y, event_ndims = value
- ildj = bijector.inverse_log_det_jacobian(y, event_ndims)
- return ildj, y
-
- ildjs = self._evaluate_bijector(bijector_fn,
- list(zip(values, event_ndims)))
- return sum(ildjs)
- # pylint: enable=protected-access
-
- # TODO(b/148485931): Fix bijector caching.
- def forward_event_shape(self, input_shapes):
- return self._event_shapes(input_shapes, 'forward_event_shape')
-
- def forward_event_shape_tensor(self, input_shapes, name=None):
- with tf.name_scope(name or 'forward_event_shape_tensor'):
- self._check_inputs_not_none(input_shapes)
- return self._event_shapes(input_shapes, 'forward_event_shape_tensor')
-
- def inverse_event_shape(self, output_shapes):
- return self._event_shapes(output_shapes, 'inverse_event_shape')
-
- def inverse_event_shape_tensor(self, output_shapes, name=None):
- with tf.name_scope('inverse_event_shape_tensor'):
- self._check_inputs_not_none(output_shapes)
- return self._event_shapes(output_shapes, 'inverse_event_shape_tensor')
+ cond = None
+ for rv in self._jd._model_flatten(samples):
+ d = gen.send(cond)
+ dist = d.distribution if type(d).__name__ == 'Root' else d
+ bij = dist.experimental_default_event_space_bijector()
+
+ if bij is None:
+ bij = identity_bijector.Identity()
+ bijectors.append(bij)
+
+ # If the RV is not yet constrained, transform it.
+ cond = rv if constrained else bij.forward(rv)
+ return bijectors
+
+ def _walk_forward(self, step_fn, values, _jd_conditioning=None): # pylint: disable=invalid-name
+ bijectors = self._conditioned_bijectors(_jd_conditioning, constrained=False)
+ return self._jd._model_unflatten(tuple(
+ step_fn(bij, value) for bij, value in
+ zip(bijectors, self._jd._model_flatten(values))))
+
+ def _walk_inverse(self, step_fn, values, _jd_conditioning=None): # pylint: disable=invalid-name
+ bijectors = self._conditioned_bijectors(_jd_conditioning, constrained=True)
+ return self._jd._model_unflatten(tuple(
+ step_fn(bij, value) for bij, value
+ in zip(bijectors, self._jd._model_flatten(values))))
+
+ def _forward(self, x, **kwargs):
+ return super(_DefaultJointBijector, self)._forward(
+ x, _jd_conditioning=x, **kwargs)
+
+ def _forward_log_det_jacobian(self, x, event_ndims, **kwargs):
+ return super(_DefaultJointBijector, self)._forward_log_det_jacobian(
+ x, event_ndims, _jd_conditioning=x, **kwargs)
+
+ def _inverse(self, y, **kwargs):
+ return super(_DefaultJointBijector, self)._inverse(
+ y, _jd_conditioning=y, **kwargs)
+
+ def _inverse_log_det_jacobian(self, y, event_ndims, **kwargs):
+ return super(_DefaultJointBijector, self)._inverse_log_det_jacobian(
+ y, event_ndims, _jd_conditioning=y, **kwargs)
diff --git a/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py b/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py
index 3816746504..1ff1c299b9 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_auto_batched.py
@@ -245,6 +245,7 @@ def __init__(
batch_ndims=0,
use_vectorized_map=True,
validate_args=False,
+ experimental_use_kahan_sum=False,
name=None,
):
"""Construct the `JointDistributionCoroutineAutoBatched` distribution.
@@ -270,13 +271,21 @@ def __init__(
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
Default value: `False`.
+ experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
+ summation to aggregate independent underlying log_prob values, which
+ improves against the precision of a naive float32 sum. This can be
+ noticeable in particular for large dimensions in float32. See CPU caveat
+ on `tfp.math.reduce_kahan_sum`.
name: The name for ops managed by the distribution.
Default value: `None` (i.e., `JointDistributionCoroutine`).
"""
+ parameters = dict(locals())
super(JointDistributionCoroutineAutoBatched, self).__init__(
model, sample_dtype=sample_dtype, batch_ndims=batch_ndims,
use_vectorized_map=use_vectorized_map, validate_args=validate_args,
+ experimental_use_kahan_sum=experimental_use_kahan_sum,
name=name or 'JointDistributionCoroutineAutoBatched')
+ self._parameters = self._no_dependency(parameters)
@property
def _require_root(self):
@@ -394,7 +403,8 @@ class JointDistributionNamedAutoBatched(
"""
def __init__(self, model, batch_ndims=0, use_vectorized_map=True,
- validate_args=False, name=None):
+ validate_args=False, experimental_use_kahan_sum=False,
+ name=None):
"""Construct the `JointDistributionNamedAutoBatched` distribution.
Args:
@@ -413,13 +423,21 @@ def __init__(self, model, batch_ndims=0, use_vectorized_map=True,
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
Default value: `False`.
+ experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
+ summation to aggregate independent underlying log_prob values, which
+ improves against the precision of a naive float32 sum. This can be
+ noticeable in particular for large dimensions in float32. See CPU caveat
+ on `tfp.math.reduce_kahan_sum`.
name: The name for ops managed by the distribution.
Default value: `None` (i.e., `JointDistributionNamed`).
"""
+ parameters = dict(locals())
super(JointDistributionNamedAutoBatched, self).__init__(
model, batch_ndims=batch_ndims, use_vectorized_map=use_vectorized_map,
validate_args=validate_args,
+ experimental_use_kahan_sum=experimental_use_kahan_sum,
name=name or 'JointDistributionNamedAutoBatched')
+ self._parameters = self._no_dependency(parameters)
# TODO(b/159723894): Reduce complexity by eliminating use of mixins.
@@ -533,7 +551,8 @@ class JointDistributionSequentialAutoBatched(
"""
def __init__(self, model, batch_ndims=0, use_vectorized_map=True,
- validate_args=False, name=None):
+ validate_args=False, experimental_use_kahan_sum=False,
+ name=None):
"""Construct the `JointDistributionSequentialAutoBatched` distribution.
Args:
@@ -552,10 +571,18 @@ def __init__(self, model, batch_ndims=0, use_vectorized_map=True,
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
Default value: `False`.
+ experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
+ summation to aggregate independent underlying log_prob values, which
+ improves against the precision of a naive float32 sum. This can be
+ noticeable in particular for large dimensions in float32. See CPU caveat
+ on `tfp.math.reduce_kahan_sum`.
name: The name for ops managed by the distribution.
Default value: `None` (i.e., `JointDistributionSequential`).
"""
+ parameters = dict(locals())
super(JointDistributionSequentialAutoBatched, self).__init__(
model, batch_ndims=batch_ndims, use_vectorized_map=use_vectorized_map,
validate_args=validate_args,
+ experimental_use_kahan_sum=experimental_use_kahan_sum,
name=name or 'JointDistributionSequentialAutoBatched')
+ self._parameters = self._no_dependency(parameters)
diff --git a/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py b/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py
index 1ba178315a..f593541bbb 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py
@@ -18,8 +18,8 @@
from __future__ import division
from __future__ import print_function
-
import collections
+import os
# Dependency imports
from absl.testing import parameterized
@@ -597,21 +597,21 @@ def coroutine_model():
g = yield tfd.LogNormal(0., [1., 2.])
df = yield tfd.Exponential([1., 2.])
loc = yield tfd.Sample(tfd.Normal(0, g), 20)
- yield tfd.StudentT(tf.expand_dims(df, -1), loc, 1)
+ yield tfd.StudentT(df[:, tf.newaxis], loc, 1)
models[tfd.JointDistributionCoroutineAutoBatched] = coroutine_model
models[tfd.JointDistributionSequentialAutoBatched] = [
tfd.LogNormal(0., [1., 2.]),
tfd.Exponential([1., 2.]),
lambda _, g: tfd.Sample(tfd.Normal(0, g), 20),
- lambda loc, df: tfd.StudentT(tf.expand_dims(df, -1), loc, 1)
+ lambda loc, df: tfd.StudentT(df[:, tf.newaxis], loc, 1)
]
models[tfd.JointDistributionNamedAutoBatched] = collections.OrderedDict((
('g', tfd.LogNormal(0., [1., 2.])),
('df', tfd.Exponential([1., 2.])),
('loc', lambda g: tfd.Sample(tfd.Normal(0, g), 20)),
- ('x', lambda loc, df: tfd.StudentT(tf.expand_dims(df, -1), loc, 1))))
+ ('x', lambda loc, df: tfd.StudentT(df[:, tf.newaxis], loc, 1))))
joint = jd_class(models[jd_class], batch_ndims=1, validate_args=True)
joint_bijector = joint.experimental_default_event_space_bijector()
@@ -650,6 +650,59 @@ def inner_fn():
lp2 = joint.log_prob(z2)
self.assertAllEqual(lp2.shape, [5])
+ @parameterized.named_parameters(*[
+ dict(testcase_name='_{}{}'.format(jd_class.__name__, # pylint: disable=g-complex-comprehension
+ '_jit' if jit else ''),
+ jd_class=jd_class, jit=jit)
+ for jd_class in (tfd.JointDistributionCoroutineAutoBatched,
+ tfd.JointDistributionSequentialAutoBatched,
+ tfd.JointDistributionNamedAutoBatched)
+ for jit in (False, True)
+ ])
+ def test_kahan_precision(self, jd_class, jit):
+ maybe_jit = lambda f: f
+ if jit:
+ self.skip_if_no_xla()
+ maybe_jit = tf.function(experimental_compile=True)
+
+ def make_models(dtype):
+ models = {}
+ def mk_20k_poisson(log_rate):
+ return tfd.Poisson(log_rate=tf.broadcast_to(log_rate[..., tf.newaxis],
+ log_rate.shape + (20_000,)))
+ def coroutine_model():
+ log_rate = yield tfd.Normal(0., dtype(.2), name='log_rate')
+ yield mk_20k_poisson(log_rate).copy(name='x')
+ models[tfd.JointDistributionCoroutineAutoBatched] = coroutine_model
+
+ models[tfd.JointDistributionSequentialAutoBatched] = [
+ tfd.Normal(0., dtype(.2)), mk_20k_poisson
+ ]
+
+ models[tfd.JointDistributionNamedAutoBatched] = collections.OrderedDict((
+ ('log_rate', tfd.Normal(0., dtype(.2))), ('x', mk_20k_poisson)))
+ return models
+
+ joint = jd_class(make_models(np.float32)[jd_class], validate_args=True,
+ experimental_use_kahan_sum=True)
+ joint64 = jd_class(make_models(np.float64)[jd_class], validate_args=True)
+ stream = test_util.test_seed_stream()
+ nsamp = 7
+ xs = self.evaluate(
+ joint.sample(log_rate=tf.zeros([nsamp]), seed=stream()))
+ if isinstance(xs, dict):
+ xs['log_rate'] = tfd.Normal(0, .2).sample(nsamp, seed=stream())
+ else:
+ xs = (tfd.Normal(0, .2).sample(nsamp, seed=stream()), xs[1])
+ xs64 = tf.nest.map_structure(lambda x: tf.cast(x, tf.float64), xs)
+ lp = maybe_jit(joint.copy(validate_args=not jit).log_prob)(xs)
+ lp64 = joint64.log_prob(xs64)
+ lp, lp64 = self.evaluate((tf.cast(lp, tf.float64), lp64))
+ # Without Kahan, example max-abs-diff: ~0.06
+ self.assertAllClose(lp64, lp, rtol=0., atol=.01)
+
if __name__ == '__main__':
+ # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term.
+ os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'
tf.test.main()
diff --git a/tensorflow_probability/python/distributions/joint_distribution_coroutine.py b/tensorflow_probability/python/distributions/joint_distribution_coroutine.py
index d5b442fd2c..b59a039bc0 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_coroutine.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_coroutine.py
@@ -394,5 +394,6 @@ def _model_unflatten(self, xs):
def _model_flatten(self, xs):
if self._sample_dtype is None:
- return tuple(xs)
+ return tuple((xs[k] for k in self._flat_resolve_names())
+ if isinstance(xs, collections.Mapping) else xs)
return nest.flatten_up_to(self._sample_dtype, xs)
diff --git a/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py b/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py
index 5079ae8ebc..67b1fde8ad 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_coroutine_test.py
@@ -939,6 +939,30 @@ def _get_support_bijectors(dists, xs=None, ys=None):
self.evaluate(bijectors[i].inverse_event_shape_tensor(
event_shapes[i])))
+ def test_default_event_space_bijector_nested(self):
+ @tfd.JointDistributionCoroutine
+ def inner():
+ a = yield Root(tfd.Exponential(1., name='a'))
+ yield tfd.Sample(tfd.LogNormal(a, a), [5], name='b')
+
+ @tfd.JointDistributionCoroutine
+ def outer():
+ yield Root(inner)
+ yield Root(inner)
+ yield Root(inner)
+
+ xs = outer.sample(seed=test_util.test_seed())
+
+ outer_bij = outer.experimental_default_event_space_bijector()
+ joint_ldj = outer_bij.forward_log_det_jacobian(xs, [(0, 1)] * len(xs))
+
+ inner_bij = inner.experimental_default_event_space_bijector()
+ inner_ldjs = [inner_bij.forward_log_det_jacobian(x, (0, 1)) for x in xs]
+
+ # Evaluate both at once, to make sure we're using the same samples.
+ joint_ldj_, inner_ldjs_ = self.evaluate((joint_ldj, inner_ldjs))
+ self.assertAllClose(joint_ldj_, sum(inner_ldjs_))
+
def test_sample_kwargs(self):
@tfd.JointDistributionCoroutine
diff --git a/tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py b/tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py
index dec87667fb..54d6d38261 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_sample_path_mixin.py
@@ -22,6 +22,7 @@
import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import prefer_static
@@ -67,6 +68,8 @@ class JointDistributionSamplePathMixin(object):
def __init__(self, *args, **kwargs):
self._batch_ndims = kwargs.pop('batch_ndims', 0)
+ self._experimental_use_kahan_sum = kwargs.pop(
+ 'experimental_use_kahan_sum', False)
super(JointDistributionSamplePathMixin, self).__init__(*args, **kwargs)
@property
@@ -143,6 +146,10 @@ def _maybe_check_batch_shape(self):
return assertions
def _log_prob(self, value):
+ if self._experimental_use_kahan_sum:
+ xs = self._map_and_reduce_measure_over_dists(
+ 'log_prob', tfp_math.reduce_kahan_sum, value)
+ return sum(xs).total
xs = self._map_and_reduce_measure_over_dists(
'log_prob', tf.reduce_sum, value)
return sum(xs)
@@ -162,8 +169,11 @@ def log_prob_parts(self, value, name='log_prob_parts'):
each `distribution_fn` evaluated at each corresponding `value`.
"""
with self._name_and_control_scope(name):
+ sum_fn = tf.reduce_sum
+ if self._experimental_use_kahan_sum:
+ sum_fn = lambda x, axis: tfp_math.reduce_kahan_sum(x, axis=axis).total
xs = self._map_and_reduce_measure_over_dists(
- 'log_prob', tf.reduce_sum, value)
+ 'log_prob', sum_fn, value)
return self._model_unflatten(xs)
def prob_parts(self, value, name='prob_parts'):
diff --git a/tensorflow_probability/python/distributions/joint_distribution_vmap_mixin.py b/tensorflow_probability/python/distributions/joint_distribution_vmap_mixin.py
index 6fa472c464..28af559cef 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_vmap_mixin.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_vmap_mixin.py
@@ -138,7 +138,7 @@ def sample_distributions(self, sample_shape=(), seed=None, value=None,
if self.use_vectorized_map and (
_might_have_nonzero_size(sample_shape) or
value_might_have_sample_dims):
- raise NotImplementedError('sample_distributions` with nontrivial '
+ raise NotImplementedError('`sample_distributions` with nontrivial '
'sample shape is not yet supported '
'for autovectorized JointDistributions.')
else:
@@ -146,7 +146,7 @@ def sample_distributions(self, sample_shape=(), seed=None, value=None,
sample_shape=sample_shape, seed=seed, value=value)
return self._model_unflatten(ds), self._model_unflatten(xs)
- def _sample_n(self, sample_shape, seed, value=None):
+ def _sample_n(self, sample_shape, seed, value=None, **kwargs):
value_might_have_sample_dims = False
if value is not None:
@@ -162,7 +162,7 @@ def _sample_n(self, sample_shape, seed, value=None):
value_might_have_sample_dims):
# No need to auto-vectorize.
xs = self._call_flat_sample_distributions(
- sample_shape=sample_shape, seed=seed, value=value)[1]
+ sample_shape=sample_shape, seed=seed, value=value, **kwargs)[1]
return self._model_unflatten(xs)
# Set up for autovectorized sampling. To support the `value` arg, we need to
@@ -209,12 +209,3 @@ def map_measure_fn(value):
validate_args=self.validate_args)
return map_measure_fn(value)
-
- # Redefine not to attempt to cache the sampled distributions, since we might
- # be inside of a vectorized_map.
- def _call_flat_sample_distributions(
- self, sample_shape=(), seed=None, value=None):
- if value is not None:
- value = self._model_flatten(value)
- ds, xs = self._flat_sample_distributions(sample_shape, seed, value)
- return ds, xs
diff --git a/tensorflow_probability/python/distributions/lkj.py b/tensorflow_probability/python/distributions/lkj.py
index d5d9b5aa2b..aacee4a7ea 100644
--- a/tensorflow_probability/python/distributions/lkj.py
+++ b/tensorflow_probability/python/distributions/lkj.py
@@ -31,6 +31,10 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import math as tfp_math
+from tensorflow_probability.python.bijectors import bijector as bijector_lib
+from tensorflow_probability.python.bijectors import chain as chain_bijector
+from tensorflow_probability.python.bijectors import cholesky_outer_product as cholesky_outer_product_bijector
+from tensorflow_probability.python.bijectors import correlation_cholesky as correlation_cholesky_bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.distributions import beta
from tensorflow_probability.python.distributions import distribution
@@ -42,6 +46,7 @@
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
+from tensorflow_probability.python.math.numeric import clip_by_value_preserve_gradient
from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import
@@ -50,6 +55,59 @@
]
+class _ClipByValue(bijector_lib.Bijector):
+ """A bijector that clips by value.
+
+ This class is intended for minute numerical issues where `|clip(x) - x| <=
+ eps`, as it defines the derivative of its application to be exactly 1.
+ """
+
+ def __init__(self,
+ clip_value_min,
+ clip_value_max,
+ validate_args=False,
+ name='clip_by_value'):
+ """Instantiates the `ClipByValue` bijector.
+
+ Args:
+ clip_value_min: Floating-point `Tensor`.
+ clip_value_max: Floating-point `Tensor`.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ parameters = dict(locals())
+ with tf.name_scope(name) as name:
+ dtype = dtype_util.common_dtype([clip_value_min, clip_value_max],
+ dtype_hint=tf.float32)
+ self._clip_value_min = tensor_util.convert_nonref_to_tensor(
+ clip_value_min, dtype=dtype, name='clip_value_min')
+ self._clip_value_max = tensor_util.convert_nonref_to_tensor(
+ clip_value_max, dtype=dtype, name='clip_value_max')
+ super(_ClipByValue, self).__init__(
+ forward_min_event_ndims=0,
+ is_constant_jacobian=True,
+ dtype=dtype,
+ validate_args=validate_args,
+ parameters=parameters,
+ name=name)
+
+ @classmethod
+ def _is_increasing(cls):
+ return False
+
+ def _forward(self, x):
+ return clip_by_value_preserve_gradient(x, self._clip_value_min,
+ self._clip_value_max)
+
+ def _inverse(self, y):
+ return y
+
+ def _forward_log_det_jacobian(self, x):
+ # We deliberately ignore the clipping operation.
+ return tf.zeros([], dtype=dtype_util.base_dtype(x.dtype))
+
+
def _uniform_unit_norm(dimension, shape, dtype, seed):
"""Returns a batch of points chosen uniformly from the unit hypersphere."""
# This works because the Gaussian distribution is spherically symmetric.
@@ -463,21 +521,25 @@ def _mean(self):
dtype=concentration.dtype)
return answer
- # TODO(b/146522000): The output of tfb.CorrelationCholesky() can have
- # values > 1. Enable this bijector when that's fixed.
- # def _default_event_space_bijector(self):
- # # TODO(b/145620027) Finalize choice of bijector.
- # cholesky_bijector = correlation_cholesky_bijector.CorrelationCholesky(
- # validate_args=self.validate_args)
- # if self.input_output_cholesky:
- # return cholesky_bijector
- # return chain_bijector.Chain([
- # cholesky_outer_product_bijector.CholeskyOuterProduct(
- # validate_args=self.validate_args),
- # cholesky_bijector
- # ], validate_args=self.validate_args)
def _default_event_space_bijector(self):
- return
+ # TODO(b/145620027) Finalize choice of bijector.
+ cholesky_bijector = correlation_cholesky_bijector.CorrelationCholesky(
+ validate_args=self.validate_args)
+
+ if self.input_output_cholesky:
+ return cholesky_bijector
+ return chain_bijector.Chain([
+ # We need to explictly clip the output of this bijector because the
+ # other two bijectors sometimes return values that exceed the bounds by
+ # an epsilon due to minute numerical errors. Even numerically stable
+ # algorithms (which the other two bijectors employ) allow for symmetric
+ # errors about the true value, which is inappropriate for a one-sided
+ # validity constraint associated with correlation matrices.
+ _ClipByValue(-1., tf.ones([], self.dtype)),
+ cholesky_outer_product_bijector.CholeskyOuterProduct(
+ validate_args=self.validate_args),
+ cholesky_bijector
+ ], validate_args=self.validate_args)
def _parameter_control_dependencies(self, is_init):
assertions = []
diff --git a/tensorflow_probability/python/distributions/lkj_test.py b/tensorflow_probability/python/distributions/lkj_test.py
index f07e0640b3..036e437ede 100644
--- a/tensorflow_probability/python/distributions/lkj_test.py
+++ b/tensorflow_probability/python/distributions/lkj_test.py
@@ -70,7 +70,8 @@ def _det_ok_mask(x, det_bounds, input_output_cholesky=False):
@test_util.test_all_tf_execution_regimes
-@parameterized.parameters(np.float32, np.float64)
+@parameterized.named_parameters(('_float32', np.float32),
+ ('_float64', np.float64))
class LKJTest(test_util.TestCase):
def testNormConst2D(self, dtype):
@@ -419,6 +420,12 @@ def testValidateConcentrationAfterMutation(self, dtype):
with tf.control_dependencies([concentration.assign(0.5)]):
self.evaluate(d.mean())
+ def testDefaultEventSpaceBijectorValidCorrelation(self, dtype):
+ d = tfd.LKJ(3, tf.constant(1., dtype), validate_args=True)
+ b = d.experimental_default_event_space_bijector()
+ sample = b(tf.zeros((3, 3), dtype))
+ self.evaluate(d.log_prob(sample))
+
class LKJTestGraphOnly(test_util.TestCase):
diff --git a/tensorflow_probability/python/distributions/mixture_same_family.py b/tensorflow_probability/python/distributions/mixture_same_family.py
index 44518fb5fc..4c5c9d5f47 100644
--- a/tensorflow_probability/python/distributions/mixture_same_family.py
+++ b/tensorflow_probability/python/distributions/mixture_same_family.py
@@ -332,8 +332,7 @@ def _sample_n(self, n, seed):
], axis=0)
mask = tf.reshape(mask, shape=target_shape)
- if x.dtype in [tf.bfloat16, tf.float16, tf.float32, tf.float64,
- tf.complex64, tf.complex128]:
+ if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype):
masked = tf.math.multiply_no_nan(x, mask)
else:
masked = x * mask
diff --git a/tensorflow_probability/python/distributions/multinomial.py b/tensorflow_probability/python/distributions/multinomial.py
index 9635055c9b..6f75d2f767 100644
--- a/tensorflow_probability/python/distributions/multinomial.py
+++ b/tensorflow_probability/python/distributions/multinomial.py
@@ -413,7 +413,7 @@ def fn(i, num_trials, consumed_prob, accum):
num_trials = tf.cast(num_trials, probs.dtype)
# Pre-broadcast with probs
- num_trials += tf.zeros_like(probs[..., 0])
+ num_trials = num_trials + tf.zeros_like(probs[..., 0])
# Pre-enlarge for different output samples
num_trials = _replicate_along_left(num_trials, num_samples)
i = tf.constant(0)
diff --git a/tensorflow_probability/python/distributions/platform_compatibility_test.py b/tensorflow_probability/python/distributions/platform_compatibility_test.py
index 5ecc7dbfee..cf82292a0b 100644
--- a/tensorflow_probability/python/distributions/platform_compatibility_test.py
+++ b/tensorflow_probability/python/distributions/platform_compatibility_test.py
@@ -54,6 +54,8 @@
'FiniteDiscrete',
# TODO(b/159996966)
'Gamma',
+ # TODO(b/173546024)
+ 'GeneralizedExtremeValue',
'OneHotCategorical',
'LogNormal',
# TODO(b/162935914): Needs to use XLA friendly Poisson sampler.
@@ -104,6 +106,7 @@
LOGPROB_AUTOVECTORIZATION_IS_BROKEN = [
'Bates', # tf.repeat and tf.range do not vectorize. (b/157665707)
+ 'ExponentiallyModifiedGaussian', # b/174778704
'HalfStudentT', # Numerical problem: b/149785284
'Skellam',
'StudentT', # Numerical problem: b/149785284
@@ -122,6 +125,7 @@
'Beta': 1e-5,
'BetaBinomial': 1e-5,
'CholeskyLKJ': 1e-4,
+ 'GammaGamma': 2e-5,
'LKJ': 1e-3,
'PowerSpherical': 2e-5,
})
@@ -129,6 +133,7 @@
VECTORIZED_LOGPROB_RTOL = collections.defaultdict(lambda: 1e-6)
VECTORIZED_LOGPROB_RTOL.update({
'Beta': 1e-5,
+ 'GammaGamma': 1e-4,
'NegativeBinomial': 1e-5,
'PERT': 1e-5,
'PowerSpherical': 5e-5,
@@ -188,7 +193,9 @@
SKIP_KL_CHECK_DIST_VAR_GRADS = [
'Kumaraswamy', # TD's KL gradients do not rely on bijector variables.
- 'JohnsonSU' # TD's KL gradients do not rely on bijector variables.
+ 'JohnsonSU', # TD's KL gradients do not rely on bijector variables.
+ 'GeneralizedExtremeValue', # TD's KL gradients do not rely on bijector
+ # variables.
]
diff --git a/tensorflow_probability/python/distributions/sample.py b/tensorflow_probability/python/distributions/sample.py
index c1d8325034..27bfcfa434 100644
--- a/tensorflow_probability/python/distributions/sample.py
+++ b/tensorflow_probability/python/distributions/sample.py
@@ -25,6 +25,7 @@
import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import math as tfp_math
from tensorflow_probability.python.bijectors import bijector as bijector_lib
from tensorflow_probability.python.distributions import distribution as distribution_lib
from tensorflow_probability.python.distributions import kullback_leibler
@@ -121,6 +122,7 @@ def __init__(
distribution,
sample_shape=(),
validate_args=False,
+ experimental_use_kahan_sum=False,
name=None):
"""Construct the `Sample` distribution.
@@ -141,10 +143,16 @@ def __init__(
validate_args: Python `bool`. Whether to validate input with asserts.
If `validate_args` is `False`, and the inputs are invalid,
correct behavior is not guaranteed.
+ experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
+ summation to aggregate independent underlying log_prob values, which
+ improves against the precision of a naive float32 sum. This can be
+ noticeable in particular for large dimensions in float32. See CPU caveat
+ on `tfp.math.reduce_kahan_sum`.
name: The name for ops managed by the distribution.
Default value: `None` (i.e., `'Sample' + distribution.name`).
"""
parameters = dict(locals())
+ self._experimental_use_kahan_sum = experimental_use_kahan_sum
with tf.name_scope(name or 'Sample' + distribution.name) as name:
self._distribution = distribution
self._sample_shape = tensor_util.convert_nonref_to_tensor(
@@ -223,6 +231,11 @@ def _sample_n(self, n, seed, **kwargs):
**kwargs)
return tf.transpose(a=x, perm=perm)
+ def _sum_fn(self):
+ if self._experimental_use_kahan_sum:
+ return lambda x, axis: tfp_math.reduce_kahan_sum(x, axis).total
+ return tf.math.reduce_sum
+
def _log_prob(self, x, **kwargs):
batch_ndims = ps.rank_from_shape(
self.distribution.batch_shape_tensor,
@@ -266,7 +279,7 @@ def _log_prob(self, x, **kwargs):
lp = tf.broadcast_to(lp, bcast_lp_shape)
# (5) Make the final reduction in x.
axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
- return tf.reduce_sum(lp, axis=axis)
+ return self._sum_fn()(lp, axis=axis)
def _entropy(self, **kwargs):
h = self.distribution.entropy(**kwargs)
@@ -282,7 +295,8 @@ def _default_event_space_bijector(self):
# TODO(b/170405182): In scenarios where we can statically prove that it has
# no batch part, avoid the transposes by directly using
# `self.distribution.experimental_default_event_space_bijector()`.
- return _DefaultSampleBijector(self.distribution, self.sample_shape)
+ return _DefaultSampleBijector(self.distribution, self.sample_shape,
+ self._sum_fn())
def _parameter_control_dependencies(self, is_init):
assertions = []
@@ -335,11 +349,12 @@ def _parameter_control_dependencies(self, is_init):
class _DefaultSampleBijector(bijector_lib.Bijector):
"""Since tfd.Sample uses transposes, it requires a custom event bijector."""
- def __init__(self, distribution, sample_shape):
+ def __init__(self, distribution, sample_shape, sum_fn):
parameters = dict(locals())
self.distribution = distribution
self.bijector = distribution.experimental_default_event_space_bijector()
self.sample_shape = sample_shape
+ self._sum_fn = sum_fn
sample_ndims = ps.rank_from_shape(self.sample_shape)
super(_DefaultSampleBijector, self).__init__(
forward_min_event_ndims=(
@@ -466,7 +481,7 @@ def _bcast_and_reduce_logdet(self, underlying_ldj):
ps.ones([batch_ndims], tf.int32),
ps.reshape(self.sample_shape, shape=[-1])], axis=0))
ldj = tf.broadcast_to(underlying_ldj, bcast_ldj_shape)
- return tf.reduce_sum(ldj, axis=-1 - ps.range(extra_sample_ndims))
+ return self._sum_fn(ldj, axis=-1 - ps.range(extra_sample_ndims))
def _forward_log_det_jacobian(self, x, **kwargs):
dist = self.distribution
diff --git a/tensorflow_probability/python/distributions/sample_test.py b/tensorflow_probability/python/distributions/sample_test.py
index 321085d7d5..1d8c023d71 100644
--- a/tensorflow_probability/python/distributions/sample_test.py
+++ b/tensorflow_probability/python/distributions/sample_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import os
+
# Dependency imports
from absl.testing import parameterized
@@ -28,6 +30,9 @@
from tensorflow_probability.python.internal import test_util
+JAX_MODE = False
+
+
@test_util.test_all_tf_execution_regimes
class SampleDistributionTest(test_util.TestCase):
@@ -422,6 +427,29 @@ def test_bijector_constant_underlying_ildj(self):
ildj = bij.inverse_log_det_jacobian(tf.zeros([2, 3]), event_ndims=2)
self.assertAllClose(-np.log([2., 3.]).sum() * 3, ildj)
+ @parameterized.named_parameters(dict(testcase_name=''),
+ dict(testcase_name='_jit', jit=True))
+ def test_kahan_precision(self, jit=False):
+ maybe_jit = lambda f: f
+ if jit:
+ self.skip_if_no_xla()
+ maybe_jit = tf.function(experimental_compile=True)
+ stream = test_util.test_seed_stream()
+ n = 20_000
+ samps = tfd.Poisson(rate=1.).sample(n, seed=stream())
+ log_rate = tfd.Normal(0, .2).sample(seed=stream())
+ pois = tfd.Poisson(log_rate=log_rate)
+ lp = maybe_jit(
+ tfd.Sample(pois, n, experimental_use_kahan_sum=True).log_prob)(samps)
+ pois64 = tfd.Poisson(log_rate=tf.cast(log_rate, tf.float64))
+ lp64 = tfd.Sample(pois64, n).log_prob(tf.cast(samps, tf.float64))
+ # Evaluate together to ensure we use the same samples.
+ lp, lp64 = self.evaluate((tf.cast(lp, tf.float64), lp64))
+ # Fails 75% CPU, 0-80% GPU --vary_seed runs w/o experimental_use_kahan_sum.
+ self.assertAllClose(lp64, lp, rtol=0., atol=.01)
+
if __name__ == '__main__':
+ # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term.
+ os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'
tf.test.main()
diff --git a/tensorflow_probability/python/distributions/spherical_uniform.py b/tensorflow_probability/python/distributions/spherical_uniform.py
index ff8794b00f..b583110627 100644
--- a/tensorflow_probability/python/distributions/spherical_uniform.py
+++ b/tensorflow_probability/python/distributions/spherical_uniform.py
@@ -31,9 +31,9 @@
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import reparameterization
-from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
+from tensorflow_probability.python.random import random_ops
__all__ = ['SphericalUniform']
@@ -164,11 +164,11 @@ def _log_prob(self, x):
return tf.fill(batch_shape, -log_nsphere_surface_area)
def _sample_n(self, n, seed=None):
- raw = samplers.normal(
- shape=ps.concat([[n], self.batch_shape, [self.dimension]], axis=0),
- seed=seed, dtype=self.dtype)
- unit_norm = raw / tf.norm(raw, ord=2, axis=-1)[..., tf.newaxis]
- return unit_norm
+ return random_ops.spherical_uniform(
+ shape=ps.concat([[n], self.batch_shape], axis=0),
+ dimension=self.dimension,
+ dtype=self.dtype,
+ seed=seed)
def _entropy(self):
log_nsphere_surface_area = (
diff --git a/tensorflow_probability/python/distributions/stochastic_process_properties_test.py b/tensorflow_probability/python/distributions/stochastic_process_properties_test.py
index cfb55cde82..2fbe7becad 100644
--- a/tensorflow_probability/python/distributions/stochastic_process_properties_test.py
+++ b/tensorflow_probability/python/distributions/stochastic_process_properties_test.py
@@ -418,7 +418,8 @@ def testExcessiveConcretizationInZeroArgPublicMethods(
try:
with tfp_hps.assert_no_excessive_var_usage(
'method `{}` of `{}`'.format(stat, process),
- max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)):
+ max_permissible=MAX_CONVERSIONS_BY_CLASS.get(process_name, 1)
+ ), kernel_hps.no_pd_errors():
getattr(process, stat)()
except NotImplementedError:
diff --git a/tensorflow_probability/python/distributions/transformed_distribution.py b/tensorflow_probability/python/distributions/transformed_distribution.py
index 2d669b4820..68daa369a7 100644
--- a/tensorflow_probability/python/distributions/transformed_distribution.py
+++ b/tensorflow_probability/python/distributions/transformed_distribution.py
@@ -478,6 +478,32 @@ def _mean_mode_impl(self, attr, kwargs):
y = self._set_sample_static_shape(y, sample_shape)
return y
+ def _stddev(self, **kwargs):
+ if not self.bijector.is_constant_jacobian:
+ raise NotImplementedError('`stddev` is not implemented for non-affine '
+ '`bijectors`.')
+ if not self.bijector._is_injective: # pylint: disable=protected-access
+ raise NotImplementedError('`stddev` is not implemented when '
+ '`bijector` is not injective.')
+ if not (self.bijector._is_scalar # pylint: disable=protected-access
+ or self.bijector._is_permutation): # pylint: disable=protected-access
+ raise NotImplementedError('`stddev` is not implemented when `bijector` '
+ 'is a multivariate transformation.')
+
+ # A scalar affine bijector is of the form `forward(x) = scale * x + shift`,
+ # where the standard deviation is invariant to the shift, so we extract the
+ # shift and subtract it.
+ distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
+ x_stddev = self.distribution.stddev(**distribution_kwargs)
+ y_stddev_plus_shift = self.bijector.forward(x_stddev, **bijector_kwargs)
+ shift = self.bijector.forward(
+ tf.nest.map_structure(
+ tf.zeros_like, x_stddev),
+ **bijector_kwargs)
+ return tf.nest.map_structure(
+ tf.abs,
+ tf.nest.map_structure(tf.subtract, y_stddev_plus_shift, shift))
+
def _entropy(self, **kwargs):
if not self.bijector.is_constant_jacobian:
raise NotImplementedError('`entropy` is not implemented.')
diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py
index 2572beb7f0..04e3ea43e6 100644
--- a/tensorflow_probability/python/distributions/transformed_distribution_test.py
+++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py
@@ -363,6 +363,39 @@ def testMean(self):
validate_args=True)
self.assertAllClose(shift, self.evaluate(fake_mvn.mean()))
+ def testStddev(self):
+ base_stddev = 2.
+ shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
+ scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
+ expected_stddev = tf.abs(base_stddev * scale)
+ normal = self._cls()(
+ distribution=tfd.Normal(loc=tf.zeros_like(shift),
+ scale=base_stddev * tf.ones_like(scale),
+ validate_args=True),
+ bijector=tfb.Chain([tfb.Shift(shift=shift),
+ tfb.Scale(scale=scale)],
+ validate_args=True),
+ validate_args=True)
+ self.assertAllClose(expected_stddev, normal.stddev())
+ self.assertAllClose(expected_stddev**2, normal.variance())
+
+ split_normal = self._cls()(
+ distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
+ bijector=tfb.Split(3),
+ validate_args=True)
+ self.assertAllCloseNested(tf.split(expected_stddev,
+ num_or_size_splits=3,
+ axis=-1),
+ split_normal.stddev())
+
+ scaled_normal = self._cls()(
+ distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
+ bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
+ validate_args=True)
+ with self.assertRaisesRegex(
+ NotImplementedError, 'is a multivariate transformation'):
+ scaled_normal.stddev()
+
def testEntropy(self):
shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
diag = np.array([[1, 2, 3], [2, 3, 2]], dtype=np.float32)
@@ -810,74 +843,6 @@ def setUp(self):
self.shape = tf.TensorShape(None)
-class ToyZipMap(tfb.Bijector):
-
- def __init__(self, bijectors):
- parameters = dict(locals())
- self._bijectors = bijectors
-
- super(ToyZipMap, self).__init__(
- forward_min_event_ndims=tf.nest.map_structure(
- lambda b: b.forward_min_event_ndims, bijectors),
- inverse_min_event_ndims=tf.nest.map_structure(
- lambda b: b.inverse_min_event_ndims, bijectors),
- is_constant_jacobian=all([
- b.is_constant_jacobian for b in tf.nest.flatten(bijectors)]),
- parameters=parameters)
-
- @property
- def bijectors(self):
- return self._bijectors
-
- def forward(self, x):
- return tf.nest.map_structure(lambda b_i, x_i: b_i.forward(x_i),
- self.bijectors, x)
-
- def inverse(self, y):
- return tf.nest.map_structure(lambda b_i, y_i: b_i.inverse(y_i),
- self.bijectors, y)
-
- def forward_dtype(self, dtype):
- return tf.nest.map_structure(lambda b_i, d_i: b_i.forward_dtype(d_i),
- self.bijectors, dtype)
-
- def inverse_dtype(self, dtype):
- return tf.nest.map_structure(lambda b_i, d_i: b_i.inverse_dtype(d_i),
- self.bijectors, dtype)
-
- def forward_event_shape(self, x_shape):
- return tf.nest.map_structure(
- lambda b_i, x_i: b_i.forward_event_shape(x_i),
- self.bijectors, x_shape)
-
- def inverse_event_shape(self, y_shape):
- return tf.nest.map_structure(
- lambda b_i, y_i: b_i.inverse_event_shape(y_i),
- self.bijectors, y_shape)
-
- def forward_event_shape_tensor(self, x_shape_tensor):
- return tf.nest.map_structure(
- lambda b_i, x_i: b_i.forward_event_shape_tensor(x_i),
- self.bijectors, x_shape_tensor)
-
- def inverse_event_shape_tensor(self, y_shape_tensor):
- return tf.nest.map_structure(
- lambda b_i, y_i: b_i.inverse_event_shape_tensor(y_i),
- self.bijectors, y_shape_tensor)
-
- def forward_log_det_jacobian(self, x, event_ndims):
- fldj_parts = tf.nest.map_structure(
- lambda b, y, n: b.forward_log_det_jacobian(x, event_ndims=n),
- self.bijectors, x, event_ndims)
- return sum(tf.nest.flatten(fldj_parts))
-
- def inverse_log_det_jacobian(self, y, event_ndims):
- ildj_parts = tf.nest.map_structure(
- lambda b, y, n: b.inverse_log_det_jacobian(y, event_ndims=n),
- self.bijectors, y, event_ndims)
- return sum(tf.nest.flatten(ildj_parts))
-
-
@test_util.test_all_tf_execution_regimes
class MultipartBijectorsTest(test_util.TestCase):
@@ -958,6 +923,10 @@ def test_transform_vector_to_parts(self, known_split_sizes):
bijector = tfb.Split(known_split_sizes, axis=-1)
split_dist = tfd.TransformedDistribution(base_dist, bijector)
+ self.assertRegex(
+ str(split_dist),
+ '{}.*batch_shape.*event_shape.*dtype'.format(split_dist.name))
+
expected_event_shape = [np.array([s]) for s in true_split_sizes]
output_event_shape = [np.array(s) for s in split_dist.event_shape]
self.assertAllEqual(output_event_shape, expected_event_shape)
@@ -1027,11 +996,16 @@ def test_transform_joint_to_joint(self, split_sizes):
minval=1., maxval=2.,
shape=bijector_batch_shape, seed=seed())),
tfb.Reshape([2, 1])]
- bijector = ToyZipMap(tf.nest.pack_sequence_as(split_sizes, bijectors))
+ bijector = tfb.JointMap(tf.nest.pack_sequence_as(split_sizes, bijectors),
+ validate_args=True)
# Transform a joint distribution that has different batch shape components
transformed_dist = tfd.TransformedDistribution(base_dist, bijector)
+ self.assertRegex(
+ str(transformed_dist),
+ '{}.*batch_shape.*event_shape.*dtype'.format(transformed_dist.name))
+
self.assertAllEqualNested(
transformed_dist.event_shape,
bijector.forward_event_shape(base_dist.event_shape))
diff --git a/tensorflow_probability/python/experimental/__init__.py b/tensorflow_probability/python/experimental/__init__.py
index 85db29fa9a..dd26b7443c 100644
--- a/tensorflow_probability/python/experimental/__init__.py
+++ b/tensorflow_probability/python/experimental/__init__.py
@@ -50,12 +50,14 @@
from tensorflow_probability.python.experimental.composite_tensor import register_composite
from tensorflow_probability.python.internal import all_util
from tensorflow_probability.python.internal.auto_composite_tensor import auto_composite_tensor
+from tensorflow_probability.python.internal.auto_composite_tensor import AutoCompositeTensor
_allowed_symbols = [
'auto_batching',
'as_composite',
'auto_composite_tensor',
+ 'AutoCompositeTensor',
'bijectors',
'distribute',
'distributions',
diff --git a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py
index de1b11bc39..616d743721 100644
--- a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py
+++ b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py
@@ -30,7 +30,7 @@ class ScalarFunctionWithInferredInverse(bijector.Bijector):
def __init__(self,
fn,
domain_constraint_fn=None,
- root_search_fn=tfp_math.secant_root,
+ root_search_fn=tfp_math.find_root_secant,
max_iterations=50,
require_convergence=True,
validate_args=False,
@@ -124,21 +124,103 @@ def _wrap_inverse_with_implicit_gradient(self):
"""Wraps the inverse to provide implicit reparameterization gradients."""
def _vjp_fwd(y):
- x = self._inverse_no_gradient(y)
- return x, x # Keep `x` as an auxiliary value for the backwards pass.
+ # Prevent autodiff from trying to backprop through the root search.
+ x = tf.stop_gradient(self._inverse_no_gradient(y))
+ return x, (x, y) # Auxiliary values for the backwards pass.
# By the inverse function theorem, the derivative of an
# inverse function is the reciprocal of the forward derivative. This has
# been popularized in machine learning by [1].
# [1] Michael Figurnov, Shakir Mohamed, Andriy Mnih (2018). Implicit
# Reparameterization Gradients. https://arxiv.org/abs/1805.08498.
- def _vjp_bwd(x, grad_x):
- _, grads = tfp_math.value_and_gradient(self.fn, x)
- return (grad_x / grads,)
+ def _vjp_bwd(aux, dresult_dx):
+ x, y = aux
+ return [dresult_dx /
+ _make_dy_dx_with_implicit_derivative_wrt_y(self.fn, x)(y)]
+
+ def _inverse_jvp(primals, tangents):
+ y, = primals
+ dy, = tangents
+ # Prevent autodiff from trying to backprop through the root search.
+ x = tf.stop_gradient(self._inverse_no_gradient(y))
+ return x, dy / _make_dy_dx_with_implicit_derivative_wrt_y(self.fn, x)(y)
@tfp_custom_gradient.custom_gradient(
vjp_fwd=_vjp_fwd,
- vjp_bwd=_vjp_bwd)
+ vjp_bwd=_vjp_bwd,
+ jvp_fn=_inverse_jvp)
def _inverse_with_gradient(y):
return self._inverse_no_gradient(y)
return _inverse_with_gradient
+
+
+def _make_dy_dx_with_implicit_derivative_wrt_y(fn, x):
+ """Given `y = fn(x)`, returns a function `dy_dx(y)` differentiable wrt y.
+
+ The returned function `dy_dx(y)` computes the reciprocal of the derivative of
+ `x = inverse_of_fn(y)`. By the inverse function theorem, this is just the
+ derivative `dy / dx` of `y = fn(x)`. Even though we'll actually care about
+ the derivative of the inverse, `dx / dy`, it's more efficient to return the
+ reciprocal of that quantity from the forward derivative.
+
+ Since `dy_dx(y)` is the first derivative of `fn(x)` evaluated at
+ `x = inverse_of_fn(y)`, we define *its* derivative in terms of the
+ second derivative of `fn`, via the chain rule:
+
+ ```
+ d / dy fn'(inverse_of_fn(y)) = fn''(inverse_of_fn(y)) * inverse_of_fn'(y)
+ = fn''(x) / fn'(x)
+ ```
+
+ When bijector log-det-jacobians are computed using autodiff, as in
+ `ScalarFunctionWithInferredInverse`, the gradients of the log-det-jacobians
+ make use of these second-derivative annotations.
+
+ Args:
+ fn: Python `callable` invertible scalar function of a scalar `x`. Must be
+ twice differentiable.
+ x: Float `Tensor` input at which `fn` and its derivatives are evaluated.
+ Returns:
+ dy_dx_fn: Python `callable` that takes an argument `y` and returns the
+ derivative of `fn(x)`. The argument `y` is ignored (it is assumed to be
+ `y = fn(x)`), but the derivative of `dy_dx_fn` wrt `y` is defined.
+ """
+
+ # To override first and second derivatives of the inverse
+ # (second derivatives are needed for gradients of
+ # `inverse_log_det_jacobian`s), we'll need the first and second
+ # derivatives from the forward direction.
+ def _dy_dx_fwd(unused_y):
+ first_order = lambda x: tfp_math.value_and_gradient(fn, x)[1]
+ dy_dx, d2y_dx2 = tfp_math.value_and_gradient(first_order, x)
+ return (dy_dx,
+ (dy_dx, d2y_dx2)) # Auxiliary values for the second-order pass.
+
+ # Chain rule for second derivative of an inverse function:
+ # f''(inv_f(y)) = f''(x) * inv_f'(y)
+ # = f''(x) / f'(x).
+ def _dy_dx_bwd(aux, dresult_d_dy_dx):
+ dy_dx, d2y_dx2 = aux
+ return [dresult_d_dy_dx * d2y_dx2 / dy_dx]
+
+ def _dy_dx_jvp(primals, tangents):
+ unused_y, = primals
+ dy, = tangents
+ first_order = lambda x: tfp_math.value_and_gradient(fn, x)[1]
+ dy_dx, ddy_dx2 = tfp_math.value_and_gradient(first_order, x)
+ return dy_dx, (dy / dy_dx) * ddy_dx2
+
+ # Naively, autodiff of this derivative would attempt to backprop through
+ # `x = root_search(fn, y)` when computing the second derivative with
+ # respect to `y`. Since that's no good, we need to provide our own
+ # custom gradient wrt `y`.
+ @tfp_custom_gradient.custom_gradient(
+ vjp_fwd=_dy_dx_fwd,
+ vjp_bwd=_dy_dx_bwd,
+ jvp_fn=_dy_dx_jvp)
+ def _dy_dx_fn(y):
+ del y # Unused.
+ _, dy_dx = tfp_math.value_and_gradient(fn, x)
+ return dy_dx
+
+ return _dy_dx_fn
diff --git a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py
index c826ac3608..645121f971 100644
--- a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py
+++ b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py
@@ -61,22 +61,39 @@ def test_domain_constraint_fn(self):
self.assertAllClose(xs, bij.inverse(bij.forward(xs)))
@test_util.numpy_disable_gradient_test
- def test_transformed_distribution_log_prob(self):
- uniform = tfd.Uniform(low=0, high=1.)
+ def test_transformed_distribution_log_prob_and_grads(self):
normal = tfd.Normal(loc=0., scale=1.)
xs = self.evaluate(normal.sample(100, seed=test_util.test_seed()))
+ lp_true, lp_grad_true = tfp.math.value_and_gradient(normal.log_prob, xs)
# Define a normal distribution using inverse-CDF sampling. Computing
# log probs under this definition requires inverting the quantile function,
# i.e., numerically approximating `normal.cdf`.
+ uniform = tfd.Uniform(low=0, high=1.)
inverse_transform_normal = tfbe.ScalarFunctionWithInferredInverse(
fn=normal.quantile,
domain_constraint_fn=uniform.experimental_default_event_space_bijector()
)(uniform)
- self.assertAllClose(normal.log_prob(xs),
- inverse_transform_normal.log_prob(xs),
- atol=1e-4)
+ lp, lp_grad = tfp.math.value_and_gradient(inverse_transform_normal.log_prob,
+ xs)
+ self.assertAllClose(lp_true, lp, atol=1e-4)
+ self.assertAllClose(lp_grad_true, lp_grad, atol=1e-4)
+ @test_util.numpy_disable_gradient_test
+ def test_ildj_gradients(self):
+ bij = tfbe.ScalarFunctionWithInferredInverse(lambda x: x**2)
+ ys = tf.convert_to_tensor([0.25, 1., 4., 9.])
+ ildj, ildj_grad = tfp.math.value_and_gradient(
+ lambda y: bij.inverse_log_det_jacobian(y, event_ndims=0),
+ ys)
+
+ # Compare ildjs from inferred inverses to ildjs from the true inverse.
+ def ildj_fn(y):
+ _, inverse_grads = tfp.math.value_and_gradient(tf.sqrt, y)
+ return tf.math.log(tf.abs(inverse_grads))
+ ildj_true, ildj_grad_true = tfp.math.value_and_gradient(ildj_fn, ys)
+ self.assertAllClose(ildj, ildj_true, atol=1e-4)
+ self.assertAllClose(ildj_grad, ildj_grad_true, rtol=1e-4)
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow_probability/python/experimental/distribute/distribute_lib.py b/tensorflow_probability/python/experimental/distribute/distribute_lib.py
index 44e3e8372b..bab1b6b37e 100644
--- a/tensorflow_probability/python/experimental/distribute/distribute_lib.py
+++ b/tensorflow_probability/python/experimental/distribute/distribute_lib.py
@@ -62,9 +62,6 @@ def make_sharded_log_prob_parts(log_prob_parts_fn, is_sharded):
@tf.custom_gradient
def sharded_log_prob_parts(value):
- if not isinstance(value, (list, tuple)):
- raise NotImplementedError('Can only shard functions that output `list`s.'
- ' or `tuple`s')
tf.nest.assert_same_structure(value, is_sharded)
with tf.GradientTape(persistent=True) as tape:
tape.watch(value)
@@ -78,15 +75,15 @@ def sharded_log_prob_parts(value):
is_sharded)
def vjp(*gs):
- assert len(gs) == len(log_prob_parts)
+ gs = tf.nest.pack_sequence_as(log_prob_parts, gs)
def local_grad(v, g):
- return _DummyGrads([
- tape.gradient(log_prob_part, v, output_gradients=g)
- for log_prob_part in log_prob_parts
- ])
+ return _DummyGrads(
+ tf.nest.map_structure(
+ lambda lp: tape.gradient(lp, v, output_gradients=g),
+ log_prob_parts))
- local_grads = tf.nest.map_structure(local_grad, value, list(gs))
+ local_grads = tf.nest.map_structure(local_grad, value, gs)
def value_grad(v, value_sharded, term_grads):
"""Computes reductions of output gradients.
@@ -117,12 +114,15 @@ def value_grad(v, value_sharded, term_grads):
`log_prob_parts` function.
"""
term_grads = term_grads.grads
- total_grad = []
- for term_grad, term_sharded in zip(term_grads, is_sharded):
+
+ def psum_grads(term_grad, term_sharded):
if term_grad is not None:
if not value_sharded and term_sharded:
term_grad = psum(term_grad)
- total_grad.append(term_grad)
+ return term_grad
+
+ total_grad = tf.nest.map_structure(psum_grads, term_grads,
+ is_sharded)
if all([grad is None for grad in tf.nest.flatten(total_grad)]):
return None
return tf.add_n(
diff --git a/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py b/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py
index b75a588e27..3f65709a54 100644
--- a/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py
+++ b/tensorflow_probability/python/experimental/distribute/distribute_lib_test.py
@@ -305,8 +305,48 @@ def true_log_prob(*value):
self.assertAllEqualNested(self.evaluate(out_grads),
self.evaluate(true_grad))
+ def test_correct_gradient_for_global_and_local_variable_dict(self):
-if __name__ == "__main__":
+ @tf.function(autograph=False)
+ def run(w, x, data):
+
+ def log_prob_parts(value):
+ return {
+ 'w': tfd.Normal(0., 1.).log_prob(value['w']),
+ 'x': tfd.Normal(w, 1.).log_prob(value['x']),
+ 'data': tfd.Normal(x, 1.).log_prob(value['data']),
+ }
+
+ def log_prob(*value):
+ w, x = value
+ sharded_log_prob_parts = distribute_lib.make_sharded_log_prob_parts(
+ log_prob_parts, {'w': False, 'x': True, 'data': True})
+ parts = sharded_log_prob_parts({'w': w, 'x': x, 'data': data})
+ return tf.add_n(tf.nest.flatten(parts))
+
+ return tfp.math.value_and_gradient(log_prob, [w, x])[1]
+
+ w = tf.constant(1.)
+ x = tf.range(4.)
+ sharded_x = self.shard_values(x)
+ data = 2 * tf.ones(4)
+ sharded_data = self.shard_values(data)
+ out_grads = per_replica_to_tensor(self.strategy.run(run, (w, sharded_x,
+ sharded_data)))
+
+ def true_log_prob(*value):
+ w, x = value
+ return (tfd.Normal(0., 1.).log_prob(w) +
+ tf.reduce_sum(tfd.Normal(w, 1.).log_prob(x)) +
+ tf.reduce_sum(tfd.Normal(x, 1.).log_prob(data)))
+
+ true_grad = tfp.math.value_and_gradient(true_log_prob, [w, x])[1]
+ true_grad[0] = tf.ones(4) * true_grad[0]
+
+ self.assertAllEqualNested(self.evaluate(out_grads),
+ self.evaluate(true_grad))
+
+if __name__ == '__main__':
tf.enable_v2_behavior()
physical_devices = tf.config.experimental.list_physical_devices()
diff --git a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py
index 5df1a0e48c..14867fda3b 100644
--- a/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py
+++ b/tensorflow_probability/python/experimental/distribute/joint_distribution_test.py
@@ -94,23 +94,39 @@ def test_jd(self, dist_fn):
@tf.function(autograph=False)
def run(key):
sample = dist.sample(seed=key)
- return sample, dist.log_prob(sample)
-
- sample, log_prob = per_replica_to_tensor(
- self.strategy.run(run, (tf.ones(2, tf.int32),)))
-
- def true_log_prob(sample):
- if isinstance(dist, jd.JointDistributionNamed):
- w, x, data = sample['w'], sample['x'], sample['data']
- else:
- w, x, data = sample
- return (tfd.Normal(0., 1.).log_prob(w[0]) +
+ # The identity is to prevent reparameterization gradients from kicking in.
+ log_prob, (log_prob_grads,) = tfp.math.value_and_gradient(
+ dist.log_prob, (tf.nest.map_structure(tf.identity, sample),))
+ return sample, log_prob, log_prob_grads
+
+ sample, log_prob, log_prob_grads = self.strategy.run(
+ run, (tf.ones(2, tf.int32),))
+ sample, log_prob, log_prob_grads = per_replica_to_tensor(
+ (sample, log_prob, log_prob_grads))
+
+ def true_log_prob_fn(w, x, data):
+ return (tfd.Normal(0., 1.).log_prob(w) +
tfd.Sample(tfd.Normal(w, 1.), (4, 1)).log_prob(x) +
tfd.Independent(tfd.Normal(x, 1.), 2).log_prob(data))
+ if isinstance(dist, jd.JointDistributionNamed):
+ # N.B. the global RV 'w' gets replicated, so we grab any single replica's
+ # result.
+ w, x, data = sample['w'][0], sample['x'], sample['data']
+ log_prob_grads = (log_prob_grads['w'][0], log_prob_grads['x'],
+ log_prob_grads['data'])
+ else:
+ w, x, data = sample[0][0], sample[1], sample[2]
+ log_prob_grads = (log_prob_grads[0][0], log_prob_grads[1],
+ log_prob_grads[2])
+
+ true_log_prob, true_log_prob_grads = tfp.math.value_and_gradient(
+ true_log_prob_fn, (w, x, data))
+
self.assertAllClose(
- self.evaluate(log_prob),
- self.evaluate(tf.ones(4) * true_log_prob(sample)))
+ self.evaluate(log_prob), self.evaluate(tf.ones(4) * true_log_prob))
+ self.assertAllCloseNested(
+ self.evaluate(log_prob_grads), self.evaluate(true_log_prob_grads))
if __name__ == '__main__':
diff --git a/tensorflow_probability/python/experimental/distribute/sharded.py b/tensorflow_probability/python/experimental/distribute/sharded.py
index 358f825426..31976f5304 100644
--- a/tensorflow_probability/python/experimental/distribute/sharded.py
+++ b/tensorflow_probability/python/experimental/distribute/sharded.py
@@ -36,6 +36,7 @@ def __init__(self,
sample_shape=(),
shard_axis=0,
validate_args=False,
+ experimental_use_kahan_sum=False,
name=None):
"""Construct the `ShardedSample` distribution.
@@ -49,9 +50,16 @@ def __init__(self,
validate_args: Python `bool`. Whether to validate input with asserts. If
`validate_args` is `False`, and the inputs are invalid, correct behavior
is not guaranteed.
+ experimental_use_kahan_sum: Python `bool`. When `True`, we use Kahan
+ summation to aggregate independent underlying log_prob values, which
+ improves against the precision of a naive float32 sum. This can be
+ noticeable in particular for large dimensions in float32. See CPU caveat
+ on `tfp.math.reduce_kahan_sum`.
name: The name for ops managed by the distribution.
Default value: `None` (i.e., `'Sample' + distribution.name`).
"""
+ parameters = dict(locals())
+
with tf.name_scope(name or 'ShardedSample' + distribution.name) as name:
self._shard_axis = shard_axis
@@ -59,7 +67,9 @@ def __init__(self,
distribution,
validate_args=validate_args,
sample_shape=sample_shape,
+ experimental_use_kahan_sum=experimental_use_kahan_sum,
name=name)
+ self._parameters = parameters
@property
def sample_shape(self):
diff --git a/tensorflow_probability/python/experimental/distributions/BUILD b/tensorflow_probability/python/experimental/distributions/BUILD
index 19d151e2c2..edc7375c7b 100644
--- a/tensorflow_probability/python/experimental/distributions/BUILD
+++ b/tensorflow_probability/python/experimental/distributions/BUILD
@@ -52,7 +52,7 @@ multi_substrate_py_library(
multi_substrate_py_test(
name = "joint_distribution_pinned_test",
- size = "medium",
+ size = "large",
srcs = ["joint_distribution_pinned_test.py"],
jax_size = "large",
shard_count = 13,
diff --git a/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned_test.py b/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned_test.py
index f11e695c5a..39dfe85afa 100644
--- a/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned_test.py
+++ b/tensorflow_probability/python/experimental/distributions/joint_distribution_pinned_test.py
@@ -55,15 +55,38 @@ def model():
return model
+def jd_coroutine_autobatched():
+ d0, d1, d2, d3 = part_dists()
+
+ root = tfd.JointDistributionCoroutineAutoBatched.Root
+ @tfd.JointDistributionCoroutineAutoBatched
+ def model():
+ w = yield root(d0)
+ x = yield root(d1)
+ y = yield d2(x)
+ yield d3(y, x, w)
+ return model
+
+
def jd_sequential(model_from_seq=tuple):
return tfd.JointDistributionSequential(model_from_seq(part_dists()))
+def jd_sequential_autobatched(model_from_seq=tuple):
+ return tfd.JointDistributionSequentialAutoBatched(
+ model_from_seq(part_dists()))
+
+
def jd_named():
d0, d1, d2, d3 = part_dists()
return tfd.JointDistributionNamed(dict(w=d0, x=d1, y=d2, z=d3))
+def jd_named_autobatched():
+ d0, d1, d2, d3 = part_dists()
+ return tfd.JointDistributionNamedAutoBatched(dict(w=d0, x=d1, y=d2, z=d3))
+
+
def jd_named_ordered():
d0, d1, d2, d3 = part_dists()
return tfd.JointDistributionNamed(
@@ -82,8 +105,10 @@ def jd_named_namedtuple():
'_'.join(map(str, sample_shape))),
jd_factory=jd_factory,
sample_shape=sample_shape)
- for jd_factory in (jd_coroutine, jd_sequential, jd_named,
- jd_named_ordered, jd_named_namedtuple)
+ for jd_factory in (jd_coroutine, jd_coroutine_autobatched, jd_sequential,
+ jd_sequential_autobatched, jd_named,
+ jd_named_autobatched, jd_named_ordered,
+ jd_named_namedtuple)
# TODO(b/168139745): Add support for: [13], [13, 1], [1, 13]
for sample_shape in ([],)))
class JointDistributionPinnedParameterizedTest(test_util.TestCase):
@@ -95,7 +120,7 @@ def test_pinned_distribution_seq_args(self, jd_factory, sample_shape):
underlying = jd_factory()
tuple_args = (None, x,), (None, x, None, None), (None, x, None, z)
- if jd_factory is jd_named:
+ if jd_factory is jd_named or jd_factory is jd_named_autobatched:
# JDNamed does not support unnamed args unless model is ordered.
for args in tuple_args:
with self.assertRaisesRegexp(ValueError, r'unordered'):
@@ -129,7 +154,8 @@ def test_pinned_distribution_kwargs(self, jd_factory, sample_shape):
self._check_pinning(pinned, sample_shape)
def _check_pinning(self, pinned, sample_shape):
- self.evaluate(pinned.event_shape_tensor())
+ self.evaluate(tf.nest.map_structure(tf.convert_to_tensor,
+ pinned.event_shape_tensor()))
s0 = pinned.sample_unpinned(
sample_shape, seed=test_util.test_seed(sampler_type='stateless'))
diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD
index 7e223ff3b4..8c87d50e75 100644
--- a/tensorflow_probability/python/experimental/mcmc/BUILD
+++ b/tensorflow_probability/python/experimental/mcmc/BUILD
@@ -40,6 +40,7 @@ multi_substrate_py_library(
srcs_version = "PY3",
substrates_omit_deps = [
":covariance_reducer",
+ ":diagonal_mass_matrix_adaptation",
":elliptical_slice_sampler",
":expectations_reducer",
":kernel_builder",
@@ -50,6 +51,7 @@ multi_substrate_py_library(
":preconditioned_hmc",
":progress_bar_reducer",
":reducer",
+ ":run",
":sample",
":sample_discarding_kernel",
":sample_fold",
@@ -58,6 +60,7 @@ multi_substrate_py_library(
],
deps = [
":covariance_reducer",
+ ":diagonal_mass_matrix_adaptation",
":elliptical_slice_sampler",
":expectations_reducer",
":gradient_based_trajectory_length_adaptation",
@@ -70,6 +73,7 @@ multi_substrate_py_library(
":preconditioned_hmc",
":progress_bar_reducer",
":reducer",
+ ":run",
":sample",
":sample_discarding_kernel",
":sample_fold",
@@ -151,6 +155,39 @@ py_test(
],
)
+py_library(
+ name = "diagonal_mass_matrix_adaptation",
+ srcs = ["diagonal_mass_matrix_adaptation.py"],
+ srcs_version = "PY3",
+ deps = [
+ # tensorflow dep,
+ "//tensorflow_probability/python/distributions:independent",
+ "//tensorflow_probability/python/distributions:joint_distribution_sequential",
+ "//tensorflow_probability/python/distributions:mvn_diag",
+ "//tensorflow_probability/python/experimental/stats:sample_stats",
+ "//tensorflow_probability/python/internal:prefer_static",
+ "//tensorflow_probability/python/internal:unnest",
+ "//tensorflow_probability/python/mcmc:kernel",
+ "//tensorflow_probability/python/mcmc/internal:util",
+ ],
+)
+
+py_test(
+ name = "diagonal_mass_matrix_adaptation_test",
+ size = "large",
+ timeout = "long",
+ srcs = ["diagonal_mass_matrix_adaptation_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ # numpy dep,
+ # tensorflow dep,
+ "//tensorflow_probability",
+ "//tensorflow_probability/python/internal:prefer_static",
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
py_library(
name = "preconditioned_hmc",
srcs = ["preconditioned_hmc.py"],
@@ -590,6 +627,35 @@ py_test(
],
)
+py_library(
+ name = "run",
+ srcs = ["run.py"],
+ srcs_version = "PY3",
+ deps = [
+ ":sample",
+ ":tracing_reducer",
+ ":with_reductions",
+ # numpy dep,
+ # tensorflow dep,
+ "//tensorflow_probability/python/mcmc/internal:util",
+ ],
+)
+
+py_test(
+ name = "run_test",
+ size = "small",
+ srcs = ["run_test.py"],
+ python_version = "PY3",
+ srcs_version = "PY3",
+ deps = [
+ ":run",
+ # tensorflow dep,
+ "//tensorflow_probability",
+ "//tensorflow_probability/python/experimental/mcmc/internal:test_fixtures",
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
py_library(
name = "sample_fold",
srcs = ["sample_fold.py"],
diff --git a/tensorflow_probability/python/experimental/mcmc/__init__.py b/tensorflow_probability/python/experimental/mcmc/__init__.py
index c2a8183b2e..a13b9bc59d 100644
--- a/tensorflow_probability/python/experimental/mcmc/__init__.py
+++ b/tensorflow_probability/python/experimental/mcmc/__init__.py
@@ -20,6 +20,7 @@
from tensorflow_probability.python.experimental.mcmc.covariance_reducer import CovarianceReducer
from tensorflow_probability.python.experimental.mcmc.covariance_reducer import VarianceReducer
+from tensorflow_probability.python.experimental.mcmc.diagonal_mass_matrix_adaptation import DiagonalMassMatrixAdaptation
from tensorflow_probability.python.experimental.mcmc.elliptical_slice_sampler import EllipticalSliceSampler
from tensorflow_probability.python.experimental.mcmc.expectations_reducer import ExpectationsReducer
from tensorflow_probability.python.experimental.mcmc.gradient_based_trajectory_length_adaptation import chees_criterion
@@ -40,6 +41,7 @@
from tensorflow_probability.python.experimental.mcmc.progress_bar_reducer import make_tqdm_progress_bar_fn
from tensorflow_probability.python.experimental.mcmc.progress_bar_reducer import ProgressBarReducer
from tensorflow_probability.python.experimental.mcmc.reducer import Reducer
+from tensorflow_probability.python.experimental.mcmc.run import run_kernel
from tensorflow_probability.python.experimental.mcmc.sample import step_kernel
from tensorflow_probability.python.experimental.mcmc.sample_discarding_kernel import SampleDiscardingKernel
from tensorflow_probability.python.experimental.mcmc.sample_fold import sample_chain
@@ -64,6 +66,17 @@
__all__ = [
+ 'CovarianceReducer',
+ 'DiagonalMassMatrixAdaptation',
+ 'EllipticalSliceSampler',
+ 'ExpectationsReducer',
+ 'NoUTurnSampler',
+ 'PreconditionedHamiltonianMonteCarlo',
+ 'ProgressBarReducer',
+ 'SequentialMonteCarlo',
+ 'SequentialMonteCarloResults',
+ 'StateWithHistory',
+ 'WeightedParticles',
'augment_prior_with_state_history',
'augment_with_observation_history',
'augment_with_state_history',
@@ -93,6 +106,7 @@
'resample_independent',
'resample_stratified',
'resample_systematic',
+ 'run_kernel',
'sample_chain',
'sample_fold',
'sample_sequential_monte_carlo',
diff --git a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py
new file mode 100644
index 0000000000..0c728cff9f
--- /dev/null
+++ b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation.py
@@ -0,0 +1,307 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""DiagonalMassMatrixAdaptation TransitionKernel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.distributions import independent
+from tensorflow_probability.python.distributions import joint_distribution_sequential as jds
+from tensorflow_probability.python.experimental.distributions import mvn_precision_factor_linop as mvn_pfl
+from tensorflow_probability.python.experimental.stats import sample_stats
+from tensorflow_probability.python.internal import auto_composite_tensor
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import unnest
+from tensorflow_probability.python.mcmc import kernel as kernel_base
+from tensorflow_probability.python.mcmc.internal import util as mcmc_util
+
+__all__ = [
+ 'DiagonalMassMatrixAdaptation',
+]
+
+# Add auto-composite tensors to the global namespace to avoid creating new
+# classes inside functions.
+_CompositeJointDistributionSequential = auto_composite_tensor.auto_composite_tensor(
+ jds.JointDistributionSequential, omit_kwargs=('name',))
+_CompositeLinearOperatorDiag = auto_composite_tensor.auto_composite_tensor(
+ tf.linalg.LinearOperatorDiag, omit_kwargs=('name',))
+_CompositeMultivariateNormalPrecisionFactorLinearOperator = auto_composite_tensor.auto_composite_tensor(
+ mvn_pfl.MultivariateNormalPrecisionFactorLinearOperator,
+ omit_kwargs=('name',))
+_CompositeIndependent = auto_composite_tensor.auto_composite_tensor(
+ independent.Independent, omit_kwargs=('name',))
+
+
+def hmc_like_momentum_distribution_setter_fn(kernel_results, new_distribution):
+ """Setter for `momentum_distribution` so it can be adapted."""
+ # Note that unnest.replace_innermost has a special path for going into
+ # `accepted_results` preferentially, so this will set
+ # `accepted_results.momentum_distribution`.
+ return unnest.replace_innermost(
+ kernel_results, momentum_distribution=new_distribution)
+
+
+class DiagonalMassMatrixAdaptationResults(
+ mcmc_util.PrettyNamedTupleMixin,
+ collections.namedtuple('DiagonalMassMatrixAdaptationResults', [
+ 'inner_results',
+ 'running_variance',
+ ])):
+ """Results of the DiagonalMassMatrixAdaptation TransitionKernel.
+
+ Attributes:
+ inner_results: Results of the inner kernel.
+ running_variance: (List of) instance(s) of
+ `tfp.experimental.stats.RunningVariance`, used to set
+ the diagonal covariance of the momentum distribution.
+ """
+ __slots__ = ()
+
+
+class DiagonalMassMatrixAdaptation(kernel_base.TransitionKernel):
+ """Adapts the inner kernel's `momentum_distribution` to estimated variance.
+
+ This kernel uses an online variance estimate to adjust a diagonal covariance
+ matrix for each of the state parts. More specifically, the
+ `momentum_distribution` of the innermost kernel is set to a diagonal
+ multivariate normal distribution whose variance is the *inverse* of the
+ online estimate. The inverse of the covariance of the momentum is often called
+ the "mass matrix" in the context of Hamiltonian Monte Carlo.
+
+ This preconditioning scheme works well when the covariance is diagonally
+ dominant, and may give reasonable results even when the number of draws is
+ less than the dimension. In particular, it should generally do a better job
+ than no preconditioning, which implicitly uses an identity mass matrix.
+
+ Note that this kernel does not implement a calibrated sampler; rather, it is
+ intended to be used as one step of an iterative adaptation process. It
+ should not be used when drawing actual samples.
+ """
+
+ def __init__(
+ self,
+ inner_kernel,
+ initial_running_variance,
+ momentum_distribution_setter_fn=hmc_like_momentum_distribution_setter_fn,
+ validate_args=False,
+ name=None):
+ """Creates the diagonal mass matrix adaptation kernel.
+
+ Users must provide an `initial_running_variance`, either from a previous
+ `DiagonalMassMatrixAdaptation`, or some other source. See
+ `RunningCovariance.from_stats` for a convenient way to construct these.
+
+
+ Args:
+ inner_kernel: `TransitionKernel`-like object.
+ initial_running_variance:
+ `tfp.experimental.stats.RunningVariance`-like object, or list of them,
+ for a batch of momentum distributions. These use `update` on the state
+ to maintain an estimate of the variance, and so space, and so must have
+ a structure compatible with the state space.
+ momentum_distribution_setter_fn: A callable with the signature
+ `(kernel_results, new_momentum_distribution) -> new_kernel_results`
+ where `kernel_results` are the results of the `inner_kernel`,
+ `new_momentum_distribution` is a `CompositeTensor` or a nested
+ collection of `CompositeTensor`s, and `new_kernel_results` are a
+ possibly-modified copy of `kernel_results`. The default,
+ `hmc_like_momentum_distribution_setter_fn`, presumes HMC-style
+ `kernel_results`, and sets the `momentum_distribution` only under the
+ `accepted_results` field.
+ validate_args: Python `bool`. When `True` kernel parameters are checked
+ for validity. When `False` invalid inputs may silently render incorrect
+ outputs.
+ name: Python `str` name prefixed to Ops created by this class. Default:
+ 'diagonal_mass_matrix_adaptation'.
+ """
+ inner_kernel = mcmc_util.enable_store_parameters_in_results(inner_kernel)
+ self._parameters = dict(
+ inner_kernel=inner_kernel,
+ initial_running_variance=initial_running_variance,
+ momentum_distribution_setter_fn=momentum_distribution_setter_fn,
+ name=name,
+ )
+
+ @property
+ def inner_kernel(self):
+ return self._parameters['inner_kernel']
+
+ @property
+ def name(self):
+ return self._parameters['name']
+
+ @property
+ def initial_running_variance(self):
+ return self._parameters['initial_running_variance']
+
+ def momentum_distribution_setter_fn(self, kernel_results,
+ new_momentum_distribution):
+ return self._parameters['momentum_distribution_setter_fn'](
+ kernel_results, new_momentum_distribution)
+
+ @property
+ def parameters(self):
+ """Return `dict` of ``__init__`` arguments and their values."""
+ return self._parameters
+
+ def one_step(self, current_state, previous_kernel_results, seed=None):
+ with tf.name_scope(
+ mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
+ 'one_step')):
+ variance_parts = previous_kernel_results.running_variance
+ diags = [variance_part.variance() for variance_part in variance_parts]
+ # Set the momentum.
+ batch_ndims = ps.rank(unnest.get_innermost(previous_kernel_results,
+ 'target_log_prob'))
+ state_parts = tf.nest.flatten(current_state)
+ new_momentum_distribution = _make_momentum_distribution(diags,
+ state_parts,
+ batch_ndims)
+ inner_results = self.momentum_distribution_setter_fn(
+ previous_kernel_results.inner_results, new_momentum_distribution)
+
+ # Step the inner kernel.
+ inner_kwargs = {} if seed is None else dict(seed=seed)
+ new_state, new_inner_results = self.inner_kernel.one_step(
+ current_state, inner_results, **inner_kwargs)
+ new_state_parts = tf.nest.flatten(new_state)
+ new_variance_parts = []
+ for variance_part, diag, state_part in zip(variance_parts, diags,
+ new_state_parts):
+ # Compute new variance for each variance part, accounting for partial
+ # batching of the variance calculation across chains (ie, some, all, or
+ # none of the chains may share the estimated mass matrix).
+ #
+ # For example, say
+ #
+ # state_part has shape [2, 3, 4] + [5, 6] (batch + event)
+ # variance_part has shape [4] + [5, 6]
+ # log_prob has shape [2, 3, 4]
+ #
+ # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
+ # matrices, each being shared across a [2, 3]-batch of chains. Note this
+ # division is inferred from the shapes of the state part, the log_prob,
+ # and the user-provided initial running variances.
+ #
+ # Until RunningVariance supports rank > 1 chunking, we need to flatten
+ # the states that go into updating the variance estimates. In the above
+ # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
+ # fed to `RunningVariance.update(state_part, axis=0)`, recording
+ # 6 new observations in the running variance calculation.
+ # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
+ # the resulting momentum distribution will have batch shape of
+ # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
+ state_rank = ps.rank(state_part)
+ variance_rank = ps.rank(diag)
+ num_reduce_dims = state_rank - variance_rank
+
+ state_part_shape = ps.shape(state_part)
+ # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
+ # dimensions to a single one otherwise.
+ reshaped_state = ps.reshape(
+ state_part,
+ ps.concat(
+ [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
+ state_part_shape[num_reduce_dims:]], axis=0))
+
+ # The `axis=0` here removes the leading dimension we got from the
+ # reshape above, so the new_variance_parts have the correct shape again.
+ new_variance_parts.append(variance_part.update(reshaped_state,
+ axis=0))
+
+ new_kernel_results = previous_kernel_results._replace(
+ inner_results=new_inner_results,
+ running_variance=new_variance_parts)
+
+ return new_state, new_kernel_results
+
+ def bootstrap_results(self, init_state):
+ with tf.name_scope(
+ mcmc_util.make_name(self.name, 'diagonal_mass_matrix_adaptation',
+ 'bootstrap_results')):
+ if isinstance(self.initial_running_variance,
+ sample_stats.RunningVariance):
+ variance_parts = [self.initial_running_variance]
+ else:
+ variance_parts = self.initial_running_variance
+
+ diags = [variance_part.variance() for variance_part in variance_parts]
+
+ # Step inner results.
+ inner_results = self.inner_kernel.bootstrap_results(init_state)
+ # Set the momentum.
+ batch_ndims = ps.rank(unnest.get_innermost(inner_results,
+ 'target_log_prob'))
+ init_state_parts = tf.nest.flatten(init_state)
+ momentum_distribution = _make_momentum_distribution(
+ diags, init_state_parts, batch_ndims)
+ inner_results = self.momentum_distribution_setter_fn(
+ inner_results, momentum_distribution)
+ proposed = unnest.get_innermost(inner_results, 'proposed_results',
+ default=None)
+ if proposed is not None:
+ proposed = proposed._replace(
+ momentum_distribution=momentum_distribution)
+ inner_results = unnest.replace_innermost(inner_results,
+ proposed_results=proposed)
+ return DiagonalMassMatrixAdaptationResults(
+ inner_results=inner_results,
+ running_variance=variance_parts)
+
+ @property
+ def is_calibrated(self):
+ return False
+
+
+def _make_momentum_distribution(running_variance_parts, state_parts,
+ batch_ndims):
+ """Construct a momentum distribution from the running variance.
+
+ This uses a running variance to construct a momentum distribution with the
+ correct batch_shape and event_shape.
+
+ Args:
+ running_variance_parts: List of `Tensor`, outputs of
+ `tfp.experimental.stats.RunningVariance.variance()`.
+ state_parts: List of `Tensor`.
+ batch_ndims: Scalar, for leading batch dimensions.
+
+ Returns:
+ `tfd.Distribution` where `.sample` has the same structure as `state_parts`,
+ and `.log_prob` of the sample will have the rank of `batch_ndims`
+ """
+ distributions = []
+ for variance_part, state_part in zip(running_variance_parts, state_parts):
+ running_variance_rank = ps.rank(variance_part)
+ state_rank = ps.rank(state_part)
+ # Pad dimensions and tile by multiplying by tf.ones to add a batch shape
+ ones = tf.ones(ps.shape(state_part)[:-(state_rank - running_variance_rank)])
+ ones = mcmc_util.left_justified_expand_dims_like(ones, state_part)
+ variance_tiled = variance_part * ones
+ reinterpreted_batch_ndims = state_rank - batch_ndims - 1
+
+ distributions.append(
+ _CompositeIndependent(
+ _CompositeMultivariateNormalPrecisionFactorLinearOperator(
+ precision_factor=_CompositeLinearOperatorDiag(
+ tf.math.sqrt(variance_tiled)),
+ precision=_CompositeLinearOperatorDiag(variance_tiled)),
+ reinterpreted_batch_ndims=reinterpreted_batch_ndims))
+ return _CompositeJointDistributionSequential(distributions)
diff --git a/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py
new file mode 100644
index 0000000000..3e7677c9eb
--- /dev/null
+++ b/tensorflow_probability/python/experimental/mcmc/diagonal_mass_matrix_adaptation_test.py
@@ -0,0 +1,340 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for DiagonalMassMatrixAdaptation kernel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+# Dependency imports
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+from tensorflow_probability.python import distributions as tfd
+from tensorflow_probability.python.internal import test_util
+
+RunHMCResults = collections.namedtuple('RunHMCResults', [
+ 'draws',
+ 'final_mean',
+ 'final_precision_factor',
+ 'final_precision',
+ 'empirical_mean',
+ 'empirical_variance',
+ 'true_mean',
+ 'true_variance'])
+
+
+@test_util.test_all_tf_execution_regimes
+class DiagonalMassMatrixAdaptationShapesTest(test_util.TestCase):
+
+ @parameterized.named_parameters([
+ {'testcase_name': '_two_batches_of_three',
+ 'state_part_shape': (2, 3),
+ 'variance_part_shape': (2, 3),
+ 'log_prob_shape': (2,)},
+ {'testcase_name': '_no_batches_of_two',
+ 'state_part_shape': (2,),
+ 'variance_part_shape': (2,),
+ 'log_prob_shape': ()},
+ {'testcase_name': '_batch_of_matrix_batches',
+ 'state_part_shape': (2, 3, 4, 5, 6),
+ 'variance_part_shape': (4, 5, 6),
+ 'log_prob_shape': (2, 3, 4)},
+ ])
+ def testBatches(self, state_part_shape, variance_part_shape, log_prob_shape):
+ dist = tfd.Independent(
+ tfd.Normal(tf.zeros(state_part_shape), tf.ones(state_part_shape)),
+ reinterpreted_batch_ndims=len(state_part_shape) - len(log_prob_shape))
+ state_part = tf.zeros(state_part_shape)
+
+ running_variance = tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=10.,
+ mean=tf.zeros(variance_part_shape),
+ variance=tf.ones(variance_part_shape))
+
+ kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+ target_log_prob_fn=dist.log_prob,
+ num_leapfrog_steps=2,
+ step_size=1.)
+ kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
+ inner_kernel=kernel,
+ initial_running_variance=running_variance)
+
+ num_results = 5
+ draws = tfp.mcmc.sample_chain(
+ num_results=num_results,
+ current_state=state_part,
+ kernel=kernel,
+ seed=test_util.test_seed(),
+ trace_fn=None)
+
+ # Make sure the result has the correct shape
+ self.assertEqual(draws.shape, (num_results,) + state_part_shape)
+
+ def testBatchBroadcast(self):
+ n_batches = 8
+ dist = tfd.MultivariateNormalDiag(tf.zeros(3), tf.ones(3))
+ target_log_prob_fn = dist.log_prob
+ kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+ target_log_prob_fn=target_log_prob_fn,
+ num_leapfrog_steps=2,
+ step_size=1.)
+ initial_running_variance = (
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=1.,
+ mean=tf.zeros(3),
+ variance=tf.ones(3)))
+ kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
+ inner_kernel=kernel,
+ initial_running_variance=initial_running_variance)
+
+ num_results = 5
+ draws = tfp.mcmc.sample_chain(
+ num_results=num_results,
+ current_state=tf.zeros([n_batches, 3]),
+ kernel=kernel,
+ seed=test_util.test_seed(),
+ trace_fn=None)
+
+ # Make sure the result has the correct shape
+ self.assertEqual(draws.shape, (num_results, n_batches, 3))
+
+ def testMultipleStateParts(self):
+ dist = tfd.JointDistributionSequential([
+ tfd.MultivariateNormalDiag(tf.zeros(3), tf.ones(3)),
+ tfd.MultivariateNormalDiag(tf.zeros(2), tf.ones(2))])
+ target_log_prob_fn = dist.log_prob
+ kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+ target_log_prob_fn=target_log_prob_fn,
+ num_leapfrog_steps=2,
+ step_size=1.)
+ initial_running_variance = [
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=1., mean=tf.zeros(3), variance=tf.ones(3)),
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=1., mean=tf.zeros(2), variance=tf.ones(2))]
+ kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
+ inner_kernel=kernel,
+ initial_running_variance=initial_running_variance)
+
+ num_results = 5
+ draws = tfp.mcmc.sample_chain(
+ num_results=num_results,
+ current_state=[tf.zeros(3), tf.zeros(2)],
+ kernel=kernel,
+ seed=test_util.test_seed(),
+ trace_fn=None)
+
+ # Make sure the result has the correct shape
+ self.assertEqual(len(draws), 2)
+ self.assertEqual(draws[0].shape, (num_results, 3))
+ self.assertEqual(draws[1].shape, (num_results, 2))
+
+
+@test_util.test_graph_and_eager_modes
+class DiagonalMassMatrixAdaptationTest(test_util.TestCase):
+
+ def setUp(self):
+ self.mvn_mean = [0., 0., 0.]
+ self.mvn_scale = [0.1, 1., 10.]
+ super(DiagonalMassMatrixAdaptationTest, self).setUp()
+
+ def testTurnOnStoreParametersInKernelResults(self):
+ mvn = tfd.MultivariateNormalDiag(self.mvn_mean, scale_diag=self.mvn_scale)
+ target_log_prob_fn = mvn.log_prob
+ kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+ target_log_prob_fn=target_log_prob_fn,
+ num_leapfrog_steps=2,
+ step_size=1.)
+ self.assertFalse(kernel.parameters['store_parameters_in_results'])
+ initial_running_variance = (
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=1., mean=tf.zeros(3), variance=tf.ones(3)))
+ kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
+ inner_kernel=kernel,
+ initial_running_variance=initial_running_variance)
+ self.assertTrue(
+ kernel.inner_kernel.parameters['store_parameters_in_results'])
+
+ def _run_hmc(self, num_results, initial_running_variance):
+ mvn = tfd.MultivariateNormalDiag(self.mvn_mean, scale_diag=self.mvn_scale)
+ target_log_prob_fn = mvn.log_prob
+ kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+ target_log_prob_fn=target_log_prob_fn,
+ num_leapfrog_steps=32,
+ step_size=0.001)
+ kernel = tfp.experimental.mcmc.DiagonalMassMatrixAdaptation(
+ inner_kernel=kernel,
+ initial_running_variance=initial_running_variance)
+
+ @tf.function
+ def do_sample():
+
+ def trace_fn(_, pkr):
+ return (pkr.running_variance,
+ pkr.inner_results.accepted_results.momentum_distribution)
+
+ draws, (rv_state, dist) = tfp.mcmc.sample_chain(
+ num_results=num_results,
+ current_state=tf.zeros(3),
+ kernel=kernel,
+ seed=test_util.test_seed(),
+ trace_fn=trace_fn)
+
+ # sample_distributions returns `[dists], [samples]`, so the 0th
+ # distribution corresponds to the 0th, and only, state part
+ # The distribution is an Independent containing the distribution
+ # we want to query, which we access with .distribution
+ momentum_dist = dist.sample_distributions()[0][0].distribution
+ final_precision_factor = tf.linalg.diag_part(
+ momentum_dist.precision_factor)[-1]
+ # Evaluate here so we can check the value twice later
+ final_precision = tf.linalg.diag_part(momentum_dist.precision)[-1]
+ final_mean = rv_state[0].mean[-1]
+ empirical_mean = tf.reduce_mean(draws, axis=0)
+ # The final_precision is taken directly from the momentum distribution,
+ # which never "sees" the last sample.
+ empirical_variance = tf.math.reduce_variance(draws[:-1], axis=0)
+ return RunHMCResults(draws=draws,
+ final_mean=final_mean,
+ final_precision_factor=final_precision_factor,
+ final_precision=final_precision,
+ empirical_mean=empirical_mean,
+ empirical_variance=empirical_variance,
+ true_mean=mvn.mean(),
+ true_variance=mvn.variance())
+ return self.evaluate(do_sample())
+
+ def testUpdatesCorrectly(self):
+ running_variance = tfp.experimental.stats.RunningVariance.from_shape((3,))
+ # This is more straightforward than doing the math, but need at least
+ # two observations to get a start.
+ pseudo_observations = [-tf.ones(3), tf.ones(3)]
+ for pseudo_observation in pseudo_observations:
+ running_variance = running_variance.update(pseudo_observation)
+
+ results = self._run_hmc(
+ num_results=5,
+ initial_running_variance=running_variance)
+ draws = tf.concat([tf.stack(pseudo_observations), results.draws], axis=0)
+ self.assertAllClose(results.final_precision_factor**2,
+ results.final_precision)
+ self.assertAllClose(results.final_mean, tf.reduce_mean(draws, axis=0))
+ self.assertAllClose(results.final_precision,
+ tf.math.reduce_variance(draws[:-1], axis=0))
+
+ def testDoesRegularize(self):
+ # Make sure that using regularization makes the final estimate closer to
+ # the initial state than the empirical result.
+ init_mean = tf.zeros(3)
+ init_variance = tf.ones(3)
+ initial_running_variance = (
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=10., mean=init_mean, variance=init_variance))
+ results = self._run_hmc(
+ num_results=50,
+ initial_running_variance=initial_running_variance)
+
+ # the `final_mean` should be a weighted average
+ self.assertAllClose(
+ results.final_mean,
+ 10. / 60. * init_mean + 50. / 60. * results.empirical_mean)
+
+ # the `final_precision` is not quite a weighted average, since the
+ # estimate of the mean also gets updated, but it is close-ish
+ self.assertAllClose(
+ results.final_precision,
+ 10. / 60. * init_variance + 50. / 60. * results.empirical_variance,
+ rtol=0.2)
+
+ def testVarGoesInRightDirection(self):
+ # This online updating scheme violates detailed balance, and in general
+ # will not leave the target distribution invariant. We test a weaker
+ # property, which is that the variance gets closer to the target variance,
+ # assuming we start at the correct mean. This test does not pass reliably
+ # when the mean is not near the true mean.
+ error_factor = 5.
+ init_variance = error_factor * tf.convert_to_tensor(self.mvn_scale)**2
+ init_mean = tf.convert_to_tensor(self.mvn_mean)
+ initial_running_variance = (
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=10., mean=init_mean, variance=init_variance))
+ results = self._run_hmc(
+ num_results=1000,
+ initial_running_variance=initial_running_variance)
+
+ # This number started off at `error_factor`, and should be smaller now
+ # This makes sure it is 90% closer to equal. The intention is that the
+ # precision of the momentum should eventually equal the variance of the
+ # state. We test elsewhere that the precision of the momentum faithfully
+ # updates according to the draws it makes. This makes sure that those draws
+ # are also getting closer to the underlying variance.
+ new_error_factor = 1. + 0.1 * (error_factor - 1.)
+
+ final_var_ratio = results.final_precision / results.true_variance
+ self.assertAllLess(final_var_ratio, new_error_factor)
+
+ def testMeanGoesInRightDirection(self):
+ # As with `testVarGoesInRightDirection`, this makes sure the mean gets
+ # closer. Pleasantly, we do not even need that the variance starts very
+ # close to the true variance.
+ mvn_scale = tf.convert_to_tensor(self.mvn_scale)
+ error_factor = 5. * mvn_scale
+ init_variance = error_factor * mvn_scale**2
+ init_mean = tf.convert_to_tensor(self.mvn_mean) + error_factor
+ initial_running_variance = (
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=10., mean=init_mean, variance=init_variance))
+ results = self._run_hmc(
+ num_results=1000,
+ initial_running_variance=initial_running_variance)
+
+ # This number started at `error_factor`. Make sure the mean is now at least
+ # 80% closer.
+ final_mean_diff = tf.abs(results.final_mean - results.true_mean)
+ np.testing.assert_array_less(
+ self.evaluate(final_mean_diff),
+ self.evaluate(0.2 * error_factor))
+
+ def testDoesNotGoesInWrongDirection(self):
+ # As above, we test a weaker property, which is that the variance and
+ # mean estimates do not get too away if initialized at the true variance
+ # and mean.
+ initial_running_variance = (
+ tfp.experimental.stats.RunningVariance.from_stats(
+ num_samples=10., mean=self.mvn_mean,
+ variance=tf.convert_to_tensor(self.mvn_scale)**2))
+ results = self._run_hmc(
+ num_results=1000,
+ initial_running_variance=initial_running_variance)
+
+ # Allow the large scale dimension to be a little further off
+ fudge_factor = tf.sqrt(results.true_variance)
+ final_mean_diff = tf.abs(results.final_mean - results.true_mean)
+ np.testing.assert_array_less(self.evaluate(final_mean_diff),
+ self.evaluate(fudge_factor))
+
+ final_std_diff = tf.abs(results.final_precision_factor -
+ tf.sqrt(results.true_variance))
+ np.testing.assert_array_less(self.evaluate(final_std_diff),
+ self.evaluate(fudge_factor))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
index fdd2754cc1..527ee5a833 100644
--- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
+++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation.py
@@ -19,7 +19,6 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import assert_util
-from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import unnest
@@ -412,14 +411,8 @@ class docstring).
with tf.name_scope(
mcmc_util.make_name(name, 'gradient_based_trajectory_length_adaptation',
'__init__')) as name:
- dtype = dtype_util.common_dtype([adaptation_rate, jitter_amount],
- tf.float32)
num_adaptation_steps = tf.convert_to_tensor(
num_adaptation_steps, dtype=tf.int32, name='num_adaptation_steps')
- adaptation_rate = tf.convert_to_tensor(
- adaptation_rate, dtype=dtype, name='adaptation_rate')
- jitter_amount = tf.convert_to_tensor(
- jitter_amount, dtype=dtype, name='jitter_amount')
max_leapfrog_steps = tf.convert_to_tensor(
max_leapfrog_steps, dtype=tf.int32, name='max_leapfrog_steps')
@@ -571,7 +564,7 @@ def bootstrap_results(self, init_state):
'gradient_based_trajectory_length_adaptation',
'bootstrap_results')):
inner_results = self.inner_kernel.bootstrap_results(init_state)
- dtype = self.parameters['adaptation_rate'].dtype
+ dtype = self.log_accept_prob_getter_fn(inner_results).dtype
init_state = tf.nest.map_structure(
lambda x: tf.convert_to_tensor(x, dtype=dtype), init_state)
step_size = _ensure_step_size_is_scalar(
@@ -583,8 +576,11 @@ def bootstrap_results(self, init_state):
inner_results=inner_results,
max_trajectory_length=init_max_trajectory_length,
step=tf.zeros([], tf.int32),
- adaptation_rate=self.parameters['adaptation_rate'],
- jitter_amount=self.parameters['jitter_amount'],
+ adaptation_rate=tf.convert_to_tensor(
+ self.parameters['adaptation_rate'], dtype,
+ name='adaptation_rate'),
+ jitter_amount=tf.convert_to_tensor(
+ self.parameters['jitter_amount'], dtype, name='jitter_amount'),
averaged_sq_grad=tf.zeros([], dtype),
averaged_max_trajectory_length=tf.zeros([], dtype),
criterion=tf.zeros_like(
diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py
index 920cb2c7af..e005840f69 100644
--- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py
@@ -121,7 +121,6 @@ def target_log_prob_fn(*x):
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
kernel,
num_adaptation_steps=num_adaptation_steps,
- adaptation_rate=tf.constant(0.025, self.dtype),
validate_args=True)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
kernel, num_adaptation_steps=num_adaptation_steps)
@@ -179,7 +178,7 @@ def target_log_prob_fn(x):
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
kernel,
num_adaptation_steps=5,
- adaptation_rate=tf.constant(1., self.dtype),
+ adaptation_rate=1.,
validate_args=True)
state = tf.zeros([64], self.dtype)
@@ -211,7 +210,7 @@ def target_log_prob_fn(x):
tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
kernel,
num_adaptation_steps=5,
- adaptation_rate=tf.constant(1., self.dtype),
+ adaptation_rate=1.,
validate_args=True))
state = tf.zeros([64, 2, 3], self.dtype)
@@ -242,7 +241,7 @@ def target_log_prob_fn(x, y):
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
kernel,
num_adaptation_steps=5,
- adaptation_rate=tf.constant(1., self.dtype),
+ adaptation_rate=1.,
validate_args=True)
state = [tf.zeros([64], self.dtype), tf.zeros([64], self.dtype)]
@@ -273,7 +272,7 @@ def target_log_prob_fn(x):
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
kernel,
num_adaptation_steps=1,
- adaptation_rate=tf.constant(1., self.dtype),
+ adaptation_rate=1.,
validate_args=True)
state = tf.zeros([64], self.dtype)
diff --git a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py
index 08fb167f74..ceb385f4ac 100644
--- a/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py
+++ b/tensorflow_probability/python/experimental/mcmc/preconditioned_hmc.py
@@ -280,7 +280,7 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
state_gradients_are_stopped=self.state_gradients_are_stopped)
seed = samplers.sanitize_seed(seed)
- current_momentum_parts = momentum_distribution.sample(seed=seed)
+ current_momentum_parts = list(momentum_distribution.sample(seed=seed))
momentum_log_prob = getattr(momentum_distribution,
'_log_prob_unnormalized',
momentum_distribution.log_prob)
@@ -331,14 +331,15 @@ def bootstrap_results(self, init_state):
mcmc_util.make_name(self.name, 'phmc', 'bootstrap_results')):
result = super(UncalibratedPreconditionedHamiltonianMonteCarlo,
self).bootstrap_results(init_state)
+
+ if (not self._store_parameters_in_results or
+ self.momentum_distribution is None):
+ momentum_distribution = []
+ else:
+ momentum_distribution = self.momentum_distribution
result = UncalibratedPreconditionedHamiltonianMonteCarloKernelResults(
**result._asdict(), # pylint: disable=protected-access
- momentum_distribution=[])
-
- if self._store_parameters_in_results:
- result = result._replace(
- momentum_distribution=[] if self.momentum_distribution is None else
- self.momentum_distribution)
+ momentum_distribution=momentum_distribution)
return result
diff --git a/tensorflow_probability/python/experimental/mcmc/run.py b/tensorflow_probability/python/experimental/mcmc/run.py
new file mode 100644
index 0000000000..32db8c0077
--- /dev/null
+++ b/tensorflow_probability/python/experimental/mcmc/run.py
@@ -0,0 +1,271 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""High(er) level driver for streaming MCMC."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import warnings
+# Dependency imports
+
+import tensorflow.compat.v2 as tf
+from tensorflow_probability.python.experimental.mcmc import sample as exp_sample_lib
+from tensorflow_probability.python.experimental.mcmc import tracing_reducer
+from tensorflow_probability.python.experimental.mcmc import with_reductions
+from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import
+
+__all__ = [
+ 'run_kernel',
+]
+
+
+def _trace_everything(chain_state, kernel_results, *reduction_results):
+ del kernel_results
+ return chain_state, reduction_results
+
+
+class RunKernelResults(collections.namedtuple(
+ 'RunKernelResults', ['trace', 'reduction_results', 'final_state',
+ 'final_kernel_results', 'resume_kwargs'])):
+ """Result from a sampling run.
+
+ Attributes:
+ trace: A `Tensor` or a nested collection of `Tensor`s representing the
+ values during the run, if any.
+
+ reduction_results: A `Tensor` or a nested collection of `Tensor`s giving the
+ results of any requested reductions.
+
+ final_state: A `Tensor` or a nested collection of `Tensor`s giving the final
+ state of the Markov chain.
+
+ final_kernel_results: The last auxiliary state of the `kernel` that was run.
+
+ resume_kwargs: A dict of `Tensor` or nested collections of `Tensor`s giving
+ keyword arguments that can be used to continue the Markov chain (and
+ auxiliaries) where it left off.
+ """
+ # This list of fields is meant to grow as we decide what metrics, diagnostics,
+ # or auxiliary information MCMC entry points should return. Part of the idea,
+ # like scipy.optimize, is to admit multiple entry points; insofar as they all
+ # need to return similar information, we should use consistent field names to
+ # store them, so users can change entry points without having to write (as
+ # much) glue code.
+
+ # Specific possible fields to add:
+ # - Performance diagnostics such as number of log_prob and gradient
+ # evaluations
+ # - Statistical diagnostics such as ESS or R-hat
+ # - Internal diagnostics about adaptation convergence, etc
+ # - Once our methods become sophisticated enough to evaluate their own
+ # efficacy, we can also adopt a "success" boolean, failure reason message,
+ # and things like that.
+ __slots__ = ()
+
+
+def run_kernel(
+ kernel,
+ num_results,
+ current_state,
+ previous_kernel_results=None,
+ reducer=(),
+ previous_reducer_state=None,
+ trace_fn=_trace_everything,
+ parallel_iterations=10,
+ seed=None,
+ name=None,
+):
+ """Runs a Markov chain defined by the given `TransitionKernel`.
+
+ This is meant as a (more) helpful frontend to the low-level
+ `TransitionKernel`-based MCMC API, supporting several main features:
+
+ - Running a batch of multiple independent chains using SIMD parallelism
+ - Tracing the history of the chains, or not tracing it to save memory
+ - Computing reductions over chain history, whether it is also traced or not
+ - Warm (re-)start, including auxiliary state
+
+ This function samples from a Markov chain at `current_state` whose
+ stationary distribution is governed by the supplied `TransitionKernel`
+ instance (`kernel`).
+
+ The `current_state` can be represented as a single `Tensor` or a `list` of
+ `Tensors` which collectively represent the current state.
+
+ This function can sample from multiple chains, in parallel. Whether or not
+ there are multiple chains is dictated by how the `kernel` treats its inputs.
+ Typically, the shape of the independent chains is shape of the result of the
+ `target_log_prob_fn` used by the `kernel` when applied to the given
+ `current_state`.
+
+ This function can compute reductions over the samples in tandem with sampling,
+ for example to return summary statistics without materializing all the
+ samples. To request reductions, pass a `Reducer` object, or a nested
+ structure of `Reducer` objects, as the `reducer=` argument.
+
+ In addition to the chain state, this function supports tracing of auxiliary
+ variables used by the kernel, as well as intermediate values of any supplied
+ reductions. The traced values are selected by specifying `trace_fn`. The
+ `trace_fn` must be a callable accepting three arguments: the chain state, the
+ kernel_results of the `kernel`, and the current results of the reductions, if
+ any are supplied. The return value of `trace_fn` (which may be a `Tensor` or
+ a nested structure of `Tensor`s) is accumulated, such that each `Tensor` gains
+ a new outmost dimension representing time in the chain history.
+
+ Since MCMC states are correlated, it is sometimes desirable to produce
+ additional intermediate states, and then discard them, ending up with a set of
+ states with decreased autocorrelation. See [Owen (2017)][1]. Such 'thinning'
+ is made possible by setting `num_steps_between_results > 0`. The chain then
+ takes `num_steps_between_results` extra steps between the steps that make it
+ into the results, or are shown to any supplied reductions. The extra steps
+ are never materialized, and thus do not increase memory requirements.
+
+ Args:
+ kernel: An instance of `tfp.mcmc.TransitionKernel` which implements one step
+ of the Markov chain.
+ num_results: Integer number of (non-discarded) Markov chain draws to
+ compute.
+ current_state: `Tensor` or Python `list` of `Tensor`s representing the
+ initial state(s) of the Markov chain(s).
+ previous_kernel_results: A `Tensor` or a nested collection of `Tensor`s
+ representing internal calculations made within the previous call to this
+ function (or as returned by `bootstrap_results`).
+ reducer: A (possibly nested) structure of `Reducer`s to be evaluated
+ on the `kernel`'s samples. If no reducers are given (`reducer=None`),
+ their states will not be passed to any supplied `trace_fn`.
+ previous_reducer_state: A (possibly nested) structure of running states
+ corresponding to the structure in `reducer`. For resuming streaming
+ reduction computations begun in a previous run.
+ trace_fn: A callable that takes in the current chain state, the current
+ auxiliary kernel state, and the current result of any reducers, and
+ returns a `Tensor` or a nested collection of `Tensor`s that is then
+ traced. If `None`, nothing is traced.
+ parallel_iterations: The number of iterations allowed to run in parallel. It
+ must be a positive integer. See `tf.while_loop` for more details.
+ seed: Optional, a seed for reproducible sampling.
+ name: Python `str` name prefixed to Ops created by this function.
+ Default value: `None` (i.e., 'mcmc_run_kernel').
+
+ Returns:
+
+ result: A `RunKernelResults` instance containing information about the
+ sampling run. Main fields are `trace`, the history of outputs of
+ `trace_fn`, and `reduction_results`, the final outputs of all supplied
+ `Reducer`s. See `RunKernelResults` for contents of other fields.
+ """
+ # Features omitted for simplicity:
+ # - Can only warm start either all the reducers or none of them, not
+ # piecemeal.
+ #
+ # Defects admitted for simplicity:
+ # - All reducers are finalized internally at every step, whether the user
+ # wished to trace them or not. We expect graph mode TF to avoid that unused
+ # computation, but eager mode will not.
+ # - The user is not given the opportunity to trace the running state of
+ # reducers. For example, the user cannot trace the sum and count of a
+ # running mean, only the running mean itself. Arguably this is a feature,
+ # because the sum and count can be considered implementation details, the
+ # hiding of which is the purpose of the `finalize` method.
+ with tf.name_scope(name or 'mcmc_run_kernel'):
+ if not kernel.is_calibrated:
+ warnings.warn('supplied `TransitionKernel` is not calibrated. Markov '
+ 'chain may not converge to intended target distribution.')
+
+ if trace_fn is None:
+ trace_fn = lambda *args: ()
+
+ # Form kernel onion
+ reduction_kernel = with_reductions.WithReductions(
+ inner_kernel=kernel,
+ reducer=reducer)
+
+ # User trace function should be called with
+ # - current chain state
+ # - kernel results structure of the passed-in kernel
+ # - if there were any reducers, their intermediate results
+ #
+ # `WithReductions` will show the TracingReducer the intermediate state as
+ # the kernel results of the onion named `reduction_kernel` above. This
+ # wrapper converts from that to what the user-supplied trace function needs
+ # to see.
+ def internal_trace_fn(curr_state, kr):
+ if reducer:
+ def fin(reducer, red_state):
+ return reducer.finalize(red_state)
+ # Extra level of list will be unwrapped by *reduction_args, below.
+ reduction_args = [nest.map_structure_up_to(
+ reducer, fin, reducer, kr.reduction_results)]
+ else:
+ reduction_args = []
+ return trace_fn(curr_state, kr.inner_results, *reduction_args)
+
+ trace_reducer = tracing_reducer.TracingReducer(
+ trace_fn=internal_trace_fn,
+ size=num_results
+ )
+ tracing_kernel = with_reductions.WithReductions(
+ inner_kernel=reduction_kernel,
+ reducer=trace_reducer,
+ )
+
+ # Bootstrap corresponding warm start
+ if previous_kernel_results is None:
+ previous_kernel_results = kernel.bootstrap_results(current_state)
+ reduction_pkr = reduction_kernel.bootstrap_results(
+ current_state, previous_kernel_results, previous_reducer_state)
+ tracing_pkr = tracing_kernel.bootstrap_results(
+ current_state, reduction_pkr)
+
+ # pylint: disable=unbalanced-tuple-unpacking
+ final_state, tracing_kernel_results = exp_sample_lib.step_kernel(
+ num_steps=num_results,
+ current_state=current_state,
+ previous_kernel_results=tracing_pkr,
+ kernel=tracing_kernel,
+ return_final_kernel_results=True,
+ parallel_iterations=parallel_iterations,
+ seed=seed,
+ name=name,
+ )
+
+ trace = trace_reducer.finalize(
+ tracing_kernel_results.reduction_results)
+
+ reduction_kernel_results = tracing_kernel_results.inner_results
+ reduction_results = nest.map_structure_up_to(
+ reducer,
+ lambda r, s: r.finalize(s),
+ reducer,
+ reduction_kernel_results.reduction_results,
+ check_types=False)
+
+ user_kernel_results = reduction_kernel_results.inner_results
+
+ resume_kwargs = {
+ 'current_state': final_state,
+ 'previous_kernel_results': user_kernel_results,
+ 'kernel': kernel,
+ 'reducer': reducer,
+ 'previous_reducer_state': reduction_kernel_results.reduction_results,
+ }
+
+ return RunKernelResults(
+ trace=trace,
+ reduction_results=reduction_results,
+ final_state=final_state,
+ final_kernel_results=user_kernel_results,
+ resume_kwargs=resume_kwargs)
diff --git a/tensorflow_probability/python/experimental/mcmc/run_test.py b/tensorflow_probability/python/experimental/mcmc/run_test.py
new file mode 100644
index 0000000000..c07fd334cc
--- /dev/null
+++ b/tensorflow_probability/python/experimental/mcmc/run_test.py
@@ -0,0 +1,125 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for high(er) level drivers for streaming MCMC."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import tensorflow.compat.v2 as tf
+import tensorflow_probability as tfp
+
+from tensorflow_probability.python.experimental.mcmc.internal import test_fixtures
+from tensorflow_probability.python.internal import test_util
+
+
+@test_util.test_all_tf_execution_regimes
+class RunTest(test_util.TestCase):
+
+ def test_simple_reduction(self):
+ fake_kernel = test_fixtures.TestTransitionKernel()
+ fake_reducer = test_fixtures.NaiveMeanReducer()
+ result = tfp.experimental.mcmc.run_kernel(
+ num_results=5,
+ current_state=0.,
+ kernel=fake_kernel,
+ reducer=fake_reducer,
+ )
+ last_sample, reduction_result, kernel_results = self.evaluate([
+ result.final_state, result.reduction_results,
+ result.final_kernel_results
+ ])
+ self.assertEqual(5, last_sample)
+ self.assertEqual(3, reduction_result)
+ self.assertEqual(5, kernel_results.counter_1)
+ self.assertEqual(10, kernel_results.counter_2)
+
+ # Warm-restart the underlying kernel but not the reduction
+ result_2 = tfp.experimental.mcmc.run_kernel(
+ num_results=5,
+ current_state=last_sample,
+ kernel=fake_kernel,
+ reducer=fake_reducer,
+ previous_kernel_results=kernel_results,
+ )
+ last_sample_2, reduction_result_2, kernel_results_2 = self.evaluate([
+ result_2.final_state, result_2.reduction_results,
+ result_2.final_kernel_results
+ ])
+ self.assertEqual(10, last_sample_2)
+ self.assertEqual(8, reduction_result_2)
+ self.assertEqual(10, kernel_results_2.counter_1)
+ self.assertEqual(20, kernel_results_2.counter_2)
+
+ def test_reducer_warm_restart(self):
+ fake_kernel = test_fixtures.TestTransitionKernel()
+ fake_reducer = test_fixtures.NaiveMeanReducer()
+ result = tfp.experimental.mcmc.run_kernel(
+ num_results=5,
+ current_state=0.,
+ kernel=fake_kernel,
+ reducer=fake_reducer,
+ )
+ last_sample, red_res, kernel_results = self.evaluate([
+ result.final_state, result.reduction_results,
+ result.final_kernel_results
+ ])
+ self.assertEqual(3, red_res)
+ self.assertEqual(5, last_sample)
+ self.assertEqual(5, kernel_results.counter_1)
+ self.assertEqual(10, kernel_results.counter_2)
+
+ # Warm-restart the underlying kernel and the reduction using the provided
+ # restart package
+ result_2 = tfp.experimental.mcmc.run_kernel(
+ num_results=5, **result.resume_kwargs)
+ last_sample_2, reduction_result_2, kernel_results_2 = self.evaluate([
+ result_2.final_state, result_2.reduction_results,
+ result_2.final_kernel_results
+ ])
+ self.assertEqual(5.5, reduction_result_2)
+ self.assertEqual(10, last_sample_2)
+ self.assertEqual(10, kernel_results_2.counter_1)
+ self.assertEqual(20, kernel_results_2.counter_2)
+
+ def test_tracing_a_reduction(self):
+ fake_kernel = test_fixtures.TestTransitionKernel()
+ fake_reducer = test_fixtures.NaiveMeanReducer()
+ result = tfp.experimental.mcmc.run_kernel(
+ num_results=5,
+ current_state=0.,
+ kernel=fake_kernel,
+ reducer=fake_reducer,
+ trace_fn=lambda _state, _kr, reductions: reductions
+ )
+ trace = self.evaluate(result.trace)
+ self.assertAllEqual(trace, [1.0, 1.5, 2.0, 2.5, 3.0])
+
+ def test_tracing_no_reduction(self):
+ fake_kernel = test_fixtures.TestTransitionKernel()
+ result = tfp.experimental.mcmc.run_kernel(
+ num_results=5,
+ current_state=0.,
+ kernel=fake_kernel,
+ trace_fn=lambda state, _kr: state + 10
+ )
+ trace = self.evaluate(result.trace)
+ self.assertAllEqual(trace, [11.0, 12.0, 13.0, 14.0, 15.0])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/experimental/mcmc/sample.py b/tensorflow_probability/python/experimental/mcmc/sample.py
index 5f6ad89931..ccae18fd37 100644
--- a/tensorflow_probability/python/experimental/mcmc/sample.py
+++ b/tensorflow_probability/python/experimental/mcmc/sample.py
@@ -43,6 +43,9 @@ def step_kernel(
):
"""Takes `num_steps` repeated `TransitionKernel` steps from `current_state`.
+ This is meant to be a minimal driver for executing `TransitionKernel`s; for
+ something more featureful, see `run_kernel`.
+
Args:
num_steps: Integer number of Markov chain steps.
current_state: `Tensor` or Python `list` of `Tensor`s representing the
diff --git a/tensorflow_probability/python/experimental/stats/BUILD b/tensorflow_probability/python/experimental/stats/BUILD
index 01effb6d03..9428d71e9d 100644
--- a/tensorflow_probability/python/experimental/stats/BUILD
+++ b/tensorflow_probability/python/experimental/stats/BUILD
@@ -55,7 +55,6 @@ py_test(
size = "small",
srcs = ["sample_stats_test.py"],
python_version = "PY3",
- shard_count = 10,
srcs_version = "PY3",
deps = [
":sample_stats",
diff --git a/tensorflow_probability/python/experimental/stats/sample_stats.py b/tensorflow_probability/python/experimental/stats/sample_stats.py
index 035351d0f4..c5c159870a 100644
--- a/tensorflow_probability/python/experimental/stats/sample_stats.py
+++ b/tensorflow_probability/python/experimental/stats/sample_stats.py
@@ -19,7 +19,6 @@
from __future__ import print_function
import functools
-import inspect
import math
# Dependency imports
@@ -43,7 +42,7 @@
@auto_composite_tensor.auto_composite_tensor(omit_kwargs='name')
-class RunningCovariance(object):
+class RunningCovariance(auto_composite_tensor.AutoCompositeTensor):
"""A running covariance computation.
The running covariance computation supports batching. The `event_ndims`
@@ -301,14 +300,7 @@ def from_shape(cls, shape=(), dtype=tf.float32):
Returns:
var: An empty `RunningCovariance`, ready for incoming samples.
"""
- # TODO(b/172068479): Get rid of this method resolution order hack.
- mro = inspect.getmro(RunningVariance)
- # This `super` needs to exclude not just the subclass of CompositeTensor
- # that `auto_composite_tensor` generates, but also the base class
- # `RunningVariance`. That way, we get the from_shape of
- # `RunningCovariance`, which is what we want here.
- # pylint: disable=bad-super-call
- return super(mro[1], cls).from_shape(shape, dtype, event_ndims=0)
+ return super().from_shape(shape, dtype, event_ndims=0)
def variance(self, ddof=0):
"""Returns the variance accumulated so far.
@@ -324,7 +316,7 @@ def variance(self, ddof=0):
return self.covariance(ddof)
@classmethod
- def init_from_stats(cls, num_samples, mean, variance):
+ def from_stats(cls, num_samples, mean, variance):
"""Initialize a `RunningVariance` object with given stats.
This allows the user to initialize knowing the mean, variance, and number
@@ -349,7 +341,7 @@ def init_from_stats(cls, num_samples, mean, variance):
@auto_composite_tensor.auto_composite_tensor(omit_kwargs='name')
-class RunningMean(object):
+class RunningMean(auto_composite_tensor.AutoCompositeTensor):
"""Computes a running mean.
In computation, samples can be provided individually or in chunks. A
@@ -437,7 +429,7 @@ def update(self, new_sample, axis=None):
@auto_composite_tensor.auto_composite_tensor
-class RunningCentralMoments(object):
+class RunningCentralMoments(auto_composite_tensor.AutoCompositeTensor):
"""Computes running central moments.
`RunningCentralMoments` will compute arbitrary central moments in
@@ -589,7 +581,7 @@ def _n_choose_k(n, k):
@auto_composite_tensor.auto_composite_tensor(omit_kwargs='name')
-class RunningPotentialScaleReduction(object):
+class RunningPotentialScaleReduction(auto_composite_tensor.AutoCompositeTensor):
"""A running R-hat diagnostic.
`RunningPotentialScaleReduction` uses Gelman and Rubin (1992)'s potential
diff --git a/tensorflow_probability/python/experimental/stats/sample_stats_test.py b/tensorflow_probability/python/experimental/stats/sample_stats_test.py
index f8cc6433e6..aa4e7eec41 100644
--- a/tensorflow_probability/python/experimental/stats/sample_stats_test.py
+++ b/tensorflow_probability/python/experimental/stats/sample_stats_test.py
@@ -25,12 +25,20 @@
import numpy as np
import scipy.stats as stats
-import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import test_util
+def consume(running_stat, elems, chunk_axis=None):
+ def body(running_stat, elem):
+ if chunk_axis is None:
+ return running_stat.update(elem)
+ else:
+ return running_stat.update(elem, axis=chunk_axis)
+ return tf.foldl(body, elems, running_stat)
+
+
@test_util.test_all_tf_execution_regimes
class RunningCovarianceTest(test_util.TestCase):
@@ -38,27 +46,17 @@ def test_from_stats(self):
num_counts = 10.
mean = 1.
variance = 3.
- var = tfp.experimental.stats.RunningVariance.init_from_stats(
+ var = tfp.experimental.stats.RunningVariance.from_stats(
num_counts, mean, variance)
self.assertEqual(self.evaluate(var.mean), mean)
self.assertEqual(self.evaluate(var.variance()), variance)
- def test_zero_running_variance(self):
- deterministic_samples = [0., 0., 0., 0.]
- var = tfp.experimental.stats.RunningVariance.from_shape()
- for sample in deterministic_samples:
- var = var.update(sample)
- final_mean, final_var = self.evaluate([var.mean, var.variance()])
- self.assertEqual(final_mean, 0.)
- self.assertEqual(final_var, 0.)
-
@parameterized.parameters(0, 1)
def test_running_variance(self, ddof):
rng = test_util.test_np_rng()
x = rng.rand(100)
var = tfp.experimental.stats.RunningVariance.from_shape()
- for sample in x:
- var = var.update(sample)
+ var = consume(var, x)
final_mean, final_var = self.evaluate([var.mean, var.variance(ddof=ddof)])
self.assertNear(np.mean(x), final_mean, err=1e-6)
self.assertNear(np.var(x, ddof=ddof), final_var, err=1e-6)
@@ -76,30 +74,11 @@ def test_integer_running_covariance(self):
self.assertNear(2, final_mean, err=1e-6)
self.assertNear(2, final_cov, err=1e-6)
- def test_higher_rank_running_variance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 5, 2)
- var = tfp.experimental.stats.RunningVariance.from_shape(
- tf.TensorShape([5, 2]))
- for sample in x:
- var = var.update(sample)
- final_mean, final_var = self.evaluate([var.mean, var.variance()])
- self.assertAllClose(np.mean(x, axis=0), final_mean, rtol=1e-5)
- self.assertEqual(final_var.shape, (5, 2))
-
- # reshaping to be compatible with a check against numpy
- x_reshaped = x.reshape(100, 10)
- final_var_reshape = tf.reshape(final_var, (10,))
- self.assertAllClose(np.var(x_reshaped, axis=0),
- final_var_reshape,
- rtol=1e-5)
-
- def test_chunked_running_variance(self):
+ def test_chunked_higher_rank_running_variance(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 2, 5)
var = tfp.experimental.stats.RunningVariance.from_shape((5,))
- for sample in x:
- var = var.update(sample, axis=0)
+ var = consume(var, x, chunk_axis=0)
final_mean, final_var = self.evaluate([var.mean, var.variance()])
self.assertAllClose(
np.mean(x.reshape(200, 5), axis=0),
@@ -113,36 +92,6 @@ def test_chunked_running_variance(self):
final_var,
rtol=1e-5)
- def test_dynamic_shape_running_variance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 2, 5)
- var = tfp.experimental.stats.RunningVariance.from_shape((5,), tf.float64)
- for sample in x:
- if not tf.executing_eagerly():
- sample = tf1.placeholder_with_default(sample, shape=None)
- var = var.update(sample, axis=0)
- final_var = self.evaluate(var.variance())
- x_reshaped = x.reshape(200, 5)
- self.assertEqual(final_var.shape, (5,))
- self.assertAllClose(np.var(x_reshaped, axis=0),
- final_var,
- rtol=1e-5)
-
- def test_running_covariance_as_variance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 5, 2)
- cov = tfp.experimental.stats.RunningCovariance.from_shape(
- tf.TensorShape([5, 2]),
- event_ndims=0)
- var = tfp.experimental.stats.RunningVariance.from_shape(
- tf.TensorShape([5, 2]))
- for sample in x:
- cov = cov.update(sample)
- var = var.update(sample)
- final_cov, final_var = self.evaluate([cov.covariance(), var.variance()])
- self.assertEqual(final_cov.shape, (5, 2))
- self.assertAllClose(final_cov, final_var, rtol=1e-5)
-
def test_zero_running_covariance(self):
fake_samples = [[0., 0.] for _ in range(2)]
cov = tfp.experimental.stats.RunningCovariance.from_shape((2,))
@@ -157,36 +106,18 @@ def test_running_covariance(self, ddof):
rng = test_util.test_np_rng()
x = rng.rand(100, 10)
cov = tfp.experimental.stats.RunningCovariance.from_shape((10,))
- for sample in x:
- cov = cov.update(sample)
+ cov = consume(cov, x)
final_mean, final_cov = self.evaluate([cov.mean, cov.covariance(ddof=ddof)])
self.assertAllClose(np.mean(x, axis=0), final_mean, rtol=1e-5)
self.assertAllClose(np.cov(x.T, ddof=ddof), final_cov, rtol=1e-5)
+ self.assertEqual(cov.event_ndims, 1)
+ self.assertEqual(cov.mean.dtype, tf.float32)
- def test_higher_rank_running_covariance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 5, 2)
- cov = tfp.experimental.stats.RunningCovariance.from_shape(
- tf.TensorShape([5, 2]))
- for sample in x:
- cov = cov.update(sample)
- final_mean, final_cov = self.evaluate([cov.mean, cov.covariance()])
- self.assertAllClose(np.mean(x, axis=0), final_mean, rtol=1e-5)
- self.assertEqual(final_cov.shape, (5, 2, 5, 2))
-
- # reshaping to be compatible with a check against numpy
- x_reshaped = x.reshape(100, 10)
- final_cov_reshaped = tf.reshape(final_cov, (10, 10))
- self.assertAllClose(np.cov(x_reshaped.T, ddof=0),
- final_cov_reshaped,
- rtol=1e-5)
-
- def test_chunked_running_covariance(self):
+ def test_chunked_high_rank_running_covariance(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 2, 3, 5)
cov = tfp.experimental.stats.RunningCovariance.from_shape((3, 5))
- for sample in x:
- cov = cov.update(sample, axis=0)
+ cov = consume(cov, x, chunk_axis=0)
final_mean, final_cov = self.evaluate([cov.mean, cov.covariance()])
self.assertAllClose(
np.mean(x.reshape((200, 3, 5)), axis=0),
@@ -201,44 +132,17 @@ def test_chunked_running_covariance(self):
final_cov_reshaped,
rtol=1e-5)
- def test_running_covariance_with_event_ndims(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 3, 5, 2)
- cov = tfp.experimental.stats.RunningCovariance.from_shape(
- tf.TensorShape([5, 2]),
- event_ndims=1)
- for sample in x:
- cov = cov.update(sample, axis=0)
- final_mean, final_cov = self.evaluate([cov.mean, cov.covariance()])
- self.assertAllClose(
- np.mean(x.reshape(300, 5, 2), axis=0),
- final_mean,
- rtol=1e-5)
- self.assertEqual(final_cov.shape, (5, 2, 2))
-
- # manual computation with loops
- manual_cov = np.zeros((5, 2, 2))
- x_reshaped = x.reshape((300, 5, 2))
- delta_mean = x_reshaped - np.mean(x_reshaped, axis=0)
- for residual in delta_mean:
- for i in range(5):
- for j in range(2):
- for k in range(2):
- manual_cov[i][j][k] += residual[i][j] * residual[i][k]
- manual_cov /= 300
- self.assertAllClose(manual_cov, final_cov, rtol=1e-5)
-
- def test_batched_running_covariance(self):
+ def test_running_covariance_with_event_ndims_2(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 3, 5, 2)
cov = tfp.experimental.stats.RunningCovariance.from_shape(
tf.TensorShape([3, 5, 2]),
event_ndims=2)
- for sample in x:
- cov = cov.update(sample)
+ cov = consume(cov, x)
final_mean, final_cov = self.evaluate([cov.mean, cov.covariance()])
self.assertAllClose(np.mean(x, axis=0), final_mean, rtol=1e-5)
self.assertEqual(final_cov.shape, (3, 5, 2, 5, 2))
+ self.assertEqual(cov.event_ndims, 2)
# check against numpy
x_reshaped = x.reshape((100, 3, 10))
@@ -246,142 +150,43 @@ def test_batched_running_covariance(self):
np_cov = np.cov(x_reshaped[:, i, :].T, ddof=0).reshape((5, 2, 5, 2))
self.assertAllClose(np_cov, final_cov[i], rtol=1e-5)
- def test_dynamic_shape_running_covariance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 3, 5, 2)
- cov = tfp.experimental.stats.RunningCovariance.from_shape(
- tf.TensorShape([5, 2]),
- event_ndims=1)
- for sample in x:
- if not tf.executing_eagerly():
- sample = tf1.placeholder_with_default(sample, shape=None)
- cov = cov.update(sample, axis=0)
- final_mean, final_cov = self.evaluate([cov.mean, cov.covariance()])
- self.assertAllClose(
- np.mean(x.reshape(300, 5, 2), axis=0),
- final_mean,
- rtol=1e-5)
- self.assertEqual(final_cov.shape, (5, 2, 2))
-
- # manual computation with loops
- manual_cov = np.zeros((5, 2, 2))
- x_reshaped = x.reshape((300, 5, 2))
- delta_mean = x_reshaped - np.mean(x_reshaped, axis=0)
- for residual in delta_mean:
- for i in range(5):
- for j in range(2):
- for k in range(2):
- manual_cov[i][j][k] += residual[i][j] * residual[i][k]
- manual_cov /= 300
- self.assertAllClose(manual_cov, final_cov, rtol=1e-5)
-
def test_manual_dtype(self):
rng = test_util.test_np_rng()
- x = rng.rand(100, 10)
+ x = rng.rand(3, 10)
cov = tfp.experimental.stats.RunningCovariance.from_shape(
(10,), dtype=tf.float64)
- for sample in x:
- cov = cov.update(sample)
+ cov = consume(cov, x)
final_cov = cov.covariance()
self.assertEqual(final_cov.dtype, tf.float64)
def test_shift_in_running_covariance(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 10) * 10
+ shifted_x = x + 1e4
cov = tfp.experimental.stats.RunningCovariance.from_shape((10,))
+ cov = consume(cov, x)
shifted_cov = tfp.experimental.stats.RunningCovariance.from_shape((10,))
- for sample in x:
- cov = cov.update(sample)
- shifted_cov = shifted_cov.update(sample + 1e4)
+ shifted_cov = consume(shifted_cov, shifted_x)
final_cov, final_shifted_cov = self.evaluate([
cov.covariance(), shifted_cov.covariance()])
self.assertAllClose(final_cov, np.cov(x.T, ddof=0), rtol=1e-5)
self.assertAllClose(
final_shifted_cov, np.cov(x.T, ddof=0), rtol=1e-1)
- def test_sorted_ascending_running_covariance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 10)
- x.sort(axis=0)
- cov = tfp.experimental.stats.RunningCovariance.from_shape((10,))
- for sample in x:
- cov = cov.update(sample)
- final_cov = self.evaluate(cov.covariance())
- self.assertAllClose(final_cov, np.cov(x.T, ddof=0), rtol=1e-5)
-
- def test_sorted_descending_running_covariance(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 10)
- x[::-1].sort(axis=0) # sorts in descending order
- cov = tfp.experimental.stats.RunningCovariance.from_shape((10,))
- for sample in x:
- cov = cov.update(sample)
- final_cov = self.evaluate(cov.covariance())
- self.assertAllClose(final_cov, np.cov(x.T, ddof=0), rtol=1e-5)
-
- def test_attributes(self):
- rng = test_util.test_np_rng()
- x = rng.rand(2, 3, 10)
- cov = tfp.experimental.stats.RunningCovariance.from_shape(
- (3, 10,), event_ndims=1)
- for sample in x:
- cov = cov.update(sample)
- self.assertEqual(self.evaluate(cov.num_samples), 2.)
- self.assertEqual(cov.event_ndims, 1)
- self.assertEqual(cov.mean.dtype, tf.float32)
-
- def test_tf_while(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 10)
- tensor_x = tf.convert_to_tensor(x, dtype=tf.float32)
- cov = tfp.experimental.stats.RunningCovariance.from_shape((10,))
- var = tfp.experimental.stats.RunningVariance.from_shape((10,))
- _, cov = tf.while_loop(
- lambda i, _: i < 100,
- lambda i, cov: (i + 1, cov.update(tensor_x[i])),
- (0, cov))
- final_cov = cov.covariance()
- _, var = tf.while_loop(
- lambda i, _: i < 100,
- lambda i, var: (i + 1, var.update(tensor_x[i])),
- (0, var))
- final_var = var.variance()
- self.assertAllClose(final_cov, np.cov(x.T, ddof=0), rtol=1e-5)
- self.assertAllClose(final_var, np.var(x, axis=0), rtol=1e-5)
-
@test_util.test_all_tf_execution_regimes
class RunningPotentialScaleReductionTest(test_util.TestCase):
- def test_simple_operation(self):
- running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
- shape=(3,),
- )
- # 5 samples from 3 independent Markov chains
- x = np.arange(15, dtype=np.float32).reshape((5, 3))
- for sample in x:
- running_rhat = running_rhat.update(sample)
- rhat = running_rhat.potential_scale_reduction()
- true_rhat = tfp.mcmc.potential_scale_reduction(
- chains_states=x,
- independent_chain_ndims=1,
- )
- true_rhat, rhat = self.evaluate([true_rhat, rhat])
- self.assertNear(true_rhat, rhat, err=1e-6)
-
def test_random_scalar_computation(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 10) * 100
running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
- shape=(10,),
- )
- for sample in x:
- running_rhat = running_rhat.update(sample)
+ shape=(10,))
+ running_rhat = consume(running_rhat, x)
rhat = running_rhat.potential_scale_reduction()
true_rhat = tfp.mcmc.potential_scale_reduction(
chains_states=x,
- independent_chain_ndims=1,
- )
+ independent_chain_ndims=1)
true_rhat, rhat = self.evaluate([true_rhat, rhat])
self.assertNear(true_rhat, rhat, err=1e-6)
@@ -389,20 +194,17 @@ def test_non_scalar_samples(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 2, 2, 3, 5) * 100
running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
- shape=(2, 2, 3, 5),
- )
- for sample in x:
- running_rhat = running_rhat.update(sample)
+ shape=(2, 2, 3, 5))
+ running_rhat = consume(running_rhat, x)
rhat = running_rhat.potential_scale_reduction()
true_rhat = tfp.mcmc.potential_scale_reduction(
chains_states=x,
- independent_chain_ndims=1,
- )
+ independent_chain_ndims=1)
true_rhat, rhat = self.evaluate([true_rhat, rhat])
self.assertAllClose(true_rhat, rhat, rtol=1e-6)
- def test_batching(self):
- n_samples = 100
+ def test_multistate(self):
+ n_samples = 5
# state_0 is two scalar chains taken from iid Normal(0, 1).
state_0 = np.random.randn(n_samples, 2)
@@ -412,8 +214,7 @@ def test_batching(self):
state_1 = np.random.randn(n_samples, 3, 4) + offset
running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
shape=[(2,), (3, 4)],
- independent_chain_ndims=[1, 1]
- )
+ independent_chain_ndims=[1, 1])
for sample in zip(state_0, state_1):
running_rhat = running_rhat.update(sample)
rhat = self.evaluate(running_rhat.potential_scale_reduction())
@@ -424,55 +225,20 @@ def test_batching(self):
def test_independent_chain_ndims(self):
running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
shape=(5, 3),
- independent_chain_ndims=2,
- )
+ independent_chain_ndims=2)
x = np.arange(30, dtype=np.float32).reshape((2, 5, 3))
- for sample in x:
- running_rhat = running_rhat.update(sample)
+ running_rhat = consume(running_rhat, x)
rhat = running_rhat.potential_scale_reduction()
true_rhat = tfp.mcmc.potential_scale_reduction(
chains_states=x,
- independent_chain_ndims=2,
- )
+ independent_chain_ndims=2)
true_rhat, rhat = self.evaluate([true_rhat, rhat])
self.assertAllClose(true_rhat, rhat, rtol=1e-6)
- def test_tf_while(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 10) * 100
- tensor_x = tf.convert_to_tensor(x)
- running_rhat = tfp.experimental.stats.RunningPotentialScaleReduction.from_shape(
- shape=(10,),
- independent_chain_ndims=1
- )
- def _loop_body(i, running_rhat):
- running_rhat = running_rhat.update(tensor_x[i])
- return i + 1, running_rhat
- _, running_rhat = tf.while_loop(
- lambda i, _: i < 100,
- _loop_body,
- (0, running_rhat)
- )
- rhat = running_rhat.potential_scale_reduction()
- true_rhat = tfp.mcmc.potential_scale_reduction(
- chains_states=x,
- independent_chain_ndims=1,
- )
- true_rhat, rhat = self.evaluate([true_rhat, rhat])
- self.assertNear(true_rhat, rhat, err=1e-6)
-
@test_util.test_all_tf_execution_regimes
class RunningMeanTest(test_util.TestCase):
- def test_zero_mean(self):
- running_mean = tfp.experimental.stats.RunningMean.from_shape(
- shape=())
- for _ in range(6):
- running_mean = running_mean.update(0)
- mean = self.evaluate(running_mean.mean)
- self.assertEqual(0, mean)
-
def test_higher_rank_shape(self):
running_mean = tfp.experimental.stats.RunningMean.from_shape(
shape=(5, 3))
@@ -481,7 +247,7 @@ def test_higher_rank_shape(self):
mean = self.evaluate(running_mean.mean)
self.assertAllEqual(np.ones((5, 3)) * 2.5, mean)
- def test_manual_dtype(self):
+ def test_zero_mean_and_manual_dtype(self):
running_mean = tfp.experimental.stats.RunningMean.from_shape(
shape=(),
dtype=tf.float64)
@@ -493,8 +259,7 @@ def test_manual_dtype(self):
def test_integer_dtype(self):
running_mean = tfp.experimental.stats.RunningMean.from_shape(
shape=(),
- dtype=tf.int32,
- )
+ dtype=tf.int32)
for sample in range(6):
running_mean = running_mean.update(sample)
mean = running_mean.mean
@@ -507,8 +272,7 @@ def test_random_mean(self):
x = rng.rand(100)
running_mean = tfp.experimental.stats.RunningMean.from_shape(
shape=())
- for sample in x:
- running_mean = running_mean.update(sample)
+ running_mean = consume(running_mean, x)
mean = self.evaluate(running_mean.mean)
self.assertAllClose(np.mean(x), mean, rtol=1e-6)
@@ -516,26 +280,11 @@ def test_chunking(self):
rng = test_util.test_np_rng()
x = rng.rand(100, 10, 5)
running_mean = tfp.experimental.stats.RunningMean.from_shape(
- shape=(5,),
- )
- for sample in x:
- running_mean = running_mean.update(sample, axis=0)
+ shape=(5,))
+ running_mean = consume(running_mean, x, chunk_axis=0)
mean = self.evaluate(running_mean.mean)
self.assertAllClose(np.mean(x.reshape(1000, 5), axis=0), mean, rtol=1e-6)
- def test_tf_while(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100, 10)
- tensor_x = tf.convert_to_tensor(x, dtype=tf.float32)
- running_mean = tfp.experimental.stats.RunningMean.from_shape(
- shape=(10,))
- _, running_mean = tf.while_loop(
- lambda i, _: i < 100,
- lambda i, running_mean: (i + 1, running_mean.update(tensor_x[i])),
- (0, running_mean))
- mean = self.evaluate(running_mean.mean)
- self.assertAllClose(np.mean(x, axis=0), mean, rtol=1e-6)
-
def test_no_inputs(self):
running_mean = tfp.experimental.stats.RunningMean.from_shape(
shape=())
@@ -596,29 +345,17 @@ def test_higher_rank_samples(self):
self.assertAllClose(tf.ones((2, 2)) * 6.8, kur, rtol=1e-6)
self.assertAllClose(tf.zeros((2, 2)), fifth_moment, rtol=1e-6)
- def test_random_scalar_samples(self):
- rng = test_util.test_np_rng()
- x = rng.rand(100)
- running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape(
- shape=(),
- moment=np.arange(5) + 1)
- for sample in x:
- running_moments = running_moments.update(sample)
- moments = self.evaluate(running_moments.moments())
- self.assertAllClose(
- stats.moment(x, moment=[1, 2, 3, 4, 5]), moments, rtol=1e-6)
-
def test_random_higher_rank_samples(self):
rng = test_util.test_np_rng()
- x = rng.rand(100, 10)
+ x_orig = rng.rand(100, 10)
+ x = tf.convert_to_tensor(x_orig, dtype=tf.float32)
running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape(
shape=(10,),
moment=np.arange(5) + 1)
- for sample in x:
- running_moments = running_moments.update(sample)
+ running_moments = consume(running_moments, x)
moments = self.evaluate(running_moments.moments())
self.assertAllClose(
- stats.moment(x, moment=[1, 2, 3, 4, 5]), moments, rtol=1e-6)
+ stats.moment(x_orig, moment=[1, 2, 3, 4, 5]), moments, rtol=1e-6)
def test_manual_dtype(self):
running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape(
@@ -638,20 +375,6 @@ def test_int_dtype_casts(self):
moment = running_moments.moments()
self.assertEqual(tf.float32, moment.dtype)
- def test_in_tf_while(self):
- running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape(
- shape=(), moment=[1, 2, 3, 4])
- _, running_moments = tf.while_loop(
- lambda i, _: i < 5,
- lambda i, mom: (i + 1, mom.update(tf.ones(()) * i)),
- (0., running_moments)
- )
- moments = self.evaluate(running_moments.moments())
- self.assertAllClose(
- stats.moment(np.arange(5), moment=np.arange(4) + 1),
- moments,
- rtol=1e-6)
-
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow_probability/python/internal/BUILD b/tensorflow_probability/python/internal/BUILD
index 166715d085..9e20a352a4 100644
--- a/tensorflow_probability/python/internal/BUILD
+++ b/tensorflow_probability/python/internal/BUILD
@@ -113,6 +113,27 @@ multi_substrate_py_library(
],
)
+multi_substrate_py_library(
+ name = "callable_util",
+ srcs = ["callable_util.py"],
+ deps = [
+ # numpy dep,
+ # tensorflow dep,
+ ],
+)
+
+multi_substrate_py_test(
+ name = "callable_util_test",
+ size = "small",
+ srcs = ["callable_util_test.py"],
+ deps = [
+ ":callable_util",
+ # numpy dep,
+ # tensorflow dep,
+ "//tensorflow_probability/python/internal:test_util",
+ ],
+)
+
multi_substrate_py_library(
name = "custom_gradient",
srcs = ["custom_gradient.py"],
@@ -579,6 +600,21 @@ multi_substrate_py_test(
],
)
+multi_substrate_py_library(
+ name = "variadic_reduce",
+ srcs = [
+ "variadic_reduce.py",
+ ],
+ srcs_version = "PY3",
+ deps = [
+ ":implementation_selection",
+ ":prefer_static",
+ ":tensorshape_util",
+ # numpy dep,
+ # tensorflow dep,
+ ],
+)
+
exports_files(
[
"assert_util.py",
diff --git a/tensorflow_probability/python/internal/auto_composite_tensor.py b/tensorflow_probability/python/internal/auto_composite_tensor.py
index 61bee6b787..26ab6b3ac5 100644
--- a/tensorflow_probability/python/internal/auto_composite_tensor.py
+++ b/tensorflow_probability/python/internal/auto_composite_tensor.py
@@ -20,101 +20,138 @@
import functools
import inspect
-import warnings
-import six
import tensorflow.compat.v2 as tf
-from tensorflow.python.framework.composite_tensor import CompositeTensor # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.saved_model import nested_structure_coder # pylint: disable=g-direct-tensorflow-import
-__all__ = ['auto_composite_tensor']
+__all__ = [
+ 'auto_composite_tensor',
+ 'AutoCompositeTensor',
+]
_registry = {} # Mapping from (python pkg, class name) -> class.
_SENTINEL = object()
+_AUTO_COMPOSITE_TENSOR_VERSION = 1
-def _mk_err_msg(clsid, obj, suffix=''):
- msg = ('Unable to expand "{}", derived from type `{}.{}`, to its Tensor '
- 'components. Email `tfprobability@tensorflow.org` or file an issue on '
- 'github if you would benefit from this working. {}'.format(
- obj, clsid[0], clsid[1], suffix))
- warnings.warn(msg)
- return msg
-
-def _kwargs_from(clsid, obj, limit_to=None):
+def _extract_init_kwargs(obj, omit_kwargs=(), limit_to=None):
"""Extract constructor kwargs to reconstruct `obj`."""
- if six.PY3:
- argspec = inspect.getfullargspec(obj.__init__)
- invalid_spec = bool(argspec.varargs or argspec.varkw)
- params = argspec.args + argspec.kwonlyargs
- else:
- argspec = inspect.getargspec(obj.__init__) # pylint: disable=deprecated-method
- invalid_spec = bool(argspec.varargs or argspec.keywords)
- params = argspec.args
- if invalid_spec:
- raise NotImplementedError(
- _mk_err_msg(
- clsid, obj,
- '*args and **kwargs are not supported. Found `{}`'.format(argspec)))
- keys = [p for p in params if p != 'self']
+ argspec = inspect.getfullargspec(obj.__init__)
+ if argspec.varargs or argspec.varkw:
+ raise ValueError(
+ '*args and **kwargs are not supported. Found `{}`'.format(argspec))
+
+ params = argspec.args + argspec.kwonlyargs
+ keys = [p for p in params if p != 'self' and p not in omit_kwargs]
if limit_to is not None:
keys = [k for k in keys if k in limit_to]
- kwargs = {k: getattr(obj, k, getattr(obj, '_' + k, _SENTINEL)) for k in keys}
- for k, v in kwargs.items():
- if v is _SENTINEL:
+
+ kwargs = {}
+ for k in keys:
+ if hasattr(obj, k):
+ kwargs[k] = getattr(obj, k)
+ elif hasattr(obj, '_' + k):
+ kwargs[k] = getattr(obj, '_' + k)
+ else:
raise ValueError(
- _mk_err_msg(
- clsid, obj,
- 'Object did not have getter for constructor argument {k}. (Tried '
- 'both `obj.{k}` and obj._{k}`).'.format(k=k)))
+ 'Object did not have an attr corresponding to constructor argument '
+ '{k}. (Tried both `obj.{k}` and obj._{k}`).'.format(k=k))
return kwargs
+def _extract_type_spec_recursively(value):
+ """Return (collection of) TypeSpec(s) for `value` if it includes `Tensor`s.
+
+ If `value` is a `Tensor` or `CompositeTensor`, return its `TypeSpec`. If
+ `value` is a collection containing `Tensor` values, recursively supplant them
+ with their respective `TypeSpec`s in a collection of parallel stucture.
+
+ If `value` is nont of the above, return it unchanged.
+
+ Args:
+ value: a Python `object` to (possibly) turn into a (collection of)
+ `tf.TypeSpec`(s).
+
+ Returns:
+ spec: the `TypeSpec` or collection of `TypeSpec`s corresponding to `value`
+ or `value`, if no `Tensor`s are found.
+ """
+ if tf.is_tensor(value):
+ return tf.TensorSpec.from_tensor(value)
+ if isinstance(value, composite_tensor.CompositeTensor):
+ return value._type_spec # pylint: disable=protected-access
+ if isinstance(value, (list, tuple)):
+ specs = [_extract_type_spec_recursively(v) for v in value]
+ has_tensors = any(a is not b for a, b in zip(value, specs))
+ has_only_tensors = all(a is not b for a, b in zip(value, specs))
+ if has_tensors:
+ if has_tensors != has_only_tensors:
+ raise NotImplementedError(
+ 'Found `{}` with both Tensor and non-Tensor parts: {}'
+ .format(type(value), value))
+ return type(value)(specs)
+ return value
+
+
class _AutoCompositeTensorTypeSpec(tf.TypeSpec):
"""A tf.TypeSpec for `AutoCompositeTensor` objects."""
- __slots__ = ('_clsid', '_param_specs', '_kwargs')
+ __slots__ = ('_param_specs', '_non_tensor_params', '_omit_kwargs')
- def __init__(self, clsid, param_specs, kwargs):
- self._clsid = clsid
+ def __init__(self, param_specs, non_tensor_params, omit_kwargs):
self._param_specs = param_specs
- self._kwargs = kwargs
+ self._non_tensor_params = non_tensor_params
+ self._omit_kwargs = omit_kwargs
- @property
- def value_type(self):
- return _registry[self._clsid]
+ @classmethod
+ def from_instance(cls, instance, omit_kwargs=()):
+ kwargs = _extract_init_kwargs(instance, omit_kwargs)
+
+ non_tensor_params = {}
+ param_specs = {}
+ for k, v in list(kwargs.items()):
+ # If v contains no Tensors, this will just be v
+ type_spec_or_v = _extract_type_spec_recursively(v)
+ if type_spec_or_v is not v:
+ param_specs[k] = type_spec_or_v
+ else:
+ non_tensor_params[k] = v
+
+ # Construct the spec.
+ return cls(param_specs=param_specs,
+ non_tensor_params=non_tensor_params,
+ omit_kwargs=omit_kwargs)
def _to_components(self, obj):
- params = _kwargs_from(self._clsid, obj, limit_to=list(self._param_specs))
- return params
+ return _extract_init_kwargs(obj, limit_to=list(self._param_specs))
def _from_components(self, components):
- kwargs = dict(self._kwargs, **components)
- return self.value_type(**kwargs) # pylint: disable=not-callable
+ kwargs = dict(self._non_tensor_params, **components)
+ return self.value_type(**kwargs)
@property
def _component_specs(self):
return self._param_specs
def _serialize(self):
- return 1, self._clsid, self._param_specs, self._kwargs
+ result = (_AUTO_COMPOSITE_TENSOR_VERSION,
+ self._param_specs,
+ self._non_tensor_params,
+ self._omit_kwargs)
+ return result
@classmethod
def _deserialize(cls, encoded):
- version, clsid, param_specs, kwargs = encoded
- if version != 1:
- raise ValueError('Unexpected version')
- if clsid not in _registry:
- raise ValueError(
- 'Unable to identify AutoCompositeTensor type for {}. Make sure the '
- 'class is decorated with `@tfp.experimental.auto_composite_tensor` '
- 'and its module is imported before calling '
- '`tf.saved_model.load`.'.format(clsid))
- return cls(clsid, param_specs, kwargs)
+ version, param_specs, non_tensor_params, omit_kwargs = encoded
+ if version != _AUTO_COMPOSITE_TENSOR_VERSION:
+ raise ValueError('Expected version {}, but got {}'
+ .format(_AUTO_COMPOSITE_TENSOR_VERSION, version))
+ return cls(param_specs, non_tensor_params, omit_kwargs)
_TypeSpecCodec = nested_structure_coder._TypeSpecCodec # pylint: disable=protected-access
@@ -125,23 +162,92 @@ def _deserialize(cls, encoded):
del _TypeSpecCodec
+class AutoCompositeTensor(composite_tensor.CompositeTensor):
+ """Recommended base class for `@auto_composite_tensor`-ified classes.
+
+ See details in `tfp.experimental.auto_composite_tensor` description.
+ """
+
+ @property
+ def _type_spec(self):
+ # This property will be overwritten by the `@auto_composite_tensor`
+ # decorator. However, we need it so that a valid subclass of the `ABCMeta`
+ # class `CompositeTensor` can be constructed and passed to the
+ # `@auto_composite_tensor` decorator
+ pass
+
+
def auto_composite_tensor(cls=None, omit_kwargs=()):
- """Automagically create a `CompositeTensor` class for `cls`.
+ """Automagically generate `CompositeTensor` behavior for `cls`.
+
+ `CompositeTensor` objects are able to pass in and out of `tf.function` and
+ `tf.while_loop`, or serve as part of the signature of a TF saved model.
+
+ The contract of `auto_composite_tensor` is that all __init__ args and kwargs
+ must have corresponding public or private attributes (or properties). Each of
+ these attributes is inspected (recursively) to determine whether it is (or
+ contains) `Tensor`s or non-`Tensor` metadata. `list` and `tuple` attributes
+ are supported, but must either contain *only* `Tensor`s (or lists, etc,
+ thereof), or *no* `Tensor`s. E.g.,
+ - object.attribute = [1., 2., 'abc'] # valid
+ - object.attribute = [tf.constant(1.), [tf.constant(2.)]] # valid
+ - object.attribute = ['abc', tf.constant(1.)] # invalid
+
+ If the decorated class `A` does not subclass `CompositeTensor`, a *new class*
+ will be generated, which mixes in `A` and `CompositeTensor`.
+
+ To avoid this extra class in the class hierarchy, we suggest inheriting from
+ `auto_composite_tensor.AutoCompositeTensor`, which inherits from
+ `CompositeTensor` and implants a trivial `_type_spec` @property. The
+ `@auto_composite_tensor` decorator will then overwrite this trivial
+ `_type_spec` @property. The trivial one is necessary because `_type_spec` is
+ an abstract property of `CompositeTensor`, and a valid class instance must be
+ created before the decorator can execute -- without the trivial `_type_spec`
+ property present, `ABCMeta` will throw an error! The user may thus do any of
+ the following:
+
+ #### `AutoCompositeTensor` base class (recommended)
+ ```python
+ @tfp.experimental.auto_composite_tensor
+ class MyClass(tfp.experimental.AutoCompositeTensor):
+ ...
- `CompositeTensor` objects are able to pass in and out of `tf.function`,
- `tf.while_loop` and serve as part of the signature of a TF saved model.
+ mc = MyClass()
+ type(mc)
+ # ==> MyClass
+ ```
- The basic contract is that all args must have public attributes (or
- properties) or private attributes corresponding to each argument to
- `__init__`. Each of these is inspected to determine whether it is a Tensor
- or non-Tensor metadata. Lists and tuples of objects are supported provided
- all items therein are all either Tensor/CompositeTensor, or all are not.
+ #### No `CompositeTensor` base class (ok, but changes expected types)
+ ```python
+ @tfp.experimental.auto_composite_tensor
+ class MyClass(object):
+ ...
- ## Example
+ mc = MyClass()
+ type(mc)
+ # ==> MyClass_AutoCompositeTensor
+ ```
+
+ #### `CompositeTensor` base class, requiring trivial `_type_spec`
+ ```python
+ from tensorflow.python.framework import composite_tensor
+ @tfp.experimental.auto_composite_tensor
+ class MyClass(composite_tensor.CompositeTensor):
+ @property
+ def _type_spec(self): # will be overwritten by @auto_composite_tensor
+ pass
+ ...
+
+ mc = MyClass()
+ type(mc)
+ # ==> MyClass
+ ```
+
+ ## Full usage example
```python
@tfp.experimental.auto_composite_tensor(omit_kwargs=('name',))
- class Adder(object):
+ class Adder(tfp.experimental.AutoCompositeTensor):
def __init__(self, x, y, name=None):
with tf.name_scope(name or 'Adder') as name:
self._x = tf.convert_to_tensor(x)
@@ -168,69 +274,44 @@ def body(obj):
omit_kwargs: Optional sequence of kwarg names to be omitted from the spec.
Returns:
- ctcls: A subclass of `cls` and TF CompositeTensor.
+ composite_tensor_subclass: A subclass of `cls` and TF CompositeTensor.
"""
if cls is None:
return functools.partial(auto_composite_tensor, omit_kwargs=omit_kwargs)
+
+ # If the declared class is already a CompositeTensor subclass, we can avoid
+ # affecting the actual type of the returned class. Otherwise, we need to
+ # explicitly mix in the CT type, and hence create and return a newly
+ # synthesized type.
+ if issubclass(cls, composite_tensor.CompositeTensor):
+ class _AlreadyCTTypeSpec(_AutoCompositeTensorTypeSpec):
+
+ @property
+ def value_type(self):
+ return cls
+ cls._type_spec = property( # pylint: disable=protected-access
+ lambda self: _AlreadyCTTypeSpec.from_instance(self, omit_kwargs))
+ return cls
+
clsid = (cls.__module__, cls.__name__, omit_kwargs)
- # Also check for subclass if retrieving from the _registry, in case the user
+ # Check for subclass if retrieving from the _registry, in case the user
# has redefined the class (e.g. in a REPL/notebook).
if clsid in _registry and issubclass(_registry[clsid], cls):
return _registry[clsid]
- class _AutoCompositeTensor(cls, CompositeTensor):
+ class _GeneratedCTTypeSpec(_AutoCompositeTensorTypeSpec):
+
+ @property
+ def value_type(self):
+ return _registry[clsid]
+
+ class _AutoCompositeTensor(cls, composite_tensor.CompositeTensor):
"""A per-`cls` subclass of `CompositeTensor`."""
@property
def _type_spec(self):
- kwargs = _kwargs_from(clsid, self)
- param_specs = {}
- # Heuristically identify the tensor parts, and separate them.
- for k, v in list(kwargs.items()): # We might pop in the loop body.
-
- if k in omit_kwargs:
- kwargs.pop(k)
- continue
-
- def reduce(v):
- has_tensors = False
- if tf.is_tensor(v):
- v = tf.TensorSpec.from_tensor(v)
- has_tensors = True
- if isinstance(v, CompositeTensor):
- v = v._type_spec # pylint: disable=protected-access
- has_tensors = True
- if isinstance(v, (list, tuple)):
- reduced = [reduce(v_) for v_ in v]
- has_tensors = any(ht for (_, ht) in reduced)
- if has_tensors != all(ht for (_, ht) in reduced):
- raise NotImplementedError(
- _mk_err_msg(
- clsid, self,
- 'Found `{}` with both Tensor and non-Tensor parts: {}'
- .format(type(v), v)))
- v = type(v)([spec for (spec, _) in reduced])
- return v, has_tensors
-
- v, has_tensors = reduce(v)
- if has_tensors:
- kwargs.pop(k)
- param_specs[k] = v
- # Else, we assume this entry is not a Tensor (bool, str, etc).
-
- # Construct the spec.
- spec = _AutoCompositeTensorTypeSpec(
- clsid, param_specs=param_specs, kwargs=kwargs)
- # Verify the spec serializes.
- struct_coder = nested_structure_coder.StructureCoder()
- try:
- struct_coder.encode_structure(spec)
- except nested_structure_coder.NotEncodableError as e:
- raise NotImplementedError(
- _mk_err_msg(clsid, self,
- '(Unable to serialize: {})'.format(str(e))))
- return spec
+ return _GeneratedCTTypeSpec.from_instance(self, omit_kwargs)
_AutoCompositeTensor.__name__ = '{}_AutoCompositeTensor'.format(cls.__name__)
_registry[clsid] = _AutoCompositeTensor
diff --git a/tensorflow_probability/python/internal/auto_composite_tensor_test.py b/tensorflow_probability/python/internal/auto_composite_tensor_test.py
index d70562d061..231f6846eb 100644
--- a/tensorflow_probability/python/internal/auto_composite_tensor_test.py
+++ b/tensorflow_probability/python/internal/auto_composite_tensor_test.py
@@ -106,6 +106,37 @@ def body(d):
after_loop,
expand_composites=True)
+ def test_already_ct_subclass(self):
+
+ @tfp.experimental.auto_composite_tensor
+ class MyCT(tfp.experimental.AutoCompositeTensor):
+
+ def __init__(self, tensor_param, non_tensor_param, maybe_tensor_param):
+ self._tensor_param = tf.convert_to_tensor(tensor_param)
+ self._non_tensor_param = non_tensor_param
+ self._maybe_tensor_param = maybe_tensor_param
+
+ def body(obj):
+ return MyCT(obj._tensor_param + 1,
+ obj._non_tensor_param,
+ obj._maybe_tensor_param),
+
+ init = MyCT(0., 0, 0)
+ result, = tf.while_loop(
+ cond=lambda *_: True,
+ body=body,
+ loop_vars=(init,),
+ maximum_iterations=3)
+ self.assertAllClose(3., result._tensor_param)
+
+ init = MyCT(0., 0, tf.constant(0))
+ result, = tf.while_loop(
+ cond=lambda *_: True,
+ body=body,
+ loop_vars=(init,),
+ maximum_iterations=3)
+ self.assertAllClose(3., result._tensor_param)
+
if __name__ == '__main__':
tf.enable_v2_behavior()
diff --git a/tensorflow_probability/python/internal/backend/numpy/nest.py b/tensorflow_probability/python/internal/backend/numpy/nest.py
index f2630b3e05..b4611b0813 100644
--- a/tensorflow_probability/python/internal/backend/numpy/nest.py
+++ b/tensorflow_probability/python/internal/backend/numpy/nest.py
@@ -32,6 +32,7 @@
# pylint: disable=unused-import
from tree import _assert_shallow_structure
+from tree import _DOT
from tree import _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ
from tree import _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE
from tree import _is_attrs
@@ -40,6 +41,7 @@
from tree import _STRUCTURES_HAVE_MISMATCHING_LENGTHS
from tree import _STRUCTURES_HAVE_MISMATCHING_TYPES
from tree import _yield_flat_up_to
+from tree import _yield_sorted_items
from tree import _yield_value
from tree import assert_same_structure
from tree import flatten as dm_flatten
diff --git a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py
index 4fdfab5efa..e487696070 100644
--- a/tensorflow_probability/python/internal/backend/numpy/numpy_math.py
+++ b/tensorflow_probability/python/internal/backend/numpy/numpy_math.py
@@ -568,7 +568,7 @@ def _divide_no_nan(x, y, name=None): # pylint: disable=unused-argument
exp = utils.copy_docstring(
'tf.math.exp',
- lambda x, name=None: np.exp(x))
+ lambda x, name=None: np.exp(_convert_to_tensor(x)))
expm1 = utils.copy_docstring(
'tf.math.expm1',
diff --git a/tensorflow_probability/python/internal/backend/numpy/v2.py b/tensorflow_probability/python/internal/backend/numpy/v2.py
index 7dfa30f913..f5e25c579c 100644
--- a/tensorflow_probability/python/internal/backend/numpy/v2.py
+++ b/tensorflow_probability/python/internal/backend/numpy/v2.py
@@ -96,7 +96,17 @@ def unflatten_f(*args_flat):
transform = jit_decorator
else:
- raise NotImplementedError('Could not find compiler: Numpy only.')
+
+ # The decoration will succeed, but calling such a function will fail. This
+ # allows us to have jitted top-level functions in a module, as long as
+ # they aren't called in Numpy mode.
+ def decorator(f):
+ @functools.wraps(f)
+ def wrapped_f(*args, **kwargs):
+ raise NotImplementedError('Could not find compiler: Numpy only.')
+ return wrapped_f
+
+ transform = decorator
# This code path is for the `foo = tf.function(foo, ...)` use case.
if func is not None:
return transform(func)
diff --git a/tensorflow_probability/python/internal/cache_util.py b/tensorflow_probability/python/internal/cache_util.py
index 9d35b540d1..d2f3607492 100644
--- a/tensorflow_probability/python/internal/cache_util.py
+++ b/tensorflow_probability/python/internal/cache_util.py
@@ -170,7 +170,8 @@ def subkey(self):
def __call__(self):
"""Unwraps the tensor reference."""
- return nest.map_structure(lambda x: x(), self._struct)
+ if self.alive:
+ return nest.map_structure(lambda x: x(), self._struct)
def __hash__(self):
"""Returns the cached hash of this structure."""
diff --git a/tensorflow_probability/python/internal/callable_util.py b/tensorflow_probability/python/internal/callable_util.py
new file mode 100644
index 0000000000..5dfb01fd50
--- /dev/null
+++ b/tensorflow_probability/python/internal/callable_util.py
@@ -0,0 +1,54 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utilities for handling Python callables."""
+
+import tensorflow.compat.v2 as tf
+
+JAX_MODE = False
+NUMPY_MODE = False
+
+
+def get_output_spec(fn, *args, **kwargs):
+ """Traces a callable to determine shape and dtype of its return value(s).
+
+ Args:
+ fn: Python `callable` accepting (structures of) `Tensor` arguments and
+ returning (structures) of `Tensor`s.
+ *args: `Tensor` and/or `tf.TensorSpec` instances representing positional
+ arguments to `fn`.
+ **kwargs: `Tensor` and/or `tf.TensorSpec` instances representing named
+ arguments to `fn`.
+ Returns:
+ structured_outputs: Object or structure of objects corresponding to the
+ value(s) returned by `fn`. These objects have `.shape` and
+ `.dtype` attributes; nothing else about them is guaranteed by the API.
+ """
+
+ if NUMPY_MODE:
+ raise NotImplementedError('Either TensorFlow or JAX is required in order '
+ 'to trace a function without executing it.')
+
+ if JAX_MODE:
+ import jax # pylint: disable=g-import-not-at-top
+ return jax.eval_shape(fn, *args, **kwargs)
+
+ def _as_tensor_spec(t):
+ if isinstance(t, tf.TensorSpec):
+ return t
+ return tf.TensorSpec.from_tensor(tf.convert_to_tensor(t))
+ return tf.function(fn, autograph=False).get_concrete_function(
+ *tf.nest.map_structure(_as_tensor_spec, args),
+ **tf.nest.map_structure(_as_tensor_spec, kwargs)).structured_outputs
+
diff --git a/tensorflow_probability/python/internal/callable_util_test.py b/tensorflow_probability/python/internal/callable_util_test.py
new file mode 100644
index 0000000000..821a845e6c
--- /dev/null
+++ b/tensorflow_probability/python/internal/callable_util_test.py
@@ -0,0 +1,79 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for cache_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import numpy as np
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.internal import callable_util
+from tensorflow_probability.python.internal import test_util
+
+
+def _return_args_from_infinite_loop(*loop_vars, additional_loop_vars=()):
+ return tf.while_loop(
+ cond=lambda *_: True,
+ body=lambda *loop_vars: loop_vars,
+ loop_vars=loop_vars + additional_loop_vars)
+
+
+class CallableUtilTest(test_util.TestCase):
+
+ @test_util.numpy_disable_test_missing_functionality('Tracing not supported')
+ def test_get_output_spec_avoids_evaluating_fn(self):
+ args = (np.array(0., dtype=np.float64),
+ (tf.convert_to_tensor(0.),
+ tf.convert_to_tensor([1., 1.], dtype=tf.float64)))
+ additional_args = (tf.convert_to_tensor([[3], [4]], dtype=tf.int32),)
+ # Trace using both positional and keyword args.
+ results = callable_util.get_output_spec(
+ _return_args_from_infinite_loop,
+ *args,
+ additional_loop_vars=additional_args)
+ self.assertAllEqualNested(
+ tf.nest.map_structure(lambda x: tf.convert_to_tensor(x).shape,
+ args + additional_args),
+ tf.nest.map_structure(lambda x: x.shape, results))
+ self.assertAllEqualNested(
+ tf.nest.map_structure(lambda x: tf.convert_to_tensor(x).dtype,
+ args + additional_args),
+ tf.nest.map_structure(lambda x: x.dtype, results))
+
+ @test_util.numpy_disable_test_missing_functionality('Tracing not supported')
+ @test_util.jax_disable_test_missing_functionality('b/174071016')
+ def test_get_output_spec_from_tensor_specs(self):
+ args = (tf.TensorSpec([], dtype=tf.float32),
+ (tf.TensorSpec([1, 1], dtype=tf.float32),
+ tf.TensorSpec([2], dtype=tf.float64)))
+ additional_args = (tf.TensorSpec([2, 1], dtype=tf.int32),)
+ # Trace using both positional and keyword args.
+ results = callable_util.get_output_spec(
+ _return_args_from_infinite_loop,
+ *args,
+ additional_loop_vars=additional_args)
+ self.assertAllEqualNested(
+ tf.nest.map_structure(lambda x: x.shape, args + additional_args),
+ tf.nest.map_structure(lambda x: x.shape, results))
+ self.assertAllEqualNested(
+ tf.nest.map_structure(lambda x: x.dtype, args + additional_args),
+ tf.nest.map_structure(lambda x: x.dtype, results))
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow_probability/python/internal/hypothesis_testlib.py b/tensorflow_probability/python/internal/hypothesis_testlib.py
index 3188944703..a7b50bafa3 100644
--- a/tensorflow_probability/python/internal/hypothesis_testlib.py
+++ b/tensorflow_probability/python/internal/hypothesis_testlib.py
@@ -61,6 +61,12 @@ def hypothesis_max_examples(default=None):
return int(os.environ.get('TFP_HYPOTHESIS_MAX_EXAMPLES', default or 20))
+def hypothesis_timeout():
+ # Use --test_env=TFP_HYPOTHESIS_TIMEOUT_SECS=600 to permit longer runs,
+ # ergo deeper exploration of the search tree.
+ return int(os.environ.get('TFP_HYPOTHESIS_TIMEOUT_SECS', 60))
+
+
def hypothesis_reproduction_seed():
# Use --test_env=TFP_HYPOTHESIS_REPRODUCE=hexjunk to reproduce a failure.
return os.environ.get('TFP_HYPOTHESIS_REPRODUCE', None)
@@ -92,14 +98,21 @@ def tfp_hp_settings(default_max_examples=None, **kwargs):
deadline=None,
suppress_health_check=[hp.HealthCheck.too_slow],
max_examples=hypothesis_max_examples(default=default_max_examples),
+ timeout=hypothesis_timeout(),
print_blob=hp.PrintSettings.ALWAYS)
kwds.update(kwargs)
def decorator(test_method):
- seed = hypothesis_reproduction_seed()
- if seed is not None:
+ repro_seed = hypothesis_reproduction_seed()
+ if repro_seed is not None:
# This implements the semantics of TFP_HYPOTHESIS_REPRODUCE via
# the `hp.reproduce_failure` decorator.
- test_method = hp.reproduce_failure('3.56.5', seed)(test_method)
+ test_method = hp.reproduce_failure('3.56.5', repro_seed)(test_method)
+ elif randomize_hypothesis():
+ # Hypothesis defaults to seeding its internal PRNG from the system time,
+ # so since we actually want randomization (including across machines) we
+ # have to force it.
+ entropy = os.urandom(64)
+ test_method = hp.seed(int.from_bytes(entropy, 'big'))(test_method)
return hp.settings(**kwds)(test_method)
return decorator
diff --git a/tensorflow_probability/python/internal/nest_util.py b/tensorflow_probability/python/internal/nest_util.py
index f75ff1f8c4..14831760f7 100644
--- a/tensorflow_probability/python/internal/nest_util.py
+++ b/tensorflow_probability/python/internal/nest_util.py
@@ -253,6 +253,9 @@ def convert_to_nested_tensor(value, dtype=None, dtype_hint=None,
elif hint_is_nested:
dtype = broadcast_structure(dtype_hint, dtype)
+ # Call coerce_structure to force the argument structure to match dtype.
+ value = coerce_structure(dtype, value)
+
def convert_fn(path, value, dtype, dtype_hint, name=None):
if not allow_packing and nest.is_nested(value) and any(
# Treat arrays like Tensors for full parity in JAX backend.
@@ -280,3 +283,110 @@ def convert_fn(path, value, dtype, dtype_hint, name=None):
else:
return nest.map_structure_with_tuple_paths_up_to(
dtype, convert_fn, value, dtype, dtype_hint, check_types=False)
+
+
+# pylint: disable=protected-access
+# TODO(b/173044916): Support namedtuple interop in nest and remove this method.
+def coerce_structure(shallow_tree, input_tree):
+ """Coerces the containers in `input_tree` to exactly match `shallow_tree`.
+
+ This method largely parallels the behavior of `nest.assert_shallow_structure`,
+ but allows `namedtuples` to be interpreted as either sequences or mappings.
+ It returns a structure with the container-classes found in `shallow_tree`
+ and the contents of `input_tree`, such that `shallow_tree` and `input_tree`
+ may be used safely in downstream calls to `nest.map_structure_up_to`.
+
+ Note: this method does not currently support `expand_composites`.
+
+ Example Usage:
+ ```python
+
+ ab = collections.namedtuple('AB', 'a b')(0, 1)
+ ba = collections.namedtuple('BA', 'b a')(2, 3)
+
+ coerce_structure(ab, ba)
+ # -> AB(a=3, b=2)
+ ```
+
+ Args:
+ shallow_tree: A (shallow) structure to be populated.
+ input_tree: A (parallel) structure of values.
+ Returns:
+ A structure with containers from shallow_tree and values from input_tree.
+ Raises:
+ ValueError: When nested sub-structures have differing lengths.
+ ValueError: When nested sub-structures have different keys.
+ TypeError: When `shallow_tree` is deeper than `input_tree`
+ TypeError: When nested sub-structures are incompatible (e.g., list vs dict).
+ """
+ try:
+ return _coerce_structure(shallow_tree, input_tree)
+ except (ValueError, TypeError) as e:
+ str1 = str(nest.map_structure(lambda _: nest._DOT, shallow_tree))
+ str2 = str(nest.map_structure(lambda _: nest._DOT, input_tree))
+ raise type(e)(('{}\n'
+ 'Entire first structure:\n{}\n'
+ 'Entire second structure:\n{}'
+ ).format(e, str1, str2))
+
+
+def _coerce_structure(shallow_tree, input_tree):
+ """Implementation of coerce_structure."""
+ if not nest.is_nested(shallow_tree):
+ return input_tree
+
+ if not nest.is_nested(input_tree):
+ raise TypeError(nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(
+ type(input_tree)))
+
+ if len(input_tree) != len(shallow_tree):
+ raise ValueError(
+ nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
+ input_length=len(input_tree),
+ shallow_length=len(shallow_tree)))
+
+ # Determine whether shallow_tree should be treated as a Mapping or a Sequence.
+ # Namedtuples can be interpreted either way (but keys take precedence).
+ _shallow_is_namedtuple = nest._is_namedtuple(shallow_tree) # pylint: disable=invalid-name
+ _shallow_is_mapping = isinstance(shallow_tree, collections.abc.Mapping) # pylint: disable=invalid-name
+ shallow_supports_keys = _shallow_is_namedtuple or _shallow_is_mapping
+ shallow_supports_iter = _shallow_is_namedtuple or not _shallow_is_mapping
+
+ # Branch-selection depends on both shallow and input container-classes.
+ input_is_mapping = isinstance(input_tree, collections.abc.Mapping)
+ if nest._is_namedtuple(input_tree):
+ if shallow_supports_keys:
+ lookup_branch = lambda k: getattr(input_tree, k)
+ else:
+ input_iter = nest._yield_value(input_tree)
+ lookup_branch = lambda _: next(input_iter)
+ elif shallow_supports_keys and input_is_mapping:
+ lookup_branch = lambda k: input_tree[k]
+ elif shallow_supports_iter and not input_is_mapping:
+ input_iter = nest._yield_value(input_tree)
+ lookup_branch = lambda _: next(input_iter)
+ else:
+ raise TypeError(nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
+ input_type=type(input_tree),
+ shallow_type=(
+ type(shallow_tree.__wrapped__)
+ if hasattr(shallow_tree, '__wrapped__') else
+ type(shallow_tree))))
+
+ flat_coerced = []
+ needs_wrapping = type(shallow_tree) is not type(input_tree)
+ for shallow_key, shallow_branch in nest._yield_sorted_items(shallow_tree):
+ try:
+ input_branch = lookup_branch(shallow_key)
+ except (KeyError, AttributeError):
+ raise ValueError(
+ nest._SHALLOW_TREE_HAS_INVALID_KEYS.format([shallow_key]))
+ flat_coerced.append(_coerce_structure(shallow_branch, input_branch))
+ # Keep track of whether nested elements have changed.
+ needs_wrapping |= input_branch is not flat_coerced[-1]
+
+ # Only create a new instance if containers differ or contents changed.
+ return (nest._sequence_like(shallow_tree, flat_coerced)
+ if needs_wrapping else input_tree)
+
+# pylint: enable=protected-access
diff --git a/tensorflow_probability/python/internal/nest_util_test.py b/tensorflow_probability/python/internal/nest_util_test.py
index 61be13021e..f2462bb04d 100644
--- a/tensorflow_probability/python/internal/nest_util_test.py
+++ b/tensorflow_probability/python/internal/nest_util_test.py
@@ -86,7 +86,8 @@ def __repr__(self):
return 'LeafDict' + super(LeafDict, self).__repr__()
-NamedTuple = collections.namedtuple('NamedTuple', 'x, y')
+NamedTuple = collections.namedtuple('NamedTuple', 'x y')
+NamedTupleYX = collections.namedtuple('NamedTuple', 'y x')
# Alias for readability.
Tensor = np.array # pylint: disable=invalid-name
@@ -341,6 +342,65 @@ def testConvertToNestedTensorRaises_incompatible_dtype(self):
value=tf.constant(1, tf.float32),
dtype=tf.float64)
+ @parameterized.named_parameters({
+ 'testcase_name': '_ntup_ntup',
+ 'target': NamedTupleYX(1, 2),
+ 'source': NamedTuple(3, 4),
+ 'expect': NamedTupleYX(y=4, x=3)
+ },{
+ 'testcase_name': '_ntup_dict',
+ 'target': NamedTupleYX(1, 2),
+ 'source': {'x': 3, 'y': 4},
+ 'expect': NamedTupleYX(y=4, x=3)
+ },{
+ 'testcase_name': '_ntup_list',
+ 'target': NamedTupleYX(1, 2),
+ 'source': [3, 4],
+ 'expect': NamedTupleYX(y=3, x=4)
+ },{
+ 'testcase_name': '_list_ntup',
+ 'target': [1, 2],
+ 'source': NamedTupleYX(3, 4),
+ 'expect': [3, 4]
+ },{
+ 'testcase_name': '_dict_ntup',
+ 'target': {'x': 1, 'y': 2},
+ 'source': NamedTupleYX(3, 4),
+ 'expect': {'x': 4, 'y': 3}
+ },{
+ 'testcase_name': '_up_to',
+ 'target': NamedTupleYX(1, 2),
+ 'source': NamedTuple([3,4], [5,6]),
+ 'expect': NamedTupleYX(y=[5,6], x=[3,4])
+ },{
+ 'testcase_name': '_deep_wrap',
+ 'target': {'foo': NamedTuple(1, 2), 'bar': NamedTupleYX(3, 4)},
+ 'source': {'foo': NamedTupleYX(5, 6), 'bar': NamedTuple(7, 8)},
+ 'expect': {'foo': NamedTuple(6, 5), 'bar': NamedTupleYX(8, 7)}
+ })
+ def testCoerceStructure(self, target, source, expect):
+ coerced = nest_util.coerce_structure(target, source)
+ self.assertAllEqualNested(expect, coerced)
+
+ @parameterized.named_parameters({
+ 'testcase_name': '_seqlength',
+ 'target': NamedTuple(1, 2),
+ 'source': [3, 4, 5],
+ 'message': 'sequence length'
+ },{
+ 'testcase_name': '_dict_list',
+ 'target': {'foo': 1, 'bar': 2},
+ 'source': [3, 4],
+ 'message': 'sequence type'
+ },{
+ 'testcase_name': '_shallow_input',
+ 'target': [1, 2],
+ 'source': 3,
+ 'message': 'input must also be a sequence'
+ })
+ def testCoerceStructureRaises(self, target, source, message):
+ with self.assertRaisesRegex((ValueError, TypeError), message):
+ nest_util.coerce_structure(target, source)
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow_probability/python/internal/prefer_static.py b/tensorflow_probability/python/internal/prefer_static.py
index 3961639109..48d576f9e8 100644
--- a/tensorflow_probability/python/internal/prefer_static.py
+++ b/tensorflow_probability/python/internal/prefer_static.py
@@ -37,6 +37,7 @@
except ImportError:
from tensorflow.python import pywrap_tensorflow as c_api # pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops import control_flow_ops # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.util import tf_inspect # pylint: disable=g-direct-tensorflow-import
@@ -109,8 +110,10 @@ def _get_static_value(pred):
pred_value = tf.get_static_value(tf.convert_to_tensor(pred))
# TODO(jamieas): remove the dependency on `pywrap_tensorflow`.
+ # Explicitly check for ops.Tensor, to avoid an AttributeError
+ # when requesting `KerasTensor.graph`.
# pylint: disable=protected-access
- if pred_value is None:
+ if pred_value is None and isinstance(pred, ops.Tensor):
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred._as_tf_output())
# pylint: enable=protected-access
@@ -468,6 +471,7 @@ def is_numpy(x):
reduce_prod = _prefer_static(tf.reduce_prod, nptf.reduce_prod)
reduce_sum = _prefer_static(tf.reduce_sum, nptf.reduce_sum)
reshape = _prefer_static(tf.reshape, nptf.reshape)
+reverse = _prefer_static(tf.reverse, nptf.reverse)
round = _prefer_static(tf.math.round, nptf.math.round) # pylint: disable=redefined-builtin
rsqrt = _prefer_static(tf.math.rsqrt, nptf.math.rsqrt)
slice = _prefer_static(tf.slice, nptf.slice) # pylint: disable=redefined-builtin
diff --git a/tensorflow_probability/python/internal/variadic_reduce.py b/tensorflow_probability/python/internal/variadic_reduce.py
new file mode 100644
index 0000000000..65dcc9e2d0
--- /dev/null
+++ b/tensorflow_probability/python/internal/variadic_reduce.py
@@ -0,0 +1,190 @@
+# Copyright 2020 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Helper for generic variadic reductions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow.compat.v1 as tf1
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.internal import implementation_selection
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import tensorshape_util
+from tensorflow.python.ops import control_flow_util # pylint: disable=g-direct-tensorflow-import
+
+
+JAX_MODE = False
+
+
+def _variadic_reduce(t, init, axis, reducer):
+ """Implements a basic variadic reduce by repeated halving.
+
+ The function executes recursively with each call trimming off one axis. Per
+ axis, the computation uses `tf.while_loop` to repeatedly reduce the two halves
+ of the tensor along the axis in question (padding with `init` as needed).
+
+ Args:
+ t: The tuple of tensors to reduce.
+ init: A tuple of scalar initializations with dtypes aligned with `t`.
+ axis: A sequence of python `int`.
+ reducer: The function implementing the reduction. Each of the two arguments
+ to `reducer` is a tuple like `t`.
+
+ Returns:
+ reduced: A tuple like `t` with the given reduction applied.
+ """
+ if not axis:
+ return t
+
+ if len(axis) > 1:
+ if any(ax < 0 for ax in axis):
+ raise ValueError('All axis args must be non-negative: {}'.format(axis))
+ axis = sorted(axis)
+ # Reduce innermost dims first, so that positive `axis` args remain aligned.
+ reduced_innermost_axis = _variadic_reduce(t, init, axis[-1:], reducer)
+ return _variadic_reduce(reduced_innermost_axis, init, axis[:-1], reducer)
+
+ ax = axis[0]
+ if any(part.shape[ax] != t[0].shape[ax] for part in t):
+ raise ValueError(
+ 'Mismatched shapes along axis {}: {}'.format(
+ ax, [part.shape for part in t]))
+
+ def cond(*t):
+ return tf.not_equal(tf.shape(t[0])[ax], 1)
+
+ def body(*t):
+ dim = tf.shape(t[0])[ax]
+ lhs, rhs = [], []
+ for part, part_init in zip(t, init):
+ p0, p1 = tf.split(part, [dim // 2, dim // 2 + dim % 2], axis=ax)
+ paddings = [(0, 0)] * tensorshape_util.rank(p0.shape)
+ # Ensure we have two `f` operands of matching size:
+ # 1. pad 0 sized dim up to 1.
+ # 2. pad smaller-by-1 dim up to larger.
+ paddings[ax] = (0, tf.maximum(tf.shape(p1)[ax], 1) - tf.shape(p0)[ax])
+ p0 = tf.pad(p0, paddings, constant_values=part_init)
+ lhs.append(p0)
+ paddings[ax] = (0, tf.shape(p0)[ax] - tf.shape(p1)[ax])
+ # p1 may need padding if both are size 0 in the dim.
+ p1 = tf.pad(p1, paddings, constant_values=part_init)
+ rhs.append(p1)
+ return reducer(tuple(lhs), tuple(rhs))
+
+ shape_invariants = tuple(
+ part.shape[:ax] + (None,) + part.shape[ax + 1:] for part in t)
+ t = tuple(t)
+ result = tf.while_loop(cond, body, t, shape_invariants=shape_invariants)
+ # Squeeze out the singleton dim in ax.
+ return tuple(tf.squeeze(part, axis=ax) for part in result)
+
+
+def make_variadic_reduce(reducer):
+ """Wraps a generic reducer function as a variadic reduction.
+
+ The current use-case for this is TFP-internal. This function captures logic
+ related to specific substrates and XLA, and enables some sharing of logic
+ around `axis` and `keepdims` args.
+
+ Args:
+ reducer: The reducer callable. Takes two tuple args and returns a single
+ reduced tuple.
+
+ Returns:
+ reduce_fn: A callable with taking args
+ `(operands, inits, axis=None, keepdims=False)`.
+ """
+
+ # Top-level `tf.function` for XLA (closed-over by the returned reduce_fn).
+ @implementation_selection.never_runs_functions_eagerly
+ @tf.function(experimental_compile=True)
+ def _xla_reduce(operands, inits, axis):
+ """JIT-ed wrapper for TF `xla.variadic_reduce(..., reducer)`."""
+ from tensorflow.compiler.tf2xla.python import xla # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
+ result = xla.variadic_reduce(
+ operands,
+ init_value=inits,
+ dimensions_to_reduce=axis,
+ reducer=tf.function(reducer).get_concrete_function(inits, inits))
+ # Graph mode: variadic reduce doesn't specify output shapes. Patch that.
+ shp = operands[0].shape
+ for arg in operands:
+ shp = tensorshape_util.merge_with(shp, arg.shape)
+ for part in result:
+ tensorshape_util.set_shape(
+ part, tuple(dim for i, dim in enumerate(shp) if i not in axis))
+ return result
+
+ def reduce_fn(operands, inits, axis=None, keepdims=False):
+ """Applies `reducer` to the given operands along the given axes.
+
+ Args:
+ operands: tuple of tensors, all having the same shape.
+ inits: tuple of scalar tensors, with dtypes aligned to those of operands.
+ axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
+ `int`. `None` is taken to mean "reduce all axes".
+ keepdims: When `True`, we do not squeeze away the reduced dims, instead
+ returning values with singleton dims in those axes.
+
+ Returns:
+ reduced: A tuple of the reduced operands.
+ """
+ # Static shape consistency checks.
+ args_shape = operands[0].shape
+ for arg in operands[1:]:
+ args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
+ ndims = tensorshape_util.rank(args_shape)
+ if ndims is None:
+ raise ValueError(
+ 'Rank of at least one of `operands` must be known statically.')
+ # Ensure the 'axis' arg is a tuple of non-negative ints.
+ axis = np.arange(ndims) if axis is None else np.array(axis)
+ if axis.ndim > 1:
+ raise ValueError('`axis` must be `None`, an `int`, or a sequence of '
+ '`int`, but got {}'.format(axis))
+ axis = np.reshape(axis, [-1])
+ axis = np.where(axis < 0, axis + ndims, axis)
+ axis = tuple(int(ax) for ax in axis)
+
+ if JAX_MODE:
+ from jax import lax # pylint: disable=g-import-not-at-top
+ result = lax.reduce(
+ operands, init_values=inits, dimensions=axis, computation=reducer)
+ elif (tf.executing_eagerly() or
+ not control_flow_util.GraphOrParentsInXlaContext(
+ tf1.get_default_graph())):
+ result = _variadic_reduce(
+ operands, init=inits, axis=axis, reducer=reducer)
+ else:
+ result = _xla_reduce(operands, inits, axis)
+
+ if keepdims:
+ axis_nhot = ps.reduce_sum(
+ ps.one_hot(axis, depth=ndims,
+ on_value=True, off_value=False, dtype=tf.bool),
+ axis=0)
+ in_shape = args_shape
+ if not tensorshape_util.is_fully_defined(in_shape):
+ in_shape = tf.shape(operands[0])
+ final_shape = ps.where(axis_nhot, 1, in_shape)
+ result = tf.nest.map_structure(
+ lambda t: tf.reshape(t, final_shape), result)
+ return result
+
+ return reduce_fn
+
diff --git a/tensorflow_probability/python/internal/vectorization_util.py b/tensorflow_probability/python/internal/vectorization_util.py
index 5c5a9b9eb4..4e5e05c7af 100644
--- a/tensorflow_probability/python/internal/vectorization_util.py
+++ b/tensorflow_probability/python/internal/vectorization_util.py
@@ -265,7 +265,7 @@ def vectorized_fn(*args):
# First, compute how many 'extra' (batch) ndims each part has. This must
# be nonnegative.
- vectorized_arg_shapes = [tf.shape(arg) for arg in vectorized_args]
+ vectorized_arg_shapes = [ps.shape(arg) for arg in vectorized_args]
batch_ndims = [
ps.rank_from_shape(arg_shape) - nd
for (arg_shape, nd) in zip(
diff --git a/tensorflow_probability/python/layers/conv_variational.py b/tensorflow_probability/python/layers/conv_variational.py
index 45b7a3b81e..1bdb7b2834 100644
--- a/tensorflow_probability/python/layers/conv_variational.py
+++ b/tensorflow_probability/python/layers/conv_variational.py
@@ -20,6 +20,7 @@
import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import random as tfp_random
from tensorflow_probability.python.distributions import independent as independent_lib
from tensorflow_probability.python.distributions import kullback_leibler as kl_lib
from tensorflow_probability.python.distributions import normal as normal_lib
@@ -1084,18 +1085,12 @@ def _apply_variational_kernel(self, inputs):
seed_stream = SeedStream(self.seed, salt='ConvFlipout')
- def random_rademacher(shape, dtype=tf.float32, seed=None):
- int_dtype = tf.int64 if tf.as_dtype(dtype) != tf.int32 else tf.int32
- random_bernoulli = tf.random.uniform(
- shape, minval=0, maxval=2, dtype=int_dtype, seed=seed)
- return tf.cast(2 * random_bernoulli - 1, dtype)
-
- sign_input = random_rademacher(
+ sign_input = tfp_random.rademacher(
tf.concat([batch_shape,
tf.expand_dims(channels, 0)], 0),
dtype=inputs.dtype,
seed=seed_stream())
- sign_output = random_rademacher(
+ sign_output = tfp_random.rademacher(
tf.concat([batch_shape,
tf.expand_dims(self.filters, 0)], 0),
dtype=inputs.dtype,
diff --git a/tensorflow_probability/python/layers/conv_variational_test.py b/tensorflow_probability/python/layers/conv_variational_test.py
index 9d505d7488..5340d35d4a 100644
--- a/tensorflow_probability/python/layers/conv_variational_test.py
+++ b/tensorflow_probability/python/layers/conv_variational_test.py
@@ -455,20 +455,14 @@ def _testConvFlipout(self, layer_class): # pylint: disable=invalid-name
seed_stream = tfp.util.SeedStream(layer.seed, salt='ConvFlipout')
- sign_input = tf.random.uniform(
+ sign_input = tfp.random.rademacher(
tf.concat([batch_shape, tf.expand_dims(channels, 0)], 0),
- minval=0,
- maxval=2,
- dtype=tf.int64,
+ dtype=inputs.dtype,
seed=seed_stream())
- sign_input = tf.cast(2 * sign_input - 1, inputs.dtype)
- sign_output = tf.random.uniform(
+ sign_output = tfp.random.rademacher(
tf.concat([batch_shape, tf.expand_dims(filters, 0)], 0),
- minval=0,
- maxval=2,
- dtype=tf.int64,
+ dtype=inputs.dtype,
seed=seed_stream())
- sign_output = tf.cast(2 * sign_output - 1, inputs.dtype)
if self.data_format == 'channels_first':
for _ in range(rank):
diff --git a/tensorflow_probability/python/layers/dense_variational.py b/tensorflow_probability/python/layers/dense_variational.py
index b70abe00d1..4bc508a541 100644
--- a/tensorflow_probability/python/layers/dense_variational.py
+++ b/tensorflow_probability/python/layers/dense_variational.py
@@ -20,6 +20,7 @@
import tensorflow.compat.v2 as tf
+from tensorflow_probability.python import random as tfp_random
from tensorflow_probability.python.distributions import independent as independent_lib
from tensorflow_probability.python.distributions import kullback_leibler as kl_lib
from tensorflow_probability.python.distributions import normal as normal_lib
@@ -687,17 +688,11 @@ def _apply_variational_kernel(self, inputs):
seed_stream = SeedStream(self.seed, salt='DenseFlipout')
- def random_rademacher(shape, dtype=tf.float32, seed=None):
- int_dtype = tf.int64 if tf.as_dtype(dtype) != tf.int32 else tf.int32
- random_bernoulli = tf.random.uniform(
- shape, minval=0, maxval=2, dtype=int_dtype, seed=seed)
- return tf.cast(2 * random_bernoulli - 1, dtype)
-
- sign_input = random_rademacher(
+ sign_input = tfp_random.rademacher(
input_shape,
dtype=inputs.dtype,
seed=seed_stream())
- sign_output = random_rademacher(
+ sign_output = tfp_random.rademacher(
tf.concat([batch_shape,
tf.expand_dims(self.units, 0)], 0),
dtype=inputs.dtype,
diff --git a/tensorflow_probability/python/layers/dense_variational_test.py b/tensorflow_probability/python/layers/dense_variational_test.py
index 11c66df68e..df9653c018 100644
--- a/tensorflow_probability/python/layers/dense_variational_test.py
+++ b/tensorflow_probability/python/layers/dense_variational_test.py
@@ -390,20 +390,16 @@ def testDenseFlipout(self):
expected_kernel_posterior_affine_tensor = (
expected_kernel_posterior_affine.sample(seed=42))
- stream = tfp.util.SeedStream(layer.seed, salt='DenseFlipout')
-
- sign_input = tf.random.uniform([batch_size, in_size],
- minval=0,
- maxval=2,
- dtype=tf.int64,
- seed=stream())
- sign_input = tf.cast(2 * sign_input - 1, inputs.dtype)
- sign_output = tf.random.uniform([batch_size, out_size],
- minval=0,
- maxval=2,
- dtype=tf.int64,
- seed=stream())
- sign_output = tf.cast(2 * sign_output - 1, inputs.dtype)
+ seed_stream = tfp.util.SeedStream(layer.seed, salt='DenseFlipout')
+
+ sign_input = tfp.random.rademacher(
+ [batch_size, in_size],
+ dtype=inputs.dtype,
+ seed=seed_stream())
+ sign_output = tfp.random.rademacher(
+ [batch_size, out_size],
+ dtype=inputs.dtype,
+ seed=seed_stream())
perturbed_inputs = tf.matmul(
inputs * sign_input, expected_kernel_posterior_affine_tensor)
perturbed_inputs *= sign_output
diff --git a/tensorflow_probability/python/math/BUILD b/tensorflow_probability/python/math/BUILD
index 8e3646f0bb..ad8b133d0e 100644
--- a/tensorflow_probability/python/math/BUILD
+++ b/tensorflow_probability/python/math/BUILD
@@ -151,6 +151,8 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:custom_gradient",
"//tensorflow_probability/python/internal:dtype_util",
"//tensorflow_probability/python/internal:prefer_static",
+ "//tensorflow_probability/python/internal:variadic_reduce",
+ # tensorflow/compiler/jit dep,
],
)
@@ -164,7 +166,6 @@ multi_substrate_py_test(
# tensorflow dep,
"//tensorflow_probability",
"//tensorflow_probability/python/internal:test_util",
-# "//third_party/tensorflow/compiler/jit:xla_cpu_jit", # DisableOnExport
],
)
@@ -392,7 +393,9 @@ multi_substrate_py_library(
srcs_version = "PY3",
deps = [
# tensorflow dep,
+ "//tensorflow_probability/python/internal:callable_util",
"//tensorflow_probability/python/internal:dtype_util",
+ "//tensorflow_probability/python/internal:tensorshape_util",
],
)
@@ -489,16 +492,3 @@ multi_substrate_py_test(
"//tensorflow_probability/python/internal:test_util",
],
)
-
-exports_files(
- [
- "generic.py",
- "generic_test.py",
- "gradient.py",
- "gradient_test.py",
- "linalg.py",
- "numeric.py",
- "special.py",
- ],
- visibility = ["//tensorflow_probability:__subpackages__"],
-)
diff --git a/tensorflow_probability/python/math/__init__.py b/tensorflow_probability/python/math/__init__.py
index d430404dde..eef9d729cf 100644
--- a/tensorflow_probability/python/math/__init__.py
+++ b/tensorflow_probability/python/math/__init__.py
@@ -33,6 +33,7 @@
from tensorflow_probability.python.math.generic import log_cosh
from tensorflow_probability.python.math.generic import log_cumsum_exp
from tensorflow_probability.python.math.generic import log_sub_exp
+from tensorflow_probability.python.math.generic import reduce_kahan_sum
from tensorflow_probability.python.math.generic import reduce_log_harmonic_mean_exp
from tensorflow_probability.python.math.generic import reduce_logmeanexp
from tensorflow_probability.python.math.generic import reduce_weighted_logsumexp
@@ -59,6 +60,7 @@
from tensorflow_probability.python.math.minimize import MinimizeTraceableQuantities
from tensorflow_probability.python.math.numeric import clip_by_value_preserve_gradient
from tensorflow_probability.python.math.numeric import log1psquare
+from tensorflow_probability.python.math.root_search import bracket_root
from tensorflow_probability.python.math.root_search import find_root_chandrupatla
from tensorflow_probability.python.math.root_search import find_root_secant
from tensorflow_probability.python.math.root_search import secant_root
@@ -88,6 +90,7 @@
'bessel_iv_ratio',
'bessel_ive',
'bessel_kve',
+ 'bracket_root',
'cholesky_concat',
'cholesky_update',
'clip_by_value_preserve_gradient',
@@ -124,6 +127,7 @@
'psd_kernels',
'random_rademacher',
'random_rayleigh',
+ 'reduce_kahan_sum',
'reduce_log_harmonic_mean_exp',
'reduce_logmeanexp',
'reduce_weighted_logsumexp',
diff --git a/tensorflow_probability/python/math/generic.py b/tensorflow_probability/python/math/generic.py
index 8aad0bb2f2..a9f3041cb3 100644
--- a/tensorflow_probability/python/math/generic.py
+++ b/tensorflow_probability/python/math/generic.py
@@ -21,13 +21,16 @@
from __future__ import division
from __future__ import print_function
+import collections
+
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.internal import dtype_util
-from tensorflow_probability.python.internal import prefer_static
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import variadic_reduce
from tensorflow_probability.python.math.scan_associative import scan_associative
@@ -38,6 +41,7 @@
'log_combinations',
'log_cumsum_exp',
'log1mexp',
+ 'reduce_kahan_sum',
'reduce_logmeanexp',
'reduce_weighted_logsumexp',
'smootherstep',
@@ -116,6 +120,80 @@ def safe_logsumexp(x, y):
dest_idx=axis)
+def _kahan_reduction(x, y):
+ """Implements the Kahan summation reduction."""
+ (s, c), (s1, c1) = x, y
+ for val in -c1, s1:
+ u = val - c
+ t = s + u
+ # TODO(b/173158845): XLA:CPU reassociates-to-zero the correction term.
+ c = (t - s) - u
+ s = t
+ return s, c
+
+
+_reduce_kahan_sum = variadic_reduce.make_variadic_reduce(_kahan_reduction)
+
+
+class Kahan(collections.namedtuple('Kahan', ['total', 'correction'])):
+ """Result of Kahan summation, i.e. `sum = total - correction`."""
+ __slots__ = ()
+
+ def __add__(self, x):
+ return Kahan._make(_kahan_reduction(
+ self, x if isinstance(x, Kahan) else (x, 0)))
+
+ def __radd__(self, x):
+ return Kahan._make(_kahan_reduction(
+ self, x if isinstance(x, Kahan) else (x, 0)))
+
+ def __neg__(self):
+ return Kahan(-self.total, -self.correction)
+
+ def __sub__(self, y):
+ return Kahan._make(_kahan_reduction(
+ self, -y if isinstance(y, Kahan) else (-y, 0)))
+
+ def __rsub__(self, x):
+ return Kahan._make(_kahan_reduction(
+ x if isinstance(x, Kahan) else (x, 0), -self))
+
+
+def reduce_kahan_sum(input_tensor, axis=None, keepdims=False, name=None):
+ """Reduces the input tensor along the given axis using Kahan summation.
+
+ Returns both the total and the correction term, as a `namedtuple`, so that a
+ more accurate sum may be written as `total - correction`.
+
+ A practical use-case is computing the difference of two large (magnitude) sums
+ we expect to be nearly equal. If instead we take their difference as
+ `(s0.total - s1.total) - (s0.correction - s1.correction)`, we can retain more
+ precision in computing their difference.
+
+ Note: (TF + JAX) This function does not work properly on XLA:CPU without the
+ environment variable: `XLA_FLAGS=--xla_cpu_enable_fast_math=false`, due to
+ LLVM's reassociation optimizations, which simplify error terms to zero.
+
+ Args:
+ input_tensor: The tensor to sum.
+ axis: One of `None`, a Python `int`, or a sequence of Python `int`. The axes
+ to be reduced. `None` is taken as "reduce all axes".
+ keepdims: Python `bool` indicating whether we return a tensor with singleton
+ dimensions in the reduced axes (`True`), or squeeze the axes out (default,
+ `False`).
+ name: Optional name for ops in scope.
+
+ Returns:
+ reduced: A `Kahan(total, correction)` namedtuple.
+ """
+ with tf.name_scope(name or 'reduce_kahan_sum'):
+ t = tf.convert_to_tensor(input_tensor)
+ operands = (t, tf.zeros_like(t))
+ inits = (tf.zeros([], dtype=t.dtype),) * 2
+ return Kahan._make(
+ _reduce_kahan_sum(operands, inits, axis=axis, keepdims=keepdims))
+
+
def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
"""Computes `log(mean(exp(input_tensor)))`.
@@ -146,7 +224,7 @@ def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
"""
with tf.name_scope(name or 'reduce_logmeanexp'):
lse = tf.reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
- n = prefer_static.size(input_tensor) // prefer_static.size(lse)
+ n = ps.size(input_tensor) // ps.size(lse)
log_n = tf.math.log(tf.cast(n, lse.dtype))
return lse - log_n
diff --git a/tensorflow_probability/python/math/generic_test.py b/tensorflow_probability/python/math/generic_test.py
index e8773207b1..0f41ecccf1 100644
--- a/tensorflow_probability/python/math/generic_test.py
+++ b/tensorflow_probability/python/math/generic_test.py
@@ -18,6 +18,9 @@
from __future__ import division
from __future__ import print_function
+import functools
+import os
+
# Dependency imports
from absl.testing import parameterized
import numpy as np
@@ -612,5 +615,61 @@ def testMatchesArgsort(self, shape, temperature):
actual_sort_ = np.argmax(soft_sort_permutation_, axis=-1)
self.assertAllClose(expected_sort, actual_sort_)
+
+@test_util.test_all_tf_execution_regimes
+class _KahanSumTest(test_util.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='_all',
+ sample_shape=[3, int(1e6)], axis=None),
+ dict(testcase_name='_ax1',
+ sample_shape=[13, int(1e6)], axis=1),
+ dict(testcase_name='_ax1_list_keepdims',
+ sample_shape=[13, int(1e6)], axis=[-1], keepdims=True),
+ dict(testcase_name='_ax_both_tuple',
+ sample_shape=[3, int(1e6)], axis=(-2, 1)),
+ dict(testcase_name='_ax_01_keepdims',
+ sample_shape=[2, int(1e6), 13], axis=[0, 1], keepdims=True),
+ dict(testcase_name='_ax_21',
+ sample_shape=[13, int(1e6), 3], axis=[2, -2]))
+ def testKahanSum(self, sample_shape, axis, keepdims=False):
+ fn = functools.partial(tfp.math.reduce_kahan_sum,
+ axis=axis, keepdims=keepdims)
+ if self.jit:
+ self.skip_if_no_xla()
+ fn = tf.function(fn, experimental_compile=True)
+ dist = tfd.MixtureSameFamily(tfd.Categorical(logits=[0., 0]),
+ tfd.Normal(loc=[0., 1e6], scale=[1., 1e3]))
+ vals = self.evaluate(dist.sample(sample_shape, seed=test_util.test_seed()))
+ oracle = tf.reduce_sum(tf.cast(vals, tf.float64), axis=axis,
+ keepdims=keepdims)
+ result = fn(vals)
+ self.assertEqual(oracle.shape, result.total.shape)
+ self.assertEqual(oracle.shape, result.correction.shape)
+ kahan64 = (tf.cast(result.total, tf.float64) -
+ self.evaluate(tf.cast(result.correction, tf.float64)))
+ if np.prod(result.correction.shape) > 1:
+ self.assertNotAllEqual(
+ result.correction, tf.zeros_like(result.correction))
+ self.assertAllClose(oracle, kahan64) # passes even with --vary_seed
+ # The counterpoint naive sum below would not typically pass (but does not
+ # reliably fail, either). It can fail w/ rtol as high as 0.006.
+ # naive_sum = tf.cast(tf.reduce_sum(vals, axis=axis, keepdims=keepdims),
+ # tf.float64)
+ # self.assertAllClose(oracle, naive_sum)
+
+
+class KahanSumJitTest(_KahanSumTest):
+ jit = True
+
+
+class KahanSumTest(_KahanSumTest):
+ jit = False
+
+del _KahanSumTest
+
+
if __name__ == '__main__':
+ # TODO(b/173158845): XLA:CPU reassociates away the Kahan correction term.
+ os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'
tf.test.main()
diff --git a/tensorflow_probability/python/math/gradient.py b/tensorflow_probability/python/math/gradient.py
index 9b609d3ff9..b4de25e263 100644
--- a/tensorflow_probability/python/math/gradient.py
+++ b/tensorflow_probability/python/math/gradient.py
@@ -345,21 +345,19 @@ def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
def _prepare_args(args, kwargs):
"""Returns structures like inputs with values as Tensors."""
- i = [-1]
+ i = -1
def c2t(x):
+ nonlocal i
# Don't use convert_nonref_to_tensor here. We want to have semantics like
# tf.GradientTape which watches only trainable_variables. (Note: we also
# don't want to cal c2t on non-trainable variables since these are already
# watchable by GradientTape.)
if tensor_util.is_module(x) or tensor_util.is_variable(x):
return x
- i[0] += 1
+ i += 1
return tf.convert_to_tensor(
- x, dtype_hint=tf.float32, name='x{}'.format(i[0]))
- return (
- type(args)(c2t(v) for v in args),
- type(kwargs)((k, c2t(v)) for k, v in kwargs.items()),
- )
+ x, dtype_hint=tf.float32, name='x{}'.format(i))
+ return tf.nest.map_structure(c2t, (args, kwargs))
def _has_args(fn):
diff --git a/tensorflow_probability/python/math/gradient_test.py b/tensorflow_probability/python/math/gradient_test.py
index b1193c7f28..bdadd43678 100644
--- a/tensorflow_probability/python/math/gradient_test.py
+++ b/tensorflow_probability/python/math/gradient_test.py
@@ -56,6 +56,16 @@ def test_list(self):
self.assertAllClose(f(*args), y, atol=1e-6, rtol=1e-6)
self.assertAllClose(g(*args), dydx, atol=1e-6, rtol=1e-6)
+ @test_util.numpy_disable_gradient_test
+ def test_nested(self):
+ f = lambda value: value['x'] * value['y']
+ g = lambda value: {'x': value['y'], 'y': value['x']}
+ args = {'x': np.linspace(0, 100, int(1e1)),
+ 'y': np.linspace(-100, 0, int(1e1))}
+ y, dydx = self.evaluate(tfm.value_and_gradient(f, args))
+ self.assertAllClose(f(args), y, atol=1e-6, rtol=1e-6)
+ self.assertAllCloseNested(g(args), dydx, atol=1e-6, rtol=1e-6)
+
@test_util.numpy_disable_gradient_test
def test_output_list(self):
f = lambda x, y: [x, x * y]
diff --git a/tensorflow_probability/python/math/gram_schmidt.py b/tensorflow_probability/python/math/gram_schmidt.py
index 63fa215492..41dcd459da 100644
--- a/tensorflow_probability/python/math/gram_schmidt.py
+++ b/tensorflow_probability/python/math/gram_schmidt.py
@@ -53,7 +53,7 @@ def gram_schmidt(vectors, num_vectors=None):
[2] P. S. Laplace, Thiorie Analytique des Probabilites. Premier Supple'ment,
Mme. Courtier, Paris, 1816.
- [3] E. Schmidt, ijber die Auflosung linearer Gleichungen mit unendlich vielen
+ [3] E. Schmidt, über die Auflosung linearer Gleichungen mit unendlich vielen
Unbekannten, Rend. Circ. Mat. Pulermo (1) 25:53-77 (1908).
Args:
diff --git a/tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py b/tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py
index 5223d2bd74..b072c32186 100644
--- a/tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py
+++ b/tensorflow_probability/python/math/psd_kernels/hypothesis_testlib.py
@@ -18,6 +18,9 @@
from __future__ import division
from __future__ import print_function
+import contextlib
+import re
+
import hypothesis as hp
from hypothesis.extra import numpy as hpnp
@@ -157,6 +160,26 @@ def _grid_indices_to_values(grid_indices):
return result
+@contextlib.contextmanager
+def no_pd_errors():
+ """Catch and ignore examples where a Cholesky decomposition fails.
+
+ This will typically occur when the matrix is not positive definite.
+
+ Yields:
+ None
+ """
+ # TODO(b/174591555): Instead of catching and `assume`ing away positive
+ # definite errors, avoid them in the first place.
+ try:
+ yield
+ except tf.errors.InvalidArgumentError as e:
+ if re.search(r'Cholesky decomposition was not successful', str(e)):
+ hp.assume(False)
+ else:
+ raise
+
+
@hps.composite
def broadcasting_params(draw,
kernel_name,
diff --git a/tensorflow_probability/python/math/root_search.py b/tensorflow_probability/python/math/root_search.py
index 99da7d3966..9d55368bdb 100644
--- a/tensorflow_probability/python/math/root_search.py
+++ b/tensorflow_probability/python/math/root_search.py
@@ -20,15 +20,21 @@
import collections
+import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.internal import assert_util
+from tensorflow_probability.python.internal import callable_util
+from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import tensorshape_util
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
+NUMPY_MODE = False
__all__ = [
+ 'bracket_root',
'secant_root',
'find_root_chandrupatla',
'find_root_secant',
@@ -338,8 +344,8 @@ def _structure_broadcasting_where(c, x, y):
def find_root_chandrupatla(objective_fn,
- low,
- high,
+ low=None,
+ high=None,
position_tolerance=1e-8,
value_tolerance=0.,
max_iterations=50,
@@ -362,9 +368,15 @@ def find_root_chandrupatla(objective_fn,
callable of a single variable. `objective_fn` must return a `Tensor` with
shape `batch_shape` and dtype matching `lower_bound` and `upper_bound`.
low: Float `Tensor` of shape `batch_shape` representing a lower
- bound(s) on the value of a root(s).
+ bound(s) on the value of a root(s). If either of `low` or `high` is not
+ provided, both are ignored and `tfp.math.bracket_root` is used to attempt
+ to infer bounds.
+ Default value: `None`.
high: Float `Tensor` of shape `batch_shape` representing an upper
- bound(s) on the value of a root(s).
+ bound(s) on the value of a root(s). If either of `low` or `high` is not
+ provided, both are ignored and `tfp.math.bracket_root` is used to attempt
+ to infer bounds.
+ Default value: `None`.
position_tolerance: Optional `Tensor` representing the maximum absolute
error in the positions of the estimated roots. Shape must broadcast with
`batch_shape`.
@@ -492,8 +504,18 @@ def _body(a, b, f_a, f_b, t, num_iterations, converged):
with tf.name_scope(name):
max_iterations = tf.convert_to_tensor(
max_iterations, name='max_iterations', dtype_hint=tf.int32)
- a = tf.convert_to_tensor(low, name='lower_bound')
- b = tf.convert_to_tensor(high, name='upper_bound')
+ dtype = dtype_util.common_dtype(
+ [low, high, position_tolerance, value_tolerance], dtype_hint=tf.float32)
+ position_tolerance = tf.convert_to_tensor(
+ position_tolerance, name='position_tolerance', dtype=dtype)
+ value_tolerance = tf.convert_to_tensor(
+ value_tolerance, name='value_tolerance', dtype=dtype)
+
+ if low is None or high is None:
+ a, b = bracket_root(objective_fn, dtype=dtype)
+ else:
+ a = tf.convert_to_tensor(low, name='lower_bound', dtype=dtype)
+ b = tf.convert_to_tensor(high, name='upper_bound', dtype=dtype)
f_a, f_b = objective_fn(a), objective_fn(b)
batch_shape = ps.broadcast_shape(ps.shape(f_a), ps.shape(f_b))
@@ -529,3 +551,95 @@ def _body(a, b, f_a, f_b, t, num_iterations, converged):
estimated_root=x_best,
objective_at_estimated_root=f_best,
num_iterations=num_iterations)
+
+
+def bracket_root(objective_fn,
+ dtype=tf.float32,
+ num_points=512,
+ name='bracket_root'):
+ """Finds bounds that bracket a root of the objective function.
+
+ This method attempts to return an interval bracketing a root of the objective
+ function. It evaluates the objective in parallel at `num_points`
+ locations, at exponentially increasing distance from the origin, and returns
+ the first pair of adjacent points `[low, high]` such that the objective is
+ finite and has a different sign at the two points. If no such pair was
+ observed, it returns the trivial interval
+ `[np.finfo(dtype).min, np.finfo(dtype).max]` containing all float values of
+ the specified `dtype`. If the objective has multiple
+ roots, the returned interval will contain at least one (but perhaps not all)
+ of the roots.
+
+ Args:
+ objective_fn: Python callable for which roots are searched. It must be a
+ continuous function that accepts a scalar `Tensor` of type `dtype` and
+ returns a `Tensor` of shape `batch_shape`.
+ dtype: Optional float `dtype` of inputs to `objective_fn`.
+ Default value: `tf.float32`.
+ num_points: Optional Python `int` number of points at which to evaluate
+ the objective.
+ Default value: `512`.
+ name: Python `str` name given to ops created by this method.
+ Returns:
+ low: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Lower bound
+ on a root of `objective_fn`.
+ high: Float `Tensor` of shape `batch_shape` and dtype `dtype`. Upper bound
+ on a root of `objective_fn`.
+ """
+ with tf.name_scope(name):
+ # Build a logarithmic sequence of `num_points` values from -inf to inf.
+ dtype_info = np.finfo(dtype_util.as_numpy_dtype(dtype))
+ xs_positive = tf.exp(tf.linspace(tf.cast(-10., dtype),
+ tf.math.log(dtype_info.max),
+ num_points // 2))
+ xs = tf.concat([-xs_positive, xs_positive], axis=0)
+
+ # Evaluate the objective at all points. The objective function may return
+ # a batch of values (e.g., `objective(x) = x - batch_of_roots`).
+ if NUMPY_MODE:
+ objective_output_spec = objective_fn(tf.zeros([], dtype=dtype))
+ else:
+ objective_output_spec = callable_util.get_output_spec(
+ objective_fn,
+ tf.convert_to_tensor(0., dtype=dtype))
+ batch_ndims = tensorshape_util.rank(objective_output_spec.shape)
+ if batch_ndims is None:
+ raise ValueError('Cannot infer tensor rank of objective values.')
+ xs_pad_shape = ps.pad([num_points],
+ paddings=[[0, batch_ndims]],
+ constant_values=1)
+ ys = objective_fn(tf.reshape(xs, xs_pad_shape))
+
+ # Find the smallest point where the objective is finite.
+ is_finite = tf.math.is_finite(ys)
+ ys_transposed = distribution_util.move_dimension( # For batch gather.
+ ys, 0, -1)
+ first_finite_value = tf.gather(
+ ys_transposed,
+ tf.argmax(is_finite, axis=0), # Index of smallest finite point.
+ batch_dims=batch_ndims,
+ axis=-1)
+ # Select the next point where the objective has a different sign.
+ sign_change_idx = tf.argmax(
+ tf.not_equal(tf.math.sign(ys),
+ tf.math.sign(first_finite_value)) & is_finite,
+ axis=0)
+ # If the sign never changes, we can't bracket a root.
+ bracketing_failed = tf.equal(sign_change_idx, 0)
+ # If the objective's sign is zero, we've found an actual root.
+ root_found = tf.equal(tf.gather(tf.math.sign(ys_transposed),
+ sign_change_idx,
+ batch_dims=batch_ndims,
+ axis=-1),
+ 0.)
+ return _structure_broadcasting_where(
+ bracketing_failed,
+ # If we didn't detect a sign change, fall back to the trivial interval.
+ (dtype_info.min, dtype_info.max),
+ # Otherwise, return the points around the sign change, unless we
+ # actually evaluated a root, in which case, return the zero-width
+ # bracket at that root.
+ (tf.gather(xs, tf.where(bracketing_failed | root_found,
+ sign_change_idx,
+ sign_change_idx - 1)),
+ tf.gather(xs, sign_change_idx)))
diff --git a/tensorflow_probability/python/math/root_search_test.py b/tensorflow_probability/python/math/root_search_test.py
index 7a0754f9cb..db998878f3 100644
--- a/tensorflow_probability/python/math/root_search_test.py
+++ b/tensorflow_probability/python/math/root_search_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
# Dependency imports
+import numpy as np
import scipy.optimize as optimize
import tensorflow.compat.v2 as tf
@@ -257,5 +258,57 @@ def test_chandrupatla_invalid_bounds(self):
4.,
validate_args=True))
+ def test_chandrupatla_automatically_selects_bounds(self):
+ expected_roots = 1e6 * samplers.normal(
+ [4, 3], seed=test_util.test_seed(sampler_type='stateless'))
+ _, value_at_roots, _ = tfp.math.find_root_chandrupatla(
+ objective_fn=lambda x: (x - expected_roots)**5,
+ position_tolerance=1e-8)
+ self.assertAllClose(value_at_roots, tf.zeros_like(value_at_roots))
+
+
+@test_util.test_all_tf_execution_regimes
+class BracketRootTest(test_util.TestCase):
+
+ def test_batch_with_nans(self):
+
+ idxs = np.arange(20, dtype=np.float32)
+ bounds = np.reshape(np.exp(idxs), [4, -1])
+ roots = np.reshape(1. / (20. - idxs), [4, -1])
+ def objective_fn(x):
+ return tf.where(x < -bounds,
+ np.nan,
+ tf.where(x > bounds,
+ np.inf,
+ (x - roots)**3))
+ low, high = self.evaluate(tfp.math.bracket_root(objective_fn))
+ f_low, f_high = self.evaluate((objective_fn(low), objective_fn(high)))
+ self.assertAllFinite(f_low)
+ self.assertAllFinite(f_high)
+ self.assertAllTrue(low < roots)
+ self.assertAllTrue(high > roots)
+
+ def test_returns_zero_width_bracket_at_root(self):
+ root = tf.exp(-10.)
+ low, high = self.evaluate(tfp.math.bracket_root(lambda x: (x - root)))
+ self.assertAllClose(low, root)
+ self.assertAllClose(high, root)
+
+ def test_backs_off_to_trivial_bracket(self):
+ dtype_info = np.finfo(np.float32)
+ low, high = self.evaluate(tfp.math.bracket_root(
+ lambda x: np.nan * x, dtype=np.float32))
+ self.assertEqual(low, dtype_info.min)
+ self.assertEqual(high, dtype_info.max)
+
+ def test_float64(self):
+ low, high = self.evaluate(tfp.math.bracket_root(
+ lambda x: (x - np.pi)**3, dtype=np.float64))
+ self.assertEqual(low.dtype, np.float64)
+ self.assertEqual(high.dtype, np.float64)
+ self.assertLess(low, np.pi)
+ self.assertGreater(high, np.pi)
+
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow_probability/python/mcmc/BUILD b/tensorflow_probability/python/mcmc/BUILD
index 5432cc4e05..0f3fc734c2 100644
--- a/tensorflow_probability/python/mcmc/BUILD
+++ b/tensorflow_probability/python/mcmc/BUILD
@@ -431,7 +431,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:tensorshape_util",
"//tensorflow_probability/python/mcmc/internal:slice_sampler_utils",
"//tensorflow_probability/python/mcmc/internal:util",
- "//tensorflow_probability/python/util:seed_stream",
+ "//tensorflow_probability/python/random:random_ops",
],
)
diff --git a/tensorflow_probability/python/mcmc/internal/util.py b/tensorflow_probability/python/mcmc/internal/util.py
index bf93304404..4eb890b9b2 100644
--- a/tensorflow_probability/python/mcmc/internal/util.py
+++ b/tensorflow_probability/python/mcmc/internal/util.py
@@ -144,7 +144,8 @@ def make_name(super_name, default_super_name, sub_name):
def _choose_base_case(is_accepted,
proposed,
current,
- name=None):
+ name=None,
+ addr=None,):
"""Helper to `choose` which expand_dims `is_accepted` and applies tf.where."""
def _where(proposed, current):
"""Wraps `tf.where`."""
@@ -162,30 +163,38 @@ def _where(proposed, current):
with tf.name_scope(name or 'choose'):
if not is_list_like(proposed):
return _where(proposed, current)
- return [(choose(is_accepted, p, c, name=name) if is_namedtuple_like(p)
- else _where(p, c))
- for p, c in zip(proposed, current)]
+ return tf.nest.pack_sequence_as(
+ current,
+ [(_choose_recursive(is_accepted, p, c, name=name, addr=f'{addr}[i]')
+ if is_namedtuple_like(p) else
+ _where(p, c)) for i, (p, c) in enumerate(zip(proposed, current))])
-def choose(is_accepted, proposed, current, name=None):
- """Helper which expand_dims `is_accepted` then applies tf.where."""
+def _choose_recursive(is_accepted, proposed, current, name=None, addr=''):
+ """Recursion helper which also reports the address of any failures."""
with tf.name_scope(name or 'choose'):
if not is_namedtuple_like(proposed):
- return _choose_base_case(is_accepted, proposed, current, name=name)
+ return _choose_base_case(is_accepted, proposed, current, name=name,
+ addr=addr)
if not isinstance(proposed, type(current)):
- raise TypeError('Type of `proposed` ({}) must be identical to '
- 'type of `current` ({})'.format(
- type(proposed).__name__,
- type(current).__name__))
+ raise TypeError(
+ f'Type of `proposed` ({type(proposed).__name__}) must be identical '
+ f'to type of `current` ({type(current).__name__}). (At "{addr}".)')
items = {}
for fn in proposed._fields:
- items[fn] = choose(is_accepted,
- getattr(proposed, fn),
- getattr(current, fn),
- name=name)
+ items[fn] = _choose_recursive(is_accepted,
+ getattr(proposed, fn),
+ getattr(current, fn),
+ name=name,
+ addr=f'{addr}/{fn}')
return type(proposed)(**items)
+def choose(is_accepted, proposed, current, name=None):
+ """Helper which expand_dims `is_accepted` then applies tf.where."""
+ return _choose_recursive(is_accepted, proposed, current, name=name)
+
+
def strip_seeds(obj):
if not is_namedtuple_like(obj):
return obj
diff --git a/tensorflow_probability/python/mcmc/metropolis_hastings.py b/tensorflow_probability/python/mcmc/metropolis_hastings.py
index b589f18cdd..f61dbd7459 100644
--- a/tensorflow_probability/python/mcmc/metropolis_hastings.py
+++ b/tensorflow_probability/python/mcmc/metropolis_hastings.py
@@ -196,6 +196,8 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
current_state,
previous_kernel_results.accepted_results,
**inner_kwargs)
+ if mcmc_util.is_list_like(current_state):
+ proposed_state = tf.nest.pack_sequence_as(current_state, proposed_state)
if (not has_target_log_prob(proposed_results) or
not has_target_log_prob(previous_kernel_results.accepted_results)):
diff --git a/tensorflow_probability/python/mcmc/sample_test.py b/tensorflow_probability/python/mcmc/sample_test.py
index 01be949bba..b3bb921046 100644
--- a/tensorflow_probability/python/mcmc/sample_test.py
+++ b/tensorflow_probability/python/mcmc/sample_test.py
@@ -22,8 +22,8 @@
import warnings
# Dependency imports
+from absl.testing import parameterized
import numpy as np
-
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
@@ -31,6 +31,11 @@
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util
+tfb = tfp.bijectors
+tfd = tfp.distributions
+
+
+NUMPY_MODE = False
TestTransitionKernelResults = collections.namedtuple(
'TestTransitionKernelResults', 'counter_1, counter_2')
@@ -338,6 +343,102 @@ def testSeedReproducibility(self):
self.assertAllCloseNested(
first_final_state, second_final_state, rtol=1e-6)
+ @parameterized.named_parameters(
+ dict(testcase_name='RWM_tuple',
+ kernel_from_log_prob=tfp.mcmc.RandomWalkMetropolis,
+ sample_dtype=(tf.float32,) * 4),
+ dict(testcase_name='RWM_namedtuple',
+ kernel_from_log_prob=tfp.mcmc.RandomWalkMetropolis),
+ dict(testcase_name='HMC_tuple',
+ kernel_from_log_prob=lambda lp_fn: tfp.mcmc.HamiltonianMonteCarlo( # pylint: disable=g-long-lambda
+ lp_fn, step_size=0.1, num_leapfrog_steps=10),
+ skip='HMC requires gradients' if NUMPY_MODE else '',
+ sample_dtype=(tf.float32,) * 4),
+ dict(testcase_name='HMC_namedtuple',
+ kernel_from_log_prob=lambda lp_fn: tfp.mcmc.HamiltonianMonteCarlo( # pylint: disable=g-long-lambda
+ lp_fn, step_size=0.1, num_leapfrog_steps=10),
+ skip='HMC requires gradients' if NUMPY_MODE else '')
+ )
+ def testStructuredState(self, kernel_from_log_prob, skip='',
+ **model_kwargs):
+ if skip:
+ self.skipTest(skip)
+ seed_stream = test_util.test_seed_stream()
+
+ n = 300
+ p = 50
+ x = tf.random.normal([n, p], seed=seed_stream())
+
+ def beta_proportion(mu, kappa):
+ return tfd.Beta(concentration0=mu * kappa,
+ concentration1=(1 - mu) * kappa)
+
+ root = tfd.JointDistributionCoroutine.Root
+ def model_coroutine():
+ beta = yield root(tfd.Sample(tfd.Normal(0, 1), [p], name='beta'))
+ alpha = yield root(tfd.Normal(0, 1, name='alpha'))
+ kappa = yield root(tfd.Gamma(1, 1, name='kappa'))
+ mu = tf.math.sigmoid(alpha[..., tf.newaxis] +
+ tf.einsum('...p,np->...n', beta, x))
+ yield tfd.Independent(beta_proportion(mu, kappa[..., tf.newaxis]),
+ reinterpreted_batch_ndims=1,
+ name='prob')
+
+ model = tfd.JointDistributionCoroutine(model_coroutine, **model_kwargs)
+ probs = model.sample(seed=seed_stream())[-1]
+ pinned = model.experimental_pin(prob=probs)
+
+ kernel = kernel_from_log_prob(pinned.unnormalized_log_prob)
+ nburnin = 5
+ if not isinstance(kernel, tfp.mcmc.RandomWalkMetropolis):
+ kernel = tfp.mcmc.SimpleStepSizeAdaptation(
+ kernel, num_adaptation_steps=nburnin // 2)
+ kernel = tfp.mcmc.TransformedTransitionKernel(
+ kernel, pinned.experimental_default_event_space_bijector())
+ nchains = 4
+
+ @tf.function
+ def sample():
+ return tfp.mcmc.sample_chain(
+ 1, current_state=pinned.sample_unpinned(nchains, seed=seed_stream()),
+ kernel=kernel, num_burnin_steps=nburnin, trace_fn=None,
+ seed=seed_stream())
+ self.evaluate(sample())
+
+ @test_util.jax_disable_test_missing_functionality('PHMC b/175107050')
+ @test_util.numpy_disable_gradient_test('HMC')
+ def testStructuredState2(self):
+ @tfd.JointDistributionCoroutineAutoBatched
+ def model():
+ mu = yield tfd.Sample(tfd.Normal(0, 1), [65], name='mu')
+ sigma = yield tfd.Sample(tfd.Exponential(1.), [65], name='sigma')
+ beta = yield tfd.Sample(
+ tfd.Normal(loc=tf.gather(mu, tf.range(436) % 65, axis=-1),
+ scale=tf.gather(sigma, tf.range(436) % 65, axis=-1)),
+ 4, name='beta')
+ _ = yield tfd.Multinomial(total_count=100.,
+ logits=tfb.Pad([[0, 1]])(beta),
+ name='y')
+
+ stream = test_util.test_seed_stream()
+ pinned = model.experimental_pin(y=model.sample(seed=stream()).y)
+ struct = pinned.dtype
+ stddevs = struct._make([
+ tf.fill([65], .1), tf.fill([65], 1.), tf.fill([436, 4], 10.)])
+ momentum_dist = tfd.JointDistributionNamedAutoBatched(
+ struct._make(tfd.Normal(0, 1 / std) for std in stddevs))
+ kernel = tfp.experimental.mcmc.PreconditionedHamiltonianMonteCarlo(
+ pinned.unnormalized_log_prob,
+ step_size=.1, num_leapfrog_steps=10,
+ momentum_distribution=momentum_dist)
+ bijector = pinned.experimental_default_event_space_bijector()
+ kernel = tfp.mcmc.TransformedTransitionKernel(kernel, bijector)
+ state = bijector(struct._make(
+ tfd.Uniform(-2., 2.).sample(shp)
+ for shp in bijector.inverse_event_shape(pinned.event_shape)))
+ self.evaluate(tfp.mcmc.sample_chain(
+ 3, current_state=state, kernel=kernel, seed=stream()))
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow_probability/python/mcmc/slice_sampler_kernel.py b/tensorflow_probability/python/mcmc/slice_sampler_kernel.py
index 86e6fcc913..28154359da 100644
--- a/tensorflow_probability/python/mcmc/slice_sampler_kernel.py
+++ b/tensorflow_probability/python/mcmc/slice_sampler_kernel.py
@@ -30,6 +30,7 @@
from tensorflow_probability.python.mcmc import kernel as kernel_base
from tensorflow_probability.python.mcmc.internal import slice_sampler_utils as ssu
from tensorflow_probability.python.mcmc.internal import util as mcmc_util
+from tensorflow_probability.python.random import random_ops
__all__ = [
@@ -341,26 +342,20 @@ def bootstrap_results(self, init_state):
def _choose_random_direction(current_state_parts, batch_rank, seed=None):
"""Chooses a random direction in the event space."""
seeds = samplers.split_seed(seed, n=len(current_state_parts))
- # Chooses the random directions across each of the input components.
- rnd_direction_parts = [
- samplers.normal(
- ps.shape(current_state_part), dtype=tf.float32, seed=part_seed)
- for (current_state_part, part_seed) in zip(current_state_parts, seeds)
- ]
-
- # Sum squares over all of the input components. Note this takes all
- # components into account.
- sum_squares = sum(
- tf.reduce_sum( # pylint: disable=g-complex-comprehension
- rnd_direction**2,
- axis=ps.range(batch_rank, ps.rank(rnd_direction)),
- keepdims=True) for rnd_direction in rnd_direction_parts)
-
- # Normalizes the random direction fragments.
- rnd_direction_parts = [rnd_direction / tf.sqrt(sum_squares)
- for rnd_direction in rnd_direction_parts]
-
- return rnd_direction_parts
+ # Sample random directions across each of the input components.
+ def _sample_direction_part(state_part, part_seed):
+ state_part_shape = ps.shape(state_part)
+ batch_shape = state_part_shape[:batch_rank]
+ dimension = ps.reduce_prod(state_part_shape[batch_rank:])
+ return ps.reshape(
+ random_ops.spherical_uniform(
+ shape=batch_shape,
+ dimension=dimension,
+ dtype=state_part.dtype,
+ seed=part_seed),
+ state_part_shape)
+ return [_sample_direction_part(state_part, seed)
+ for state_part, seed in zip(current_state_parts, seeds)]
def _sample_next(target_log_prob_fn,
diff --git a/tensorflow_probability/python/mcmc/transformed_kernel.py b/tensorflow_probability/python/mcmc/transformed_kernel.py
index 27a19dd4cd..1e69098925 100644
--- a/tensorflow_probability/python/mcmc/transformed_kernel.py
+++ b/tensorflow_probability/python/mcmc/transformed_kernel.py
@@ -81,7 +81,9 @@ def fn(state_parts):
if len(bijector) != len(state_parts):
raise ValueError('State has {} parts, but bijector has {}.'.format(
len(state_parts), len(bijector)))
- return [getattr(b, direction)(sp) for b, sp in zip(bijector, state_parts)]
+ transformed_parts = [
+ getattr(b, direction)(sp) for b, sp in zip(bijector, state_parts)]
+ return tf.nest.pack_sequence_as(state_parts, transformed_parts)
return fn
@@ -397,8 +399,9 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
with tf.name_scope(mcmc_util.make_name(
self.name, 'transformed_kernel', 'one_step')):
inner_kwargs = {} if seed is None else dict(seed=seed)
+ transformed_prev_state = previous_kernel_results.transformed_state
transformed_next_state, kernel_results = self._inner_kernel.one_step(
- previous_kernel_results.transformed_state,
+ transformed_prev_state,
previous_kernel_results.inner_results,
**inner_kwargs)
transformed_next_state_parts = (
@@ -410,6 +413,9 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
next_state = (
next_state_parts if mcmc_util.is_list_like(transformed_next_state)
else next_state_parts[0])
+ if mcmc_util.is_list_like(transformed_prev_state):
+ transformed_next_state = tf.nest.pack_sequence_as(
+ transformed_prev_state, transformed_next_state)
kernel_results = TransformedTransitionKernelResults(
transformed_state=transformed_next_state,
inner_results=kernel_results)
@@ -472,14 +478,15 @@ def bootstrap_results(self, init_state=None, transformed_init_state=None):
transformed_init_state_parts = (
self._transform_target_support_to_unconstrained(init_state_parts))
transformed_init_state = (
- transformed_init_state_parts if mcmc_util.is_list_like(init_state)
+ tf.nest.pack_sequence_as(init_state, transformed_init_state_parts)
+ if mcmc_util.is_list_like(init_state)
else transformed_init_state_parts[0])
else:
if mcmc_util.is_list_like(transformed_init_state):
- transformed_init_state = [
- tf.convert_to_tensor(s, name='transformed_init_state')
- for s in transformed_init_state
- ]
+ transformed_init_state = tf.nest.pack_sequence_as(
+ transformed_init_state,
+ [tf.convert_to_tensor(s, name='transformed_init_state')
+ for s in transformed_init_state])
else:
transformed_init_state = tf.convert_to_tensor(
value=transformed_init_state, name='transformed_init_state')
diff --git a/tensorflow_probability/python/optimizer/bfgs.py b/tensorflow_probability/python/optimizer/bfgs.py
index 86f60a946d..068ac776c7 100644
--- a/tensorflow_probability/python/optimizer/bfgs.py
+++ b/tensorflow_probability/python/optimizer/bfgs.py
@@ -315,12 +315,11 @@ def _inv_hessian_control_inputs(inv_hessian):
# The easiest way to validate if the inverse Hessian is positive definite is
# to compute its Cholesky decomposition.
is_positive_definite = tf.reduce_all(
- tf.math.is_finite(tf.linalg.cholesky(inv_hessian)),
- axis=[-1, -2])
+ tf.math.is_finite(tf.linalg.cholesky(inv_hessian)))
# Then check that the supplied inverse Hessian is symmetric.
- is_symmetric = tf.equal(bfgs_utils.norm(
- inv_hessian - _batch_transpose(inv_hessian), dims=2), 0)
+ is_symmetric = tf.reduce_all(tf.equal(bfgs_utils.norm(
+ inv_hessian - _batch_transpose(inv_hessian), dims=2), 0))
# Simply adding a control dependencies on these results is not enough to
# trigger them, we need to add asserts on the results.
diff --git a/tensorflow_probability/python/optimizer/bfgs_test.py b/tensorflow_probability/python/optimizer/bfgs_test.py
index 22e56afdc1..7f267c1d42 100644
--- a/tensorflow_probability/python/optimizer/bfgs_test.py
+++ b/tensorflow_probability/python/optimizer/bfgs_test.py
@@ -151,6 +151,29 @@ def quadratic(x):
quadratic, initial_position=start, tolerance=1e-8,
initial_inverse_hessian_estimate=bad_inv_hessian))
+ def test_batched_inverse_hessian(self):
+ """Checks that specifying a batch of inverse hessians works."""
+ minimum = np.array([1.0, 1.0], dtype=np.float32)
+ scales = np.array([2.0, 3.0], dtype=np.float32)
+
+ @_make_val_and_grad_fn
+ def batched_quadratic(x):
+ return tf.reduce_sum(
+ scales * tf.math.squared_difference(x, minimum), axis=-1)
+
+ start = tf.constant([[0.6, 0.8], [0.5, 0.5]], dtype=tf.float32)
+ test_inv_hessian = tf.constant([[[2.0, 1.0], [1.0, 2.0]],
+ [[1.0, 0.0], [0.0, 1.0]]], dtype=tf.float32)
+ results = self.evaluate(tfp.optimizer.bfgs_minimize(
+ batched_quadratic, initial_position=start, tolerance=1e-8,
+ initial_inverse_hessian_estimate=test_inv_hessian))
+ self.assertAllTrue(results.converged)
+ final_gradient = results.objective_gradient
+ final_gradient_norm = _norm(final_gradient)
+ self.assertAllLessEqual(final_gradient_norm, 1e-8)
+ self.assertArrayNear(results.position[0], minimum, 1e-5)
+ self.assertArrayNear(results.position[1], minimum, 1e-5)
+
def test_quadratic_bowl_10d(self):
"""Can minimize a ten dimensional quadratic function."""
dim = 10
diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py
index 0f946d664b..b2ae389cdb 100644
--- a/tensorflow_probability/python/stats/sample_stats.py
+++ b/tensorflow_probability/python/stats/sample_stats.py
@@ -319,8 +319,8 @@ def covariance(x,
cov_matrix = tfp.stats.covariance(x, y, sample_axis=0, event_axis=-1)
```
- Notice we divide by `N` (the numpy default), which does not create `NaN`
- when `N = 1`, but is slightly biased.
+ Notice we divide by `N`, which does not create `NaN` when `N = 1`, but is
+ slightly biased.
Args:
x: A numeric `Tensor` holding samples.
diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py
index bb2e0a94fa..0b75390eae 100644
--- a/tensorflow_probability/python/version.py
+++ b/tensorflow_probability/python/version.py
@@ -24,7 +24,7 @@
# stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a
# release branch, the current version is by default assumed to be a
# 'development' version, labeled 'dev'.
-_VERSION_SUFFIX = 'rc2'
+_VERSION_SUFFIX = 'rc4'
# Example, '0.4.0-dev'
__version__ = '.'.join([
diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py
index a56d2eb002..fdd3fcb895 100644
--- a/tensorflow_probability/substrates/meta/rewrite.py
+++ b/tensorflow_probability/substrates/meta/rewrite.py
@@ -82,12 +82,13 @@
LIBS = ('bijectors', 'distributions', 'experimental', 'math', 'mcmc',
'optimizer', 'random', 'stats', 'util')
INTERNALS = ('assert_util', 'batched_rejection_sampler', 'cache_util',
+ 'callable_util',
'custom_gradient', 'distribution_util', 'dtype_util',
'hypothesis_testlib', 'implementation_selection', 'monte_carlo',
'name_util', 'nest_util', 'parameter_properties', 'prefer_static',
'samplers', 'special_math', 'structural_tuple', 'tensor_util',
'tensorshape_util', 'test_combinations', 'test_util', 'unnest',
- 'vectorization_util')
+ 'variadic_reduce', 'vectorization_util')
OPTIMIZERS = ('linesearch',)
LINESEARCH = ('internal',)
SAMPLERS = ('categorical', 'normal', 'poisson', 'uniform', 'shuffle')
@@ -189,7 +190,7 @@ def main(argv):
})
filename = argv[1]
- contents = open(filename).read()
+ contents = open(filename, encoding='utf-8').read()
if '__init__.py' in filename:
# Comment out items from __all__.
for pkg, disabled in disabled_by_pkg.items():
@@ -247,7 +248,7 @@ def disable_all(name):
print('# ' + '@' * 78)
print('\n# (This notice adds 10 to line numbering.)\n\n')
- print(contents)
+ print(contents, file=open(1, 'w', encoding='utf-8', closefd=False))
if __name__ == '__main__':
diff --git a/testing/install_test_dependencies.sh b/testing/install_test_dependencies.sh
index 91c4ed5529..5007fdb30f 100755
--- a/testing/install_test_dependencies.sh
+++ b/testing/install_test_dependencies.sh
@@ -77,6 +77,16 @@ else
TF_NIGHTLY_PACKAGE=tf-nightly-cpu
fi
+PYTHON_PARSE_PACKAGE_JSON="
+import sys
+import json
+package_data = json.loads(sys.stdin.read())
+linux_versions = []
+for release, release_info in package_data['releases'].items():
+ if any('linux' in wheel['filename'] for wheel in release_info):
+ print(release)
+"
+
find_good_tf_nightly_version_str() {
PKG_NAME=$1
# These are nightly builds we'd like to avoid for some reason; separated by
@@ -86,7 +96,9 @@ find_good_tf_nightly_version_str() {
# stderr. We then sort, remove bad versions and take the last entry. This
# allows us to avoid hardcoding the main version number, which would then need
# to be updated on every new TF release.
- python -m pip install $PKG_NAME==X 2>&1 \
+ VERSIONS=$(curl -s https://pypi.org/pypi/$PKG_NAME/json \
+ | python -c "$PYTHON_PARSE_PACKAGE_JSON")
+ echo $VERSIONS \
| grep -o "[0-9.]\+dev[0-9]\{8\}" \
| sort \
| grep -v "$BAD_NIGHTLY_DATES" \
diff --git a/testing/run_travis_lints.sh b/testing/run_github_lints.sh
similarity index 79%
rename from testing/run_travis_lints.sh
rename to testing/run_github_lints.sh
index fb8a3f636e..793db2119b 100755
--- a/testing/run_travis_lints.sh
+++ b/testing/run_github_lints.sh
@@ -18,13 +18,13 @@ set -v # print commands as they are executed
set -e # fail and exit on any command erroring
get_changed_py_files() {
- # Need to fetch the base branch in case it is not master.
- git remote set-branches --add origin ${TRAVIS_BRANCH}
- git fetch --depth=20 --quiet
- git diff \
- --name-only \
- --diff-filter=AM origin/${TRAVIS_BRANCH}...HEAD \
- | grep '^tensorflow_probability.*\.py$' || true
+ if [ $GITHUB_BASE_REF ]; then
+ git fetch origin ${GITHUB_BASE_REF} --depth=1
+ git diff \
+ --name-only \
+ --diff-filter=AM origin/${GITHUB_BASE_REF} \
+ | grep '^tensorflow_probability.*\.py$'
+ fi
}
pip install --quiet pylint
diff --git a/testing/run_travis_tests.sh b/testing/run_github_tests.sh
similarity index 84%
rename from testing/run_travis_tests.sh
rename to testing/run_github_tests.sh
index c5250d256b..76f1414d0b 100755
--- a/testing/run_travis_tests.sh
+++ b/testing/run_github_tests.sh
@@ -34,13 +34,6 @@ if [ -z "${NUM_SHARDS}" ]; then
exit -1
fi
-call_with_log_folding() {
- local command=$1
- echo "travis_fold:start:$command"
- $command
- echo "travis_fold:end:$command"
-}
-
install_bazel() {
# Install Bazel for tests. Based on instructions at
# https://docs.bazel.build/versions/master/install-ubuntu.html#install-on-ubuntu
@@ -63,8 +56,8 @@ install_python_packages() {
# Only install bazel if not already present (useful for locally testing this
# script).
-which bazel || call_with_log_folding install_bazel
-call_with_log_folding install_python_packages
+which bazel || install_bazel
+install_python_packages
test_tags_to_skip="(gpu|requires-gpu-nvidia|notap|no-oss-ci|tfp_jax|tf2-broken|tf2-kokoro-broken)"
@@ -87,9 +80,6 @@ sharded_tests="$(query_and_shard_tests_by_size small)"
sharded_tests="${sharded_tests} $(query_and_shard_tests_by_size medium)"
sharded_tests="${sharded_tests} $(query_and_shard_tests_by_size large)"
-# Run tests using run_tfp_test.sh script. We append the following flags:
-# --notest_keep_going -- stop running tests as soon as anything fails. This is
-# to minimize load on Travis, where we share a limited number of concurrent
-# jobs with a bunch of other TensorFlow projects.
+# Run tests using run_tfp_test.sh script.
echo "${sharded_tests}" \
- | xargs $DIR/run_tfp_test.sh --notest_keep_going
+ | xargs $DIR/run_tfp_test.sh