diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py index 6ae297efea..174a8bb262 100644 --- a/tensorflow_probability/python/stats/sample_stats.py +++ b/tensorflow_probability/python/stats/sample_stats.py @@ -22,9 +22,9 @@ import tensorflow.compat.v2 as tf if JAX_MODE or NUMPY_MODE: - tnp = np + numpy_ops = np else: - import tensorflow.experimental.numpy as tnp + from tensorflow.python.ops import numpy_ops from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import distribution_util @@ -739,13 +739,12 @@ def windowed_variance( Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. - The shape of indices must be broadcastable with `x` unless the rank is lower - than the rank of `x`, then the shape is expanded with extra inner dimensions - to match the rank of `x`. + The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. - In the special case where the rank of indices is one, i.e when - `rank(Bi) = rank(F) = 0`, the indices are reshaped to - `[1] * rank(Bx) + [M] + [1] * rank(E)`. + If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded + with extra inner dimensions to match the rank of `x`. In the special + case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, + the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -801,10 +800,10 @@ def windowed_variance( def index_for_cumulative(indices): return tf.maximum(indices - 1, 0) cum_sums = tf.cumsum(x, axis=axis) - sums = tnp.take_along_axis( + sums = numpy_ops.take_along_axis( cum_sums, index_for_cumulative(indices), axis=axis) cum_variances = cumulative_variance(x, sample_axis=axis) - variances = tnp.take_along_axis( + variances = numpy_ops.take_along_axis( cum_variances, index_for_cumulative(indices), axis=axis) # This formula is the binary accurate variance merge from [1], @@ -860,13 +859,12 @@ def windowed_mean( Then each element of `low_indices` and `high_indices` must be between 0 and N+1, and the shape of the output will be `Bx + [M] + E`. - The shape of indices must be broadcastable with `x` unless the rank is lower - than the rank of `x`, then the shape is expanded with extra inner dimensions - to match the rank of `x`. + The shape `Bi + [1] + F` must be broadcastable with the shape of `x`. - In the special case where the rank of indices is one, i.e when - `rank(Bi) = rank(F) = 0`, the indices are reshaped to - `[1] * rank(Bx) + [M] + [1] * rank(E)`. + If `rank(Bi + [M] + F) < rank(x)`, then the indices are expanded + with extra inner dimensions to match the rank of `x`. In the special + case where the rank of indices is one, i.e when `rank(Bi) = rank(F) = 0`, + the indices are reshaped to `[1] * rank(Bx) + [M] + [1] * rank(E)`. The default windows are `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...` @@ -906,7 +904,7 @@ def windowed_mean( paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32), (rank, 2)) cum_sums = ps.pad(raw_cumsum, paddings) - sums = tnp.take_along_axis(cum_sums, indices, + sums = numpy_ops.take_along_axis(cum_sums, indices, axis=axis) counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype) return tf.math.divide_no_nan(sums[1] - sums[0], counts) @@ -915,7 +913,7 @@ def windowed_mean( def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): """Common argument defaulting logic for windowed statistics.""" if high_indices is None: - high_indices = tf.range(ps.shape(x)[axis]) + 1 + high_indices = ps.range(ps.shape(x)[axis]) + 1 else: high_indices = tf.convert_to_tensor(high_indices) if low_indices is None: @@ -941,7 +939,7 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0): bc_shape = indices_shape bc_shape = ps.concat([[2], bc_shape], axis=0) - indices = tf.stack([low_indices, high_indices], axis=0) + indices = ps.stack([low_indices, high_indices], axis=0) indices = ps.reshape(indices, bc_shape) x = tf.expand_dims(x, axis=0) axis += 1 diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index 47d5dc29b8..edc6b6447d 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -735,17 +735,26 @@ def check_gaussian_windowed(self, shape, indice_shape, axis, indices = rng.randint(shape[axis] + 1, size=indice_shape) indices = np.sort(indices, axis=0) low_indices, high_indices = indices[0], indices[1] + + tf_low_indices = self._make_dynamic_shape(low_indices) + tf_high_indices = self._make_dynamic_shape(high_indices) + tf_x = self._make_dynamic_shape(x) + + a = window_func(tf_x, low_indices=tf_low_indices, + high_indices=tf_high_indices, axis=axis) + low_indices = self._maybe_expand_dims_to_make_broadcastable( low_indices, x.shape, axis) high_indices = self._maybe_expand_dims_to_make_broadcastable( high_indices, x.shape, axis) - a = window_func(x, low_indices=low_indices, - high_indices=high_indices, axis=axis) b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices, axis=axis) b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros self.assertAllClose(a, b) + def _make_dynamic_shape(self, x): + return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) + def check_windowed(self, func, numpy_func): check_fn = functools.partial(self.check_gaussian_windowed, window_func=func, np_func=numpy_func)