Skip to content

Commit 6a18217

Browse files
committed
Allow lower rank indices
1 parent a602a8c commit 6a18217

File tree

2 files changed

+68
-15
lines changed

2 files changed

+68
-15
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -735,12 +735,18 @@ def windowed_variance(
735735
last half of an MCMC chain.
736736
737737
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
738-
have shape `Bi + [M] + F`, such that:
739-
- `rank(Bx) = rank(Bi) = axis`,
740-
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
738+
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
741739
Then each element of `low_indices` and `high_indices` must be
742740
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
743741
742+
The shape of indices must be broadcastable with `x` unless the rank is lower
743+
than the rank of `x`, in which case the shape is expanded with extra inner dimensions
744+
to match the rank of `x`.
745+
746+
In the special case where the rank of indices is one, i.e. when
747+
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
748+
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
749+
744750
The default windows are
745751
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
746752
This corresponds to analyzing `x` as though it were streaming, for
@@ -850,12 +856,18 @@ def windowed_mean(
850856
last half of an MCMC chain.
851857
852858
Suppose `x` has shape `Bx + [N] + E`, `low_indices` and `high_indices`
853-
have shape `Bi + [M] + F`, such that:
854-
- `rank(Bx) = rank(Bi) = axis`,
855-
- `Bi + [1] + F` broadcasts to `Bx + [N] + E`.
859+
have shape `Bi + [M] + F`, such that `rank(Bx) = rank(Bi) = axis`.
856860
Then each element of `low_indices` and `high_indices` must be
857861
between 0 and N+1, and the shape of the output will be `Bx + [M] + E`.
858862
863+
The shape of indices must be broadcastable with `x` unless the rank is lower
864+
than the rank of `x`, in which case the shape is expanded with extra inner dimensions
865+
to match the rank of `x`.
866+
867+
In the special case where the rank of indices is one, i.e. when
868+
`rank(Bi) = rank(F) = 0`, the indices are reshaped to
869+
`[1] * rank(Bx) + [M] + [1] * rank(E)`.
870+
859871
The default windows are
860872
`[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
861873
This corresponds to analyzing `x` as though it were streaming, for
@@ -913,17 +925,26 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
913925
# Broadcast indices together.
914926
high_indices = high_indices + tf.zeros_like(low_indices)
915927
low_indices = low_indices + tf.zeros_like(high_indices)
916-
indices = ps.stack([low_indices, high_indices], axis=0)
928+
929+
indices_shape = ps.shape(low_indices)
930+
if ps.rank(low_indices) < ps.rank(x):
931+
if ps.rank(low_indices) == 1:
932+
size = ps.size(low_indices)
933+
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
934+
off_value=1)
935+
else:
936+
# we assume the first dimensions are broadcastable with `x`,
937+
# we add trailing dimensions
938+
extra_dims = ps.rank(x) - ps.rank(low_indices)
939+
bc_shape = ps.concat([indices_shape, [1]*extra_dims], axis=0)
940+
else:
941+
bc_shape = indices_shape
942+
943+
bc_shape = ps.concat([[2], bc_shape], axis=0)
944+
indices = tf.stack([low_indices, high_indices], axis=0)
945+
indices = ps.reshape(indices, bc_shape)
917946
x = tf.expand_dims(x, axis=0)
918947
axis += 1
919-
920-
if ps.rank(indices) != ps.rank(x) and ps.rank(indices) == 2:
921-
# legacy usage, kept for backward compatibility
922-
size = ps.size(indices) // 2
923-
bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
924-
off_value=1)
925-
bc_shape = ps.concat([[2], bc_shape[1:]], axis=0)
926-
indices = ps.reshape(indices, bc_shape)
927948
# `take_along_axis` requires the type to be int32
928949
indices = ps.cast(indices, dtype=tf.int32)
929950
return x, indices, axis

tensorflow_probability/python/stats/sample_stats_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ def test_windowed_mean_corner_cases(self):
685685

686686
@test_util.test_all_tf_execution_regimes
687687
class WindowedStatsTest(test_util.TestCase):
688+
689+
def _maybe_expand_dims_to_make_broadcastable(self, x, shape, axis):
690+
if len(shape) > len(x.shape):
691+
if len(x.shape) == 1:
692+
bc_shape = np.ones(len(shape), dtype=np.int32)
693+
bc_shape[axis] = x.shape[0]
694+
return x.reshape(bc_shape)
695+
else:
696+
extra_dims = len(shape) - len(x.shape)
697+
bc_shape = x.shape + (1,) * extra_dims
698+
return x.reshape(bc_shape)
699+
return x
700+
688701
def apply_slice_along_axis(self, func, arr, low, high, axis):
689702
"""Applies `func` over slices of `arr` along `axis`. Slice intervals are
690703
specified through `low` and `high`. Supports broadcasting.
@@ -709,6 +722,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
709722
for r in range(j):
710723
out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]])
711724
return out
725+
712726
def check_gaussian_windowed(self, shape, indice_shape, axis,
713727
window_func, np_func):
714728
stat_shape = np.array(shape).astype(np.int32)
@@ -721,6 +735,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
721735
indices = rng.randint(shape[axis] + 1, size=indice_shape)
722736
indices = np.sort(indices, axis=0)
723737
low_indices, high_indices = indices[0], indices[1]
738+
low_indices = self._maybe_expand_dims_to_make_broadcastable(
739+
low_indices, x.shape, axis)
740+
high_indices = self._maybe_expand_dims_to_make_broadcastable(
741+
high_indices, x.shape, axis)
724742
a = window_func(x, low_indices=low_indices,
725743
high_indices=high_indices, axis=axis)
726744
b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices,
@@ -736,20 +754,34 @@ def check_windowed(self, func, numpy_func):
736754
check_fn((64, 4, 8), (32, 4, 1), axis=0)
737755
check_fn((64, 4, 8), (32, 4, 8), axis=0)
738756
check_fn((64, 4, 8), (64, 4, 8), axis=0)
757+
check_fn((64, 4, 8), (128, 1), axis=0)
758+
check_fn((64, 4, 8), (32,), axis=0)
759+
check_fn((64, 4, 8), (32, 4), axis=0)
760+
739761
check_fn((64, 4, 8), (64, 64, 1), axis=1)
740762
check_fn((64, 4, 8), (1, 64, 1), axis=1)
741763
check_fn((64, 4, 8), (64, 2, 8), axis=1)
742764
check_fn((64, 4, 8), (64, 4, 8), axis=1)
765+
check_fn((64, 4, 8), (16,), axis=1)
766+
check_fn((64, 4, 8), (1, 64), axis=1)
767+
743768
check_fn((64, 4, 8), (64, 4, 64), axis=2)
744769
check_fn((64, 4, 8), (1, 1, 64), axis=2)
745770
check_fn((64, 4, 8), (64, 4, 4), axis=2)
746771
check_fn((64, 4, 8), (1, 1, 4), axis=2)
747772
check_fn((64, 4, 8), (64, 4, 8), axis=2)
773+
check_fn((64, 4, 8), (16,), axis=2)
774+
check_fn((64, 4, 8), (1, 4), axis=2)
775+
check_fn((64, 4, 8), (64, 4), axis=2)
748776

749777
with self.assertRaises(Exception):
750778
# Non broadcastable shapes
751779
check_fn((64, 4, 8), (4, 1, 4), axis=2)
752780

781+
with self.assertRaises(Exception):
782+
# Non broadcastable shapes
783+
check_fn((64, 4, 8), (2, 4), axis=2)
784+
753785
def test_windowed_mean(self):
754786
self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean)
755787

0 commit comments

Comments
 (0)