Skip to content

Commit

Permalink
Fix TruncatedNormal's sample/log_prob dtype when jax_enable_x64=True.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 378583501
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jun 10, 2021
1 parent 3427e09 commit 92497f4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,15 @@ def grad(dy):
return std_samples * scale[tf.newaxis] + loc[tf.newaxis]

def _log_prob(self, x):
np_dtype = dtype_util.as_numpy_dtype(x.dtype)
loc, scale, low, high = self._loc_scale_low_high()
log_prob = -(0.5 * tf.square(
(x - loc) / scale) + 0.5 * np.log(2. * np.pi) + tf.math.log(scale) +
log_prob = -(np_dtype(0.5) * tf.square(
(x - loc) / scale) + (0.5 * np.log(2. * np.pi)).astype(np_dtype) +
tf.math.log(scale) +
self._log_normalizer(loc=loc, scale=scale, low=low, high=high))
# p(x) is 0 outside the bounds.
bounded_log_prob = tf.where((x > high) | (x < low),
dtype_util.as_numpy_dtype(x.dtype)(-np.inf),
np_dtype(-np.inf),
log_prob)
return bounded_log_prob

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,10 @@ def _truncated_normal_jax(
import jax.random as jaxrand # pylint: disable=g-import-not-at-top
if seed is None:
raise ValueError('Must provide PRNGKey to sample in JAX.')
dtype = utils.common_dtype([means, stddevs, minvals, maxvals])
std_low = (minvals - means) / stddevs
std_high = (maxvals - means) / stddevs
std_samps = jaxrand.truncated_normal(seed, std_low, std_high, shape)
std_samps = jaxrand.truncated_normal(seed, std_low, std_high, shape, dtype)
return std_samps * stddevs + means


Expand Down

0 comments on commit 92497f4

Please sign in to comment.