Skip to content

Commit

Permalink
Assorted JAX shape staging fixes for mcmc/diagnostic.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 398397197
  • Loading branch information
SiegeLordEx authored and jburnim committed Sep 30, 2021
1 parent 7089d1a commit cc77ec8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
tensorshape_util.rank(x.shape)
if tensorshape_util.rank(x.shape) is not None else tf.rank(
x, name='ndims'))
axis = tf.convert_to_tensor(axis, name='axis')
axis = ps.convert_to_shape_tensor(axis, name='axis')
axis_ = tf.get_static_value(axis)
if axis_ is not None:
axis = axis_
Expand Down
10 changes: 5 additions & 5 deletions tensorflow_probability/python/mcmc/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def potential_scale_reduction(chains_states,
# array) is not efficiently computable. Therefore, we try constant_value then
# check for None.
icn_const_ = tf.get_static_value(
tf.convert_to_tensor(independent_chain_ndims))
ps.convert_to_shape_tensor(independent_chain_ndims))
if icn_const_ is not None:
independent_chain_ndims = icn_const_
if icn_const_ < 1:
Expand Down Expand Up @@ -539,15 +539,15 @@ def _potential_scale_reduction_single_state(state, independent_chain_ndims,
state = tf.transpose(
a=state,
perm=ps.concat(
[[1, 0], tf.range(2, tf.rank(state))], axis=0))
[[1, 0], ps.range(2, ps.rank(state))], axis=0))

# We're treating the new dim as indexing 2 chains, so increment.
independent_chain_ndims += 1

sample_axis = tf.range(0, sample_ndims)
chain_axis = tf.range(sample_ndims,
sample_axis = ps.range(0, sample_ndims)
chain_axis = ps.range(sample_ndims,
sample_ndims + independent_chain_ndims)
sample_and_chain_axis = tf.range(
sample_and_chain_axis = ps.range(
0, sample_ndims + independent_chain_ndims)

n = _axis_size(state, sample_axis)
Expand Down
8 changes: 4 additions & 4 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def auto_correlation(x,
if max_lags is None:
max_lags = x_len - 1
else:
max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
max_lags = ps.convert_to_shape_tensor(max_lags, name='max_lags')
max_lags_ = tf.get_static_value(max_lags)
if max_lags_ is None or not know_static_shape:
know_static_shape = False
Expand Down Expand Up @@ -285,7 +285,7 @@ def cholesky_covariance(x, sample_axis=0, keepdims=False, name=None):
lower triangular matrices (the Cholesky factors).
"""
with tf.name_scope(name or 'cholesky_covariance'):
sample_axis = tf.convert_to_tensor(sample_axis, dtype=tf.int32)
sample_axis = ps.convert_to_shape_tensor(sample_axis, dtype=tf.int32)
cov = covariance(
x, sample_axis=sample_axis, event_axis=-1, keepdims=keepdims)
return tf.linalg.cholesky(cov)
Expand Down Expand Up @@ -971,10 +971,10 @@ def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False,
with tf.name_scope(name or 'average_sigmoid'):
logits = tf.convert_to_tensor(logits, dtype_hint=tf.float32, name='logits')
if sample_axis is not None:
sample_axis = tf.convert_to_tensor(
sample_axis = ps.convert_to_shape_tensor(
sample_axis, dtype_hint=tf.int32, name='sample_axis')
if event_axis is not None:
event_axis = tf.convert_to_tensor(
event_axis = ps.convert_to_shape_tensor(
event_axis, dtype_hint=tf.int32, name='event_axis')
if event_axis is None:
# log(sigmoid(x)) = log(1 / (1 + exp(-x))) = -log1p(exp(-x)) = -sp(-x)
Expand Down

0 comments on commit cc77ec8

Please sign in to comment.