Skip to content

Commit a1f3789

Browse files
committed
cleanup Lorentz ODE example
- 1.0 syntax (fix Array ctor, use range()) - reduce memory footprint by avoid repeated array allocations - more inbounds fixes #116
1 parent 195b009 commit a1f3789

File tree

1 file changed

+66
-56
lines changed

1 file changed

+66
-56
lines changed

examples/ode_lorentz_attractor.jl

+66-56
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,33 @@
11
# Let's try parameter estimation for an ODE, here the famous Lorentz attractor.
22
# We base our implementation on the code from Paulo Marques available at:
33
# https://github.com/pjpmarques/Julia-Modeling-the-World/blob/master/Lorenz%20Attractor.ipynb
4-
function lorentz_equations(params::Vector{Float64}, r::Vector{Float64})
4+
5+
function lorentz_ode!(dr::AbstractVector{Float64},
6+
params::AbstractVector{Float64},
7+
r::AbstractVector{Float64})
8+
@assert length(dr) == length(r) == 3
9+
510
# Get individual params for this ODE
6-
sigma, rho, beta = params
11+
@inbounds sigma, rho, beta = params
712

813
# Extract the coordinates from the r vector
9-
x, y, z = r
14+
@inbounds x, y, z = r
1015

11-
# The Lorenz equations
12-
dx_dt = sigma*(y - x)
13-
dy_dt = x*(rho - z) - y
14-
dz_dt = x*y - beta*z
16+
# The Lorenz equations, put the derivative into dr
17+
@inbounds dr[1] = sigma*(y - x) # dx/dt
18+
@inbounds dr[2] = x*(rho - z) - y # dy/dt
19+
@inbounds dr[3] = x*y - beta*z # dz/dt
1520

16-
# Return the derivatives as a vector
17-
Float64[dx_dt, dy_dt, dz_dt]
21+
return dr
1822
end
1923

2024
# Define time vector and interval grid.
21-
dt = 0.001
22-
tf = 100.0
23-
tinterval = 0:dt:tf
24-
t = collect(tinterval)
25-
25+
t = range(0.0, step=0.001, stop=100.0)
2626
# But in order to compare to [Xiang2015] paper we use their choice instead:
27-
h = 0.01
28-
M = 300
29-
tstart = 0.0
30-
tstop = tstart + M * h
31-
tinterval_Xiang2015 = 0:h:tstop
32-
t_Xiang2015 = collect(tinterval_Xiang2015)
27+
t_Xiang2015 = range(0.0, step=0.01, length=301)
3328

3429
# Initial position in space
35-
r0 = [0.1; 0.0; 0.0]
30+
r0 = [0.1, 0.0, 0.0]
3631

3732
# Constants sigma, rho and beta. In a real ODE problem these would not be known
3833
# and would be estimated.
@@ -42,65 +37,80 @@ beta = 8.0/3.0
4237
real_params = [sigma, rho, beta]
4338

4439
# "Play" an ODE from a starting point and into the future given a sequence of time steps.
45-
function calc_state_vectors(params::Vector{Float64}, odefunc::Function,
46-
startx::Vector{Float64}, times::Vector{Float64}; states = nothing)
40+
# Put the restults into `states`
41+
function calc_states!(states::AbstractMatrix{Float64},
42+
params::AbstractVector{Float64}, odefunc!::Function,
43+
startx::AbstractVector{Float64}, times::AbstractVector{Float64})
4744

4845
numsamples = length(times)
49-
if states == nothing
50-
states = Array(Float64, length(startx), numsamples)
51-
end
46+
@assert size(states) == (length(startx), numsamples)
5247

5348
states[:, 1] = startx
5449
tprev = times[1]
50+
deriv = similar(startx)
5551
for i in 2:numsamples
56-
@inbounds derivatives = odefunc(params, states[:, (i-1)])
57-
@inbounds tnow = times[i]
58-
@inbounds states[:, i] = states[:, (i-1)] .+ derivatives * (tnow - tprev)
52+
tnow = times[i]
53+
stateprev = view(states, :, i-1)
54+
odefunc!(deriv, params, stateprev)
55+
@inbounds states[:, i] .= stateprev .+ deriv .* (tnow - tprev)
5956
tprev = tnow
6057
end
6158

6259
return states
6360
end
6461

62+
calc_states(params::AbstractVector{Float64}, odefunc!::Function,
63+
startx::AbstractVector{Float64}, times::AbstractVector{Float64}) =
64+
calc_states!(Matrix{Float64}(undef, length(startx), length(times)),
65+
params, odefunc!, startx, times)
66+
6567
# RSS = Residual Sum of Squares, columnwise
66-
function rss(actual::Array{Float64, 2}, estimated::Array{Float64, 2})
67-
M = size(actual, 2)
68-
sumsq = 0.0
69-
for i in 1:M
70-
@inbounds sumsq += sumabs2(actual[:, i] .- estimated[:, i])
71-
end
72-
sumsq
73-
end
68+
rss(actual::AbstractMatrix{Float64}, estimated::AbstractMatrix{Float64}) =
69+
@inbounds sum(i -> abs2(actual[i] - estimated[i]), eachindex(actual, estimated))
7470

7571
# Calculate the actual/original state vectors that we will use for parameter
7672
# estimation:
77-
origstates = calc_state_vectors(real_params, lorentz_equations, r0, t)
78-
origstates_Xiang2015 = calc_state_vectors(real_params, lorentz_equations, r0, t_Xiang2015)
73+
origstates = calc_states(real_params, lorentz_ode!, r0, t)
74+
origstates_Xiang2015 = calc_states(real_params, lorentz_ode!, r0, t_Xiang2015)
7975

