Skip to content

Commit cfa4d58

Browse files
committed
fix tests and nonlinear precs
1 parent 1a76866 commit cfa4d58

File tree

8 files changed

+79
-137
lines changed

8 files changed

+79
-137
lines changed

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ end
107107
end
108108
const TryAgain = SlowConvergence
109109

110-
DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
110+
DEFAULT_PRECS(W, p) = nothing, nothing
111111
isdiscretecache(cache) = false
112112

113113
include("doc_utils.jl")

lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ issuccess_W(::Any) = true
44

55
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
66
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
7-
A !== nothing && (linsolve.A = A)
87
b !== nothing && (linsolve.b = b)
98
linu !== nothing && (linsolve.u = linu)
109

lib/OrdinaryDiffEqExtrapolation/src/extrapolation_caches.jl

+8-24
Original file line numberDiff line numberDiff line change
@@ -263,18 +263,14 @@ function alg_cache(alg::ImplicitEulerExtrapolation, u, rate_prototype,
263263
linsolve_tmps[i] = zero(rate_prototype)
264264
end
265265

266-
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
266+
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
267267
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
268-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
269-
#Pr = Diagonal(_vec(weight)))
270268

271269
linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
272270
linsolve[1] = linsolve1
273271
for i in 2:Threads.nthreads()
274-
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
272+
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
275273
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
276-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
277-
#Pr = Diagonal(_vec(weight)))
278274
end
279275

280276
res = uEltypeNoUnits.(zero(u))
@@ -1150,18 +1146,14 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation, u, rate_prototype,
11501146
linsolve_tmps[i] = zero(rate_prototype)
11511147
end
11521148

1153-
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
1149+
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
11541150
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
1155-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
1156-
#Pr = Diagonal(_vec(weight)))
11571151

11581152
linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
11591153
linsolve[1] = linsolve1
11601154
for i in 2:Threads.nthreads()
1161-
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
1155+
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
11621156
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
1163-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
1164-
#Pr = Diagonal(_vec(weight)))
11651157
end
11661158
grad_config = build_grad_config(alg, f, tf, du1, t)
11671159
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, du1, du2)
@@ -1478,18 +1470,14 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation, u, rate_prototype,
14781470
linsolve_tmps[i] = zero(rate_prototype)
14791471
end
14801472

1481-
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
1473+
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
14821474
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
1483-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
1484-
#Pr = Diagonal(_vec(weight)))
14851475

14861476
linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
14871477
linsolve[1] = linsolve1
14881478
for i in 2:Threads.nthreads()
1489-
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
1479+
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
14901480
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
1491-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
1492-
#Pr = Diagonal(_vec(weight)))
14931481
end
14941482
grad_config = build_grad_config(alg, f, tf, du1, t)
14951483
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, du1, du2)
@@ -1674,18 +1662,14 @@ function alg_cache(alg::ImplicitEulerBarycentricExtrapolation, u, rate_prototype
16741662
linsolve_tmps[i] = zero(rate_prototype)
16751663
end
16761664

1677-
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]); u0 = _vec(k_tmps[1]))
1665+
linprob = LinearProblem(W[1], _vec(linsolve_tmps[1]), (nothing, u, p, t); u0 = _vec(k_tmps[1]))
16781666
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
1679-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
1680-
#Pr = Diagonal(_vec(weight)))
16811667

16821668
linsolve = Array{typeof(linsolve1), 1}(undef, Threads.nthreads())
16831669
linsolve[1] = linsolve1
16841670
for i in 2:Threads.nthreads()
1685-
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]); u0 = _vec(k_tmps[i]))
1671+
linprob = LinearProblem(W[i], _vec(linsolve_tmps[i]), (nothing, u, p, t); u0 = _vec(k_tmps[i]))
16861672
linsolve[i] = init(linprob, alg.linsolve, alias_A = true, alias_b = true)
1687-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
1688-
#Pr = Diagonal(_vec(weight)))
16891673
end
16901674
grad_config = build_grad_config(alg, f, tf, du1, t)
16911675
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, du1, du2)

lib/OrdinaryDiffEqFIRK/src/firk_caches.jl

+6-18
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,9 @@ function alg_cache(alg::RadauIIA3, u, rate_prototype, ::Type{uEltypeNoUnits},
107107
recursivefill!(atmp, false)
108108
jac_config = jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw12)
109109

110-
linprob = LinearProblem(W1, _vec(cubuff); u0 = _vec(dw12))
110+
linprob = LinearProblem(W1, _vec(cubuff), (nothing,u,p,t); u0 = _vec(dw12))
111111
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
112112
assumptions = LinearSolve.OperatorAssumptions(true))
113-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
114-
#Pr = Diagonal(_vec(weight)))
115113

