From 617cd00674de3143a9ae068ed0492a8b0098bf28 Mon Sep 17 00:00:00 2001
From: kc611
Date: Mon, 16 Aug 2021 22:09:17 +0530
Subject: [PATCH 1/2] Added ZeroSumNormal

Co-authored-by: Adrian Seyboldt
---
 pymc3/distributions/__init__.py          |  2 +
 pymc3/distributions/continuous.py        | 62 +++++++++++++++++++
 pymc3/distributions/distribution.py      |  2 +-
 pymc3/distributions/transforms.py        | 77 ++++++++++++++++++++++++
 pymc3/tests/test_distributions.py        | 26 +++++++-
 pymc3/tests/test_distributions_random.py | 20 ++++++
 pymc3/tests/test_transforms.py           |  8 +++
 7 files changed, 195 insertions(+), 2 deletions(-)

diff --git a/pymc3/distributions/__init__.py b/pymc3/distributions/__init__.py
index 51958f541b..b1d28bcaa7 100644
--- a/pymc3/distributions/__init__.py
+++ b/pymc3/distributions/__init__.py
@@ -48,6 +48,7 @@
     VonMises,
     Wald,
     Weibull,
+    ZeroSumNormal,
 )
 from pymc3.distributions.discrete import (
     Bernoulli,
@@ -123,6 +124,7 @@
     "HalfStudentT",
     "ChiSquared",
     "HalfNormal",
+    "ZeroSumNormal",
     "Wald",
     "Pareto",
     "InverseGamma",
diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py
index 234ed935f2..d493630689 100644
--- a/pymc3/distributions/continuous.py
+++ b/pymc3/distributions/continuous.py
@@ -65,6 +65,7 @@
     "Lognormal",
     "ChiSquared",
     "HalfNormal",
+    "ZeroSumNormal",
     "Wald",
     "Pareto",
     "InverseGamma",
@@ -924,6 +925,67 @@ def logcdf(self, value):
         )
 
 