80-
function subsample(origstates::Array{Float64, 2}, times::Vector{Float64}, lenratio = 0.25)
81-
@assert size(origstates, 2) == length(times)
76+
function subsample(origstates::AbstractMatrix{Float64},
77+
times::AbstractVector{Float64};
78+
lenratio = 0.25)
8279
N = length(times)
80+
@assert size(origstates, 2) == N
8381
stopidx = round(Int, lenratio*N)
8482
indexes = 1:stopidx
8583
return origstates[:, indexes], times[indexes]
8684
end
8785

8886
# The [Xiang2015] paper, https://www.hindawi.com/journals/ddns/2015/740721/,
8987
# used these param bounds:
90-
Xiang2015Bounds = Tuple{Float64, Float64}[(9, 11), (20, 30), (2, 3)]
88+
Xiang2015Bounds = [(9., 11.), (20., 30.), (2., 3.)]
9189

9290
# Now we can optimize using BlackBoxOptim
9391
using BlackBoxOptim
9492

95-
function lorentz_fitness(params::Vector{Float64}, origstates::Array{Float64, 2}, times::Vector{Float64})
96-
states = calc_state_vectors(params, lorentz_equations, r0, times)
93+
# store temporary states for fitness calculation
94+
const tmpstates_pool = Dict{NTuple{2, Int}, Vector{Matrix{Float64}}}()
95+
96+
# get the RSS between the origstates trajectory and ODE solution for given params
97+
function lorentz_fitness(params::AbstractVector{Float64},
98+
origstates::AbstractMatrix{Float64},
99+
times::AbstractVector{Float64})
100+
# get the state matrix from the pool of the proper size
101+
statesize = size(origstates)
102+
states_pool = get!(() -> Vector{Matrix{Float64}}(), tmpstates_pool, statesize)
103+
states = !isempty(states_pool) ? pop!(states_pool) : Matrix{Float64}(undef, statesize...)
104+
# solve ODE for given params
105+
calc_states!(states, params, lorentz_ode!, r0, times)
106+
push!(states_pool, states) # return states matrix back to the pool
97107
return rss(origstates, states)
98108
end
99109

100110
# But optimizing all states in each optimization step is too much so lets
101111
# sample a small subset and use for first opt iteration.
102-
origstates1, times1 = subsample(origstates, t, 0.04); # Sample only first 4%
103-
origstates1_Xiang2015, times1_Xiang2015 = subsample(origstates_Xiang2015, t_Xiang2015, 1.00);
112+
origstates1, times1 = subsample(origstates, t; lenratio=0.04); # Sample only first 4%
113+
origstates1_Xiang2015, times1_Xiang2015 = subsample(origstates_Xiang2015, t_Xiang2015; lenratio=1.00);
104114

105115
res1 = bboptimize(params -> lorentz_fitness(params, origstates1, times1);
106116
SearchRange = Xiang2015Bounds, MaxSteps = 8e3)
@@ -109,24 +119,24 @@ res2 = bboptimize(params -> lorentz_fitness(params, origstates1_Xiang2015, times
109119
SearchRange = Xiang2015Bounds, MaxSteps = 11e3) # They allow 12k fitness evals for 3-param estimation
110120

111121
# But lets also try with less tight bounds
112-
LooserBounds = Tuple{Float64, Float64}[(0, 22), (0, 60), (1, 6)]
122+
LooserBounds = [(0., 22.), (0., 60.), (1., 6.)]
113123
res3 = bboptimize(params -> lorentz_fitness(params, origstates1_Xiang2015, times1_Xiang2015);
114124
SearchRange = LooserBounds, MaxSteps = 11e3) # They allow 12k fitness evals for 3-param estimation
115125

116-
println("Results on the long time sequence from Paulo Marques:")
126+
@info "Results on the long time sequence from Paulo Marques:"
117127
estfitness = lorentz_fitness(best_candidate(res1), origstates, t)
118-
@show (estfitness, best_candidate(res1), best_fitness(res1))
128+
@show estfitness best_candidate(res1) best_fitness(res1)
119129
origfitness = lorentz_fitness(real_params, origstates, t)
120-
@show (origfitness, real_params)
130+
@show origfitness real_params
121131

122-
println("Results on the short time sequence used in [Xiang2015] paper:")
132+
@info "Results on the short time sequence used in [Xiang2015] paper:"
123133
estfitness = lorentz_fitness(best_candidate(res2), origstates_Xiang2015, t_Xiang2015)
124-
@show (estfitness, best_candidate(res2), best_fitness(res2))
134+
@show estfitness best_candidate(res2) best_fitness(res2)
125135
origfitness = lorentz_fitness(real_params, origstates_Xiang2015, t_Xiang2015)
126-
@show (origfitness, real_params)
136+
@show origfitness real_params
127137

128-
println("Results on the short time sequence used in [Xiang2015] paper, but with looser bounds:")
138+
@info "Results on the short time sequence used in [Xiang2015] paper, but with looser bounds:"
129139
estfitness = lorentz_fitness(best_candidate(res3), origstates_Xiang2015, t_Xiang2015)
130-
@show (estfitness, best_candidate(res3), best_fitness(res3))
140+
@show estfitness best_candidate(res3) best_fitness(res3)
131141
origfitness = lorentz_fitness(real_params, origstates_Xiang2015, t_Xiang2015)
132-
@show (origfitness, real_params)
142+
@show origfitness real_params

0 commit comments

Comments
 (0)