Skip to content

Commit 65ae00e

Browse files
authored
Add logjac to logdensity_fn (#751)
* Add logjac to logdensity_fn * Refactor logprior_fn in SMCLinearRegressionTestCase
1 parent b107f9f commit 65ae00e

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

docs/examples/howto_sample_multiple_chains.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ observed = np.random.normal(loc, scale, size=1_000)
5757
def logdensity_fn(loc, log_scale, observed=observed):
5858
"""Univariate Normal"""
5959
scale = jnp.exp(log_scale)
60+
logjac = log_scale
6061
logpdf = stats.norm.logpdf(observed, loc, scale)
61-
return jnp.sum(logpdf)
62+
return logjac + jnp.sum(logpdf)
6263
6364
6465
def logdensity(x):

docs/examples/quickstart.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ observed = np.random.normal(loc, scale, size=1_000)
4848
def logdensity_fn(loc, log_scale, observed=observed):
4949
"""Univariate Normal"""
5050
scale = jnp.exp(log_scale)
51+
logjac = log_scale
5152
logpdf = stats.norm.logpdf(observed, loc, scale)
52-
return jnp.sum(logpdf)
53+
return logjac + jnp.sum(logpdf)
5354
5455
5556
logdensity = lambda x: logdensity_fn(**x)

tests/smc/__init__.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ def logdensity_fn(self, log_scale, coefs, preds, x):
1616
logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x)
1717
return jnp.sum(logpdf)
1818

19+
def logprior_fn(self, log_scale, coefs):
20+
return log_scale + stats.norm.logpdf(log_scale) + stats.norm.logpdf(coefs)
21+
1922
def observations(self):
2023
num_particles = 100
2124

@@ -27,9 +30,7 @@ def observations(self):
2730
def particles_prior_loglikelihood(self):
2831
observations, num_particles = self.observations()
2932

30-
logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
31-
x["coefs"]
32-
)
33+
logprior_fn = lambda x: self.logprior_fn(**x)
3334
loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations)
3435

3536
log_scale_init = np.random.randn(num_particles)
@@ -45,9 +46,7 @@ def partial_posterior_test_case(self):
4546
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
4647
observations = {"x": x_data, "preds": y_data}
4748

48-
logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
49-
x["coefs"]
50-
)
49+
logprior_fn = lambda x: self.logprior_fn(**x)
5150

5251
log_scale_init = np.random.randn(num_particles)
5352
coeffs_init = np.random.randn(num_particles)

0 commit comments

Comments
 (0)