Skip to content

CaoTauLeaping Implementation #603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,102 @@ SKenCarp(;chunk_size=0,autodiff=true,diff_type=Val{:central},

# Jumps

struct TauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm end
struct CaoTauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm end
function TauLeaping_docstring(
description::String,
name::String;
references::String = "",
extra_keyword_description::String = "",
extra_keyword_default::String = "")
keyword_default = """
adaptive = true,
""" * "\n" * extra_keyword_default

keyword_default_description = """
- `adaptive`: Boolean to enable/disable adaptive step sizing. When `true`, the step size `τ` is adjusted dynamically based on error estimates or bounds. Defaults to `true`.
""" * "\n" * extra_keyword_description

docstring = """
$description

### Algorithm Type
Stochastic Jump Method

### References
$references

### Keyword Arguments
$keyword_default_description

### Default Values
$keyword_default
"""
return docstring
end

@doc TauLeaping_docstring(
"An explicit tau-leaping method for stochastic jump processes with optional post-leap step size adaptivity. " *
"This algorithm approximates the stochastic simulation algorithm (SSA) by advancing the system state over " *
"a fixed time step `τ` using Poisson-distributed jump counts based on initial propensities. When `adaptive=true`, " *
"it adjusts `τ` dynamically based on post-leap error estimates derived from propensity changes.",
"TauLeaping",
references = """@article{gillespie2001approximate,
title={Approximate accelerated stochastic simulation of chemically reacting systems},
author={Gillespie, Daniel T},
journal={The Journal of Chemical Physics},
volume={115},
number={4},
pages={1716--1733},
year={2001},
publisher={AIP Publishing}}""",
extra_keyword_description = """
- `dtmax`: Maximum allowed step size.
- `dtmin`: Minimum allowed step size.
""",
extra_keyword_default = """
dtmax = 10.0,
dtmin = 1e-6
Comment on lines +897 to +903
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are these doing here?

""")
struct TauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm
adaptive::Bool
end

function TauLeaping(; adaptive=true)
TauLeaping(adaptive)
end

@doc TauLeaping_docstring(
"An adaptive tau-leaping method for stochastic jump processes that selects the step size `τ` prior to each leap " *
"based on bounds on the expected change in state variables. Introduced by Cao et al., this method ensures stability " *
"and accuracy by constraining the relative change in propensities, controlled by the `epsilon` parameter. " *
"When `adaptive=false`, a fixed step size is used.",
"CaoTauLeaping",
references = """@article{cao2006efficient,
title={Efficient step size selection for the tau-leaping simulation method},
author={Cao, Yang and Gillespie, Daniel T and Petzold, Linda R},
journal={The Journal of Chemical Physics},
volume={124},
number={4},
pages={044109},
year={2006},
publisher={AIP Publishing}}""",
extra_keyword_description = """
- `epsilon`: Tolerance parameter controlling the relative change in state variables for step size selection.
- `dtmax`: Maximum allowed step size.
- `dtmin`: Minimum allowed step size.
""",
extra_keyword_default = """
epsilon = 0.03,
dtmax = 10.0,
dtmin = 1e-6
Comment on lines +935 to +936
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these don't belong here

""")
struct CaoTauLeaping <: StochasticDiffEqJumpAdaptiveAlgorithm
adaptive::Bool
epsilon::Float64
end

function CaoTauLeaping(; adaptive=true, epsilon=0.03)
CaoTauLeaping(adaptive, epsilon)
end

################################################################################

Expand Down
36 changes: 25 additions & 11 deletions src/caches/tau_caches.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,39 @@
struct TauLeapingConstantCache <: StochasticDiffEqConstantCache end

