Skip to content

Commit 617cd00

Browse files
kc611aseyboldt
andcommitted
Added ZeroSumNormal
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 9e4c7f9 commit 617cd00

File tree

7 files changed

+195
-2
lines changed

7 files changed

+195
-2
lines changed

pymc3/distributions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
VonMises,
4949
Wald,
5050
Weibull,
51+
ZeroSumNormal,
5152
)
5253
from pymc3.distributions.discrete import (
5354
Bernoulli,
@@ -123,6 +124,7 @@
123124
"HalfStudentT",
124125
"ChiSquared",
125126
"HalfNormal",
127+
"ZeroSumNormal",
126128
"Wald",
127129
"Pareto",
128130
"InverseGamma",

pymc3/distributions/continuous.py

+62
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"Lognormal",
6666
"ChiSquared",
6767
"HalfNormal",
68+
"ZeroSumNormal",
6869
"Wald",
6970
"Pareto",
7071
"InverseGamma",
@@ -924,6 +925,67 @@ def logcdf(self, value):
924925
)
925926

926927

928+
class ZeroSumNormal(Continuous):
929+
def __init__(self, sigma=1, zerosum_dims=None, zerosum_axes=None, **kwargs):
930+
shape = kwargs.get("shape", ())
931+
dims = kwargs.get("dims", None)
932+
if isinstance(shape, int):
933+
shape = (shape,)
934+
935+
if isinstance(dims, str):
936+
dims = (dims,)
937+
938+
self.mu = self.median = self.mode = tt.zeros(shape)
939+
self.sigma = tt.as_tensor_variable(sigma)
940+
941+
if zerosum_dims is None and zerosum_axes is None:
942+
if shape:
943+
zerosum_axes = (-1,)
944+
else:
945+
zerosum_axes = ()
946+
947+
if isinstance(zerosum_axes, int):
948+
zerosum_axes = (zerosum_axes,)
949+
950+
if isinstance(zerosum_dims, str):
951+
zerosum_dims = (zerosum_dims,)
952+
953+
if zerosum_axes is not None and zerosum_dims is not None:
954+
raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
955+
956+
if zerosum_dims is not None:
957+
if dims is None:
958+
raise ValueError("zerosum_dims can only be used with the dims kwargs.")
959+
zerosum_axes = []
960+
for dim in zerosum_dims:
961+
zerosum_axes.append(dims.index(dim))
962+
self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
963+
964+
if "transform" not in kwargs or kwargs["transform"] == None:
965+
kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes)
966+
967+
super().__init__(**kwargs)
968+
969+
def logp(self, value):
970+
return Normal.dist(sigma=self.sigma).logp(value)
971+
972+
def _random(self, scale, size):
973+
samples = stats.norm.rvs(loc=0, scale=scale, size=size)
974+
for axis in self.zerosum_axes:
975+
samples -= np.mean(samples, axis=axis, keepdims=True)
976+
return samples
977+
978+
def random(self, point=None, size=None):
979+
(sigma,) = draw_values([self.sigma], point=point, size=size)
980+
return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
981+
982+
def _distr_parameters_for_repr(self):
983+
return ["sigma"]
984+
985+
def logcdf(self, value):
986+
raise NotImplementedError()
987+
988+
927989
class Wald(PositiveContinuous):
928990
r"""
929991
Wald log-likelihood.

pymc3/distributions/distribution.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs):
9898
raise TypeError("observed needs to be data but got: {}".format(type(data)))
9999
total_size = kwargs.pop("total_size", None)
100100

101-
dims = kwargs.pop("dims", None)
101+
dims = kwargs["dims"] if "dims" in kwargs else None
102102
has_shape = "shape" in kwargs
103103
shape = kwargs.pop("shape", None)
104104
if dims is not None:

pymc3/distributions/transforms.py

+77
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import warnings
1616

17+
from typing import List
18+
1719
import numpy as np
1820
import theano.tensor as tt
1921

@@ -565,3 +567,78 @@ def jacobian_det(self, y):
565567
else:
566568
det += det_
567569
return det
570+
571+
572+
def _extend_axis(array, axis):
573+
n = array.shape[axis] + 1
574+
sum_vals = array.sum(axis, keepdims=True)
575+
norm = sum_vals / (np.sqrt(n) + n)
576+
fill_val = norm - sum_vals / np.sqrt(n)
577+
578+
out = tt.concatenate([array, fill_val], axis=axis)
579+
return out - norm
580+
581+
582+
def _extend_axis_rev(array, axis):
583+
if axis < 0:
584+
axis = axis % array.ndim
585+
assert axis >= 0 and axis < array.ndim
586+
587+
n = array.shape[axis]
588+
last = tt.take(array, [-1], axis=axis)
589+
590+
sum_vals = -last * np.sqrt(n)
591+
norm = sum_vals / (np.sqrt(n) + n)
592+
slice_before = (slice(None, None),) * axis
593+
return array[slice_before + (slice(None, -1),)] + norm
594+
595+
596+
def _extend_axis_val(array, axis):
597+
n = array.shape[axis] + 1
598+
sum_vals = array.sum(axis, keepdims=True)
599+
norm = sum_vals / (np.sqrt(n) + n)
600+
fill_val = norm - sum_vals / np.sqrt(n)
601+
602+
out = np.concatenate([array, fill_val], axis=axis)
603+
return out - norm
604+
605+
606+
def _extend_axis_rev_val(array, axis):
607+
n = array.shape[axis]
608+
last = np.take(array, [-1], axis=axis)
609+
610+
sum_vals = -last * np.sqrt(n)
611+
norm = sum_vals / (np.sqrt(n) + n)
612+
slice_before = (slice(None, None),) * len(array.shape[:axis])
613+
return array[slice_before + (slice(None, -1),)] + norm
614+
615+
616+
class ZeroSumTransform(Transform):
617+
name = "zerosum"
618+
619+
_zerosum_axes: List[int]
620+
621+
def __init__(self, zerosum_axes):
622+
self._zerosum_axes = zerosum_axes
623+
624+
def forward(self, x):
625+
for axis in self._zerosum_axes:
626+
x = _extend_axis_rev(x, axis=axis)
627+
return floatX(x)
628+
629+
def forward_val(self, x, point):
630+
for axis in self._zerosum_axes:
631+
x = _extend_axis_rev_val(x, axis=axis)
632+
return x
633+
634+
def backward(self, z):
635+
z = tt.as_tensor_variable(z)
636+
for axis in self._zerosum_axes:
637+
z = _extend_axis(z, axis=axis)
638+
return floatX(z)
639+
640+
def jacobian_det(self, x):
641+
return tt.constant(0.0)
642+
643+
644+
zerosum = ZeroSumTransform

pymc3/tests/test_distributions.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@
9393
ZeroInflatedBinomial,
9494
ZeroInflatedNegativeBinomial,
9595
ZeroInflatedPoisson,
96+
ZeroSumNormal,
9697
continuous,
9798
)
99+
from pymc3.distributions.transforms import zerosum
98100
from pymc3.math import kronecker, logsumexp
99101
from pymc3.model import Deterministic, Model, Point
100102
from pymc3.tests.helpers import select_by_precision
@@ -556,6 +558,7 @@ def check_logp(
556558
n_samples=100,
557559
extra_args=None,
558560
scipy_args=None,
561+
transform=None,
559562
):
560563
"""
561564
Generic test for PyMC3 logp methods
@@ -599,13 +602,18 @@ def check_logp(
599602

600603
def logp_reference(args):
601604
args.update(scipy_args)
605+
if transform:
606+
args["value"] = args.pop(f"value_{transform.name}__")
602607
return scipy_logp(**args)
603608

604609
model = build_model(pymc3_dist, domain, paramdomains, extra_args)
605610
logp = model.fastlogp
606611

607612
domains = paramdomains.copy()
608-
domains["value"] = domain
613+
if transform:
614+
domains[f"value_{transform.name}__"] = domain
615+
else:
616+
domains["value"] = domain
609617
for pt in product(domains, n_samples=n_samples):
610618
pt = Point(pt, model=model)
611619
assert_almost_equal(
@@ -932,6 +940,22 @@ def test_half_normal(self):
932940
lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
933941
)
934942

943+
def test_zerosum_normal(self):
944+
zerosum_axes = [-1]
945+
946+
def ref_fn(value, sigma):
947+
mu = 0
948+
return sp.norm.logpdf(value, mu, sigma)
949+
950+
self.check_logp(
951+
ZeroSumNormal,
952+
R,
953+
{"sigma": Rplus},
954+
ref_fn,
955+
decimal=select_by_precision(float64=6, float32=1),
956+
transform=zerosum(zerosum_axes),
957+
)
958+
935959
def test_chi_squared(self):
936960
self.check_logp(
937961
ChiSquared,

pymc3/tests/test_distributions_random.py

+20
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,11 @@ class TestHalfNormal(BaseTestCases.BaseTestCase):
343343
params = {"tau": 1.0}
344344

345345

346+
class TestZeroSumNormal(BaseTestCases.BaseTestCase):
347+
distribution = pm.ZeroSumNormal
348+
params = {"sigma": 1.0}
349+
350+
346351
class TestUniform(BaseTestCases.BaseTestCase):
347352
distribution = pm.Uniform
348353
params = {"lower": 0.0, "upper": 1.0}
@@ -622,6 +627,21 @@ def ref_rand(size, tau):
622627

623628
pymc3_random(pm.HalfNormal, {"tau": Rplus}, ref_rand=ref_rand)
624629

630+
def test_zerosum_normal(self):
631+
def ref_rand(size, sigma):
632+
shape = sigma.shape
633+
zerosum_axes = (-1,) if shape else ()
634+
zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
635+
n = shape[-1] if shape else 1
636+
samples = st.multivariate_normal.rvs(
637+
cov=sigma ** 2 * (np.eye(n) - np.ones(n) / n), size=n
638+
)
639+
for axis in zerosum_axes:
640+
samples -= np.mean(samples, axis=axis, keepdims=True)
641+
return samples
642+
643+
pymc3_random(pm.ZeroSumNormal, {"sigma": PdMatrix(3)}, ref_rand=ref_rand)
644+
625645
def test_wald(self):
626646
# Cannot do anything too exciting as scipy wald is a
627647
# location-scale model of the *standard* wald with mu=1 and lam=1

pymc3/tests/test_transforms.py

+8
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,14 @@ def test_chain():
242242
close_to_logical(np.diff(vals) >= 0, True, tol)
243243

244244

245+
def test_zerosum():
246+
zerosum_axes = [0]
247+
zerosum_transf = tr.ZeroSumTransform(zerosum_axes)
248+
249+
vals = get_values(zerosum_transf, Vector(R, 5), tt.dvector, np.random.random(5))
250+
close_to_logical(np.mean(vals) >= 0, True, tol)
251+
252+
245253
class TestElementWiseLogp(SeededTest):
246254
def build_model(self, distfam, params, shape, transform, testval=None):
247255
if testval is not None:

0 commit comments

Comments
 (0)