Skip to content

Commit 8c652af

Browse files
feat: implement BVProblem for System
1 parent e679e0c commit 8c652af

File tree

3 files changed

+95
-0
lines changed

3 files changed

+95
-0
lines changed

src/problems/bvproblem.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
@fallback_iip_specialize function SciMLBase.BVProblem{iip, spec}(
2+
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
3+
check_compatibility = true, cse = true, checkbounds = false, eval_expression = false,
4+
eval_module = @__MODULE__, guesses = Dict(), kwargs...) where {iip, spec}
5+
check_complete(sys, BVProblem)
6+
check_compatibility && check_compatible_system(BVProblem, sys)
7+
8+
# ODESystems without algebraic equations should use both fixed values + guesses
9+
# for initialization.
10+
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
11+
fode, u0, p = process_SciMLProblem(
12+
ODEFunction{iip, spec}, sys, _u0map, parammap; guesses,
13+
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility = false, cse, checkbounds,
14+
time_dependent_init = false, kwargs...)
15+
16+
dvs = unknowns(sys)
17+
stidxmap = Dict([v => i for (i, v) in enumerate(dvs)])
18+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) : [stidxmap[k] for (k, v) in u0map]
19+
fbc = generate_boundary_conditions(
20+
sys, u0, u0_idxs, tspan; expression = Val{false}, cse, checkbounds)
21+
kwargs = process_kwargs(sys; kwargs...)
22+
# Call `remake` so it runs initialization if it is trivial
23+
return remake(BVProblem{iip}(fode, fbc, u0, tspan[1], p; kwargs...))
24+
end
25+
26+
function check_compatible_system(T::Union{Type{BVPFunction}, Type{BVProblem}}, sys::System)
27+
check_time_dependent(sys, T)
28+
check_not_dde(sys)
29+
check_no_cost(sys, T)
30+
check_has_constraints(sys, T)
31+
check_no_jumps(sys, T)
32+
check_no_noise(sys, T)
33+
check_is_continuous(sys, T)
34+
end

src/problems/compatibility.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ function check_no_constraints(sys::System, T)
6969
end
7070
end
7171

72+
function check_has_constraints(sys::System, T)
73+
if isempty(constraints(sys))
74+
throw(SystemCompatibilityError("""
75+
A system without constraints cannot be used to construct a `$T`. Consider an \
76+
`ODEProblem` instead.
77+
"""))
78+
end
79+
end
80+
7281
function check_no_jumps(sys::System, T)
7382
if !isempty(jumps(sys))
7483
throw(SystemCompatibilityError("""

src/systems/codegen.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,55 @@ function isautonomous(sys::System)
317317
tgrad = calculate_tgrad(sys; simplify = true)
318318
all(iszero, tgrad)
319319
end
320+
321+
function get_bv_solution_symbol(ns)
322+
only(@variables BV_SOLUTION(..)[1:ns])
323+
end
324+
325+
function get_constraint_unknown_subs!(subs::Dict, cons::Vector, stidxmap::Dict, iv, sol)
326+
vs = vars(cons)
327+
for v in vs
328+
iscall(v) || continue
329+
op = operation(v)
330+
args = arguments(v)
331+
issym(op) && length(args) == 1 || continue
332+
newv = op(iv)
333+
haskey(stidxmap, newv) || continue
334+
subs[v] = sol(args[1])[stidxmap[newv]]
335+
end
336+
end
337+
338+
function generate_boundary_conditions(sys::System, u0, u0_idxs, t0; expression = Val{true},
339+
eval_expression = false, eval_module = @__MODULE__, kwargs...)
340+
iv = get_iv(sys)
341+
sts = unknowns(sys)
342+
ps = parameters(sys)
343+
np = length(ps)
344+
ns = length(sts)
345+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
346+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
347+
348+
sol = get_bv_solution_symbol(ns)
349+
350+
cons = [con.lhs - con.rhs for con in constraints(sys)]
351+
conssubs = Dict()
352+
get_constraint_unknown_subs!(conssubs, cons, stidxmap, iv, sol)
353+
cons = map(x -> fast_substitute(x, conssubs), cons)
354+
355+
init_conds = Any[]
356+
for i in u0_idxs
357+
expr = sol(t0)[i] - u0[i]
358+
push!(init_conds, expr)
359+
end
360+
361+
exprs = vcat(init_conds, cons)
362+
_p = reorder_parameters(sys, ps)
363+
364+
res = build_function_wrapper(sys, exprs, sol, _p..., iv; output_type = Array, kwargs...)
365+
if expression == Val{true}
366+
return res
367+
end
368+
369+
f_oop, f_iip = eval_or_rgf.(res; eval_expression, eval_module)
370+
return GeneratedFunctionWrapper{(2, 3, is_split(sys))}(f_oop, f_iip)
371+
end

0 commit comments

Comments
 (0)