Skip to content

Commit a004da3

Browse files
committed
add some tests, edit demo
1 parent 75e2e09 commit a004da3

2 files changed

Lines changed: 129 additions & 15 deletions

File tree

demo.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
1+
"""
2+
Instructive examples of how to use jaxpurify.
3+
4+
1. Basic model with random parameters
5+
2. Demonstration of `shapes`, `zeros`, `fixed`, and `intermediates` convenience methods
6+
3. Demonstration of `ravel=True` and `unravel` functionality
7+
4. Complex mock example with function calls, log-normal and uniform variables, and higher-order primitives
8+
5. Bayesian inference application with a linear Gaussian process model
9+
10+
"""
11+
112
# %%
213
import jax
314
import jax.numpy as jnp
415
import jax.random as jr
16+
from jax.scipy.sparse.linalg import cg
17+
from functools import partial
18+
import matplotlib.pyplot as plt
519

620
import jaxpurify as jp
721
from jaxpurify import purify
822

923
rng = jr.key(0)
1024

1125

12-
# %% Minimal example
26+
# %% 1. Basic model with random parameters
1327
@purify
1428
def model():
1529
x = jp.param(3, name="x")
@@ -22,10 +36,10 @@ def model():
2236
print("Result:", result)
2337

2438

25-
# %% Demonstration of `shapes`, `zeros`, `fixed`, `intermediate`
39+
# %% 2. Demonstration of `shapes`, `zeros`, `fixed`, and `intermediates` convenience methods
2640
@purify
2741
def model():
28-
a = jnp.array([[1, 0, 1]])
42+
a = jnp.array([[1, 2, 3]])
2943
x = jp.param(3, name="x")
3044
b = jp.fixed(name="b")
3145
y = jp.intermediate(a * x, name="prod") + b
@@ -34,15 +48,17 @@ def model():
3448
param_shapes = model.shapes()
3549
zero_params = model.zeros()
3650
zero_fixed = model.fixed()
51+
result = model(zero_params, {"b": 2.0})
3752
intermediates = model.intermediates(zero_params, {"b": 2.0})
3853

3954
print("Param shapes:", param_shapes)
4055
print("Zero params:", zero_params)
4156
print("Zero fixed:", zero_fixed)
57+
print("Result:", result)
4258
print("Intermediates:", intermediates)
4359

4460

45-
# %% Demonstration of `ravel=True`, in `unraveled_params` we see unnamed parameters get object id as name
61+
# %% 3. Demonstration of `ravel=True` and `unravel` functionality
4662
@purify(ravel=True)
4763
def model():
4864
x = jp.param(2)
@@ -58,7 +74,7 @@ def model():
5874
print("Unraveled params:", unraveled_params)
5975
print("Result:", result)
6076

61-
# %% Complex mock example with many functions, log-normal and uniform variables, higher-order primitives, and intermediates
77+
# %% 4. Complex mock example with many functions, log-normal and uniform variables, and higher-order primitives
6278
def chop(vegetable, slices):
6379
return (vegetable / slices) * jnp.ones(slices)
6480

@@ -84,7 +100,6 @@ def eat(food):
84100

85101
@purify
86102
def dinner():
87-
88103
carrots = jnp.cos(jp.param(5, name="carrots"))
89104
celery = jnp.sin(jp.param(5, name="celery"))
90105
onions = jnp.tanh(jp.param(2, name="onions"))
@@ -102,7 +117,7 @@ def dinner():
102117
jp.UniformVariable("bay_leaf", low=2, high=3),
103118
])
104119

