Skip to content

Commit 0b191ad

Browse files
purna135Sayam753
andauthored
Allow for batched alpha in StickBreakingWeights (#6042)
Co-authored-by: Sayam Kumar <[email protected]>
1 parent c8ce9c9 commit 0b191ad

File tree

4 files changed

+77
-16
lines changed

4 files changed

+77
-16
lines changed

pymc/distributions/multivariate.py

+8-16
Original file line numberDiff line numberDiff line change
@@ -2200,32 +2200,23 @@ def make_node(self, rng, size, dtype, alpha, K):
22002200
alpha = at.as_tensor_variable(alpha)
22012201
K = at.as_tensor_variable(intX(K))
22022202

2203-
if alpha.ndim > 0:
2204-
raise ValueError("The concentration parameter needs to be a scalar.")
2205-
22062203
if K.ndim > 0:
22072204
raise ValueError("K must be a scalar.")
22082205

22092206
return super().make_node(rng, size, dtype, alpha, K)
22102207

2211-
def _infer_shape(self, size, dist_params, param_shapes=None):
2212-
alpha, K = dist_params
2213-
2214-
size = tuple(size)
2215-
2216-
return size + (K + 1,)
2208+
def _supp_shape_from_params(self, dist_params, **kwargs):
2209+
K = dist_params[1]
2210+
return (K + 1,)
22172211

22182212
@classmethod
22192213
def rng_fn(cls, rng, alpha, K, size):
22202214
if K < 0:
22212215
raise ValueError("K needs to be positive.")
22222216

2223-
if size is None:
2224-
size = (K,)
2225-
elif isinstance(size, int):
2226-
size = (size,) + (K,)
2227-
else:
2228-
size = tuple(size) + (K,)
2217+
size = to_tuple(size) if size is not None else alpha.shape
2218+
size = size + (K,)
2219+
alpha = alpha[..., np.newaxis]
22292220

22302221
betas = rng.beta(1, alpha, size=size)
22312222

@@ -2294,9 +2285,10 @@ def dist(cls, alpha, K, *args, **kwargs):
22942285
return super().dist([alpha, K], **kwargs)
22952286

22962287
def moment(rv, size, alpha, K):
2288+
alpha = alpha[..., np.newaxis]
22972289
moment = (alpha / (1 + alpha)) ** at.arange(K)
22982290
moment *= 1 / (1 + alpha)
2299-
moment = at.concatenate([moment, [(alpha / (1 + alpha)) ** K]], axis=-1)
2291+
moment = at.concatenate([moment, (alpha / (1 + alpha)) ** K], axis=-1)
23002292
if not rv_size_is_none(size):
23012293
moment_size = at.concatenate(
23022294
[

pymc/tests/test_distributions.py

+31
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from aeppl.logprob import ParameterValueError
2525
from aesara.tensor.random.utils import broadcast_params
2626

27+
from pymc.aesaraf import compile_pymc
2728
from pymc.distributions.continuous import get_tau_sigma
2829
from pymc.util import UNSET
2930

@@ -953,6 +954,17 @@ def test_hierarchical_obs_logp():
953954
assert not any(isinstance(o, RandomVariable) for o in ops)
954955

955956

957+
@pytest.fixture(scope="module")
958+
def stickbreakingweights_logpdf():
959+
_value = at.vector()
960+
_alpha = at.scalar()
961+
_k = at.iscalar()
962+
_logp = logp(StickBreakingWeights.dist(_alpha, _k), _value)
963+
core_fn = compile_pymc([_value, _alpha, _k], _logp)
964+
965+
return np.vectorize(core_fn, signature="(n),(),()->()")
966+
967+
956968
class TestMatchesScipy:
957969
def test_uniform(self):
958970
check_logp(
@@ -2318,6 +2330,25 @@ def test_stickbreakingweights_invalid(self):
23182330
assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
23192331
assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
23202332

2333+
@pytest.mark.parametrize(
2334+
"alpha,K",
2335+
[
2336+
(np.array([0.5, 1.0, 2.0]), 3),
2337+
(np.arange(1, 7, dtype="float64").reshape(2, 3), 5),
2338+
],
2339+
)
2340+
def test_stickbreakingweights_vectorized(self, alpha, K, stickbreakingweights_logpdf):
2341+
value = pm.StickBreakingWeights.dist(alpha, K).eval()
2342+
with Model():
2343+
sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
2344+
pt = {"sbw": value}
2345+
assert_almost_equal(
2346+
pm.logp(sbw, value).eval(),
2347+
stickbreakingweights_logpdf(value, alpha, K),
2348+
decimal=select_by_precision(float64=6, float32=2),
2349+
err_msg=str(pt),
2350+
)
2351+
23212352
@aesara.config.change_flags(compute_test_value="raise")
23222353
def test_categorical_bounds(self):
23232354
with Model():

pymc/tests/test_distributions_moments.py

+26
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,32 @@ def test_rice_moment(nu, sigma, size, expected):
11661166
fill_value=np.append((1 / 3) ** np.arange(5) * 2 / 3, (1 / 3) ** 5),
11671167
),
11681168
),
1169+
(
1170+
np.array([1, 3]),
1171+
11,
1172+
None,
1173+
np.array(
1174+
[
1175+
np.append((1 / 2) ** np.arange(11) * 1 / 2, (1 / 2) ** 11),
1176+
np.append((3 / 4) ** np.arange(11) * 1 / 4, (3 / 4) ** 11),
1177+
]
1178+
),
1179+
),
1180+
(
1181+
np.array([1, 3, 5]),
1182+
9,
1183+
(5, 3),
1184+
np.full(
1185+
shape=(5, 3, 10),
1186+
fill_value=np.array(
1187+
[
1188+
np.append((1 / 2) ** np.arange(9) * 1 / 2, (1 / 2) ** 9),
1189+
np.append((3 / 4) ** np.arange(9) * 1 / 4, (3 / 4) ** 9),
1190+
np.append((5 / 6) ** np.arange(9) * 1 / 6, (5 / 6) ** 9),
1191+
]
1192+
),
1193+
),
1194+
),
11691195
],
11701196
)
11711197
def test_stickbreakingweights_moment(alpha, K, size, expected):

pymc/tests/test_distributions_random.py

+12
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,18 @@ def check_basic_properties(self):
13291329
assert np.all(draws <= 1)
13301330

13311331

1332+
class TestStickBreakingWeights_1D_alpha(BaseTestDistributionRandom):
1333+
pymc_dist = pm.StickBreakingWeights
1334+
pymc_dist_params = {"alpha": [1.0, 2.0, 3.0], "K": 19}
1335+
expected_rv_op_params = {"alpha": [1.0, 2.0, 3.0], "K": 19}
1336+
sizes_to_check = [None, (3,), (5, 3)]
1337+
sizes_expected = [(3, 20), (3, 20), (5, 3, 20)]
1338+
checks_to_run = [
1339+
"check_pymc_params_match_rv_op",
1340+
"check_rv_size",
1341+
]
1342+
1343+
13321344
class TestCategorical(BaseTestDistributionRandom):
13331345
pymc_dist = pm.Categorical
13341346
pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])}

0 commit comments

Comments
 (0)