Skip to content

Commit

Permalink
Infer the log-det-jacobian of scalar bijectors using autodiff, if not…
Browse files Browse the repository at this point in the history
… otherwise specified.

This fixes http://github.com/tensorflow/probability/issues/573 in the scalar case.

PiperOrigin-RevId: 341531582
  • Loading branch information
davmre authored and tensorflower-gardener committed Nov 10, 2020
1 parent 0b8af93 commit de48765
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 8 deletions.
1 change: 1 addition & 0 deletions tensorflow_probability/python/bijectors/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ multi_substrate_py_library(
"//tensorflow_probability/python/internal:name_util",
"//tensorflow_probability/python/internal:nest_util",
"//tensorflow_probability/python/internal:tensorshape_util",
"//tensorflow_probability/python/math:gradient",
],
)

Expand Down
19 changes: 19 additions & 0 deletions tensorflow_probability/python/bijectors/bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tensorflow_probability.python.internal import name_util
from tensorflow_probability.python.internal import nest_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.math import gradient
from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import


Expand Down Expand Up @@ -604,6 +605,11 @@ def _is_injective(self):
"""
return True

@property
def _is_scalar(self):
  """Whether both static min-event-ndims of this bijector equal zero.

  The result is truthy only when `tf.get_static_value` resolves both
  `_forward_min_event_ndims` and `_inverse_min_event_ndims` to `0`.
  """
  forward_ndims = tf.get_static_value(self._forward_min_event_ndims)
  # `and` short-circuits, so the inverse side is only inspected once the
  # forward side has already compared equal to zero.
  return forward_ndims == 0 and (
      tf.get_static_value(self._inverse_min_event_ndims) == 0)

@property
def validate_args(self):
"""Returns True if Tensor arguments will be validated."""
Expand Down Expand Up @@ -1033,6 +1039,8 @@ def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
elif hasattr(self, '_forward_log_det_jacobian'):
x = self.inverse(y, **kwargs) # Fall back to computing `-fldj(x)`
ildj = attrs['ildj'] = -self._forward_log_det_jacobian(x, **kwargs)
elif self._is_scalar:
ildj = _autodiff_log_det_jacobian(self._inverse, y)
else:
raise NotImplementedError(
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
Expand Down Expand Up @@ -1136,6 +1144,8 @@ def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
elif hasattr(self, '_inverse_log_det_jacobian'):
y = self.forward(x, **kwargs) # Fall back to computing `ildj(y)`
ildj = attrs['ildj'] = self._inverse_log_det_jacobian(y, **kwargs)
elif self._is_scalar:
ildj = -_autodiff_log_det_jacobian(self._forward, x)
else:
raise NotImplementedError(
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
Expand Down Expand Up @@ -1670,3 +1680,12 @@ def ldj_reduction_shape(shape_structure,
'LDJ reduction shape.')))

return ldj_reduce_shape, assertions


def _autodiff_log_det_jacobian(fn, x):
  """Computes the log-abs-derivative of a scalar function via autodiff.

  Args:
    fn: Callable to differentiate; it is handed to
      `gradient.value_and_gradient` together with `x`.
    x: The point at which `fn` is differentiated.

  Returns:
    `tf.math.log(tf.abs(g))`, where `g` is the gradient reported by
    `gradient.value_and_gradient(fn, x)`.

  Raises:
    ValueError: if the reported gradient is `None`.
  """
  # Only the derivative is needed; the function value is discarded.
  unused_value, derivative = gradient.value_and_gradient(fn, x)
  if derivative is not None:
    return tf.math.log(tf.abs(derivative))

  message = ('Cannot compute log det jacobian; function {} has `None` '
             'gradient.'.format(fn))
  raise ValueError(message)
51 changes: 49 additions & 2 deletions tensorflow_probability/python/bijectors/bijector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def __init__(self):

with self.assertRaisesRegexp(
NotImplementedError,
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian.*'):
'inverse not implemented'):
bij.inverse_log_det_jacobian(0, event_ndims=0)

with self.assertRaisesRegexp(
NotImplementedError,
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian.*'):
'forward not implemented'):
bij.forward_log_det_jacobian(0, event_ndims=0)

@test_util.disable_test_for_backend(
Expand Down Expand Up @@ -124,6 +124,53 @@ def _forward(self, x):
error_clazz, 'Tensor conversion requested dtype'):
b64.forward(x32)

@test_util.numpy_disable_gradient_test
def testAutodiffLogDetJacobian(self):
  """Checks log-det-Jacobians of a bijector that defines neither LDJ method.

  The helper bijector only implements `_forward` and `_inverse`, so both
  `forward_log_det_jacobian` and `inverse_log_det_jacobian` are compared
  against hand-derived closed forms and against each other.
  """

  class NoJacobianBijector(tfb.Bijector):
    """Bijector `y = exp(scale * x)` with no log det jacobian methods."""

    def __init__(self, scale=2.):
      # Must stay the first statement so only the constructor arguments are
      # captured.
      parameters = dict(locals())
      self._scale = tensor_util.convert_nonref_to_tensor(scale)
      super(NoJacobianBijector, self).__init__(
          validate_args=True,
          forward_min_event_ndims=0,
          parameters=parameters)

    def _forward(self, x):
      return tf.exp(self._scale * x)

    def _inverse(self, y):
      return tf.math.log(y) / self._scale

  bijector = NoJacobianBijector(scale=1.4)

  # Forward direction: d/dx exp(s * x) = s * exp(s * x), hence
  # log|dy/dx| = log(s) + s * x.
  x = tf.convert_to_tensor([2., -3.])
  fldj_at_x = bijector.forward_log_det_jacobian(x, event_ndims=0)
  expected_fldj = tf.math.log(bijector._scale) + bijector._scale * x
  ildj_at_forward_x = bijector.inverse_log_det_jacobian(
      bijector.forward(x), event_ndims=0)
  fldj, true_fldj, ildj = self.evaluate(
      [fldj_at_x, expected_fldj, ildj_at_forward_x])
  self.assertAllClose(fldj, true_fldj)
  self.assertAllClose(fldj, -ildj)

  # Inverse direction: d/dy log(y) / s = 1 / (s * y), hence
  # log|dx/dy| = -log|s * y|.
  y = tf.convert_to_tensor([27., 5.])
  ildj_at_y = bijector.inverse_log_det_jacobian(y, event_ndims=0)
  expected_ildj = -tf.math.log(tf.abs(y * bijector._scale))
  fldj_at_inverse_y = bijector.forward_log_det_jacobian(
      bijector.inverse(y), event_ndims=0)
  ildj, true_ildj, fldj = self.evaluate(
      [ildj_at_y, expected_ildj, fldj_at_inverse_y])
  self.assertAllClose(ildj, true_ildj)
  self.assertAllClose(ildj, -fldj)


class IntentionallyMissingError(Exception):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors.kumaraswamy_cdf import KumaraswamyCDF
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.math.psd_kernels import feature_transformed
Expand Down Expand Up @@ -62,6 +61,11 @@ def __init__(self,
name: Python `str` name given to ops managed by this object.
"""
parameters = dict(locals())

