Skip to content

feat: add doc example for implementing the interface #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Documenter = "1"
OrdinaryDiffEqTsit5 = "1.1.0"
SciMLSensitivity = "7.69"
SciMLStructures = "1"
Zygote = "0.6.72"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
pages = [
"Home" => "index.md",
"interface.md",
"example.md",
"api.md"
]

Expand Down
171 changes: 171 additions & 0 deletions docs/src/example.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# An example implementation of the interface

In this tutorial we will implement the SciMLStructures.jl interface for a parameter
object. This is useful when differentiating through ODE solves using SciMLSensitivity.jl
and only part of the parameters are differentiable.

```@example basic_tutorial
using OrdinaryDiffEqTsit5
using LinearAlgebra

mutable struct SubproblemParameters{P, Q, R}
p::P # tunable
q::Q
r::R
end

mutable struct Parameters{P, C}
subparams::P
coeffs::C # tunable matrix
end

# the rhs is `du[i] = p[i] * u[i]^2 + q[i] * u[i] + r[i] * t` for i in 1:length(subparams)
# and `du[length(subparams)+1:end] .= coeffs * u`
function rhs!(du, u, p::Parameters, t)
for (i, subpars) in enumerate(p.subparams)
du[i] = subpars.p * u[i]^2 + subpars.q * u[i] + subpars.r * t
end
N = length(p.subparams)
mul!(view(du, (N+1):(length(du))), p.coeffs, u)
return nothing
end

u = sin.(0.1:0.1:1.0)
subparams = [SubproblemParameters(0.1i, 0.2i, 0.3i) for i in 1:5]
p = Parameters(subparams, cos.([0.1i+0.33j for i in 1:5, j in 1:10]))
tspan = (0.0, 1.0)

prob = ODEProblem(rhs!, u, tspan, p)
solve(prob, Tsit5())
```

The ODE solves fine. Now let's try to differentiate with respect to the tunable parameters.

```@example basic_tutorial
using Zygote
using SciMLSensitivity

# 5 subparams[i].p, 50 elements in coeffs
function simulate_with_tunables(tunables)
subpars = [SubproblemParameters(tunables[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
coeffs = reshape(tunables[6:end], size(p.coeffs))
newp = Parameters(subpars, coeffs)
newprob = remake(prob; p = newp)
sol = solve(prob, Tsit5())
return sum(sol.u[end])
end
```

SciMLSensitivity does not know how to handle the parameter object, because it does not
implement the SciMLStructures interface. The bare minimum necessary for SciMLSensitivity
is the `Tunable` portion.

```@example basic_tutorial
import SciMLStructures as SS

# Mark the struct as a SciMLStructure
SS.isscimlstructure(::Parameters) = true
# It is mutable
SS.ismutablescimlstructure(::Parameters) = true

# Only contains `Tunable` portion
# We could also add a `Constants` portion to contain the values that are
# not tunable. The implementation would be similar to this one.
SS.hasportion(::SS.Tunable, ::Parameters) = true

function SS.canonicalize(::SS.Tunable, p::Parameters)
# concatenate all tunable values into a single vector
buffer = vcat([subpar.p for subpar in p.subparams], vec(p.coeffs))

# repack takes a new vector of the same length as `buffer`, and constructs
# a new `Parameters` object using the values from the new vector for tunables
# and retaining old values for other parameters. This is exactly what replace does,
# so we can use that instead.
repack = let p = p
function repack(newbuffer)
SS.replace(SS.Tunable(), p, newbuffer)
end
end
# the canonicalized vector, the repack function, and a boolean indicating
# whether the buffer aliases values in the parameter object (here, it doesn't)
return buffer, repack, false
end

function SS.replace(::SS.Tunable, p::Parameters, newbuffer)
N = length(p.subparams) + length(p.coeffs)
@assert length(newbuffer) == N
subparams = [SubproblemParameters(newbuffer[i], subpar.q, subpar.r) for (i, subpar) in enumerate(p.subparams)]
coeffs = reshape(view(newbuffer, (length(p.subparams)+1):length(newbuffer)), size(p.coeffs))
return Parameters(subparams, coeffs)
end

function SS.replace!(::SS.Tunable, p::Parameters, newbuffer)
N = length(p.subparams) + length(p.coeffs)
@assert length(newbuffer) == N
for (subpar, val) in zip(p.subparams, newbuffer)
subpar.p = val
end
copyto!(coeffs, view(newbuffer, (length(p.subparams)+1):length(newbuffer)))
return p
end
```

Now, we should be able to differentiate through the ODE solve.

```@example basic_tutorial
Zygote.gradient(simulate_with_tunables, 0.1ones(length(SS.canonicalize(SS.Tunable(), p)[1])))
```

We can also implement a `Constants` portion to store the rest of the values:

```@example basic_tutorial
SS.hasportion(::SS.Constants, ::Parameters) = true

function SS.canonicalize(::SS.Constants, p::Parameters)
buffer = mapreduce(vcat, p.subparams) do subpar
[subpar.q, subpar.r]
end
repack = let p = p
function repack(newbuffer)
SS.replace(SS.Constants(), p, newbuffer)
end
end

return buffer, repack, false
end

function SS.replace(::SS.Constants, p::Parameters, newbuffer)
subpars = [SubproblemParameters(p.subparams[i].p, newbuffer[2i-1], newbuffer[2i]) for i in eachindex(p.subparams)]
return Parameters(subpars, p.coeffs)
end

function SS.replace!(::SS.Constants, p::Parameters, newbuffer)
for i in eachindex(p.subparams)
p.subparams[i].q = newbuffer[2i-1]
p.subparams[i].r = newbuffer[2i]
end
return p
end

buf, repack, alias = SS.canonicalize(SS.Constants(), p)
buf
```

```@example basic_tutorial
repack(ones(length(buf)))
```

```@example basic_tutorial
SS.replace(SS.Constants(), p, ones(length(buf)))
```

```@example basic_tutorial
SS.replace!(SS.Constants(), p, ones(length(buf)))
p
```

In general, all values belonging to a portion should be concatenated into an array of the
appropriate length in `canonicalize`. If a higher dimensional array is part of the portion,
it should be flattened. If a portion contains values of multiple types, a non-concrete
array should be used to store the values. `replace` and `replace!` should assume the array
they receive have the same ordering as the one returned from `canonicalize`.
Loading