Skip to content

Commit 25415c9

Browse files
committed
simplify dolinsolve
1 parent ceae110 commit 25415c9

File tree

21 files changed

+349
-605
lines changed

21 files changed

+349
-605
lines changed

lib/OrdinaryDiffEqBDF/src/algorithms.jl

+39-55
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@ an Adaptive BDF2 Formula and Comparison with The MATLAB Ode15s. Procedia Compute
66
ABDF2: Multistep Method
77
An adaptive order 2 L-stable fixed leading coefficient multistep BDF method.
88
"""
9-
struct ABDF2{CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <:
9+
struct ABDF2{CS, AD, F, F2, FDT, ST, CJ, K, T, StepLimiter} <:
1010
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
1111
linsolve::F
1212
nlsolve::F2
13-
precs::P
1413
κ::K
1514
tol::T
1615
smooth_est::Bool
@@ -20,14 +19,14 @@ struct ABDF2{CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <:
2019
end
2120
function ABDF2(; chunk_size = Val{0}(), autodiff = true, standardtag = Val{true}(),
2221
concrete_jac = nothing, diff_type = Val{:forward},
23-
κ = nothing, tol = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
22+
κ = nothing, tol = nothing, linsolve = nothing,
2423
nlsolve = NLNewton(),
2524
smooth_est = true, extrapolant = :linear,
2625
controller = :Standard, step_limiter! = trivial_limiter!)
2726
ABDF2{
2827
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
29-
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
30-
typeof(κ), typeof(tol), typeof(step_limiter!)}(linsolve, nlsolve, precs, κ, tol,
28+
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
29+
typeof(κ), typeof(tol), typeof(step_limiter!)}(linsolve, nlsolve, κ, tol,
3130
smooth_est, extrapolant, controller, step_limiter!)
3231
end
3332

@@ -36,11 +35,10 @@ Uri M. Ascher, Steven J. Ruuth, Brian T. R. Wetton. Implicit-Explicit Methods fo
3635
Dependent Partial Differential Equations. 1995 Society for Industrial and Applied Mathematics
3736
Journal on Numerical Analysis, 32(3), pp 797-823, 1995. doi: https://doi.org/10.1137/0732037
3837
"""
39-
struct SBDF{CS, AD, F, F2, P, FDT, ST, CJ, K, T} <:
38+
struct SBDF{CS, AD, F, F2, FDT, ST, CJ, K, T} <:
4039
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
4140
linsolve::F
4241
nlsolve::F2
43-
precs::P
4442
κ::K
4543
tol::T
4644
extrapolant::Symbol
@@ -50,14 +48,13 @@ end
5048