@cache struct TauLeapingCache{uType,rateType} <: StochasticDiffEqMutableCache
@cache mutable struct TauLeapingCache{uType, rateType, jumpRateType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
tmp::uType
rate::rateType
newrate::rateType
EEstcache::rateType
EEstcache::jumpRateType
end

alg_cache(alg::TauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = TauLeapingConstantCache()

function alg_cache(alg::TauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
function alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tmp = zero(u)
rate = zero(jump_rate_prototype)
newrate = zero(jump_rate_prototype)
EEstcache = zero(jump_rate_prototype)
TauLeapingCache(u,uprev,tmp,newrate,EEstcache)
TauLeapingCache(u, uprev, tmp, rate, newrate, EEstcache)
end

alg_cache(alg::CaoTauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = TauLeapingConstantCache()
@cache mutable struct CaoTauLeapingCache{uType, rateType, muType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
tmp::uType
rate::rateType
mu::muType
sigma2::muType
end

function alg_cache(alg::CaoTauLeaping,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
function alg_cache(alg::CaoTauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tmp = zero(u)
TauLeapingCache(u,uprev,tmp,nothing,nothing)
rate = zero(jump_rate_prototype)
mu = zero(u)
sigma2 = zero(u)
CaoTauLeapingCache(u, uprev, tmp, rate, mu, sigma2)
end

struct TauLeapingConstantCache <: StochasticDiffEqConstantCache end
struct CaoTauLeapingConstantCache <: StochasticDiffEqConstantCache end

alg_cache(alg::TauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} = TauLeapingConstantCache()
alg_cache(alg::CaoTauLeaping, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} = CaoTauLeapingConstantCache()
68 changes: 63 additions & 5 deletions src/integrators/stepsize_controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,85 @@ end


function stepsize_controller!(integrator::SDEIntegrator, alg::TauLeaping)
nothing
nothing # Post-leap adjustment happens in perform_step!
end

function step_accept_controller!(integrator::SDEIntegrator, alg::TauLeaping)
if alg.adaptive
integrator.q = min(integrator.opts.gamma / integrator.EEst, integrator.opts.qmax)
return integrator.dt * integrator.q
else
return integrator.dt
end
end

function step_reject_controller!(integrator::SDEIntegrator, alg::TauLeaping)
if alg.adaptive
integrator.dt = integrator.opts.gamma * integrator.dt / integrator.EEst
end
Comment on lines +35 to +37
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

end


# CaoTauLeaping: Pre-leap τ computation
function stepsize_controller!(integrator::SDEIntegrator, alg::CaoTauLeaping)
nothing
if !alg.adaptive
return
end

@unpack u, p, t, P, opts, c = integrator
cache = integrator.cache

P === nothing && error("CaoTauLeaping requires a JumpProblem with a RegularJump")

# Handle both constant and mutable caches
if isa(cache, CaoTauLeapingConstantCache)
rate = P.cache.rate(u, p, t) # Compute propensities directly
mu = zero(u)
sigma2 = zero(u)
else # CaoTauLeapingCache
@unpack mu, sigma2, rate = cache
P.cache.rate(rate, u, p, t) # Compute propensities into cache
fill!(mu, zero(eltype(mu)))
fill!(sigma2, zero(eltype(sigma2)))
end

# Infer ν_ij using c by applying unit counts for each reaction
num_reactions = length(rate)
ν = zeros(eltype(u), length(u), num_reactions)
unit_counts = zeros(eltype(rate), num_reactions)
for j in 1:num_reactions
unit_counts[j] = 1
c(ν[:, j], u, p, t, unit_counts, nothing) # ν[:, j] is the change vector for reaction j
unit_counts[j] = 0 # Reset
end
Comment on lines +63 to +71
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but we really should be storing that differently in the jumps. @isaacsas this is worth thinking about.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be worthwhile to make leaping work with individually defined jumps too. Split-step methods require that anyways.

But that wouldn’t solve this issue for general non-mass action jumps.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and we'd need two implementations of each leaping? I'm thinking the interface should maybe require more information in the jump building, under the assumption users will be using something like Catalyst rather than building most jumps by hand. To do some of this optimally we have to shift been a few different representations, where for a user this would be redundant information to the solver it's different views that have different optimal uses.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For methods that work on the full vector at once, can’t we can manually build the needed function from the list of jumps? So those methods can stick with the current interface, but then we could support methods that really need jump-by-jump rate/affect access too.

But yes, Catalyst should build the needed inputs for different representations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For methods that work on the full vector at once, can’t we can manually build the needed function from the list of jumps?

You get a lot of compile time trying to inline all of the small functions into the large one. Also there's a lack of fusion. So it's quite slow to do it that way.

I think the point is, some times you need jump-by-jump information, and sometimes you want the aggregate information (all tau leap steppers will use an aggregate function a lot of the time), and those are the same information in two very different computational representations, and building one from the other isn't easy to do from the implementation but it is trivial to do symbolically.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely, so it should be preferred that methods that work on the full vector get a single input function, but having a fallback would be useful. Catalyst/MTK, should just generate both inputs (or whichever is most appropriate for the selected integrator -- we can add a function to query the preferred input type).


# Compute μ_i and σ_i^2
for i in eachindex(u)
for j in 1:num_reactions
ν_ij = ν[i, j]
mu[i] += ν_ij * rate[j]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formula uses the derivative of the rate function, not the rate itself.

sigma2[i] += ν_ij^2 * rate[j]
end
end

# Compute τ per species
ϵ = alg.epsilon
τ_vals = similar(u, Float64)
for i in eachindex(u)
max_term = max(ϵ * u[i], 1.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is supposed to use the sum of rates term, not u[i], also why the 1.0?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is missing the g_i factor

τ1 = abs(mu[i]) > 0 ? max_term / abs(mu[i]) : Inf
τ2 = sigma2[i] > 0 ? max_term^2 / sigma2[i] : Inf
τ_vals[i] = min(τ1, τ2)
end

τ = min(minimum(τ_vals), opts.dtmax)
integrator.dt = max(τ, opts.dtmin)
integrator.EEst = 1.0
end

function step_accept_controller!(integrator::SDEIntegrator, alg::CaoTauLeaping)
return integrator.EEst # use EEst for the τ
return integrator.dt
end

function step_reject_controller!(integrator::SDEIntegrator, alg::CaoTauLeaping)
error("CaoTauLeaping should never reject steps")
error("CaoTauLeaping should never reject steps")
end
80 changes: 53 additions & 27 deletions src/perform_step/tau_leaping.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,66 @@
@muladd function perform_step!(integrator,cache::TauLeapingConstantCache)
@unpack t,dt,uprev,u,W,p,P,c = integrator
# Perform Step
@muladd function perform_step!(integrator, cache::TauLeapingConstantCache)
@unpack t, dt, uprev, u, p, P, c = integrator

P === nothing && error("TauLeaping requires a JumpProblem with a RegularJump")
P.dt = dt
tmp = c(uprev, p, t, P.dW, nothing)
integrator.u = uprev .+ tmp

if integrator.opts.adaptive
if integrator.alg isa TauLeaping
oldrate = P.cache.currate
newrate = P.cache.rate(integrator.u,p,t+dt)
EEstcache = @. abs(newrate - oldrate) / max(50integrator.opts.reltol*oldrate,integrator.rate_constants/integrator.dt)
integrator.EEst = maximum(EEstcache)
if integrator.EEst <= 1
P.cache.currate = newrate
end
elseif integrator.alg isa CaoTauLeaping
# Calculate τ as EEst
if integrator.alg.adaptive
oldrate = P.cache.currate
newrate = P.cache.rate(integrator.u, p, t + dt)
EEstcache = @. abs(newrate - oldrate) / max(50 * integrator.opts.reltol * oldrate, integrator.rate_constants / integrator.dt)
integrator.EEst = integrator.opts.internalnorm(EEstcache, t)
if integrator.EEst <= 1
P.cache.currate = newrate
end
else
integrator.EEst = 1.0
end
end

@muladd function perform_step!(integrator,cache::TauLeapingCache)
@unpack t,dt,uprev,u,W,p,P,c = integrator
@unpack tmp, newrate, EEstcache = cache
@muladd function perform_step!(integrator, cache::TauLeapingCache)
@unpack t, dt, uprev, u, p, P, c = integrator
@unpack tmp, rate, newrate, EEstcache = cache

P === nothing && error("TauLeaping requires a JumpProblem with a RegularJump")
P.dt = dt
c(tmp, uprev, p, t, P.dW, nothing)
@.. u = uprev + tmp

if integrator.opts.adaptive
if integrator.alg isa TauLeaping
oldrate = P.cache.currate
P.cache.rate(newrate,u,p,t+dt)
@.. EEstcache = abs(newrate - oldrate) / max(50integrator.opts.reltol*oldrate,integrator.rate_constants/integrator.dt)
integrator.EEst = maximum(EEstcache)
if integrator.EEst <= 1
P.cache.currate .= newrate
end
elseif integrator.alg isa CaoTauLeaping
# Calculate τ as EEst
if integrator.alg.adaptive
P.cache.rate(newrate, u, p, t + dt)
P.cache.rate(rate, uprev, p, t)
@.. EEstcache = abs(newrate - rate) / max(50 * integrator.opts.reltol * rate, integrator.rate_constants / integrator.dt)
integrator.EEst = integrator.opts.internalnorm(EEstcache, t)
if integrator.EEst <= 1
P.cache.currate .= newrate
end
else
integrator.EEst = 1.0
end
end

@muladd function perform_step!(integrator, cache::CaoTauLeapingConstantCache)
@unpack t, dt, uprev, u, p, P, c = integrator

P === nothing && error("CaoTauLeaping requires a JumpProblem with a RegularJump")
P.dt = dt
tmp = c(uprev, p, t, P.dW, nothing)
integrator.u = uprev .+ tmp

integrator.EEst = 1.0
end

@muladd function perform_step!(integrator, cache::CaoTauLeapingCache)
@unpack t, dt, uprev, u, p, P, c = integrator
@unpack tmp = cache

P === nothing && error("CaoTauLeaping requires a JumpProblem with a RegularJump")
P.dt = dt
c(tmp, uprev, p, t, P.dW, nothing)
@.. u = uprev + tmp

integrator.EEst = 1.0
end
8 changes: 8 additions & 0 deletions test/tau_leaping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ jump_iipprob = JumpProblem(iip_prob,Direct(),rj)
N = 40_000
sol1 = solve(EnsembleProblem(jump_iipprob),SimpleTauLeaping();dt=1.0,trajectories = N)
sol2 = solve(EnsembleProblem(jump_iipprob),TauLeaping();dt=1.0,adaptive=false,save_everystep=false,trajectories = N)
sol3 = solve(EnsembleProblem(jump_iipprob),CaoTauLeaping();dt=1.0,trajectories = N)

mean1 = mean([sol1[i][end,end] for i in 1:N])
mean2 = mean([sol2[i][end,end] for i in 1:N])
mean3 = mean([sol3[i][end,end] for i in 1:N])
@test mean1 ≈ mean2 rtol=1e-2
@test mean2 ≈ mean3 rtol=1e-2
@test mean1 ≈ mean3 rtol=1e-2

f(du,u,p,t) = (du .= 0)
g(du,u,p,t) = (du .= 0)
Expand Down Expand Up @@ -68,8 +72,12 @@ jump_prob = JumpProblem(prob,Direct(),rj)
sol = solve(jump_prob,TauLeaping(),reltol=5e-2)

sol2 = solve(EnsembleProblem(jump_prob),TauLeaping();dt=1.0,adaptive=false,save_everystep=false,trajectories = N)
sol3 = solve(EnsembleProblem(jump_prob),CaoTauLeaping();dt=1.0,adaptive=false,save_everystep=false,trajectories = N)
mean2 = mean([sol2[i][end,end] for i in 1:N])
mean3 = mean([sol3[i][end,end] for i in 1:N])
@test mean1 ≈ mean2 rtol=1e-2
@test mean2 ≈ mean3 rtol=1e-2
@test mean1 ≈ mean3 rtol=1e-2

foop(u,p,t) = [0.0,0.0,0.0]
goop(u,p,t) = [0.0,0.0,0.0]
Expand Down