Skip to content

Commit

Permalink
Rename interpolate_nondiscrete flag to `force_probs_to_zero_outside…
Browse files Browse the repository at this point in the history
…_support` (with the opposite sense) in Skellam.

The second name leaves room for other distributions to use the same
flag to control tf.where gates on extrapolating the support to, e.g.,
negative arguments.

Not going through the usual deprecation process because Skellam was
only contributed a few days ago, and in particular was not in any
stable release yet.

PiperOrigin-RevId: 341064310
  • Loading branch information
axch authored and tensorflower-gardener committed Nov 6, 2020
1 parent 7f9b9ba commit 2c75fb3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
18 changes: 9 additions & 9 deletions tensorflow_probability/python/distributions/skellam.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self,
rate2=None,
log_rate1=None,
log_rate2=None,
interpolate_nondiscrete=True,
force_probs_to_zero_outside_support=False,
validate_args=False,
allow_nan_stats=True,
name='Skellam'):
Expand All @@ -86,12 +86,12 @@ def __init__(self,
Must specify exactly one of `rate1` and `log_rate1`.
log_rate2: Floating point tensor, the log of the second rate parameter.
Must specify exactly one of `rate2` and `log_rate2`.
interpolate_nondiscrete: Python `bool`. When `False`,
force_probs_to_zero_outside_support: Python `bool`. When `True`,
`log_prob` returns `-inf` (and `prob` returns `0`) for non-integer
inputs. When `True`, `log_prob` evaluates the Skellam pmf as a
inputs. When `False`, `log_prob` evaluates the Skellam pmf as a
continuous function (note that this function is not itself
a normalized probability log-density).
Default value: `True`.
Default value: `False`.
validate_args: Python `bool`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
Expand Down Expand Up @@ -126,7 +126,7 @@ def __init__(self,
self._log_rate2 = tensor_util.convert_nonref_to_tensor(
log_rate2, name='log_rate2', dtype=dtype)

self._interpolate_nondiscrete = interpolate_nondiscrete
self._force_probs_to_zero_outside_support = force_probs_to_zero_outside_support
super(Skellam, self).__init__(
dtype=dtype,
reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
Expand Down Expand Up @@ -171,9 +171,9 @@ def log_rate2(self):
return self._log_rate2

@property
def force_probs_to_zero_outside_support(self):
  """Whether `prob` is zero (and `log_prob` is `-inf`) outside the support.

  When `True`, non-integer inputs get probability `0` and log-probability
  `-inf`. When `False`, the Skellam pmf is interpolated and evaluated as a
  continuous function of the input (note that the result is then not itself
  a normalized probability log-density).
  """
  # NOTE(review): the commit renamed the flag (with inverted sense) but left
  # the old "Interpolate (log) probs on non-integer inputs." docstring, which
  # now described the opposite behavior; corrected here.
  return self._force_probs_to_zero_outside_support

def _batch_shape_tensor(self, log_rate1=None, log_rate2=None):
x1 = self._rate1 if self._log_rate1 is None else self._log_rate1
Expand All @@ -198,7 +198,7 @@ def _log_prob(self, x):
# Catch such x's and set the output value accordingly.
lr1, r1, lr2, r2 = self._all_rate_parameters()

safe_x = x if self.interpolate_nondiscrete else tf.floor(x)
safe_x = tf.floor(x) if self.force_probs_to_zero_outside_support else x
y = tf.math.multiply_no_nan(0.5 * (lr1 - lr2), safe_x)
numpy_dtype = dtype_util.as_numpy_dtype(y.dtype)

Expand All @@ -211,7 +211,7 @@ def _log_prob(self, x):
safe_x, 2. * tf.math.sqrt(r1 * r2)) - tf.math.square(
tf.math.sqrt(r1) - tf.math.sqrt(r2))
y = tf.where(tf.math.equal(x, safe_x), y, numpy_dtype(-np.inf))
if not self.interpolate_nondiscrete:
if self.force_probs_to_zero_outside_support:
# Ensure the gradient wrt `rate` is zero at non-integer points.
y = tf.where(
(y < 0.) & tf.math.is_inf(y), numpy_dtype(-np.inf), y)
Expand Down
19 changes: 10 additions & 9 deletions tensorflow_probability/python/distributions/skellam_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def _make_skellam(self,
rate1,
rate2,
validate_args=True,
interpolate_nondiscrete=True):
return tfd.Skellam(rate1=rate1,
rate2=rate2,
validate_args=validate_args,
interpolate_nondiscrete=interpolate_nondiscrete)
force_probs_to_zero_outside_support=False):
return tfd.Skellam(
rate1=rate1,
rate2=rate2,
validate_args=validate_args,
force_probs_to_zero_outside_support=force_probs_to_zero_outside_support)

def testSkellamShape(self):
rate1 = tf.constant([3.0] * 5, dtype=self.dtype)
Expand Down Expand Up @@ -77,7 +78,7 @@ def testSkellamLogPmfDiscreteMatchesScipy(self):
dtype=self.dtype)
skellam = self._make_skellam(
rate1=rate1, rate2=rate2,
interpolate_nondiscrete=False, validate_args=False)
force_probs_to_zero_outside_support=True, validate_args=False)
log_pmf = skellam.log_prob(x)
self.assertEqual(log_pmf.shape, (2, batch_size))
self.assertAllClose(
Expand Down Expand Up @@ -124,7 +125,7 @@ def skellam_log_prob(lam):
return self._make_skellam(
rate1=rate1 if apply_to_second_rate else lam,
rate2=lam if apply_to_second_rate else rate2,
interpolate_nondiscrete=False,
force_probs_to_zero_outside_support=True,
validate_args=False).log_prob(x)
return skellam_log_prob
_, dlog_pmf_dlam = self.evaluate(tfp.math.value_and_gradient(
Expand Down Expand Up @@ -283,12 +284,12 @@ def _make_skellam(self,
rate1,
rate2,
validate_args=True,
interpolate_nondiscrete=True):
force_probs_to_zero_outside_support=False):
return tfd.Skellam(
log_rate1=tf.math.log(rate1),
log_rate2=tf.math.log(rate2),
validate_args=validate_args,
interpolate_nondiscrete=interpolate_nondiscrete)
force_probs_to_zero_outside_support=force_probs_to_zero_outside_support)

# No need to worry about the non-negativity of `rate` when using the
# `log_rate` parameterization.
Expand Down

0 comments on commit 2c75fb3

Please sign in to comment.