116114
rtol = reltol isa Number ? reltol : zero(reltol)
117115
atol = reltol isa Number ? reltol : zero(reltol)
@@ -252,16 +250,12 @@ function alg_cache(alg::RadauIIA5, u, rate_prototype, ::Type{uEltypeNoUnits},
252250
recursivefill!(atmp, false)
253251
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
254252

255-
linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
253+
linprob = LinearProblem(W1, _vec(ubuff), (nothing,u,p,t); u0 = _vec(dw1))
256254
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
257255
assumptions = LinearSolve.OperatorAssumptions(true))
258-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
259-
#Pr = Diagonal(_vec(weight)))
260-
linprob = LinearProblem(W2, _vec(cubuff); u0 = _vec(dw23))
256+
linprob = LinearProblem(W2, _vec(cubuff), (nothing,u,p,t); u0 = _vec(dw23))
261257
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
262258
assumptions = LinearSolve.OperatorAssumptions(true))
263-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
264-
#Pr = Diagonal(_vec(weight)))
265259

266260
rtol = reltol isa Number ? reltol : zero(reltol)
267261
atol = reltol isa Number ? reltol : zero(reltol)
@@ -441,21 +435,15 @@ function alg_cache(alg::RadauIIA9, u, rate_prototype, ::Type{uEltypeNoUnits},
441435
recursivefill!(atmp, false)
442436
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
443437

444-
linprob = LinearProblem(W1, _vec(ubuff); u0 = _vec(dw1))
438+
linprob = LinearProblem(W1, _vec(ubuff), (nothing,u,p,t); u0 = _vec(dw1))
445439
linsolve1 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
446440
assumptions = LinearSolve.OperatorAssumptions(true))
447-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
448-
#Pr = Diagonal(_vec(weight)))
449-
linprob = LinearProblem(W2, _vec(cubuff1); u0 = _vec(dw23))
441+
linprob = LinearProblem(W2, _vec(cubuff1), (nothing,u,p,t); u0 = _vec(dw23))
450442
linsolve2 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
451443
assumptions = LinearSolve.OperatorAssumptions(true))
452-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
453-
#Pr = Diagonal(_vec(weight)))
454-
linprob = LinearProblem(W3, _vec(cubuff2); u0 = _vec(dw45))
444+
linprob = LinearProblem(W3, _vec(cubuff2), (nothing,u,p,t); u0 = _vec(dw45))
455445
linsolve3 = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
456446
assumptions = LinearSolve.OperatorAssumptions(true))
457-
#Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
458-
#Pr = Diagonal(_vec(weight)))
459447

460448
rtol = reltol isa Number ? reltol : zero(reltol)
461449
atol = reltol isa Number ? reltol : zero(reltol)

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

-5
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,7 @@ function build_nlsolver(
192192
jac_config = build_jac_config(alg, nf, uf, du1, uprev, u, ztmp, dz)
193193
end
194194
linprob = LinearProblem(W, _vec(k), (isdae ? du1 : nothing,u,p,t); u0 = _vec(dz))
195-
Pl, Pr = wrapprecs(
196-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
197-
nothing)...,
198-
weight, dz)
199195
linsolve = init(linprob, alg.linsolve, (isdae ? du1 : nothing,u,p,t); alias_A = true, alias_b = true,
200-
Pl = Pl, Pr = Pr,
201196
assumptions = LinearSolve.OperatorAssumptions(true))
202197

203198
tType = typeof(t)

test/interface/linear_nonlinear_tests.jl

+29-40
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,22 @@ end
88
u0 = rand(3)
99
prob = ODEProblem(rn, u0, (0, 50.0))
1010