# Delayed import to avoid circular dependency between `tfp.bijectors` and
# `tfp.math`
from tensorflow_probability.python.bijectors import kumaraswamy_cdf # pylint: disable=g-import-not-at-top

with tf.name_scope(name):
self._concentration1 = tensor_util.convert_nonref_to_tensor(
concentration1, name='concentration1')
Expand All @@ -78,9 +82,9 @@ def transform_by_kumaraswamy(x, feature_ndims, example_ndims):
self.concentration0,
example_ndims,
start=-(feature_ndims + 1))
bij = KumaraswamyCDF(concentration1,
concentration0,
validate_args=validate_args)
bij = kumaraswamy_cdf.KumaraswamyCDF(concentration1,
concentration0,
validate_args=validate_args)
return bij.forward(x)

super(KumaraswamyTransformed, self).__init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import functools

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import cholesky_outer_product
from tensorflow_probability.python.bijectors import invert
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util
Expand Down Expand Up @@ -203,6 +201,13 @@ def __init__(self,
Default value: `"SchurComplement"`
"""
parameters = dict(locals())

# Delayed import to avoid circular dependency between `tfp.bijectors` and
# `tfp.math`
# pylint: disable=g-import-not-at-top
from tensorflow_probability.python.bijectors import cholesky_outer_product
from tensorflow_probability.python.bijectors import invert
# pylint: enable=g-import-not-at-top
with tf.name_scope(name) as name:
dtype = dtype_util.common_dtype(
[base_kernel, fixed_inputs, diag_shift], tf.float32)
Expand Down

0 comments on commit de48765

Please sign in to comment.