Skip to content

Commit

Permalink
Allow lower rank indices
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaspi committed Aug 8, 2022
1 parent a602a8c commit 6a18217
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 15 deletions.
51 changes: 36 additions & 15 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,12 +735,18 @@ def windowed_variance(
last half of an MCMC chain.
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
have shape `Bi + [M] + F`, such that:
- `rank(Bx) = rank(Bi) = axis`,
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
Then each element of `low_indices` and `high_indices` must be
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
The shape of indices must be broadcastable with `x` unless the rank is lower
than the rank of `x`, then the shape is expanded with extra inner dimensions
to match the rank of `x`.
In the special case where the rank of indices is one, i.e when
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
The default windows are
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
This corresponds to analyzing `x` as though it were streaming, for
Expand Down Expand Up @@ -850,12 +856,18 @@ def windowed_mean(
last half of an MCMC chain.
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
have shape `Bi + [M] + F`, such that:
- `rank(Bx) = rank(Bi) = axis`,
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
Then each element of `low_indices` and `high_indices` must be
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
The shape of indices must be broadcastable with `x` unless the rank is lower
than the rank of `x`, then the shape is expanded with extra inner dimensions
to match the rank of `x`.
In the special case where the rank of indices is one, i.e when
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
The default windows are
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
This corresponds to analyzing `x` as though it were streaming, for
Expand Down Expand Up @@ -913,17 +925,26 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
# Broadcast indices together.
high_indices = high_indices + tf.zeros_like(low_indices)
low_indices = low_indices + tf.zeros_like(high_indices)
indices = ps.stack([low_indices, high_indices], axis=0)

indices_shape = ps.shape(low_indices)
if ps.rank(low_indices) < ps.rank(x):
if ps.rank(low_indices) == 1:
size = ps.size(low_indices)
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
off_value=1)
else:
# we assume the first dimensions are broadcastable with `x`,
# we add trailing dimensions
extra_dims = ps.rank(x) - ps.rank(low_indices)
bc_shape = ps.concat([indices_shape, [1]*extra_dims], axis=0)
else:
bc_shape = indices_shape

bc_shape = ps.concat([[2], bc_shape], axis=0)
indices = tf.stack([low_indices, high_indices], axis=0)
indices = ps.reshape(indices, bc_shape)
x = tf.expand_dims(x, axis=0)
axis += 1

if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2:
# legacy usage, kept for backward compatibility
size = ps.size(indices) // 2
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
off_value=1)
bc_shape = ps.concat([[2], bc_shape[1:]], axis=0)
indices = ps.reshape(indices, bc_shape)
# `take_along_axis` requires the type to be int32
indices = ps.cast(indices, dtype=tf.int32)
return x, indices, axis
Expand Down
32 changes: 32 additions & 0 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,19 @@ def test_windowed_mean_corner_cases(self):

@test_util.test_all_tf_execution_regimes
class WindowedStatsTest(test_util.TestCase):

def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis):
if len(shape) > len(x.shape):
if len(x.shape) == 1:
bc_shape = np.ones(len(shape), dtype=np.int32)
bc_shape[axis] = x.shape[0]
return x.reshape(bc_shape)
else:
extra_dims = len(shape) - len(x.shape)
bc_shape = x.shape + (1,) * extra_dims
return x.reshape(bc_shape)
return x

def apply_slice_along_axis(self, func, arr, low, high, axis):
"""Applies `func` over slices of `arr` along `axis`. Slices intervals are
specified through `low` and `high`. Support broadcasting.
Expand All @@ -709,6 +722,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
for r in range(j):
out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]])
return out

def check_gaussian_windowed(self, shape, indice_shape, axis,
window_func, np_func):
stat_shape = np.array(shape).astype(np.int32)
Expand All @@ -721,6 +735,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
indices = rng.randint(shape[axis] + 1, size=indice_shape)
indices = np.sort(indices, axis=0)
low_indices, high_indices = indices[0], indices[1]
low_indices = self._maybe_expand_dims_to_make_broadcastable(
low_indices, x.shape, axis)
high_indices = self._maybe_expand_dims_to_make_broadcastable(
high_indices, x.shape, axis)
a = window_func(x, low_indices=low_indices,
high_indices=high_indices, axis=axis)
b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices,
Expand All @@ -736,20 +754,34 @@ def check_windowed(self, func, numpy_func):
check_fn((64, 4, 8), (32, 4, 1), axis=0)
check_fn((64, 4, 8), (32, 4, 8), axis=0)
check_fn((64, 4, 8), (64, 4, 8), axis=0)
check_fn((64, 4, 8), (128, 1), axis=0)
check_fn((64, 4, 8), (32,), axis=0)
check_fn((64, 4, 8), (32, 4), axis=0)

check_fn((64, 4, 8), (64, 64, 1), axis=1)
check_fn((64, 4, 8), (1, 64, 1), axis=1)
check_fn((64, 4, 8), (64, 2, 8), axis=1)
check_fn((64, 4, 8), (64, 4, 8), axis=1)
check_fn((64, 4, 8), (16,), axis=1)
check_fn((64, 4, 8), (1, 64), axis=1)

check_fn((64, 4, 8), (64, 4, 64), axis=2)
check_fn((64, 4, 8), (1, 1, 64), axis=2)
check_fn((64, 4, 8), (64, 4, 4), axis=2)
check_fn((64, 4, 8), (1, 1, 4), axis=2)
check_fn((64, 4, 8), (64, 4, 8), axis=2)
check_fn((64, 4, 8), (16,), axis=2)
check_fn((64, 4, 8), (1, 4), axis=2)
check_fn((64, 4, 8), (64, 4), axis=2)

with self.assertRaises(Exception):
# Non broadcastable shapes
check_fn((64, 4, 8), (4, 1, 4), axis=2)

with self.assertRaises(Exception):
# Non broadcastable shapes
check_fn((64, 4, 8), (2, 4), axis=2)

def test_windowed_mean(self):
self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean)

Expand Down

0 comments on commit 6a18217

Please sign in to comment.