5149
function SBDF(order; chunk_size = Val{0}(), autodiff = Val{true}(),
5250
standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward},
53-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
51+
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
5452
tol = nothing,
5553
extrapolant = :linear, ark = false)
5654
SBDF{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
57-
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
55+
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
5856
typeof(κ), typeof(tol)}(linsolve,
5957
nlsolve,
60-
precs,
6158
κ,
6259
tol,
6360
extrapolant,
@@ -68,15 +65,14 @@ end
6865
# All keyword form needed for remake
6966
function SBDF(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
7067
concrete_jac = nothing, diff_type = Val{:forward},
71-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
68+
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
7269
tol = nothing,
7370
extrapolant = :linear,
7471
order, ark = false)
7572
SBDF{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
76-
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
73+
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
7774
typeof(κ), typeof(tol)}(linsolve,
7875
nlsolve,
79-
precs,
8076
κ,
8177
tol,
8278
extrapolant,
@@ -136,11 +132,10 @@ Optional parameter kappa defaults to Shampine's accuracy-optimal -0.1850.
136132
137133
See also `QNDF`.
138134
"""
139-
struct QNDF1{CS, AD, F, F2, P, FDT, ST, CJ, κType, StepLimiter} <:
135+
struct QNDF1{CS, AD, F, F2, FDT, ST, CJ, κType, StepLimiter} <:
140136
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
141137
linsolve::F
142138
nlsolve::F2
143-
precs::P
144139
extrapolant::Symbol
145140
kappa::κType
146141
controller::Symbol
@@ -149,15 +144,14 @@ end
149144

150145
function QNDF1(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
151146
concrete_jac = nothing, diff_type = Val{:forward},
152-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
147+
linsolve = nothing, nlsolve = NLNewton(),
153148
extrapolant = :linear, kappa = -37 // 200,
154149
controller = :Standard, step_limiter! = trivial_limiter!)
155150
QNDF1{
156151
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
157-
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
152+
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
158153
typeof(kappa), typeof(step_limiter!)}(linsolve,
159154
nlsolve,
160-
precs,
161155
extrapolant,
162156
kappa,
163157
controller,
@@ -170,11 +164,10 @@ An adaptive order 2 quasi-constant timestep L-stable numerical differentiation f
170164
171165
See also `QNDF`.
172166
"""
173-
struct QNDF2{CS, AD, F, F2, P, FDT, ST, CJ, κType, StepLimiter} <:
167+
struct QNDF2{CS, AD, F, F2, FDT, ST, CJ, κType, StepLimiter} <:
174168
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
175169
linsolve::F
176170
nlsolve::F2
177-
precs::P
178171
extrapolant::Symbol
179172
kappa::κType
180173
controller::Symbol
@@ -183,15 +176,14 @@ end
183176

184177
function QNDF2(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
185178
concrete_jac = nothing, diff_type = Val{:forward},
186-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
179+
linsolve = nothing, nlsolve = NLNewton(),
187180
extrapolant = :linear, kappa = -1 // 9,
188181
controller = :Standard, step_limiter! = trivial_limiter!)
189182
QNDF2{
190183
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
191-
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
184+
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
192185
typeof(kappa), typeof(step_limiter!)}(linsolve,
193186
nlsolve,
194-
precs,
195187
extrapolant,
196188
kappa,
197189
controller,
@@ -214,12 +206,11 @@ year={1997},
214206
publisher={SIAM}
215207
}
216208
"""
217-
struct QNDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, κType, StepLimiter} <:
209+
struct QNDF{MO, CS, AD, F, F2, FDT, ST, CJ, K, T, κType, StepLimiter} <:
218210
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
219211
max_order::Val{MO}
220212
linsolve::F
221213
nlsolve::F2
222-
precs::P
223214
κ::K
224215
tol::T
225216
extrapolant::Symbol
@@ -231,16 +222,15 @@ end
231222
function QNDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
232223
autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing,
233224
diff_type = Val{:forward},
234-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
225+
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
235226
tol = nothing,
236-
extrapolant = :linear, kappa = (
237-
-37 // 200, -1 // 9, -823 // 10000, -83 // 2000, 0 // 1),
227+
extrapolant = :linear, kappa = (-37 // 200, -1 // 9, -823 // 10000, -83 // 2000, 0 // 1),
238228
controller = :Standard, step_limiter! = trivial_limiter!) where {MO}
239229
QNDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
240-
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
230+
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
241231
_unwrap_val(concrete_jac),
242232
typeof(κ), typeof(tol), typeof(kappa), typeof(step_limiter!)}(
243-
max_order, linsolve, nlsolve, precs, κ, tol,
233+
max_order, linsolve, nlsolve, κ, tol,
244234
extrapolant, kappa, controller, step_limiter!)
245235
end
246236

@@ -251,22 +241,20 @@ MEBDF2: Multistep Method
251241
The second order Modified Extended BDF method, which has improved stability properties over the standard BDF.
252242
Fixed timestep only.
253243
"""
254-
struct MEBDF2{CS, AD, F, F2, P, FDT, ST, CJ} <:
244+
struct MEBDF2{CS, AD, F, F2, FDT, ST, CJ} <:
255245
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
256246
linsolve::F
257247
nlsolve::F2
258-
precs::P
259248
extrapolant::Symbol
260249
end
261250
function MEBDF2(; chunk_size = Val{0}(), autodiff = true, standardtag = Val{true}(),
262251
concrete_jac = nothing, diff_type = Val{:forward},
263-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
252+
linsolve = nothing, nlsolve = NLNewton(),
264253
extrapolant = :constant)
265254
MEBDF2{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
266-
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
255+
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
267256
_unwrap_val(concrete_jac)}(linsolve,
268257
nlsolve,
269-
precs,
270258
extrapolant)
271259
end
272260

@@ -283,12 +271,11 @@ year={2002},
283271
publisher={Walter de Gruyter GmbH \\& Co. KG}
284272
}
285273
"""
286-
struct FBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <:
274+
struct FBDF{MO, CS, AD, F, F2, FDT, ST, CJ, K, T, StepLimiter} <:
287275
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
288276
max_order::Val{MO}
289277
linsolve::F
290278
nlsolve::F2
291-
precs::P
292279
κ::K
293280
tol::T
294281
extrapolant::Symbol
@@ -299,14 +286,14 @@ end
299286
function FBDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
300287
autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing,
301288
diff_type = Val{:forward},
302-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
289+
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
303290
tol = nothing,
304291
extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter!) where {MO}
305292
FBDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
306-
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
293+
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
307294
_unwrap_val(concrete_jac),
308295
typeof(κ), typeof(tol), typeof(step_limiter!)}(
309-
max_order, linsolve, nlsolve, precs, κ, tol, extrapolant,
296+
max_order, linsolve, nlsolve, κ, tol, extrapolant,
310297
controller, step_limiter!)
311298
end
312299

@@ -390,41 +377,39 @@ See also `SBDF`, `IMEXEuler`.
390377
"""
391378
IMEXEulerARK(; kwargs...) = SBDF(1; ark = true, kwargs...)
392379

393-
struct DImplicitEuler{CS, AD, F, F2, P, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
380+
struct DImplicitEuler{CS, AD, F, F2, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
394381
linsolve::F
395382
nlsolve::F2
396-
precs::P
397383
extrapolant::Symbol
398384
controller::Symbol
399385
end
400386
function DImplicitEuler(;
401387
chunk_size = Val{0}(), autodiff = true, standardtag = Val{true}(),
402388
concrete_jac = nothing, diff_type = Val{:forward},
403-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
389+
linsolve = nothing, nlsolve = NLNewton(),
404390
extrapolant = :constant,
405391
controller = :Standard)
406392
DImplicitEuler{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
407-
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
393+
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
408394
_unwrap_val(concrete_jac)}(linsolve,
409-
nlsolve, precs, extrapolant, controller)
395+
nlsolve, extrapolant, controller)
410396
end
411397

412-
struct DABDF2{CS, AD, F, F2, P, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
398+
struct DABDF2{CS, AD, F, F2, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
413399
linsolve::F
414400
nlsolve::F2
415-
precs::P
416401
extrapolant::Symbol
417402
controller::Symbol
418403
end
419404
function DABDF2(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
420405
concrete_jac = nothing, diff_type = Val{:forward},
421-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
406+
linsolve = nothing, nlsolve = NLNewton(),
422407
extrapolant = :constant,
423408
controller = :Standard)
424409
DABDF2{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
425-
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
410+
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
426411
_unwrap_val(concrete_jac)}(linsolve,
427-
nlsolve, precs, extrapolant, controller)
412+
nlsolve, extrapolant, controller)
428413
end
429414

430415
#=
@@ -441,11 +426,10 @@ DBDF(;chunk_size=Val{0}(),autodiff=Val{true}(), standardtag = Val{true}(), concr
441426
linsolve,nlsolve,precs,extrapolant)
442427
=#
443428

444-
struct DFBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
429+
struct DFBDF{MO, CS, AD, F, F2, FDT, ST, CJ, K, T} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
445430
max_order::Val{MO}
446431
linsolve::F
447432
nlsolve::F2
448-
precs::P
449433
κ::K
450434
tol::T
451435
extrapolant::Symbol
@@ -454,13 +438,13 @@ end
454438
function DFBDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
455439
autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing,
456440
diff_type = Val{:forward},
457-
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
441+
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
458442
tol = nothing,
459443
extrapolant = :linear, controller = :Standard) where {MO}
460444
DFBDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
461-
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
445+
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
462446
_unwrap_val(concrete_jac),
463-
typeof(κ), typeof(tol)}(max_order, linsolve, nlsolve, precs, κ, tol, extrapolant,
447+
typeof(κ), typeof(tol)}(max_order, linsolve, nlsolve, κ, tol, extrapolant,
464448
controller)
465449
end
466450

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ end
106106
end
107107
const TryAgain = SlowConvergence
108108

109-
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
109+
DEFAULT_PRECS(W, p) = nothing, nothing
110110
isdiscretecache(cache) = false
111111

112112
include("doc_utils.jl")

lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
2727
using DiffEqBase: TimeGradientWrapper,
2828
UJacobianWrapper, TimeDerivativeWrapper,
2929
UDerivativeWrapper
30-
using SciMLBase: AbstractSciMLOperator
30+
using SciMLBase: AbstractSciMLOperator, DEIntegrator
3131
import OrdinaryDiffEqCore
3232
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
3333
DAEAlgorithm,

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

+21-26
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,20 @@ issuccess_W(W::Number) = !iszero(W)
33
issuccess_W(::Any) = true
44

55
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6-
du = nothing, u = nothing, p = nothing, t = nothing,
7-
weight = nothing, solverdata = nothing,
86
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
9-
A !== nothing && (linsolve.A = A)
107
b !== nothing && (linsolve.b = b)
118
linu !== nothing && (linsolve.u = linu)
129

13-
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14-
linsolve.Pl
15-
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16-
linsolve.Pr
17-
1810
_alg = unwrap_alg(integrator, true)
19-
20-
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
21-
solverdata)
22-
if (_Pl !== nothing || _Pr !== nothing)
23-
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24-
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25-
linsolve.Pl = __Pl
26-
linsolve.Pr = __Pr
11+
if !isnothing(A)
12+
if integrator isa DEIntegrator
13+
(;u, p, t) = integrator
14+
du = hasproperty(integrator, :du) ? integrator.du : nothing
15+
p = (du, u, p, t)
16+
reinit!(linsolve; A, p)
17+
else
18+
reinit!(linsolve; A)
19+
end
2720
end
2821

2922
linres = solve!(linsolve; reltol)
@@ -44,16 +37,18 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi
4437
return linres
4538
end
4639

47-
function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
48-
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
49-
Pr = Diagonal(_vec(weight))
50-
Pl, Pr
51-
end
52-
53-
function wrapprecs(_Pl, _Pr, weight, u)
54-
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
55-
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
56-
Pl, Pr
40+
function wrapprecs(linsolver, W, weight)
41+
if isnothing(linsolver)
42+
linsolver = LinearSolve.defaultalg(W, weight, LinearSolve.OperatorAssumptions(true))
43+
end
44+
if hasproperty(linsolver, :precs) && isnothing(linsolver.precs)
45+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
46+
Pr = Diagonal(_vec(weight))
47+
precs = Returns((Pl, Pr))
48+
return remake(linsolver; precs)
49+
else
50+
return linsolver
51+
end
5752
end
5853

5954
Base.resize!(p::LinearSolve.LinearCache, i) = p

0 commit comments

Comments
 (0)