11-
function precsl(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
12-
if newW === nothing || newW
13-
Pl = lu(convert(AbstractMatrix, W), check = false)
14-
else
15-
Pl = Plprev
16-
end
17-
Pl, nothing
11+
function precsl(W, p)
12+
Pl = lu(convert(AbstractMatrix, W), check = false)
13+
Pl, IdentityOperator(size(W, 1))
1814
end
1915

20-
function precsr(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
21-
if newW === nothing || newW
22-
Pr = lu(convert(AbstractMatrix, W), check = false)
23-
else
24-
Pr = Prprev
25-
end
26-
nothing, Pr
16+
function precsr(W, p)
17+
Pr = lu(convert(AbstractMatrix, W), check = false)
18+
IdentityOperator(size(W, 1)), Pr
2719
end
2820

29-
function precslr(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
30-
if newW === nothing || newW
31-
Pr = lu(convert(AbstractMatrix, W), check = false)
32-
else
33-
Pr = Prprev
34-
end
21+
function precslr(W, p)
22+
Pr = lu(convert(AbstractMatrix, W), check = false)
3523
Pr, Pr
3624
end
3725

26+
3827
sol = @test_nowarn solve(prob, TRBDF2(autodiff = false));
3928
@test length(sol.t) < 20
4029
sol = @test_nowarn solve(prob,
@@ -45,29 +34,29 @@ solref = @test_nowarn solve(prob,
4534
smooth_est = false));
4635
@test length(sol.t) < 20
4736
sol = @test_nowarn solve(prob,
48-
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
49-
precs = precsl, smooth_est = false, concrete_jac = true));
37+
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsl),
38+
smooth_est = false, concrete_jac = true));
5039
@test length(sol.t) < 20
5140
sol = @test_nowarn solve(prob,
52-
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
53-
precs = precsr, smooth_est = false, concrete_jac = true));
41+
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsr),
42+
smooth_est = false, concrete_jac = true));
5443
@test length(sol.t) < 20
5544
sol = @test_nowarn solve(prob,
56-
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
57-
precs = precslr, smooth_est = false, concrete_jac = true));
45+
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr)
46+
, smooth_est = false, concrete_jac = true));
5847
@test length(sol.t) < 20
5948
sol = @test_nowarn solve(prob,
6049
QNDF(autodiff = false, linsolve = KrylovJL_GMRES(),
6150
concrete_jac = true));
6251
@test length(sol.t) < 25
6352
sol = @test_nowarn solve(prob,
6453
Rosenbrock23(autodiff = false,
65-
linsolve = KrylovJL_GMRES(),
66-
precs = precslr, concrete_jac = true));
54+
linsolve = KrylovJL_GMRES(precs = precslr),
55+
concrete_jac = true));
6756
@test length(sol.t) < 20
6857
sol = @test_nowarn solve(prob,
69-
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(),
70-
precs = precslr, concrete_jac = true));
58+
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
59+
concrete_jac = true));
7160
@test length(sol.t) < 20
7261

7362
sol = @test_nowarn solve(prob, TRBDF2(autodiff = false));
@@ -79,26 +68,26 @@ sol = @test_nowarn solve(prob,
7968
smooth_est = false));
8069
@test length(sol.t) < 20
8170
sol = @test_nowarn solve(prob,
82-
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
83-
precs = precsl, smooth_est = false, concrete_jac = true));
71+
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsl),
72+
smooth_est = false, concrete_jac = true));
8473
@test length(sol.t) < 20
8574
sol = @test_nowarn solve(prob,
86-
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
87-
precs = precsr, smooth_est = false, concrete_jac = true));
75+
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precsr),
76+
smooth_est = false, concrete_jac = true));
8877
@test length(sol.t) < 20
8978
sol = @test_nowarn solve(prob,
90-
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(),
91-
precs = precslr, smooth_est = false, concrete_jac = true));
79+
TRBDF2(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
80+
smooth_est = false, concrete_jac = true));
9281
@test length(sol.t) < 20
9382
sol = @test_nowarn solve(prob,
9483
QNDF(autodiff = false, linsolve = KrylovJL_GMRES(),
9584
concrete_jac = true));
9685
@test length(sol.t) < 25
9786
sol = @test_nowarn solve(prob,
98-
Rosenbrock23(autodiff = false, linsolve = KrylovJL_GMRES(),
99-
precs = precslr, concrete_jac = true));
87+
Rosenbrock23(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
88+
concrete_jac = true));
10089
@test length(sol.t) < 20
10190
sol = @test_nowarn solve(prob,
102-
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(),
103-
precs = precslr, concrete_jac = true));
91+
Rodas4(autodiff = false, linsolve = KrylovJL_GMRES(precs = precslr),
92+
concrete_jac = true));
10493
@test length(sol.t) < 20

test/interface/linear_solver_test.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ refsol = solve(probiip, FBDF(), abstol = 1e-12, reltol = 1e-12)
197197
@testset "$solname" for (solname, solver) in pairs(solvers)
198198
sol = solve(prob, solver, abstol = 1e-12, reltol = 1e-12, maxiters = 2e4)
199199
@test sol.retcode == ReturnCode.Success
200-
@test isapprox(sol.u[end], refsol.u[end], rtol = 1e-8, atol = 1e-10)
200+
@test isapprox(sol.u[end], refsol.u[end], rtol = 2e-8, atol = 1e-10)
201201
end
202202
end
203203
end
@@ -207,7 +207,7 @@ end
207207
@testset "$solname" for (solname, solver) in pairs(solvers)
208208
sol = solve(prob, solver, maxiters = 2e4)
209209
@test sol.retcode == ReturnCode.Success
210-
@test isapprox(sol.u[end], refsol.u[end], rtol = 2e-3, atol = 1e-6)
210+
@test isapprox(sol.u[end], refsol.u[end], rtol = 5e-3, atol = 1e-6)
211211
end
212212
end
213213
end

0 commit comments

Comments
 (0)