Skip to content

Commit b13c3a3

Browse files
Merge pull request #28 from AayushSabharwal/as/doc-example
feat: add doc example for implementing the interface
2 parents 0b7dd1b + e69ac19 commit b13c3a3

File tree

3 files changed

+178
-0
lines changed

3 files changed

+178
-0
lines changed

docs/Project.toml

+6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
4+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
35
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
6+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
47

58
[compat]
69
Documenter = "1"
10+
OrdinaryDiffEqTsit5 = "1.1.0"
11+
SciMLSensitivity = "7.69"
712
SciMLStructures = "1"
13+
Zygote = "0.6.72"

docs/make.jl

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
66
pages = [
77
"Home" => "index.md",
88
"interface.md",
9+
"example.md",
910
"api.md"
1011
]
1112

docs/src/example.md

+171
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# An example implementation of the interface
2+
3+
In this tutorial we will implement the SciMLStructures.jl interface for a parameter
4+
object. This is useful when differentiating through ODE solves using SciMLSensitivity.jl
5+
and only part of the parameters are differentiable.
6+
7+
```@example basic_tutorial
8+
using OrdinaryDiffEqTsit5
9+
using LinearAlgebra
10+
11+
mutable struct SubproblemParameters{P, Q, R}
12+
p::P # tunable
13+
q::Q
14+
r::R
15+
end
16+
17+
mutable struct Parameters{P, C}
18+
subparams::P
19+
coeffs::C # tunable matrix
20+
end
21+
22+
# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
23+
# and `du[length(subparams)+1:end] .= coeffs * u`
24+
function rhs!(du, u, p::Parameters, t)
25+
for (i, subpars) in enumerate(p.subparams)
26+
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
27+
end
28+
N = length(p.subparams)
29+
mul!(view(du, (N+1):(length(du))), p.coeffs, u)
30+
return nothing
31+
end
32+
33+
u = sin.(0.1:0.1:1.0)
34+
subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
35+
p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10]))
36+
tspan = (0.0, 1.0)
37+
38+
prob = ODEProblem(rhs!, u, tspan, p)
39+
solve(prob, Tsit5())
40+
```
41+
42+
The ODE solves fine. Now let's try to differentiate with respect to the tunable parameters.
43+
44+
```@example basic_tutorial
45+
using Zygote
46+
using SciMLSensitivity
47+
48+
# 5 subparams[i].p, 50 elements in coeffs
49+
function simulate_with_tunables(tunables)
50+
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
51+
coeffs = reshape(tunables[6:end], size(p.coeffs))
52+
newp = Parameters(subpars, coeffs)
53+
newprob = remake(prob; p = newp)
54+
sol = solve(prob, Tsit5())
55+
return sum(sol.u[end])
56+
end
57+
```
58+
59+
SciMLSensitivity does not know how to handle the parameter object, because it does not
60+
implement the SciMLStructures interface. The bare minimum necessary for SciMLSensitivity
61+
is the `Tunable` portion.
62+
63+
```@example basic_tutorial
64+
import SciMLStructures as SS
65+
66+
# Mark the struct as a SciMLStructure
67+
SS.isscimlstructure(::Parameters) = true
68+
# It is mutable
69+
SS.ismutablescimlstructure(::Parameters) = true
70+
71+
# Only contains `Tunable` portion
72+
# We could also add a `Constants` portion to contain the values that are
73+
# not tunable. The implementation would be similar to this one.
74+
SS.hasportion(::SS.Tunable, ::Parameters) = true
75+
76+
function SS.canonicalize(::SS.Tunable, p::Parameters)
77+
# concatenate all tunable values into a single vector
78+
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))
79+
80+
# repack takes a new vector of the same length as `buffer`, and constructs
81+
# a new `Parameters` object using the values from the new vector for tunables
82+
# and retaining old values for other parameters. This is exactly what replace does,
83+
# so we can use that instead.
84+
repack = let p = p
85+
function repack(newbuffer)
86+
SS.replace(SS.Tunable(), p, newbuffer)
87+
end
88+
end
89+
# the canonicalized vector, the repack function, and a boolean indicating
90+
# whether the buffer aliases values in the parameter object (here, it doesn't)
91+
return buffer, repack, false
92+
end
93+
94+
function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
95+
N = length(p.subparams) + length(p.coeffs)
96+
@assert length(newbuffer) == N
97+
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
98+
coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
99+
return Parameters(subparams, coeffs)
100+
end
101+
102+
function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
103+
N = length(p.subparams) + length(p.coeffs)
104+
@assert length(newbuffer) == N
105+
for (subpar, val) in zip(p.subparams, newbuffer)
106+
subpar.p = val
107+
end
108+
copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer)))
109+
return p
110+
end
111+
```
112+
113+
Now, we should be able to differentiate through the ODE solve.
114+
115+
```@example basic_tutorial
116+
Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
117+
```
118+
119+
We can also implement a `Constants` portion to store the rest of the values:
120+
121+
```@example basic_tutorial
122+
SS.hasportion(::SS.Constants, ::Parameters) = true
123+
124+
function SS.canonicalize(::SS.Constants, p::Parameters)
125+
buffer = mapreduce(vcat, p.subparams) do subpar
126+
[subpar.q, subpar.r]
127+
end
128+
repack = let p = p
129+
function repack(newbuffer)
130+
SS.replace(SS.Constants(), p, newbuffer)
131+
end
132+
end
133+
134+
return buffer, repack, false
135+
end
136+
137+
function SS.replace(::SS.Constants, p::Parameters, newbuffer)
138+
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
139+
return Parameters(subpars, p.coeffs)
140+
end
141+
142+
function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
143+
for i in eachindex(p.subparams)
144+
p.subparams[i].q = newbuffer[2i-1]
145+
p.subparams[i].r = newbuffer[2i]
146+
end
147+
return p
148+
end
149+
150+
buf, repack, alias = SS.canonicalize(SS.Constants(), p)
151+
buf
152+
```
153+
154+
```@example basic_tutorial
155+
repack(ones(length(buf)))
156+
```
157+
158+
```@example basic_tutorial
159+
SS.replace(SS.Constants(), p, ones(length(buf)))
160+
```
161+
162+
```@example basic_tutorial
163+
SS.replace!(SS.Constants(), p, ones(length(buf)))
164+
p
165+
```
166+
167+
In general, all values belonging to a portion should be concatenated into an array of the
168+
appropriate length in `canonicalize`. If a higher dimensional array is part of the portion,
169+
it should be flattened. If a portion contains values of multiple types, a non-concrete
170+
array should be used to store the values. `replace` and `replace!` should assume the array
171+
they receive have the same ordering as the one returned from `canonicalize`.

0 commit comments

Comments
 (0)