+class ZeroSumNormal(Continuous):
+    def __init__(self, sigma=1, zerosum_dims=None, zerosum_axes=None, **kwargs):
+        shape = kwargs.get("shape", ())
+        dims = kwargs.get("dims", None)
+        if isinstance(shape, int):
+            shape = (shape,)
+
+        if isinstance(dims, str):
+            dims = (dims,)
+
+        self.mu = self.median = self.mode = tt.zeros(shape)
+        self.sigma = tt.as_tensor_variable(sigma)
+
+        if zerosum_dims is None and zerosum_axes is None:
+            if shape:
+                zerosum_axes = (-1,)
+            else:
+                zerosum_axes = ()
+
+        if isinstance(zerosum_axes, int):
+            zerosum_axes = (zerosum_axes,)
+
+        if isinstance(zerosum_dims, str):
+            zerosum_dims = (zerosum_dims,)
+
+        if zerosum_axes is not None and zerosum_dims is not None:
+            raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
+
+        if zerosum_dims is not None:
+            if dims is None:
+                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
+            zerosum_axes = []
+            for dim in zerosum_dims:
+                zerosum_axes.append(dims.index(dim))
+        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
+
+        if "transform" not in kwargs or kwargs["transform"] == None:
+            kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes)
+
+        super().__init__(**kwargs)
+
+    def logp(self, value):
+        return Normal.dist(sigma=self.sigma).logp(value)
+
+    def _random(self, scale, size):
+        samples = stats.norm.rvs(loc=0, scale=scale, size=size)
+        for axis in self.zerosum_axes:
+            samples -= np.mean(samples, axis=axis, keepdims=True)
+        return samples
+
+    def random(self, point=None, size=None):
+        (sigma,) = draw_values([self.sigma], point=point, size=size)
+        return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
+
+    def _distr_parameters_for_repr(self):
+        return ["sigma"]
+
+    def logcdf(self, value):
+        raise NotImplementedError()
+
+
 class Wald(PositiveContinuous):
     r"""
     Wald log-likelihood.
diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py
index 3ba2d0a040..0a9d4a967c 100644
--- a/pymc3/distributions/distribution.py
+++ b/pymc3/distributions/distribution.py
@@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs):
                 raise TypeError("observed needs to be data but got: {}".format(type(data)))
             total_size = kwargs.pop("total_size", None)
 
-            dims = kwargs.pop("dims", None)
+            dims = kwargs["dims"] if "dims" in kwargs else None
             has_shape = "shape" in kwargs
             shape = kwargs.pop("shape", None)
             if dims is not None:
diff --git a/pymc3/distributions/transforms.py b/pymc3/distributions/transforms.py
index 880301182c..a32130a29e 100644
--- a/pymc3/distributions/transforms.py
+++ b/pymc3/distributions/transforms.py
@@ -14,6 +14,8 @@
 
 import warnings
 
+from typing import List
+
 import numpy as np
 import theano.tensor as tt
 
@@ -565,3 +567,78 @@ def jacobian_det(self, y):
             else:
                 det += det_
         return det
+
+
+def _extend_axis(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = tt.concatenate([array, fill_val], axis=axis)
+    return out - norm
+
+
+def _extend_axis_rev(array, axis):
+    if axis < 0:
+        axis = axis % array.ndim
+    assert axis >= 0 and axis < array.ndim
+
+    n = array.shape[axis]
+    last = tt.take(array, [-1], axis=axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * axis
+    return array[slice_before + (slice(None, -1),)] + norm
+
+
+def _extend_axis_val(array, axis):
+    n = array.shape[axis] + 1
+    sum_vals = array.sum(axis, keepdims=True)
+    norm = sum_vals / (np.sqrt(n) + n)
+    fill_val = norm - sum_vals / np.sqrt(n)
+
+    out = np.concatenate([array, fill_val], axis=axis)
+    return out - norm
+
+
+def _extend_axis_rev_val(array, axis):
+    n = array.shape[axis]
+    last = np.take(array, [-1], axis=axis)
+
+    sum_vals = -last * np.sqrt(n)
+    norm = sum_vals / (np.sqrt(n) + n)
+    slice_before = (slice(None, None),) * len(array.shape[:axis])
+    return array[slice_before + (slice(None, -1),)] + norm
+
+
+class ZeroSumTransform(Transform):
+    name = "zerosum"
+
+    _zerosum_axes: List[int]
+
+    def __init__(self, zerosum_axes):
+        self._zerosum_axes = zerosum_axes
+
+    def forward(self, x):
+        for axis in self._zerosum_axes:
+            x = _extend_axis_rev(x, axis=axis)
+        return floatX(x)
+
+    def forward_val(self, x, point):
+        for axis in self._zerosum_axes:
+            x = _extend_axis_rev_val(x, axis=axis)
+        return x
+
+    def backward(self, z):
+        z = tt.as_tensor_variable(z)
+        for axis in self._zerosum_axes:
+            z = _extend_axis(z, axis=axis)
+        return floatX(z)
+
+    def jacobian_det(self, x):
+        return tt.constant(0.0)
+
+
+zerosum = ZeroSumTransform
diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py
index 06efc90b8d..6523d2f718 100644
--- a/pymc3/tests/test_distributions.py
+++ b/pymc3/tests/test_distributions.py
@@ -93,8 +93,10 @@
     ZeroInflatedBinomial,
     ZeroInflatedNegativeBinomial,
     ZeroInflatedPoisson,
+    ZeroSumNormal,
     continuous,
 )
+from pymc3.distributions.transforms import zerosum
 from pymc3.math import kronecker, logsumexp
 from pymc3.model import Deterministic, Model, Point
 from pymc3.tests.helpers import select_by_precision
@@ -556,6 +558,7 @@ def check_logp(
         n_samples=100,
         extra_args=None,
         scipy_args=None,
+        transform=None,
     ):
         """
        Generic test for PyMC3 logp methods
@@ -599,13 +602,18 @@
 
         def logp_reference(args):
             args.update(scipy_args)
