@@ -685,6 +685,19 @@ def test_windowed_mean_corner_cases(self):
685
685
686
686
@test_util .test_all_tf_execution_regimes
687
687
class WindowedStatsTest (test_util .TestCase ):
688
+
689
+ def _maybe_expand_dims_to_make_broadcastable (self , x , shape , axis ):
690
+ if len (shape ) > len (x .shape ):
691
+ if len (x .shape ) == 1 :
692
+ bc_shape = np .ones (len (shape ), dtype = np .int32 )
693
+ bc_shape [axis ] = x .shape [0 ]
694
+ return x .reshape (bc_shape )
695
+ else :
696
+ extra_dims = len (shape ) - len (x .shape )
697
+ bc_shape = x .shape + (1 ,) * extra_dims
698
+ return x .reshape (bc_shape )
699
+ return x
700
+
688
701
def apply_slice_along_axis (self , func , arr , low , high , axis ):
689
702
"""Applies `func` over slices of `arr` along `axis`. Slices intervals are
690
703
specified through `low` and `high`. Support broadcasting.
@@ -709,6 +722,7 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
709
722
for r in range (j ):
710
723
out_1d [r ] = func (a_1d [low_1d [r ]:high_1d [r ]])
711
724
return out
725
+
712
726
def check_gaussian_windowed (self , shape , indice_shape , axis ,
713
727
window_func , np_func ):
714
728
stat_shape = np .array (shape ).astype (np .int32 )
@@ -721,6 +735,10 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
721
735
indices = rng .randint (shape [axis ] + 1 , size = indice_shape )
722
736
indices = np .sort (indices , axis = 0 )
723
737
low_indices , high_indices = indices [0 ], indices [1 ]
738
+ low_indices = self ._maybe_expand_dims_to_make_broadcastable (
739
+ low_indices , x .shape , axis )
740
+ high_indices = self ._maybe_expand_dims_to_make_broadcastable (
741
+ high_indices , x .shape , axis )
724
742
a = window_func (x , low_indices = low_indices ,
725
743
high_indices = high_indices , axis = axis )
726
744
b = self .apply_slice_along_axis (np_func , x , low_indices , high_indices ,
@@ -736,20 +754,34 @@ def check_windowed(self, func, numpy_func):
736
754
check_fn ((64 , 4 , 8 ), (32 , 4 , 1 ), axis = 0 )
737
755
check_fn ((64 , 4 , 8 ), (32 , 4 , 8 ), axis = 0 )
738
756
check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 0 )
757
+ check_fn ((64 , 4 , 8 ), (128 , 1 ), axis = 0 )
758
+ check_fn ((64 , 4 , 8 ), (32 ,), axis = 0 )
759
+ check_fn ((64 , 4 , 8 ), (32 , 4 ), axis = 0 )
760
+
739
761
check_fn ((64 , 4 , 8 ), (64 , 64 , 1 ), axis = 1 )
740
762
check_fn ((64 , 4 , 8 ), (1 , 64 , 1 ), axis = 1 )
741
763
check_fn ((64 , 4 , 8 ), (64 , 2 , 8 ), axis = 1 )
742
764
check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 1 )
765
+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 1 )
766
+ check_fn ((64 , 4 , 8 ), (1 , 64 ), axis = 1 )
767
+
743
768
check_fn ((64 , 4 , 8 ), (64 , 4 , 64 ), axis = 2 )
744
769
check_fn ((64 , 4 , 8 ), (1 , 1 , 64 ), axis = 2 )
745
770
check_fn ((64 , 4 , 8 ), (64 , 4 , 4 ), axis = 2 )
746
771
check_fn ((64 , 4 , 8 ), (1 , 1 , 4 ), axis = 2 )
747
772
check_fn ((64 , 4 , 8 ), (64 , 4 , 8 ), axis = 2 )
773
+ check_fn ((64 , 4 , 8 ), (16 ,), axis = 2 )
774
+ check_fn ((64 , 4 , 8 ), (1 , 4 ), axis = 2 )
775
+ check_fn ((64 , 4 , 8 ), (64 , 4 ), axis = 2 )
748
776
749
777
with self .assertRaises (Exception ):
750
778
# Non broadcastable shapes
751
779
check_fn ((64 , 4 , 8 ), (4 , 1 , 4 ), axis = 2 )
752
780
781
+ with self .assertRaises (Exception ):
782
+ # Non broadcastable shapes
783
+ check_fn ((64 , 4 , 8 ), (2 , 4 ), axis = 2 )
784
+
753
785
def test_windowed_mean (self ):
754
786
self .check_windowed (func = tfp .stats .windowed_mean , numpy_func = np .mean )
755
787
0 commit comments