Skip to content

Commit 67f869d

Browse files
committed
simplify dolinsolve
1 parent f5f1cc4 commit 67f869d

File tree

3 files changed

+11
-130
lines changed

3 files changed

+11
-130
lines changed

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,19 @@ issuccess_W(W::Number) = !iszero(W)
33
issuccess_W(::Any) = true
44

55
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6-
du = nothing, u = nothing, p = nothing, t = nothing,
7-
weight = nothing, solverdata = nothing,
86
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
97
A !== nothing && (linsolve.A = A)
108
b !== nothing && (linsolve.b = b)
119
linu !== nothing && (linsolve.u = linu)
1210

13-
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14-
linsolve.Pl
15-
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16-
linsolve.Pr
17-
1811
_alg = unwrap_alg(integrator, true)
1912

2013
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
2114
solverdata)
22-
if (_Pl !== nothing || _Pr !== nothing)
23-
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24-
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25-
linsolve.Pl = __Pl
26-
linsolve.Pr = __Pr
15+
if !isnothing(A)
16+
(;du, u, p, t) = integrator
17+
p = isnothing(integrator) ? nothing : (du, u, p, t)
18+
reinit!(linsolve; A, p)
2719
end
2820

2921
linres = solve!(linsolve; reltol)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
9595
linsolve_tmp = zero(rate_prototype)
9696

9797
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
98-
Pl, Pr = wrapprecs(
99-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
100-
nothing)..., weight, tmp)
10198
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
102-
Pl = Pl, Pr = Pr,
10399
assumptions = LinearSolve.OperatorAssumptions(true))
104100

105101
grad_config = build_grad_config(alg, f, tf, du1, t)
@@ -141,11 +137,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
141137
linsolve_tmp = zero(rate_prototype)
142138
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
143139

144-
Pl, Pr = wrapprecs(
145-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
146-
nothing)..., weight, tmp)
147140
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
148-
Pl = Pl, Pr = Pr,
149141
assumptions = LinearSolve.OperatorAssumptions(true))
150142
grad_config = build_grad_config(alg, f, tf, du1, t)
151143
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
@@ -289,11 +281,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
289281
uf = UJacobianWrapper(f, t, p)
290282
linsolve_tmp = zero(rate_prototype)
291283
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
292-
Pl, Pr = wrapprecs(
293-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
294-
nothing)..., weight, tmp)
295284
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
296-
Pl = Pl, Pr = Pr,
297285
assumptions = LinearSolve.OperatorAssumptions(true))
298286
grad_config = build_grad_config(alg, f, tf, du1, t)
299287
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -375,11 +363,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
375363
uf = UJacobianWrapper(f, t, p)
376364
linsolve_tmp = zero(rate_prototype)
377365
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
378-
Pl, Pr = wrapprecs(
379-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
380-
nothing)..., weight, tmp)
381366
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
382-
Pl = Pl, Pr = Pr,
383367
assumptions = LinearSolve.OperatorAssumptions(true))
384368
grad_config = build_grad_config(alg, f, tf, du1, t)
385369
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -568,11 +552,7 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
568552
uf = UJacobianWrapper(f, t, p)
569553
linsolve_tmp = zero(rate_prototype)
570554
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
571-
Pl, Pr = wrapprecs(
572-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
573-
nothing)..., weight, tmp)
574555
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
575-
Pl = Pl, Pr = Pr,
576556
assumptions = LinearSolve.OperatorAssumptions(true))
577557
grad_config = build_grad_config(alg, f, tf, du1, t)
578558
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -612,11 +592,7 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
612592
uf = UJacobianWrapper(f, t, p)
613593
linsolve_tmp = zero(rate_prototype)
614594
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
615-
Pl, Pr = wrapprecs(
616-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
617-
nothing)..., weight, tmp)
618595
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
619-
Pl = Pl, Pr = Pr,
620596
assumptions = LinearSolve.OperatorAssumptions(true))
621597
grad_config = build_grad_config(alg, f, tf, du1, t)
622598
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -735,11 +711,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
735711
uf = UJacobianWrapper(f, t, p)
736712
linsolve_tmp = zero(rate_prototype)
737713
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
738-
Pl, Pr = wrapprecs(
739-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
740-
nothing)..., weight, tmp)
741714
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
742-
Pl = Pl, Pr = Pr,
743715
assumptions = LinearSolve.OperatorAssumptions(true))
744716
grad_config = build_grad_config(alg, f, tf, du1, t)
745717
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -795,11 +767,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
795767
uf = UJacobianWrapper(f, t, p)
796768
linsolve_tmp = zero(rate_prototype)
797769
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
798-
Pl, Pr = wrapprecs(
799-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
800-
nothing)..., weight, tmp)
801770
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
802-
Pl = Pl, Pr = Pr,
803771
assumptions = LinearSolve.OperatorAssumptions(true))
804772
grad_config = build_grad_config(alg, f, tf, du1, t)
805773
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -855,11 +823,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
855823
uf = UJacobianWrapper(f, t, p)
856824
linsolve_tmp = zero(rate_prototype)
857825
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
858-
Pl, Pr = wrapprecs(
859-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
860-
nothing)..., weight, tmp)
861826
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
862-
Pl = Pl, Pr = Pr,
863827
assumptions = LinearSolve.OperatorAssumptions(true))
864828
grad_config = build_grad_config(alg, f, tf, du1, t)
865829
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -915,11 +879,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
915879
uf = UJacobianWrapper(f, t, p)
916880
linsolve_tmp = zero(rate_prototype)
917881
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
918-
Pl, Pr = wrapprecs(
919-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
920-
nothing)..., weight, tmp)
921882
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
922-
Pl = Pl, Pr = Pr,
923883
assumptions = LinearSolve.OperatorAssumptions(true))
924884
grad_config = build_grad_config(alg, f, tf, du1, t)
925885
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1032,11 +992,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
1032992
uf = UJacobianWrapper(f, t, p)
1033993
linsolve_tmp = zero(rate_prototype)
1034994
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1035-
Pl, Pr = wrapprecs(
1036-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
1037-
nothing)..., weight, tmp)
1038995
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
1039-
Pl = Pl, Pr = Pr,
1040996
assumptions = LinearSolve.OperatorAssumptions(true))
1041997
grad_config = build_grad_config(alg, f, tf, du1, t)
1042998
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
@@ -1096,11 +1052,7 @@ function alg_cache(
10961052
uf = UJacobianWrapper(f, t, p)
10971053
linsolve_tmp = zero(rate_prototype)
10981054
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
1099-
Pl, Pr = wrapprecs(
1100-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
1101-
nothing)..., weight, tmp)
11021055
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
1103-
Pl = Pl, Pr = Pr,
11041056
assumptions = LinearSolve.OperatorAssumptions(true))
11051057
grad_config = build_grad_config(alg, f, tf, du1, t)
11061058
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 7 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,7 @@ end
5050
integrator.opts.abstol, integrator.opts.reltol,
5151
integrator.opts.internalnorm, t)
5252