+            if transform:
args["value"] = args.pop(f"value_{transform.name}__") return scipy_logp(**args) model = build_model(pymc3_dist, domain, paramdomains, extra_args) logp = model.fastlogp domains = paramdomains.copy() - domains["value"] = domain + if transform: + domains[f"value_{transform.name}__"] = domain + else: + domains["value"] = domain for pt in product(domains, n_samples=n_samples): pt = Point(pt, model=model) assert_almost_equal( @@ -932,6 +940,22 @@ def test_half_normal(self): lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma), ) + def test_zerosum_normal(self): + zerosum_axes = [-1] + + def ref_fn(value, sigma): + mu = 0 + return sp.norm.logpdf(value, mu, sigma) + + self.check_logp( + ZeroSumNormal, + R, + {"sigma": Rplus}, + ref_fn, + decimal=select_by_precision(float64=6, float32=1), + transform=zerosum(zerosum_axes), + ) + def test_chi_squared(self): self.check_logp( ChiSquared, diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index a56f3f3b7b..93c5c231bd 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -343,6 +343,11 @@ class TestHalfNormal(BaseTestCases.BaseTestCase): params = {"tau": 1.0} +class TestZeroSumNormal(BaseTestCases.BaseTestCase): + distribution = pm.ZeroSumNormal + params = {"sigma": 1.0} + + class TestUniform(BaseTestCases.BaseTestCase): distribution = pm.Uniform params = {"lower": 0.0, "upper": 1.0} @@ -622,6 +627,21 @@ def ref_rand(size, tau): pymc3_random(pm.HalfNormal, {"tau": Rplus}, ref_rand=ref_rand) + def test_zerosum_normal(self): + def ref_rand(size, sigma): + shape = sigma.shape + zerosum_axes = (-1,) if shape else () + zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes] + n = shape[-1] if shape else 1 + samples = st.multivariate_normal.rvs( + cov=sigma ** 2 * (np.eye(n) - np.ones(n) / n), size=n + ) + for axis in zerosum_axes: + samples -= np.mean(samples, axis=axis, keepdims=True) + return samples + + pymc3_random(pm.ZeroSumNormal, {"sigma": PdMatrix(3)}, ref_rand=ref_rand) + def test_wald(self): # Cannot do anything too exciting as scipy wald is a # location-scale model of the *standard* wald with mu=1 and lam=1 diff --git a/pymc3/tests/test_transforms.py b/pymc3/tests/test_transforms.py index e9ab89938b..d16a4fc159 100644 --- a/pymc3/tests/test_transforms.py +++ b/pymc3/tests/test_transforms.py @@ -242,6 +242,14 @@ def test_chain(): close_to_logical(np.diff(vals) >= 0, True, tol) +def test_zerosum(): + zerosum_axes = [0] + zerosum_transf = tr.ZeroSumTransform(zerosum_axes) + + vals = get_values(zerosum_transf, Vector(R, 5), tt.dvector, np.random.random(5)) + close_to_logical(np.mean(vals) >= 0, True, tol) + + class TestElementWiseLogp(SeededTest): def build_model(self, distfam, params, shape, transform, testval=None): if testval is not None: From f08bfcb8174ab363ffb9201c5665569bd1bb179b Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Tue, 5 Oct 2021 16:58:59 +0200 Subject: [PATCH 2/2] Enable zerosum_dims --- pymc3/distributions/continuous.py | 42 ++++++++++++++++------------- pymc3/distributions/distribution.py | 2 +- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index d493630689..e3e874797f 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -926,19 +926,37 @@ def logcdf(self, value): class ZeroSumNormal(Continuous): - def __init__(self, sigma=1, zerosum_dims=None, zerosum_axes=None, **kwargs): - shape = 
kwargs.get("shape", ()) + def __new__(cls, name, *args, **kwargs): + zerosum_axes = kwargs.get("zerosum_axes", None) + zerosum_dims = kwargs.get("zerosum_dims", None) dims = kwargs.get("dims", None) - if isinstance(shape, int): - shape = (shape,) + if isinstance(zerosum_dims, str): + zerosum_dims = (zerosum_dims,) if isinstance(dims, str): dims = (dims,) + if zerosum_dims is not None: + if dims is None: + raise ValueError("zerosum_dims can only be used with the dims kwargs.") + if zerosum_axes is not None: + raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.") + zerosum_axes = [] + for dim in zerosum_dims: + zerosum_axes.append(dims.index(dim)) + kwargs["zerosum_axes"] = zerosum_axes + + return super().__new__(cls, name, *args, **kwargs) + + def __init__(self, sigma=1, zerosum_axes=None, zerosum_dims=None, **kwargs): + shape = kwargs.get("shape", ()) + if isinstance(shape, int): + shape = (shape,) + self.mu = self.median = self.mode = tt.zeros(shape) self.sigma = tt.as_tensor_variable(sigma) - if zerosum_dims is None and zerosum_axes is None: + if zerosum_axes is None: if shape: zerosum_axes = (-1,) else: @@ -947,21 +965,9 @@ def __init__(self, sigma=1, zerosum_dims=None, zerosum_axes=None, **kwargs): if isinstance(zerosum_axes, int): zerosum_axes = (zerosum_axes,) - if isinstance(zerosum_dims, str): - zerosum_dims = (zerosum_dims,) - - if zerosum_axes is not None and zerosum_dims is not None: - raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.") - - if zerosum_dims is not None: - if dims is None: - raise ValueError("zerosum_dims can only be used with the dims kwargs.") - zerosum_axes = [] - for dim in zerosum_dims: - zerosum_axes.append(dims.index(dim)) self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes] - if "transform" not in kwargs or kwargs["transform"] == None: + if "transform" not in kwargs or kwargs["transform"] is None: kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes) super().__init__(**kwargs) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 0a9d4a967c..3ba2d0a040 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -98,7 +98,7 @@ def __new__(cls, name, *args, **kwargs): raise TypeError("observed needs to be data but got: {}".format(type(data))) total_size = kwargs.pop("total_size", None) - dims = kwargs["dims"] if "dims" in kwargs else None + dims = kwargs.pop("dims", None) has_shape = "shape" in kwargs shape = kwargs.pop("shape", None) if dims is not None: