Skip to content

Commit 4e98ee2

Browse files
srvasudeCopybara-Service
authored andcommitted
Deflake linear_gaussian_ssm_test and log_normal_test
PiperOrigin-RevId: 205160259
1 parent c069ac5 commit 4e98ee2

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,8 @@ def testTooManyDimsThrowsError(self):
518518

519519
full_batch_shape, dist = self.build_inputs([5, 4, 2, 3], [6, 5, 4, 2, 3])
520520

521-
with self.assertRaisesError("Cannot broadcast"):
521+
with self.assertRaisesError(
522+
"(Broadcasting is not supported|Cannot broadcast)"):
522523
self.maybe_evaluate(
523524
_augment_sample_shape(dist, full_batch_shape,
524525
validate_args=True))

tensorflow_probability/python/distributions/lognormal_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def testLogNormalStats(self):
5555
def testLogNormalSample(self):
5656
loc, scale = 1.5, 0.4
5757
dist = tfd.LogNormal(loc=loc, scale=scale)
58-
samples = self.evaluate(dist.sample(5000))
58+
samples = self.evaluate(dist.sample(6000, seed=1234))
5959
self.assertAllClose(np.mean(samples),
6060
self.evaluate(dist.mean()),
6161
atol=0.1)

0 commit comments

Comments
 (0)