Skip to content

Commit fd08798

Browse files
committed
WIP: use nl_prob
1 parent 4b29bab commit fd08798

File tree

2 files changed

+60
-35
lines changed

2 files changed

+60
-35
lines changed

lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl

+29-12
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, false},
3838
integrator.stats.nnonliniter += cache.cache.stats.nsteps
3939
integrator.stats.njacs += cache.cache.stats.njacs
4040
end
41-
if f isa DAEFunction
42-
nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f)
41+
new_prob = __has_nlprob_data(f)
42+
update_nlprob!(cache.prob, u0=z, p=(;dt, γ, inner_tmp, outer_tmp, t, p))
4343
else
44-
nlp_params = (tmp, γ, α, tstep, invγdt, method, p, dt, f)
44+
if f isa DAEFunction
45+
nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f)
46+
else
47+
nlp_params = (tmp, γ, α, tstep, invγdt, method, p, dt, f)
48+
end
49+
remake(cache.prob, p = nlp_params, u0 = z)
4550
end
46-
new_prob = remake(cache.prob, p = nlp_params, u0 = z)
4751
cache.cache = init(new_prob, alg.alg)
4852
nothing
4953
end
@@ -63,26 +67,35 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, true},
6367
integrator.stats.nnonliniter += cache.cache.stats.nsteps
6468
integrator.stats.njacs += cache.cache.stats.njacs
6569
end
66-
if f isa DAEFunction
67-
nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
70+
71+
new_prob = __has_nlprob_data(f)
72+
update_nlprob!(cache.prob, u0=z, p=(;dt, γ, inner_tmp, outer_tmp, t, p))
6873
else
69-
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f)
74+
if f isa DAEFunction
75+
nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
76+
else
77+
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f)
78+
end
79+
remake(cache.prob, p = nlp_params, u0 = z)
7080
end
71-
new_prob = remake(cache.prob, p = nlp_params, u0 = z)
7281
cache.cache = init(new_prob, alg.alg)
7382
nothing
7483
end
7584

7685
## compute_step!
7786

7887
@muladd function compute_step!(nlsolver::NLSolver{<:NonlinearSolveAlg, false}, integrator)
79-
@unpack uprev, t, p, dt, opts = integrator
88+
@unpack uprev, t, p, dt, opts, f = integrator
8089
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
8190
@unpack tstep, invγdt = cache
8291

8392
nlcache = nlsolver.cache.cache
8493
step!(nlcache)
85-
nlsolver.ztmp = nlcache.u
94+
if __has_nlprob_data(f)
95+
ztmp = nlprobmap(nlcache)
96+
else
97+
ztmp = nlcache.u
98+
end
8699

87100
ustep = compute_ustep(tmp, γ, z, method)
88101
atmp = calculate_residuals(nlcache.fu, uprev, ustep, opts.abstol, opts.reltol,
@@ -98,13 +111,17 @@ end
98111
end
99112

100113
@muladd function compute_step!(nlsolver::NLSolver{<:NonlinearSolveAlg, true}, integrator)
101-
@unpack uprev, t, p, dt, opts = integrator
114+
@unpack uprev, t, p, dt, opts, f = integrator
102115
@unpack z, tmp, ztmp, γ, α, cache, method = nlsolver
103116
@unpack tstep, invγdt, atmp, ustep = cache
104117

105118
nlcache = nlsolver.cache.cache
106119
step!(nlcache)
107-
@.. broadcast=false ztmp=nlcache.u
120+
if __has_nlprob_data(f)
121+
@.. ztmp = nlprobmap(nlcache)
122+
else
123+
@.. ztmp=nlcache.u
124+
end
108125

109126
ustep = compute_ustep!(ustep, tmp, γ, z, method)
110127
calculate_residuals!(atmp, nlcache.fu, uprev, ustep, opts.abstol, opts.reltol,

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

+31-23
Original file line numberDiff line numberDiff line change
@@ -206,21 +206,25 @@ function build_nlsolver(
206206
if nlalg isa NonlinearSolveAlg
207207
α = tTypeNoUnits(α)
208208
dt = tTypeNoUnits(dt)
209-
if isdae
210-
nlf = (ztmp, z, p) -> begin
211-
tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p
212-
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1]
213-
end
214-
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
209+
prob = if __has_nlprob_data(f)
210+
f.nlprob_data.nlprob
215211
else
216-
nlf = (ztmp, z, p) -> begin
217-
tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p
218-
_compute_rhs!(
219-
tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1]
212+
if isdae
213+
nlf = (ztmp, z, p) -> begin
214+
tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p
215+
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1]
216+
end
217+
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
218+
else
219+
nlf = (ztmp, z, p) -> begin
220+
tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p
221+
_compute_rhs!(
222+
tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1]
223+
end
224+
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f)
220225
end
221-
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f)
226+
NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params)
222227
end
223-
prob = NonlinearProblem(NonlinearFunction(nlf), ztmp, nlp_params)
224228
cache = init(prob, nlalg.alg)
225229
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache)
226230
else
@@ -291,20 +295,24 @@ function build_nlsolver(
291295
if nlalg isa NonlinearSolveAlg
292296
α = tTypeNoUnits(α)
293297
dt = tTypeNoUnits(dt)
294-
if isdae
295-
nlf = (z, p) -> begin
296-
tmp, α, tstep, invγdt, _p, dt, uprev, f = p
297-
_compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1]
298-
end
299-
nlp_params = (tmp, α, tstep, invγdt, _p, dt, uprev, f)
298+
prob = if __has_nlprob_data(f)
299+
f.nlprob_data.nlprob
300300
else
301-
nlf = (z, p) -> begin
302-
tmp, γ, α, tstep, invγdt, method, _p, dt, f = p
303-
_compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1]
301+
if isdae
302+
nlf = (z, p) -> begin
303+
tmp, α, tstep, invγdt, _p, dt, uprev, f = p
304+
_compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1]
305+
end
306+
nlp_params = (tmp, α, tstep, invγdt, _p, dt, uprev, f)
307+
else
308+
nlf = (z, p) -> begin
309+
tmp, γ, α, tstep, invγdt, method, _p, dt, f = p
310+
_compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1]
311+
end
312+
nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f)
304313
end
305-
nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f)
314+
NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params)
306315
end
307-
prob = NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params)
308316
cache = init(prob, nlalg.alg)
309317
nlcache = NonlinearSolveCache(
310318
nothing, tstep, nothing, nothing, invγdt, prob, cache)

0 commit comments

Comments
 (0)