Skip to content

Commit

Permalink
Deflake linear_gaussian_ssm_test and log_normal_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 205160259
  • Loading branch information
srvasude authored and Copybara-Service committed Jul 18, 2018
1 parent c069ac5 commit 4e98ee2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,8 @@ def testTooManyDimsThrowsError(self):

full_batch_shape, dist = self.build_inputs([5, 4, 2, 3], [6, 5, 4, 2, 3])

with self.assertRaisesError("Cannot broadcast"):
with self.assertRaisesError(
"(Broadcasting is not supported|Cannot broadcast)"):
self.maybe_evaluate(
_augment_sample_shape(dist, full_batch_shape,
validate_args=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def testLogNormalStats(self):
def testLogNormalSample(self):
loc, scale = 1.5, 0.4
dist = tfd.LogNormal(loc=loc, scale=scale)
samples = self.evaluate(dist.sample(5000))
samples = self.evaluate(dist.sample(6000, seed=1234))
self.assertAllClose(np.mean(samples),
self.evaluate(dist.mean()),
atol=0.1)
Expand Down

0 comments on commit 4e98ee2

Please sign in to comment.