Skip to content

Commit 1656084

Browse files
feat: implement NonlinearProblem and NonlinearFunction for System
1 parent c3365b1 commit 1656084

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

src/problems/compatibility.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ function check_time_dependent(sys::System, T)
2323
end
2424
end
2525

26+
function check_time_independent(sys::System, T)
27+
if is_time_dependent(sys)
28+
throw(SystemCompatibilityError("""
29+
`$T` requires a time-independent system.
30+
"""))
31+
end
32+
end
33+
2634
function check_is_dde(sys::System)
2735
altT = get_noise_eqs(sys) === nothing ? ODEProblem : SDEProblem
2836
if !is_dde(sys)

src/problems/nonlinearproblem.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
@fallback_iip_specialize function SciMLBase.NonlinearFunction{iip, spec}(
2+
sys::System, _d = nothing, u0 = nothing, p = nothing; jac = false,
3+
eval_expression = false, eval_module = @__MODULE__, sparse = false,
4+
checkbounds = false, sparsity = false, analytic = nothing,
5+
simplify = false, cse = true, initialization_data = nothing,
6+
check_compatibility = true, kwargs...) where {iip, spec}
7+
check_complete(sys, NonlinearFunction)
8+
check_compatibility && check_compatible_system(NonlinearFunction, sys)
9+
10+
dvs = unknowns(sys)
11+
ps = parameters(sys)
12+
f = generate_rhs(sys, dvs, ps; expression = Val{false},
13+
eval_expression, eval_module, checkbounds = checkbounds, cse,
14+
kwargs...)
15+
16+
if spec === SciMLBase.FunctionWrapperSpecialize && iip
17+
if u0 === nothing || p === nothing
18+
error("u0, and p must be specified for FunctionWrapperSpecialize on NonlinearFunction.")
19+
end
20+
f = SciMLBase.wrapfun_iip(f, (u0, u0, p))
21+
end
22+
23+
if jac
24+
_jac = generate_jacobian(sys, dvs, ps; expression = Val{false},
25+
simplify, sparse, cse, eval_expression, eval_module, checkbounds, kwargs...)
26+
else
27+
_jac = nothing
28+
end
29+
30+
observedfun = ObservedFunctionCache(
31+
sys; steady_state = false, eval_expression, eval_module, checkbounds, cse)
32+
33+
if length(dvs) == length(equations(sys))
34+
resid_prototype = nothing
35+
else
36+
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
37+
end
38+
39+
if sparse
40+
jac_prototype = similar(calculate_jacobian(sys; sparse), eltype(u0))
41+
else
42+
jac_prototype = nothing
43+
end
44+
45+
NonlinearFunction{iip, spec}(f;
46+
sys = sys,
47+
jac = _jac,
48+
observed = observedfun,
49+
analytic = analytic,
50+
jac_prototype,
51+
resid_prototype,
52+
initialization_data)
53+
end
54+
55+
@fallback_iip_specialize function SciMLBase.NonlinearProblem{iip, spec}(
56+
sys::System, u0map, parammap = SciMLBase.NullParameters();
57+
check_length = true, check_compatibility = true, kwargs...) where {iip, spec}
58+
check_complete(sys, NonlinearProblem)
59+
check_compatibility && check_compatible_system(NonlinearProblem, sys)
60+
61+
f, u0, p = process_SciMLProblem(NonlinearFunction{iip, spec}, sys, u0map, parammap;
62+
check_length, check_compatibility, kwargs...)
63+
64+
kwargs = process_kwargs(sys; kwargs...)
65+
# Call `remake` so it runs initialization if it is trivial
66+
return remake(NonlinearProblem{iip}(
67+
f, u0, p, StandardNonlinearProblem(); kwargs...))
68+
end
69+
70+
function check_compatible_system(
71+
T::Union{Type{NonlinearFunction}, Type{NonlinearProblem}}, sys::System)
72+
check_time_independent(sys, T)
73+
check_not_dde(sys)
74+
check_no_cost(sys, T)
75+
check_no_constraints(sys, T)
76+
check_no_jumps(sys, T)
77+
check_no_noise(sys, T)
78+
end

src/systems/codegen.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ function generate_jacobian(sys::System, dvs = unknowns(sys),
143143
jac = calculate_jacobian(sys; simplify, sparse, dvs)
144144
p = reorder_parameters(sys, ps)
145145
t = get_iv(sys)
146-
if t !== nothing
146+
if t === nothing
147+
wrap_code = (identity, identity)
148+
else
147149
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity)
148150
end
149151
res = build_function_wrapper(sys, jac, dvs, p..., t; wrap_code, expression = Val{true},

0 commit comments

Comments
 (0)