Skip to content

Commit cc77ec8

Browse files
SiegeLordExjburnim
authored andcommitted
Assorted JAX shape staging fixes for mcmc/diagnostic.py.
PiperOrigin-RevId: 398397197
1 parent 7089d1a commit cc77ec8

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

tensorflow_probability/python/internal/distribution_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ def pad(x, axis, front=False, back=False, value=0, count=1, name=None):
12381238
tensorshape_util.rank(x.shape)
12391239
if tensorshape_util.rank(x.shape) is not None else tf.rank(
12401240
x, name='ndims'))
1241-
axis = tf.convert_to_tensor(axis, name='axis')
1241+
axis = ps.convert_to_shape_tensor(axis, name='axis')
12421242
axis_ = tf.get_static_value(axis)
12431243
if axis_ is not None:
12441244
axis = axis_

tensorflow_probability/python/mcmc/diagnostic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def potential_scale_reduction(chains_states,
461461
# array) is not efficiently computable. Therefore, we try constant_value then
462462
# check for None.
463463
icn_const_ = tf.get_static_value(
464-
tf.convert_to_tensor(independent_chain_ndims))
464+
ps.convert_to_shape_tensor(independent_chain_ndims))
465465
if icn_const_ is not None:
466466
independent_chain_ndims = icn_const_
467467
if icn_const_ < 1:
@@ -539,15 +539,15 @@ def _potential_scale_reduction_single_state(state, independent_chain_ndims,
539539
state = tf.transpose(
540540
a=state,
541541
perm=ps.concat(
542-
[[1, 0], tf.range(2, tf.rank(state))], axis=0))
542+
[[1, 0], ps.range(2, ps.rank(state))], axis=0))
543543

544544
# We're treating the new dim as indexing 2 chains, so increment.
545545
independent_chain_ndims += 1
546546

547-
sample_axis = tf.range(0, sample_ndims)
548-
chain_axis = tf.range(sample_ndims,
547+
sample_axis = ps.range(0, sample_ndims)
548+
chain_axis = ps.range(sample_ndims,
549549
sample_ndims + independent_chain_ndims)
550-
sample_and_chain_axis = tf.range(
550+
sample_and_chain_axis = ps.range(
551551
0, sample_ndims + independent_chain_ndims)
552552

553553
n = _axis_size(state, sample_axis)

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def auto_correlation(x,
183183
if max_lags is None:
184184
max_lags = x_len - 1
185185
else:
186-
max_lags = tf.convert_to_tensor(max_lags, name='max_lags')
186+
max_lags = ps.convert_to_shape_tensor(max_lags, name='max_lags')
187187
max_lags_ = tf.get_static_value(max_lags)
188188
if max_lags_ is None or not know_static_shape:
189189
know_static_shape = False
@@ -285,7 +285,7 @@ def cholesky_covariance(x, sample_axis=0, keepdims=False, name=None):
285285
lower triangular matrices (the Cholesky factors).
286286
"""
287287
with tf.name_scope(name or 'cholesky_covariance'):
288-
sample_axis = tf.convert_to_tensor(sample_axis, dtype=tf.int32)
288+
sample_axis = ps.convert_to_shape_tensor(sample_axis, dtype=tf.int32)
289289
cov = covariance(
290290
x, sample_axis=sample_axis, event_axis=-1, keepdims=keepdims)
291291
return tf.linalg.cholesky(cov)
@@ -971,10 +971,10 @@ def log_average_probs(logits, sample_axis=0, event_axis=None, keepdims=False,
971971
with tf.name_scope(name or 'average_sigmoid'):
972972
logits = tf.convert_to_tensor(logits, dtype_hint=tf.float32, name='logits')
973973
if sample_axis is not None:
974-
sample_axis = tf.convert_to_tensor(
974+
sample_axis = ps.convert_to_shape_tensor(
975975
sample_axis, dtype_hint=tf.int32, name='sample_axis')
976976
if event_axis is not None:
977-
event_axis = tf.convert_to_tensor(
977+
event_axis = ps.convert_to_shape_tensor(
978978
event_axis, dtype_hint=tf.int32, name='event_axis')
979979
if event_axis is None:
980980
# log(sigmoid(x)) = log(1 / (1 + exp(-x))) = -log1p(exp(-x)) = -sp(-x)

0 commit comments

Comments
 (0)