Skip to content

Commit c8ce9c9

Browse files
ArmavicaricardoV94
authored andcommitted
Clean up more warnings from the test suite
1 parent a7f5d22 commit c8ce9c9

7 files changed

+86
-43
lines changed

pymc/tests/test_distributions.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -2077,12 +2077,13 @@ def test_mvt(self, n):
20772077

20782078
@pytest.mark.parametrize("n", [2, 3])
20792079
def test_wishart(self, n):
2080-
check_logp(
2081-
Wishart,
2082-
PdMatrix(n),
2083-
{"nu": Domain([0, 3, 4, np.inf], "int64"), "V": PdMatrix(n)},
2084-
lambda value, nu, V: scipy.stats.wishart.logpdf(value, int(nu), V),
2085-
)
2080+
with pytest.warns(UserWarning, match="Wishart distribution can currently not be used"):
2081+
check_logp(
2082+
Wishart,
2083+
PdMatrix(n),
2084+
{"nu": Domain([0, 3, 4, np.inf], "int64"), "V": PdMatrix(n)},
2085+
lambda value, nu, V: scipy.stats.wishart.logpdf(value, int(nu), V),
2086+
)
20862087

20872088
@pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
20882089
def test_lkjcorr(self, x, eta, n, lp):

pymc/tests/test_distributions_random.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,11 @@ class BaseTestDistributionRandom(SeededTest):
230230

231231
def test_distribution(self):
232232
self.validate_tests_list()
233-
self._instantiate_pymc_rv()
233+
if self.pymc_dist == pm.Wishart:
234+
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
235+
self._instantiate_pymc_rv()
236+
else:
237+
self._instantiate_pymc_rv()
234238
if self.reference_dist is not None:
235239
self.reference_dist_draws = self.reference_dist()(
236240
size=self.size, **self.reference_dist_params
@@ -240,7 +244,11 @@ def test_distribution(self):
240244
raise ValueError(
241245
"Custom check cannot start with `test_` or else it will be executed twice."
242246
)
243-
getattr(self, check_name)()
247+
if self.pymc_dist == pm.Wishart and check_name.startswith("check_rv_size"):
248+
with pytest.warns(UserWarning, match="can currently not be used for MCMC sampling"):
249+
getattr(self, check_name)()
250+
else:
251+
getattr(self, check_name)()
244252

245253
def _instantiate_pymc_rv(self, dist_params=None):
246254
params = dist_params if dist_params else self.pymc_dist_params

pymc/tests/test_distributions_timeseries.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
import aesara
1517
import numpy as np
1618
import pytest
@@ -125,7 +127,7 @@ def check_rv_inferred_size(self):
125127

126128
def test_steps_scalar_check(self):
127129
with pytest.raises(ValueError, match="steps must be an integer scalar"):
128-
self.pymc_dist.dist(steps=[1])
130+
self.pymc_dist.dist(steps=[1], init_dist=pm.DiracDelta.dist(0))
129131

130132
def test_gaussianrandomwalk_inference(self):
131133
mu, sigma, steps = 2, 1, 1000
@@ -136,7 +138,9 @@ def test_gaussianrandomwalk_inference(self):
136138
_sigma = pm.Uniform("sigma", 0, 10)
137139

138140
obs_data = pm.MutableData("obs_data", obs)
139-
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
141+
grw = GaussianRandomWalk(
142+
"grw", _mu, _sigma, steps=steps, observed=obs_data, init_dist=Normal.dist(0, 100)
143+
)
140144

141145
trace = pm.sample(chains=1)
142146

@@ -147,26 +151,30 @@ def test_gaussianrandomwalk_inference(self):
147151
@pytest.mark.parametrize("init", [None, pm.Normal.dist()])
148152
def test_gaussian_random_walk_init_dist_shape(self, init):
149153
"""Test that init_dist is properly resized"""
150-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init)
151-
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
154+
with warnings.catch_warnings():
155+
warnings.filterwarnings("ignore", "Initial distribution not specified.*", UserWarning)
156+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init)
157+
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
152158

153-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, size=(5,))
154-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
159+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, size=(5,))
160+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
155161

156-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, shape=2)
157-
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
162+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, shape=2)
163+
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
158164

159-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, shape=(5, 2))
160-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
165+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init_dist=init, shape=(5, 2))
166+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
161167

162-
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init_dist=init)
163-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
168+
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init_dist=init)
169+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
164170

165-
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init_dist=init)
166-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
171+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init_dist=init)
172+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
167173

168-
grw = pm.GaussianRandomWalk.dist(mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init_dist=init)
169-
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
174+
grw = pm.GaussianRandomWalk.dist(
175+
mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init_dist=init
176+
)
177+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
170178

