Added ZeroSumNormal Distribution #4776

Closed · wants to merge 2 commits

2 changes: 2 additions & 0 deletions pymc3/distributions/__init__.py
@@ -48,6 +48,7 @@
    VonMises,
    Wald,
    Weibull,
    ZeroSumNormal,
)
from pymc3.distributions.discrete import (
    Bernoulli,
@@ -123,6 +124,7 @@
"HalfStudentT",
"ChiSquared",
"HalfNormal",
"ZeroSumNormal",
"Wald",
"Pareto",
"InverseGamma",
Expand Down
68 changes: 68 additions & 0 deletions pymc3/distributions/continuous.py
@@ -65,6 +65,7 @@
"Lognormal",
"ChiSquared",
"HalfNormal",
"ZeroSumNormal",
"Wald",
"Pareto",
"InverseGamma",
@@ -924,6 +925,73 @@ def logcdf(self, value):
        )


class ZeroSumNormal(Continuous):
    def __new__(cls, name, *args, **kwargs):
        zerosum_axes = kwargs.get("zerosum_axes", None)
        zerosum_dims = kwargs.get("zerosum_dims", None)
        dims = kwargs.get("dims", None)

        if isinstance(zerosum_dims, str):
            zerosum_dims = (zerosum_dims,)
        if isinstance(dims, str):
            dims = (dims,)

        if zerosum_dims is not None:
            if dims is None:
                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
            if zerosum_axes is not None:
                raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
            zerosum_axes = []
            for dim in zerosum_dims:
                zerosum_axes.append(dims.index(dim))
            kwargs["zerosum_axes"] = zerosum_axes

        return super().__new__(cls, name, *args, **kwargs)

    def __init__(self, sigma=1, zerosum_axes=None, zerosum_dims=None, **kwargs):
Contributor:
Note that zerosum_dims is not used in __init__, but if I don't put it here, it doesn't seem to be passed on to __new__: TypeError: __init__() got an unexpected keyword argument 'zerosum_dims'
Not sure we can do it otherwise, though. If someone has a better idea, I'm all ears.

Member:
I think the zerosum_dims is probably still in kwargs from line 949? We could just remove it there.
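
Following up on that, a minimal sketch of the idea (illustrative only; it relies on Distribution.__new__ forwarding the remaining kwargs to cls.dist() and thus to __init__, so popping zerosum_dims here would let us drop the unused parameter from __init__):

    def __new__(cls, name, *args, **kwargs):
        # pop (rather than get) so the keyword is consumed here and never
        # forwarded to __init__ via Distribution.__new__ -> cls.dist(...)
        zerosum_dims = kwargs.pop("zerosum_dims", None)
        zerosum_axes = kwargs.get("zerosum_axes", None)
        dims = kwargs.get("dims", None)

        if isinstance(zerosum_dims, str):
            zerosum_dims = (zerosum_dims,)
        if isinstance(dims, str):
            dims = (dims,)

        if zerosum_dims is not None:
            if dims is None:
                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
            if zerosum_axes is not None:
                raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
            kwargs["zerosum_axes"] = [dims.index(dim) for dim in zerosum_dims]

        return super().__new__(cls, name, *args, **kwargs)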

shape = kwargs.get("shape", ())
if isinstance(shape, int):
shape = (shape,)

self.mu = self.median = self.mode = tt.zeros(shape)
self.sigma = tt.as_tensor_variable(sigma)

if zerosum_axes is None:
if shape:
zerosum_axes = (-1,)
else:
zerosum_axes = ()
Contributor:
I think it makes no sense to have a ZeroSumNormal when shape=() or None. In that case, the RV should also be exactly equal to zero. I think that we should test if shape is None or len(shape) == 0 and raise a ValueError in that case, with a message saying that ZeroSumNormal is only defined for RVs that are not scalar.
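
A minimal sketch of that guard (hypothetical wording), to be placed right after the shape handling at the top of __init__:

        if shape is None or len(shape) == 0:
            raise ValueError(
                "ZeroSumNormal is only defined for non-scalar RVs; "
                "please pass a shape (or dims) with at least one dimension."
            )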


        if isinstance(zerosum_axes, int):
            zerosum_axes = (zerosum_axes,)

        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
Contributor:
Enforcing positive axes here leads to problems when you draw samples from the prior predictive. It's better to replace this line with the suggestion below (a short sketch of the failure mode follows it):

Suggested change:
-        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
+        self.zerosum_axes = [a if a < 0 else a - len(shape) for a in zerosum_axes]
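
To illustrate why (a standalone numpy sketch with assumed shapes, not the PR code): prior predictive sampling prepends a draws dimension, so an axis converted to a positive index ends up pointing at the draws dimension, whereas a negative index keeps pointing at the distribution's own axis.

import numpy as np

dist_shape = (4,)  # one zero-sum axis, specified as -1
samples = np.random.default_rng(1).normal(size=(500,) + dist_shape)  # draws dimension prepended

centered_neg = samples - samples.mean(axis=-1, keepdims=True)  # axis -1: centers each draw
centered_pos = samples - samples.mean(axis=0, keepdims=True)   # axis 0 (= -1 + len(dist_shape)): centers across draws

print(np.allclose(centered_neg.sum(axis=-1), 0))  # True: every draw sums to zero
print(np.allclose(centered_pos.sum(axis=-1), 0))  # False: the draws are not zero-sum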


if "transform" not in kwargs or kwargs["transform"] is None:
kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes)

super().__init__(**kwargs)

def logp(self, value):
return Normal.dist(sigma=self.sigma).logp(value)
Contributor:
If we don’t add the scaling of sigma here, our random method will be inconsistent with the logp.

Contributor (author):
How so? Isn't random also drawing samples using self.sigma?

Member:
This logp is somewhat strange still. From the math side this should be pm.MvNormal with cov = I - J / n, where J is a matrix of all 1s. We don't want to write it like this though, because we don't want to do matrix factorization, and pm.MvNormal doesn't work if an eigenvalue is 0.
It would be great if, instead, we could define the logp simply in the transformed space. This would imply changes to TransformedDistribution though.
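
For reference, the zero eigenvalue comes from the all-ones direction, since (I - J/n) annihilates the ones vector. A quick numpy check (illustrative):

import numpy as np

n = 4
cov = np.eye(n) - np.ones((n, n)) / n      # I - J/n
print(np.allclose(cov @ np.ones(n), 0))    # True: the ones vector is in the null space
print(np.linalg.matrix_rank(cov))          # n - 1, so a Cholesky-based MvNormal logp breaks down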

Contributor (@lucianopaz, Jan 30, 2022):
We can avoid the MvNormal.logp problem if we force self.sigma to be a scalar or to have a single element in all the zerosum_axes. In this case, all the directions in the zerosum manifold are uncorrelated and have equal variance. This means that we can use the Normal.logp as long as we also include a bound condition that guarantees that we are on the zerosum manifold. You can do this by using this logp:

Suggested change:
-        return Normal.dist(sigma=self.sigma).logp(value)
+        zerosums = [tt.all(tt.abs_(tt.mean(value, axis=axis)) <= 1e-9) for axis in self.zerosum_axes]
+        return bound(
+            Normal.dist(sigma=self.sigma).logp(value),
+            tt.all(self.sigma > 0),
+            broadcast_conditions=False,
+            *zerosums,
+        )
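
An illustrative numpy analogue of that bound logic (not the theano suggestion above): off the zero-sum manifold the density is zero, on it we fall back to the plain Normal logp.

import numpy as np
from scipy import stats

def zerosum_logp(value, sigma, zerosum_axes=(-1,)):
    on_manifold = all(np.all(np.abs(value.mean(axis=axis)) <= 1e-9) for axis in zerosum_axes)
    return stats.norm.logpdf(value, scale=sigma).sum() if on_manifold else -np.inf

x = np.array([0.3, -0.1, -0.2])    # sums to zero
print(zerosum_logp(x, 1.0))        # finite
print(zerosum_logp(x + 0.5, 1.0))  # -inf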

Contributor (@lucianopaz, Feb 4, 2022):
I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test if the logp that we are using in the distribution matches it. The expected logp would look something like this:

def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None]**2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 *  pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)

Contributor (@lucianopaz, Feb 6, 2022):
I ran a few tests with the logp, and it looks like the logp that we are using in this PR doesn't match what one would expect from a degenerate multivariate normal distribution. In my comment above, I posted what a degenerate MvNormal logp looks like. For this particular problem, where we know that we have only one eigenvector with zero eigenvalue, we can re-write the logp as:

def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None]**2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet =  0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)

This is different from the logp that we are currently using in this PR. The difference is in the normalization constant:

psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))

In particular, since all eigenvalues v except the first one are the same and equal to sigma**2, psdet = (n - 1) * (0.5 * np.log(2 * np.pi) + np.log(sigma)). Whereas, with the assumed Normal.dist(sigma=self.sigma).logp(value), the normalization factor we are getting is:

psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))

This means that we have to multiply the logp that we are using by (n-1)/n (in the case where only one axis sums to zero) to get the correct log probability density. I'll check what happens when more than one axis has to sum to zero.
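
A quick numerical sketch of this point (illustrative only; scalar sigma and a single zero-sum axis are assumed): the degenerate-MvNormal logp differs from the sum of independent Normal logps by exactly one 0.5 * log(2 * pi) + log(sigma) term, and it agrees with scipy's allow_singular=True handling.

import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
n, sigma = 5, 1.3
value = rng.normal(size=n) * sigma
value -= value.mean()  # project onto the zero-sum manifold

# What the PR currently uses: sum of independent Normal logps
naive = stats.norm.logpdf(value, scale=sigma).sum()

# Degenerate MvNormal logp, via the eigendecomposition formula above
cov = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)
v, w = np.linalg.eigh(cov)
psdet = 0.5 * (np.log(v[1:]).sum() + (n - 1) * np.log(2 * np.pi))
cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
degenerate = -psdet - 0.5 * value @ cov_pinv @ value

# scipy's degenerate MvNormal agrees with the formula ...
ref = stats.multivariate_normal(np.zeros(n), cov, allow_singular=True).logpdf(value)
print(np.isclose(degenerate, ref))  # True
# ... and differs from the naive logp only in the normalization constant
print(np.isclose(degenerate - naive, 0.5 * np.log(2 * np.pi) + np.log(sigma)))  # True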


    def _random(self, scale, size):
        samples = stats.norm.rvs(loc=0, scale=scale, size=size)
        for axis in self.zerosum_axes:
            samples -= np.mean(samples, axis=axis, keepdims=True)
        return samples

    def random(self, point=None, size=None):
        (sigma,) = draw_values([self.sigma], point=point, size=size)
        return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)

    def _distr_parameters_for_repr(self):
        return ["sigma"]

    def logcdf(self, value):
        raise NotImplementedError()
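
For orientation, a hypothetical usage sketch of the new distribution (the variable names are made up; the exact free-variable naming depends on how pymc3 labels transformed RVs):

import pymc3 as pm

with pm.Model() as model:
    # 4 group offsets constrained to sum to zero along the last axis
    group_effect = pm.ZeroSumNormal("group_effect", sigma=1.0, shape=4, zerosum_axes=-1)

# Sampling happens in the (n - 1)-dimensional unconstrained space introduced by
# ZeroSumTransform, so the free RV has one fewer element than the RV itself.
print(model.free_RVs)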


class Wald(PositiveContinuous):
r"""
Wald log-likelihood.
Expand Down
77 changes: 77 additions & 0 deletions pymc3/distributions/transforms.py
@@ -14,6 +14,8 @@

import warnings

from typing import List

import numpy as np
import theano.tensor as tt

@@ -565,3 +567,78 @@ def jacobian_det(self, y):
            else:
                det += det_
        return det


def _extend_axis(array, axis):
    n = array.shape[axis] + 1
    sum_vals = array.sum(axis, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)

    out = tt.concatenate([array, fill_val], axis=axis)
    return out - norm


def _extend_axis_rev(array, axis):
    if axis < 0:
        axis = axis % array.ndim
    assert axis >= 0 and axis < array.ndim

    n = array.shape[axis]
    last = tt.take(array, [-1], axis=axis)

    sum_vals = -last * np.sqrt(n)
    norm = sum_vals / (np.sqrt(n) + n)
    slice_before = (slice(None, None),) * axis
    return array[slice_before + (slice(None, -1),)] + norm


def _extend_axis_val(array, axis):
    n = array.shape[axis] + 1
    sum_vals = array.sum(axis, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)

    out = np.concatenate([array, fill_val], axis=axis)
    return out - norm


def _extend_axis_rev_val(array, axis):
    n = array.shape[axis]
    last = np.take(array, [-1], axis=axis)

    sum_vals = -last * np.sqrt(n)
    norm = sum_vals / (np.sqrt(n) + n)
    slice_before = (slice(None, None),) * len(array.shape[:axis])
    return array[slice_before + (slice(None, -1),)] + norm


class ZeroSumTransform(Transform):
    name = "zerosum"

    _zerosum_axes: List[int]

    def __init__(self, zerosum_axes):
        self._zerosum_axes = zerosum_axes

    def forward(self, x):
        for axis in self._zerosum_axes:
            x = _extend_axis_rev(x, axis=axis)
        return floatX(x)

    def forward_val(self, x, point):
        for axis in self._zerosum_axes:
            x = _extend_axis_rev_val(x, axis=axis)
        return x

    def backward(self, z):
        z = tt.as_tensor_variable(z)
        for axis in self._zerosum_axes:
            z = _extend_axis(z, axis=axis)
        return floatX(z)

    def jacobian_det(self, x):
        return tt.constant(0.0)


zerosum = ZeroSumTransform
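
As a sanity check on the algebra above, a standalone numpy sketch mirroring _extend_axis_val and _extend_axis_rev_val (illustrative; a single non-negative axis is assumed): the backward-style map lands on the zero-sum manifold, and the forward-style map recovers the original free values.

import numpy as np

def extend_axis_val(array, axis):
    # backward direction: (n - 1) free values -> n zero-sum values
    n = array.shape[axis] + 1
    sum_vals = array.sum(axis, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)
    out = np.concatenate([array, fill_val], axis=axis)
    return out - norm

def extend_axis_rev_val(array, axis):
    # forward direction: n zero-sum values -> (n - 1) free values
    n = array.shape[axis]
    last = np.take(array, [-1], axis=axis)
    sum_vals = -last * np.sqrt(n)
    norm = sum_vals / (np.sqrt(n) + n)
    slice_before = (slice(None),) * axis
    return array[slice_before + (slice(None, -1),)] + norm

rng = np.random.default_rng(3)
z = rng.normal(size=4)                                  # unconstrained values
x = extend_axis_val(z, axis=0)                          # constrained, length 5
print(np.isclose(x.sum(), 0))                           # True: lies on the zero-sum manifold
print(np.allclose(extend_axis_rev_val(x, axis=0), z))   # True: round trip recovers z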
26 changes: 25 additions & 1 deletion pymc3/tests/test_distributions.py
@@ -93,8 +93,10 @@
    ZeroInflatedBinomial,
    ZeroInflatedNegativeBinomial,
    ZeroInflatedPoisson,
    ZeroSumNormal,
    continuous,
)
from pymc3.distributions.transforms import zerosum
from pymc3.math import kronecker, logsumexp
from pymc3.model import Deterministic, Model, Point
from pymc3.tests.helpers import select_by_precision
@@ -556,6 +558,7 @@ def check_logp(
        n_samples=100,
        extra_args=None,
        scipy_args=None,
        transform=None,
    ):
        """
        Generic test for PyMC3 logp methods
@@ -599,13 +602,18 @@

        def logp_reference(args):
            args.update(scipy_args)
            if transform:
                args["value"] = args.pop(f"value_{transform.name}__")
            return scipy_logp(**args)

        model = build_model(pymc3_dist, domain, paramdomains, extra_args)
        logp = model.fastlogp

        domains = paramdomains.copy()
-       domains["value"] = domain
+       if transform:
+           domains[f"value_{transform.name}__"] = domain
+       else:
+           domains["value"] = domain
        for pt in product(domains, n_samples=n_samples):
            pt = Point(pt, model=model)
            assert_almost_equal(
@@ -932,6 +940,22 @@ def test_half_normal(self):
            lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma),
        )

    def test_zerosum_normal(self):
        zerosum_axes = [-1]

        def ref_fn(value, sigma):
            mu = 0
            return sp.norm.logpdf(value, mu, sigma)

        self.check_logp(
            ZeroSumNormal,
            R,
            {"sigma": Rplus},
            ref_fn,
            decimal=select_by_precision(float64=6, float32=1),
            transform=zerosum(zerosum_axes),
        )

    def test_chi_squared(self):
        self.check_logp(
            ChiSquared,

20 changes: 20 additions & 0 deletions pymc3/tests/test_distributions_random.py
@@ -343,6 +343,11 @@ class TestHalfNormal(BaseTestCases.BaseTestCase):
    params = {"tau": 1.0}


class TestZeroSumNormal(BaseTestCases.BaseTestCase):
    distribution = pm.ZeroSumNormal
    params = {"sigma": 1.0}


class TestUniform(BaseTestCases.BaseTestCase):
    distribution = pm.Uniform
    params = {"lower": 0.0, "upper": 1.0}
@@ -622,6 +627,21 @@ def ref_rand(size, tau):

        pymc3_random(pm.HalfNormal, {"tau": Rplus}, ref_rand=ref_rand)

    def test_zerosum_normal(self):
        def ref_rand(size, sigma):
            shape = sigma.shape
            zerosum_axes = (-1,) if shape else ()
            zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
            n = shape[-1] if shape else 1
            samples = st.multivariate_normal.rvs(
                cov=sigma ** 2 * (np.eye(n) - np.ones(n) / n), size=n
            )
            for axis in zerosum_axes:
                samples -= np.mean(samples, axis=axis, keepdims=True)
            return samples

        pymc3_random(pm.ZeroSumNormal, {"sigma": PdMatrix(3)}, ref_rand=ref_rand)

    def test_wald(self):
        # Cannot do anything too exciting as scipy wald is a
        # location-scale model of the *standard* wald with mu=1 and lam=1

8 changes: 8 additions & 0 deletions pymc3/tests/test_transforms.py
@@ -242,6 +242,14 @@ def test_chain():
    close_to_logical(np.diff(vals) >= 0, True, tol)


def test_zerosum():
    zerosum_axes = [0]
    zerosum_transf = tr.ZeroSumTransform(zerosum_axes)

    vals = get_values(zerosum_transf, Vector(R, 5), tt.dvector, np.random.random(5))
    close_to_logical(np.mean(vals) >= 0, True, tol)


class TestElementWiseLogp(SeededTest):
    def build_model(self, distfam, params, shape, transform, testval=None):
        if testval is not None: