Skip to content

Commit 20d304b

Browse files
penelopeysmsunxd3
andauthored
Fix differential equation tutorial not converging (#592)
* Fix differential equation tutorial not converging * Add a note about the priors --------- Co-authored-by: Xianda Sun <[email protected]>
1 parent e09085b commit 20d304b

File tree

1 file changed

+29
-26
lines changed
  • tutorials/bayesian-differential-equations

1 file changed

+29
-26
lines changed

tutorials/bayesian-differential-equations/index.qmd

+29-26
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ To make the example more realistic we add random normally distributed noise to t
7979

8080
```{julia}
8181
sol = solve(prob, Tsit5(); saveat=0.1)
82-
odedata = Array(sol) + 0.8 * randn(size(Array(sol)))
82+
odedata = Array(sol) + 0.5 * randn(size(Array(sol)))
8383
8484
# Plot simulation and noisy observations.
8585
plot(sol; alpha=0.3)
@@ -93,23 +93,26 @@ Alternatively, we can use real-world data from Hudson’s Bay Company records (a
9393
Previously, functions in Turing and DifferentialEquations were not inter-composable, so Bayesian inference of differential equations needed to be handled by another package called [DiffEqBayes.jl](https://github.com/SciML/DiffEqBayes.jl) (note that DiffEqBayes works also with CmdStan.jl, Turing.jl, DynamicHMC.jl and ApproxBayes.jl - see the [DiffEqBayes docs](https://docs.sciml.ai/latest/analysis/parameter_estimation/#Bayesian-Methods-1) for more info).
9494

9595
Nowadays, however, Turing and DifferentialEquations are completely composable and we can just simulate differential equations inside a Turing `@model`.
96-
Therefore, we write the Lotka-Volterra parameter estimation problem using the Turing `@model` macro as below:
96+
Therefore, we write the Lotka-Volterra parameter estimation problem using the Turing `@model` macro as below.
97+
For the purposes of this tutorial, we choose priors for the parameters that are quite close to the ground truth.
98+
This helps us to illustrate the results without needing to run overly long MCMC chains:
99+
97100

98101
```{julia}
99102
@model function fitlv(data, prob)
100103
# Prior distributions.
101-
σ ~ InverseGamma(2, 3)
102-
α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
103-
β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
104-
γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
105-
δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)
104+
σ ~ InverseGamma(3, 2)
105+
α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
106+
β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
107+
γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
108+
δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
106109
107110
# Simulate Lotka-Volterra model.
108111
p = [α, β, γ, δ]
109112
predicted = solve(prob, Tsit5(); p=p, saveat=0.1)
110113
111114
# Observations.
112-
for i in 1:length(predicted)
115+
for i in eachindex(predicted)
113116
data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
114117
end
115118
@@ -160,11 +163,11 @@ I.e., we fit the model only to the $y$ variable of the system without providing
160163
```{julia}
161164
@model function fitlv2(data::AbstractVector, prob)
162165
# Prior distributions.
163-
σ ~ InverseGamma(2, 3)
164-
α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
165-
β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
166-
γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
167-
δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)
166+
σ ~ InverseGamma(3, 2)
167+
α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
168+
β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
169+
γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
170+
δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
168171
169172
# Simulate Lotka-Volterra model but save only the second state of the system (predators).
170173
p = [α, β, γ, δ]
@@ -260,18 +263,18 @@ Now we define the Turing model for the Lotka-Volterra model with delay and sampl
260263
```{julia}
261264
@model function fitlv_dde(data, prob)
262265
# Prior distributions.
263-
σ ~ InverseGamma(2, 3)
264-
α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
265-
β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
266-
γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
267-
δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)
266+
σ ~ InverseGamma(3, 2)
267+
α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
268+
β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
269+
γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
270+
δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
268271
269272
# Simulate Lotka-Volterra model.
270273
p = [α, β, γ, δ]
271274
predicted = solve(prob, MethodOfSteps(Tsit5()); p=p, saveat=0.1)
272275
273276
# Observations.
274-
for i in 1:length(predicted)
277+
for i in eachindex(predicted)
275278
data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
276279
end
277280
end
@@ -340,18 +343,18 @@ Here we will not choose a `sensealg` and let it use the default choice:
340343
```{julia}
341344
@model function fitlv_sensealg(data, prob)
342345
# Prior distributions.
343-
σ ~ InverseGamma(2, 3)
344-
α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
345-
β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
346-
γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
347-
δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)
346+
σ ~ InverseGamma(3, 2)
347+
α ~ truncated(Normal(1.5, 0.2); lower=0.5, upper=2.5)
348+
β ~ truncated(Normal(1.1, 0.2); lower=0, upper=2)
349+
γ ~ truncated(Normal(3.0, 0.2); lower=1, upper=4)
350+
δ ~ truncated(Normal(1.0, 0.2); lower=0, upper=2)
348351
349352
# Simulate Lotka-Volterra model and use a specific algorithm for computing sensitivities.
350353
p = [α, β, γ, δ]
351354
predicted = solve(prob; p=p, saveat=0.1)
352355
353356
# Observations.
354-
for i in 1:length(predicted)
357+
for i in eachindex(predicted)
355358
data[:, i] ~ MvNormal(predicted[i], σ^2 * I)
356359
end
357360
@@ -361,7 +364,7 @@ end;
361364
model_sensealg = fitlv_sensealg(odedata, prob)
362365
363366
# Sample a single chain with 1000 samples using Zygote.
364-
sample(model_sensealg, NUTS(;adtype=AutoZygote()), 1000; progress=false)
367+
sample(model_sensealg, NUTS(; adtype=AutoZygote()), 1000; progress=false)
365368
```
366369

367370
For more examples of adjoint usage on large parameter models, consult the [DiffEqFlux documentation](https://diffeqflux.sciml.ai/dev/).

0 commit comments

Comments
 (0)