171179
def test_shape_ellipsis(self):
172180
grw = pm.GaussianRandomWalk.dist(
@@ -184,28 +192,28 @@ def test_gaussianrandomwalk_broadcasted_by_init_dist(self):
184192

185193
@pytest.mark.parametrize("shape", ((6,), (3, 6)))
186194
def test_inferred_steps_from_shape(self, shape):
187-
x = GaussianRandomWalk.dist(shape=shape)
195+
x = GaussianRandomWalk.dist(shape=shape, init_dist=Normal.dist(0, 100))
188196
steps = x.owner.inputs[-1]
189197
assert steps.eval() == 5
190198

191199
@pytest.mark.parametrize("shape", (None, (5, ...)))
192200
def test_missing_steps(self, shape):
193201
with pytest.raises(ValueError, match="Must specify steps or shape parameter"):
194-
GaussianRandomWalk.dist(shape=shape)
202+
GaussianRandomWalk.dist(shape=shape, init_dist=Normal.dist(0, 100))
195203

196204
def test_inconsistent_steps_and_shape(self):
197205
with pytest.raises(AssertionError, match="Steps do not match last shape dimension"):
198-
x = GaussianRandomWalk.dist(steps=12, shape=45)
206+
x = GaussianRandomWalk.dist(steps=12, shape=45, init_dist=Normal.dist(0, 100))
199207

200208
def test_inferred_steps_from_dims(self):
201209
with pm.Model(coords={"batch": range(5), "steps": range(20)}):
202-
x = GaussianRandomWalk("x", dims=("batch", "steps"))
210+
x = GaussianRandomWalk("x", dims=("batch", "steps"), init_dist=Normal.dist(0, 100))
203211
steps = x.owner.inputs[-1]
204212
assert steps.eval() == 19
205213

206214
def test_inferred_steps_from_observed(self):
207215
with pm.Model():
208-
x = GaussianRandomWalk("x", observed=np.zeros(10))
216+
x = GaussianRandomWalk("x", observed=np.zeros(10), init_dist=Normal.dist(0, 100))
209217
steps = x.owner.inputs[-1]
210218
assert steps.eval() == 9
211219

pymc/tests/test_mixture.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,9 @@ def test_component_choice_random(self):
337337
@pytest.mark.parametrize(
338338
"comp_dists",
339339
(
340-
[Normal.dist(size=(2,))],
340+
Normal.dist(size=(2,)),
341341
[Normal.dist(), Normal.dist()],
342-
[MvNormal.dist(np.ones(3), np.eye(3), size=(2,))],
342+
MvNormal.dist(np.ones(3), np.eye(3), size=(2,)),
343343
[
344344
MvNormal.dist(np.ones(3), np.eye(3)),
345345
MvNormal.dist(np.ones(3), np.eye(3)),
@@ -348,7 +348,10 @@ def test_component_choice_random(self):
348348
)
349349
def test_components_expanded_by_weights(self, comp_dists):
350350
"""Test that components are expanded when size or weights are larger than components"""
351-
univariate = comp_dists[0].owner.op.ndim_supp == 0
351+
if isinstance(comp_dists, list):
352+
univariate = comp_dists[0].owner.op.ndim_supp == 0
353+
else:
354+
univariate = comp_dists.owner.op.ndim_supp == 0
352355

353356
mix = Mixture.dist(
354357
w=Dirichlet.dist([1, 1], shape=(3, 2)),
@@ -371,9 +374,9 @@ def test_components_expanded_by_weights(self, comp_dists):
371374
@pytest.mark.parametrize(
372375
"comp_dists",
373376
(
374-
[Normal.dist(size=(2,))],
377+
Normal.dist(size=(2,)),
375378
[Normal.dist(), Normal.dist()],
376-
[MvNormal.dist(np.ones(3), np.eye(3), size=(2,))],
379+
MvNormal.dist(np.ones(3), np.eye(3), size=(2,)),
377380
[
378381
MvNormal.dist(np.ones(3), np.eye(3)),
379382
MvNormal.dist(np.ones(3), np.eye(3)),
@@ -382,7 +385,10 @@ def test_components_expanded_by_weights(self, comp_dists):
382385
)
383386
@pytest.mark.parametrize("expand", (False, True))
384387
def test_change_size(self, comp_dists, expand):
385-
univariate = comp_dists[0].owner.op.ndim_supp == 0
388+
if isinstance(comp_dists, list):
389+
univariate = comp_dists[0].owner.op.ndim_supp == 0
390+
else:
391+
univariate = comp_dists.owner.op.ndim_supp == 0
386392

387393
mix = Mixture.dist(w=Dirichlet.dist([1, 1]), comp_dists=comp_dists)
388394
mix = Mixture.change_size(mix, new_size=(4,), expand=expand)
@@ -444,6 +450,7 @@ def test_single_poisson_sampling(self):
444450
step = Metropolis()
445451
with warnings.catch_warnings():
446452
warnings.filterwarnings("ignore", "More chains .* than draws.*", UserWarning)
453+
warnings.filterwarnings("ignore", "overflow encountered in exp", RuntimeWarning)
447454
trace = sample(
448455
5000,
449456
step,
@@ -467,6 +474,7 @@ def test_list_poissons_sampling(self):
467474
Mixture("x_obs", w, [Poisson.dist(mu[0]), Poisson.dist(mu[1])], observed=pois_x)
468475
with warnings.catch_warnings():
469476
warnings.filterwarnings("ignore", "More chains .* than draws.*", UserWarning)
477+
warnings.filterwarnings("ignore", "overflow encountered in exp", RuntimeWarning)
470478
trace = sample(
471479
5000,
472480
chains=1,
@@ -497,6 +505,7 @@ def test_list_normals_sampling(self):
497505
)
498506
with warnings.catch_warnings():
499507
warnings.filterwarnings("ignore", "More chains .* than draws.*", UserWarning)
508+
warnings.filterwarnings("ignore", "overflow encountered in exp", RuntimeWarning)
500509
trace = sample(
501510
5000,
502511
chains=1,
@@ -755,6 +764,7 @@ def test_normal_mixture_sampling(self):
755764
step = Metropolis()
756765
with warnings.catch_warnings():
757766
warnings.filterwarnings("ignore", "More chains .* than draws.*", UserWarning)
767+
warnings.filterwarnings("ignore", "overflow encountered in exp", RuntimeWarning)
758768
trace = sample(
759769
5000,
760770
step,
@@ -989,7 +999,8 @@ def test_with_multinomial(self, batch_shape):
989999
w = np.ones(self.mixture_comps) / self.mixture_comps
9901000
mixture_axis = len(batch_shape)
9911001
with Model() as model:
992-
comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
1002+
with pytest.warns(UserWarning, match="parameters sum up to"):
1003+
comp_dists = Multinomial.dist(p=p, n=n, shape=(*batch_shape, self.mixture_comps, 3))
9931004
mixture = Mixture(
9941005
"mixture",
9951006
w=w,

pymc/tests/test_model.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import unittest
1515
import warnings
1616

17-
from functools import reduce
18-
1917
import aesara
2018
import aesara.sparse as sparse
2119
import aesara.tensor as at
@@ -354,7 +352,8 @@ def test_missing_data(self):
354352

355353
res = [gf(DictToArrayBijection.map(Point(pnt, model=m))) for i in range(5)]
356354

357-
assert reduce(lambda x, y: np.array_equal(x, y) and y, res) is not False
355+
# Assert that all the elements of res are equal
356+
assert res[1:] == res[:-1]
358357

359358
def test_aesara_switch_broadcast_edge_cases_1(self):
360359
# Tests against two subtle issues related to a previous bug in Theano

pymc/tests/test_ode.py

+12
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,9 @@ def system(y, t, p):
366366
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
367367
with warnings.catch_warnings():
368368
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
369+
warnings.filterwarnings(
370+
"ignore", "invalid value encountered in log", RuntimeWarning
371+
)
369372
idata = pm.sample(50, tune=0, chains=1)
370373

371374
assert idata.posterior["alpha"].shape == (1, 50)
@@ -399,6 +402,9 @@ def system(y, t, p):
399402
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
400403
with warnings.catch_warnings():
401404
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
405+
warnings.filterwarnings(
406+
"ignore", "invalid value encountered in log", RuntimeWarning
407+
)
402408
idata = pm.sample(50, tune=0, chains=1)
403409

404410
assert idata.posterior["alpha"].shape == (1, 50)
@@ -443,6 +449,9 @@ def system(y, t, p):
443449
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
444450
with warnings.catch_warnings():
445451
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
452+
warnings.filterwarnings(
453+
"ignore", "invalid value encountered in log", RuntimeWarning
454+
)
446455
idata = pm.sample(50, tune=0, chains=1)
447456

448457
assert idata.posterior["R"].shape == (1, 50)
@@ -486,6 +495,9 @@ def system(y, t, p):
486495
with aesara.config.change_flags(mode=fast_unstable_sampling_mode):
487496
with warnings.catch_warnings():
488497
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
498+
warnings.filterwarnings(
499+
"ignore", "invalid value encountered in log", RuntimeWarning
500+
)
489501
idata = pm.sample(50, tune=0, chains=1)
490502

491503
assert idata.posterior["beta"].shape == (1, 50)

pymc/tests/test_shape_handling.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,14 @@ class TestShapesBroadcasting:
110110
ids=str,
111111
)
112112
def test_type_check_raises(self, bad_input):
113-
with pytest.raises(TypeError):
114-
shapes_broadcasting(bad_input, tuple(), raise_exception=True)
115-
with pytest.raises(TypeError):
116-
shapes_broadcasting(bad_input, tuple(), raise_exception=False)
113+
with warnings.catch_warnings():
114+
warnings.filterwarnings(
115+
"ignore", ".*ragged nested sequences.*", np.VisibleDeprecationWarning
116+
)
117+
with pytest.raises(TypeError):
118+
shapes_broadcasting(bad_input, tuple(), raise_exception=True)
119+
with pytest.raises(TypeError):
120+
shapes_broadcasting(bad_input, tuple(), raise_exception=False)
117121

118122
def test_type_check_success(self):
119123
inputs = [3, 3.0, tuple(), [3], (3,), np.array(3), np.array([3])]

0 commit comments

Comments
 (0)