diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index ef10557afd..1e44f811cb 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -613,7 +613,7 @@ def assertAllInRange(self, 'The value of %s does not have an ordered numeric type, instead it ' 'has type: %s' % (target, target.dtype)) - nan_subscripts = np.where(np.isnan(target)) + nan_subscripts = np.where(np.atleast_1d(np.isnan(target))) if np.size(nan_subscripts): raise AssertionError( '%d of the %d element(s) are NaN. ' @@ -631,7 +631,7 @@ def assertAllInRange(self, violations, np.greater_equal(target, upper_bound) if open_upper_bound else np.greater(target, upper_bound)) - violation_subscripts = np.where(violations) + violation_subscripts = np.where(np.atleast_1d(violations)) if np.size(violation_subscripts): raise AssertionError( '%d of the %d element(s) are outside the range %s. ' %