105-
lentils = jp.param(3000, name="lentils")
120+
lentils = jp.log_normal(jp.param(300, name="lentils"), mean=1.0, sigma=0.1)
106121
lentils = jax.jvp(jnp.sin, (lentils,), (lentils,))[1]
107122
soup = boil(
108123
seasoning * jnp.concatenate([lentils, aromatics]),
@@ -116,15 +131,10 @@ def dinner():
116131
params = dinner.normal(rng)
117132
result = dinner(params)
118133
intermediates = dinner.intermediates(params)
119-
120134
lots_of_soup = jax.vmap(dinner.normal)(jr.split(rng, 10))
121135
grads = jax.jit(jax.vmap(jax.grad(dinner)))(lots_of_soup)
122136

123-
# %% Legitimate example of Bayesian inference for a (linear) Gaussian process model
124-
from jax.scipy.sparse.linalg import cg
125-
import matplotlib.pyplot as plt
126-
from functools import partial
127-
137+
# %% 5. Bayesian inference application with a linear Gaussian process model
128138
def field(xi, pad=10):
129139
n = xi.shape[0]
130140
k = jnp.fft.fftfreq(n, d=1.0/n)
@@ -139,7 +149,6 @@ def model():
139149
grid = jp.fixed(1000, name="grid")
140150
xi = jp.param(grid.shape[0] + 2*pad, name="xi")
141151
y = jp.intermediate(field(xi, pad=pad), name="y")
142-
143152
x_obs = jp.fixed(100, name="x_obs")
144153
y_obs = jnp.interp(x_obs, grid, y)
145154
return y_obs
@@ -176,4 +185,5 @@ def model():
176185
plt.errorbar(x_obs, y_obs, yerr=noise_std, fmt="o", c='k', alpha=0.5, ms=2, lw=1)
177186
plt.plot(grid, field_mean, c='C0')
178187
plt.plot(grid, field_samples.T, alpha=0.5, c='C0')
179-
plt.show()
188+
plt.show()
189+

tests/test_purify.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import jax
2+
import jax.numpy as jnp
3+
import jax.random as jr
4+
5+
import jaxpurify as jp
6+
from jaxpurify import purify
7+
8+
rng = jr.key(0)
9+
10+
def test_params():
11+
@purify
12+
def model():
13+
x = jp.param(3, name="x")
14+
return x**2
15+
16+
params = model.normal(rng)
17+
result = model(params)
18+
19+
assert params["x"].shape == (3,)
20+
assert jnp.allclose(result, params["x"]**2)
21+
22+
def test_shapes():
23+
@purify
24+
def model():
25+
x = jp.param((4,3), name="x")
26+
y = jp.param(2, name="y")
27+
28+
shapes = model.shapes()
29+
zeros = model.zeros()
30+
params = model.normal(rng)
31+
32+
assert shapes["x"].shape == (4, 3)
33+
assert shapes["y"].shape == (2,)
34+
assert zeros["x"].shape == (4, 3)
35+
assert zeros["y"].shape == (2,)
36+
assert params["x"].shape == (4, 3)
37+
assert params["y"].shape == (2,)
38+
39+
def test_ravel():
40+
@purify(ravel=True)
41+
def model():
42+
x = jp.param((4,3), name="x")
43+
y = jp.param(2, name="y")
44+
45+
shapes = model.shapes()
46+
zeros = model.zeros()
47+
params = model.normal(rng)
48+
49+
unraveled_params = model.unravel(params)
50+
51+
assert shapes.shape == (14,)
52+
assert zeros.shape == (14,)
53+
assert params.shape == (14,)
54+
assert unraveled_params["x"].shape == (4, 3)
55+
assert unraveled_params["y"].shape == (2,)
56+
57+
def test_fixed():
58+
@purify
59+
def model():
60+
x = jp.param(3, name="x")
61+
b = jp.fixed(name="b")
62+
return x + b
63+
64+
params = model.normal(rng)
65+
fixed = model.fixed()
66+
result = model(params, fixed)
67+
68+
assert fixed["b"] == 0.0
69+
assert jnp.allclose(result, params["x"] + fixed["b"])
70+
71+
def test_intermediates():
72+
@purify
73+
def model():
74+
x = jp.param(3, name="x")
75+
y = jp.intermediate(x**2, name="y")
76+
return y + 1
77+
78+
params = model.normal(rng)
79+
intermediates = model.intermediates(params)
80+
assert jnp.allclose(intermediates["y"], params["x"]**2)
81+
82+
def higher_order_primitives():
83+
@purify
84+
def model():
85+
x = jp.param(3, name="x")
86+
y = jax.vjp(jnp.sin, x)[1](x + 1)[0]
87+
z = jax.jit(jax.vmap(jnp.cos))(y)
88+
return z
89+
90+
def model_explicit(x):
91+
y = jax.vjp(jnp.sin, x)[1](x + 1)[0]
92+
z = jax.jit(jax.vmap(jnp.cos))(y)
93+
return z
94+
95+
params = model.normal(rng)
96+
result = model(params)
97+
result_explicit = model_explicit(params)
98+
assert jnp.allclose(result, result_explicit)
99+
100+
result_vjp = jax.vjp(model_explicit, params)[1](2*result)
101+
result_vjp_explicit = jax.vjp(model_explicit, params)[1](2*result_explicit)
102+
assert jnp.allclose(result_vjp, result_vjp_explicit)
103+
104+

0 commit comments

Comments
 (0)