53-
if repeat_step
54-
linres = dolinsolve(
55-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
56-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
57-
solverdata = (; gamma = γ))
58-
else
59-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
60-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
61-
solverdata = (; gamma = γ))
62-
end
53+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
6354

6455
vecu = _vec(linres.u)
6556
veck₁ = _vec(k₁)
@@ -162,16 +153,7 @@ end
162153
integrator.opts.abstol, integrator.opts.reltol,
163154
integrator.opts.internalnorm, t)
164155

165-
if repeat_step
166-
linres = dolinsolve(
167-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
168-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
169-
solverdata = (; gamma = γ))
170-
else
171-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
172-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
173-
solverdata = (; gamma = γ))
174-
end
156+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
175157

176158
vecu = _vec(linres.u)
177159
veck₁ = _vec(k₁)
@@ -521,16 +503,7 @@ end
521503
integrator.opts.abstol, integrator.opts.reltol,
522504
integrator.opts.internalnorm, t)
523505

524-
if repeat_step
525-
linres = dolinsolve(
526-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
527-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
528-
solverdata = (; gamma = dtgamma))
529-
else
530-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
531-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
532-
solverdata = (; gamma = dtgamma))
533-
end
506+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
534507

535508
vecu = _vec(linres.u)
536509
veck1 = _vec(k1)
@@ -716,16 +689,7 @@ end
716689
integrator.opts.abstol, integrator.opts.reltol,
717690
integrator.opts.internalnorm, t)
718691

719-
if repeat_step
720-
linres = dolinsolve(
721-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
722-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
723-
solverdata = (; gamma = dtgamma))
724-
else
725-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
726-
du = integrator.fsalfirst, u = u, p = p, t = t, weight = weight,
727-
solverdata = (; gamma = dtgamma))
728-
end
692+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
729693

730694
vecu = _vec(linres.u)
731695
veck1 = _vec(k1)
@@ -1024,16 +988,7 @@ end
1024988
integrator.opts.abstol, integrator.opts.reltol,
1025989
integrator.opts.internalnorm, t)
1026990

1027-
if repeat_step
1028-
linres = dolinsolve(
1029-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1030-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1031-
solverdata = (; gamma = dtgamma))
1032-
else
1033-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1034-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1035-
solverdata = (; gamma = dtgamma))
1036-
end
991+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
1037992

1038993
@.. broadcast=false $(_vec(k1))=-linres.u
1039994

@@ -1387,16 +1342,7 @@ end
13871342
integrator.opts.abstol, integrator.opts.reltol,
13881343
integrator.opts.internalnorm, t)
13891344

1390-
if repeat_step
1391-
linres = dolinsolve(
1392-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1393-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1394-
solverdata = (; gamma = dtgamma))
1395-
else
1396-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1397-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1398-
solverdata = (; gamma = dtgamma))
1399-
end
1345+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
14001346

14011347
@.. broadcast=false $(_vec(k1))=-linres.u
14021348

@@ -1790,16 +1736,7 @@ end
17901736
integrator.opts.abstol, integrator.opts.reltol,
17911737
integrator.opts.internalnorm, t)
17921738

1793-
if repeat_step
1794-
linres = dolinsolve(
1795-
integrator, cache.linsolve; A = nothing, b = _vec(linsolve_tmp),
1796-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1797-
solverdata = (; gamma = dtgamma))
1798-
else
1799-
linres = dolinsolve(integrator, cache.linsolve; A = W, b = _vec(linsolve_tmp),
1800-
du = cache.fsalfirst, u = u, p = p, t = t, weight = weight,
1801-
solverdata = (; gamma = dtgamma))
1802-
end
1739+
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
18031740

18041741
vecu = _vec(linres.u)
18051742
veck1 = _vec(k1)

0 commit comments

Comments
 (0)