Skip to content

Commit

Permalink
Test against tensors with dynamic shapes
Browse files Browse the repository at this point in the history
Some `tensorflow` to `prefer_static` replacement
  • Loading branch information
nicolaspi committed Aug 10, 2022
1 parent 6a18217 commit 8d20563
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 21 deletions.
36 changes: 17 additions & 19 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import tensorflow.compat.v2 as tf

if JAX_MODE or NUMPY_MODE:
tnp = np
numpy_ops = np
else:
import tensorflow.experimental.numpy as tnp
from tensorflow.python.ops import numpy_ops

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
Expand Down Expand Up @@ -739,13 +739,12 @@ def windowed_variance(
Then each element of `low_indices` and `high_indices` must be
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
The shape of indices must be broadcastable with `x` unless the rank is lower
than the rank of `x`, then the shape is expanded with extra inner dimensions
to match the rank of `x`.
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
In the special case where the rank of indices is one, i.e when
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
with extra inner dimensions to match the rank of `x`. In the special
case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`,
the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.
The default windows are
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
Expand Down Expand Up @@ -801,10 +800,10 @@ def windowed_variance(
def index_for_cumulative(indices):
return tf.maximum(indices - 1, 0)
cum_sums = tf.cumsum(x, axis=axis)
sums = tnp.take_along_axis(
sums = numpy_ops.take_along_axis(
cum_sums, index_for_cumulative(indices), axis=axis)
cum_variances = cumulative_variance(x, sample_axis=axis)
variances = tnp.take_along_axis(
variances = numpy_ops.take_along_axis(
cum_variances, index_for_cumulative(indices), axis=axis)

# This formula is the binary accurate variance merge from [1],
Expand Down Expand Up @@ -860,13 +859,12 @@ def windowed_mean(
Then each element of `low_indices` and `high_indices` must be
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
The shape of indices must be broadcastable with `x` unless the rank is lower
than the rank of `x`, then the shape is expanded with extra inner dimensions
to match the rank of `x`.
The shape `Bi + [1] + F` must be broadcastable with the shape of `x`.
In the special case where the rank of indices is one, i.e when
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded
with extra inner dimensions to match the rank of `x`. In the special
case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`,
the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`.
The default windows are
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
Expand Down Expand Up @@ -906,7 +904,7 @@ def windowed_mean(
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
(rank, 2))
cum_sums = ps.pad(raw_cumsum, paddings)
sums = tnp.take_along_axis(cum_sums, indices,
sums = numpy_ops.take_along_axis(cum_sums, indices,
axis=axis)
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
return tf.math.divide_no_nan(sums[1] - sums[0], counts)
Expand All @@ -915,7 +913,7 @@ def windowed_mean(
def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
"""Common argument defaulting logic for windowed statistics."""
if high_indices is None:
high_indices = tf.range(ps.shape(x)[axis]) + 1
high_indices = ps.range(ps.shape(x)[axis]) + 1
else:
high_indices = tf.convert_to_tensor(high_indices)
if low_indices is None:
Expand All @@ -941,7 +939,7 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
bc_shape = indices_shape

bc_shape = ps.concat([[2], bc_shape], axis=0)
indices = tf.stack([low_indices, high_indices], axis=0)
indices = ps.stack([low_indices, high_indices], axis=0)
indices = ps.reshape(indices, bc_shape)
x = tf.expand_dims(x, axis=0)
axis += 1
Expand Down
13 changes: 11 additions & 2 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,17 +735,26 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
indices = rng.randint(shape[axis] + 1, size=indice_shape)
indices = np.sort(indices, axis=0)
low_indices, high_indices = indices[0], indices[1]

tf_low_indices = self._make_dynamic_shape(low_indices)
tf_high_indices = self._make_dynamic_shape(high_indices)
tf_x = self._make_dynamic_shape(x)

a = window_func(tf_x, low_indices=tf_low_indices,
high_indices=tf_high_indices, axis=axis)

low_indices = self._maybe_expand_dims_to_make_broadcastable(
low_indices, x.shape, axis)
high_indices = self._maybe_expand_dims_to_make_broadcastable(
high_indices, x.shape, axis)
a = window_func(x, low_indices=low_indices,
high_indices=high_indices, axis=axis)
b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices,
axis=axis)
b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros
self.assertAllClose(a, b)

def _make_dynamic_shape(self, x):
return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape))

def check_windowed(self, func, numpy_func):
check_fn = functools.partial(self.check_gaussian_windowed,
window_func=func, np_func=numpy_func)
Expand Down

0 comments on commit 8d20563

Please sign in to comment.