1+ """
2+ Instructive examples of how to use jaxpurify.
3+
4+ 1. Basic model with random parameters
5+ 2. Demonstration of `shapes`, `zeros`, `fixed`, and `intermediates` convenience methods
6+ 3. Demonstration of `ravel=True` and `unravel` functionality
7+ 4. Complex mock example with function calls, log-normal and uniform variables, and higher-order primitives
8+ 5. Bayesian inference application with a linear Gaussian process model
9+
10+ """
11+
112# %%
213import jax
314import jax .numpy as jnp
415import jax .random as jr
16+ from jax .scipy .sparse .linalg import cg
17+ from functools import partial
18+ import matplotlib .pyplot as plt
519
620import jaxpurify as jp
721from jaxpurify import purify
822
923rng = jr .key (0 )
1024
1125
12- # %% Minimal example
26+ # %% 1. Basic model with random parameters
1327@purify
1428def model ():
1529 x = jp .param (3 , name = "x" )
@@ -22,10 +36,10 @@ def model():
2236print ("Result:" , result )
2337
2438
25- # %% Demonstration of `shapes`, `zeros`, `fixed`, `intermediate`
39+ # %% 2. Demonstration of `shapes`, `zeros`, `fixed`, and `intermediates` convenience methods
2640@purify
2741def model ():
28- a = jnp .array ([[1 , 0 , 1 ]])
42+ a = jnp .array ([[1 , 2 , 3 ]])
2943 x = jp .param (3 , name = "x" )
3044 b = jp .fixed (name = "b" )
3145 y = jp .intermediate (a * x , name = "prod" ) + b
@@ -34,15 +48,17 @@ def model():
3448param_shapes = model .shapes ()
3549zero_params = model .zeros ()
3650zero_fixed = model .fixed ()
51+ result = model (zero_params , {"b" : 2.0 })
3752intermediates = model .intermediates (zero_params , {"b" : 2.0 })
3853
3954print ("Param shapes:" , param_shapes )
4055print ("Zero params:" , zero_params )
4156print ("Zero fixed:" , zero_fixed )
57+ print ("Result:" , result )
4258print ("Intermediates:" , intermediates )
4359
4460
45- # %% Demonstration of `ravel=True`, in `unraveled_params` we see unnamed parameters get object id as name
61+ # %% 3. Demonstration of `ravel=True` and `unravel` functionality
4662@purify (ravel = True )
4763def model ():
4864 x = jp .param (2 )
@@ -58,7 +74,7 @@ def model():
5874print ("Unraveled params:" , unraveled_params )
5975print ("Result:" , result )
6076
61- # %% Complex mock example with many functions, log-normal and uniform variables, higher-order primitives, and intermediates
77+ # %% 4. Complex mock example with many functions, log-normal and uniform variables, and higher-order primitives
6278def chop (vegetable , slices ):
6379 return (vegetable / slices ) * jnp .ones (slices )
6480
@@ -84,7 +100,6 @@ def eat(food):
84100
85101@purify
86102def dinner ():
87-
88103 carrots = jnp .cos (jp .param (5 , name = "carrots" ))
89104 celery = jnp .sin (jp .param (5 , name = "celery" ))
90105 onions = jnp .tanh (jp .param (2 , name = "onions" ))
@@ -102,7 +117,7 @@ def dinner():
102117 jp .UniformVariable ("bay_leaf" , low = 2 , high = 3 ),
103118 ])
104119
105- lentils = jp .param (3000 , name = "lentils" )
120+ lentils = jp .log_normal ( jp . param (300 , name = "lentils" ), mean = 1.0 , sigma = 0.1 )
106121 lentils = jax .jvp (jnp .sin , (lentils ,), (lentils ,))[1 ]
107122 soup = boil (
108123 seasoning * jnp .concatenate ([lentils , aromatics ]),
@@ -116,15 +131,10 @@ def dinner():
116131params = dinner .normal (rng )
117132result = dinner (params )
118133intermediates = dinner .intermediates (params )
119-
120134lots_of_soup = jax .vmap (dinner .normal )(jr .split (rng , 10 ))
121135grads = jax .jit (jax .vmap (jax .grad (dinner )))(lots_of_soup )
122136
123- # %% Legitimate example of Bayesian inference for a (linear) Gaussian process model
124- from jax .scipy .sparse .linalg import cg
125- import matplotlib .pyplot as plt
126- from functools import partial
127-
137+ # %% 5. Bayesian inference application with a linear Gaussian process model
128138def field (xi , pad = 10 ):
129139 n = xi .shape [0 ]
130140 k = jnp .fft .fftfreq (n , d = 1.0 / n )
@@ -139,7 +149,6 @@ def model():
139149 grid = jp .fixed (1000 , name = "grid" )
140150 xi = jp .param (grid .shape [0 ] + 2 * pad , name = "xi" )
141151 y = jp .intermediate (field (xi , pad = pad ), name = "y" )
142-
143152 x_obs = jp .fixed (100 , name = "x_obs" )
144153 y_obs = jnp .interp (x_obs , grid , y )
145154 return y_obs
@@ -176,4 +185,5 @@ def model():
176185plt .errorbar (x_obs , y_obs , yerr = noise_std , fmt = "o" , c = 'k' , alpha = 0.5 , ms = 2 , lw = 1 )
177186plt .plot (grid , field_mean , c = 'C0' )
178187plt .plot (grid , field_samples .T , alpha = 0.5 , c = 'C0' )
179- plt .show ()
188+ plt .show ()
189+
0 commit comments