Added ZeroSumNormal Distribution #4776
```diff
@@ -65,6 +65,7 @@
     "Lognormal",
     "ChiSquared",
     "HalfNormal",
+    "ZeroSumNormal",
     "Wald",
     "Pareto",
     "InverseGamma",
```
```diff
@@ -924,6 +925,73 @@ def logcdf(self, value):
         )
 
 
+class ZeroSumNormal(Continuous):
+    def __new__(cls, name, *args, **kwargs):
+        zerosum_axes = kwargs.get("zerosum_axes", None)
+        zerosum_dims = kwargs.get("zerosum_dims", None)
+        dims = kwargs.get("dims", None)
+
+        if isinstance(zerosum_dims, str):
+            zerosum_dims = (zerosum_dims,)
+        if isinstance(dims, str):
+            dims = (dims,)
+
+        if zerosum_dims is not None:
+            if dims is None:
+                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
+            if zerosum_axes is not None:
+                raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
+            zerosum_axes = []
+            for dim in zerosum_dims:
+                zerosum_axes.append(dims.index(dim))
+            kwargs["zerosum_axes"] = zerosum_axes
+
+        return super().__new__(cls, name, *args, **kwargs)
+
+    def __init__(self, sigma=1, zerosum_axes=None, zerosum_dims=None, **kwargs):
+        shape = kwargs.get("shape", ())
+        if isinstance(shape, int):
+            shape = (shape,)
+
+        self.mu = self.median = self.mode = tt.zeros(shape)
+        self.sigma = tt.as_tensor_variable(sigma)
+
+        if zerosum_axes is None:
+            if shape:
+                zerosum_axes = (-1,)
+            else:
+                zerosum_axes = ()
```
**Review comment:** I think it makes no sense to have a […]
```diff
+        if isinstance(zerosum_axes, int):
+            zerosum_axes = (zerosum_axes,)
+
+        self.zerosum_axes = [a if a >= 0 else len(shape) + a for a in zerosum_axes]
```
**Review comment:** Enforcing a positive axis here leads to problems when you draw samples from the prior predictive. It's better to replace this line with the following suggested change: […]
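The reviewer's suggested replacement itself isn't preserved above, so the snippet below is only a guess at what it might look like: instead of forcing the axes to be non-negative, normalize them to negative indices, so that extra leading dimensions prepended by prior-predictive draws don't change which axis is zero-summed. A minimal, self-contained sketch with a hypothetical helper name (not from the PR):

```python
# Hypothetical sketch, not the reviewer's actual suggestion: keep zero-sum axes
# as negative indices so that prepending sample dimensions (as prior-predictive
# sampling does) leaves them pointing at the same trailing axes.
def normalize_zerosum_axes(zerosum_axes, ndim):
    return [a if a < 0 else a - ndim for a in zerosum_axes]


print(normalize_zerosum_axes([0, 1], ndim=2))  # [-2, -1]
print(normalize_zerosum_axes([-1], ndim=3))    # [-1]
```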
```diff
+        if "transform" not in kwargs or kwargs["transform"] is None:
+            kwargs["transform"] = transforms.ZeroSumTransform(zerosum_axes)
+
+        super().__init__(**kwargs)
+
+    def logp(self, value):
+        return Normal.dist(sigma=self.sigma).logp(value)
```
**Review comment:** If we don't add the scaling of sigma here, our […]

**Review comment:** How so? Isn't […]

**Review comment:** This logp is somewhat strange still. From the math side this should be `pm.MvNormal` with `cov = I - J / n`, where `J` is a matrix of all 1s. We don't want to write it like this though, because we don't want to do matrix factorization, and `pm.MvNormal` doesn't work if an eigenvalue is 0.

**Review comment:** We can avoid the […]

**Review comment:** I came across this wiki section that talks about the degenerate MvNormal case (which is what we have with the ZeroSumNormal). We could use that formula as the expected logp value and test if the […]

```python
import numpy as np


def pseudo_log_det(A, tol=1e-13):
    v, w = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)


def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    exp = 0.5 * (value[..., None, :] @ np.linalg.pinv(cov) @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
```

**Review comment:** I ran a few tests with the logp and it looks like the […]

```python
def logp(value, sigma):
    n = value.shape[-1]
    cov = np.asarray(sigma)[..., None, None] ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    v, w = np.linalg.eigh(cov)
    psdet = 0.5 * (np.sum(np.log(v[..., 1:])) + (n - 1) * np.log(2 * np.pi))
    cov_pinv = w[:, 1:] @ np.diag(1 / v[1:]) @ w[:, 1:].T
    exp = 0.5 * (value[..., None, :] @ cov_pinv @ value[..., None])[..., 0, 0]
    return np.where(np.abs(np.sum(value, axis=-1)) < 1e-9, -psdet - exp, -np.inf)
```

This is different from the […]

```python
psdet = n * (0.5 * np.log(2 * np.pi) + np.log(sigma))
```

This means that we have to multiply the […]
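To make the difference between the two normalizing constants concrete, here is a small numerical check (not part of the PR; it only reuses the pseudo-determinant formula quoted above, with NumPy/SciPy). On a zero-sum vector, the degenerate-MvNormal logp and the sum of n independent Normal log-densities (what the current `logp` returns) differ by exactly one Gaussian normalizing term, `0.5 * log(2 * pi * sigma**2)`, because the pseudo-determinant only counts the n - 1 nonzero eigenvalues:

```python
import numpy as np
from scipy import stats


def pseudo_log_det(A, tol=1e-13):
    # Sum of the logs of the eigenvalues that are not (numerically) zero.
    v, _ = np.linalg.eigh(A)
    return np.sum(np.log(np.where(np.abs(v) >= tol, v, 1)), axis=-1)


def degenerate_logp(value, sigma):
    # Degenerate MvNormal density on the zero-sum subspace, cov = sigma**2 * (I - J / n).
    n = value.shape[-1]
    cov = sigma ** 2 * (np.eye(n) - np.ones((n, n)) / n)
    psdet = 0.5 * pseudo_log_det(2 * np.pi * cov)
    quad = 0.5 * value @ np.linalg.pinv(cov) @ value
    return -psdet - quad


rng = np.random.default_rng(123)
sigma, n = 1.3, 5
x = rng.normal(scale=sigma, size=n)
x -= x.mean()  # project onto the zero-sum subspace

naive = stats.norm(scale=sigma).logpdf(x).sum()  # n independent Normal terms
degen = degenerate_logp(x, sigma)                # n - 1 effective dimensions

print(naive - degen)                          # -0.5 * log(2 * pi * sigma**2)
print(-0.5 * np.log(2 * np.pi * sigma ** 2))
```

On the constraint surface the quadratic terms of the two expressions are identical; only the constant differs, which is why the discussion above centers on how sigma enters the normalizing factor.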
```diff
+    def _random(self, scale, size):
+        samples = stats.norm.rvs(loc=0, scale=scale, size=size)
+        for axis in self.zerosum_axes:
+            samples -= np.mean(samples, axis=axis, keepdims=True)
+        return samples
+
+    def random(self, point=None, size=None):
+        (sigma,) = draw_values([self.sigma], point=point, size=size)
+        return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
+
+    def _distr_parameters_for_repr(self):
+        return ["sigma"]
+
+    def logcdf(self, value):
+        raise NotImplementedError()
+
+
 class Wald(PositiveContinuous):
     r"""
     Wald log-likelihood.
```
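For orientation, here is a minimal usage sketch of the distribution this diff adds, written against the PyMC3 v3 API. The model, data, and coords are made up, and the example assumes this branch (with `ZeroSumNormal` and `transforms.ZeroSumTransform`) is installed, since the class is not part of a released PyMC3:

```python
import numpy as np
import pymc3 as pm

coords = {"group": ["a", "b", "c", "d"]}

with pm.Model(coords=coords) as model:
    # One offset per group; with the default zerosum_axes=(-1,) the offsets are
    # constrained to sum to zero across the last (group) axis.
    group_effect = pm.ZeroSumNormal("group_effect", sigma=1.0, dims="group")
    mu = 3.0 + group_effect
    pm.Normal("obs", mu=mu, sigma=0.5, observed=np.array([2.5, 3.1, 3.4, 2.9]))

    idata = pm.sample(1000, tune=1000, return_inferencedata=True)

# Posterior draws of group_effect should sum to (numerically) zero across groups,
# since the ZeroSumTransform enforces the constraint during sampling.
print(idata.posterior["group_effect"].sum("group"))
```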
**Review comment:** Note that `zerosum_dims` is not used in `__init__`, but if I don't put it here, it doesn't seem to be passed on to `__new__`: `TypeError: __init__() got an unexpected keyword argument 'zerosum_dims'`. Not sure we can do it otherwise though. If someone has a better idea, I'm all ears.
**Review comment:** I think the `zerosum_dims` is probably still in `kwargs` from line 949? We could just remove it there.
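One way to act on that reply, sketched below purely as an illustration (it assumes, as in the v3 code base, that `Distribution.__new__` forwards the mutated `kwargs` on to `__init__` via `cls.dist`): pop `zerosum_dims` out of `kwargs` once it has been translated to `zerosum_axes`, so `__init__` no longer needs the unused parameter.

```python
# Illustrative sketch only, not code from the PR: consume zerosum_dims in __new__
# so it never reaches __init__, removing the need for the placeholder argument.
class ZeroSumNormal(Continuous):
    def __new__(cls, name, *args, **kwargs):
        zerosum_axes = kwargs.get("zerosum_axes", None)
        zerosum_dims = kwargs.pop("zerosum_dims", None)  # popped instead of read
        dims = kwargs.get("dims", None)

        if isinstance(zerosum_dims, str):
            zerosum_dims = (zerosum_dims,)
        if isinstance(dims, str):
            dims = (dims,)

        if zerosum_dims is not None:
            if dims is None:
                raise ValueError("zerosum_dims can only be used with the dims kwargs.")
            if zerosum_axes is not None:
                raise ValueError("Only one of zerosum_axes and zerosum_dims can be specified.")
            kwargs["zerosum_axes"] = [dims.index(dim) for dim in zerosum_dims]

        return super().__new__(cls, name, *args, **kwargs)

    def __init__(self, sigma=1, zerosum_axes=None, **kwargs):  # no zerosum_dims needed
        ...
```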