Skip to content

Commit 63bd4a9

Browse files
committed
Use a different termination norm for NLLS
1 parent e231d64 commit 63bd4a9

File tree

5 files changed

+24
-10
lines changed

5 files changed

+24
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.8.4"
4+
version = "3.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ BandedMatrices = "1.4"
6363
BenchmarkTools = "1.4"
6464
ConcreteStructs = "0.2.3"
6565
CUDA = "5.1"
66-
DiffEqBase = "6.146.0"
66+
DiffEqBase = "6.149.0"
6767
Enzyme = "0.11.15"
6868
FastBroadcast = "0.2.8"
6969
FastClosures = "0.3"

src/core/approximate_jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function SciMLBase.__init(
167167
prob, alg.initialization, alg, f, fu, u, p; linsolve, maxiters, internalnorm)
168168

169169
abstol, reltol, termination_cache = init_termination_cache(
170-
abstol, reltol, fu, u, termination_condition)
170+
prob, abstol, reltol, fu, u, termination_condition)
171171
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
172172

173173
J = initialization_cache(nothing)

src/core/generalized_first_order.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function SciMLBase.__init(
156156
linsolve = get_linear_solver(alg.descent)
157157

158158
abstol, reltol, termination_cache = init_termination_cache(
159-
abstol, reltol, fu, u, termination_condition)
159+
prob, abstol, reltol, fu, u, termination_condition)
160160
linsolve_kwargs = merge((; abstol, reltol), linsolve_kwargs)
161161

162162
jac_cache = JacobianCache(

src/core/spectral_methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
133133
prob, alg.linesearch, prob.f, fu, u, prob.p; maxiters, internalnorm, kwargs...)
134134

135135
abstol, reltol, tc_cache = init_termination_cache(
136-
abstol, reltol, fu, u_cache, termination_condition)
136+
prob, abstol, reltol, fu, u_cache, termination_condition)
137137
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
138138

139139
if alg.σ_1 === nothing

src/internal/termination.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
1-
function init_termination_cache(abstol, reltol, du, u, ::Nothing)
2-
return init_termination_cache(
3-
abstol, reltol, du, u, AbsSafeBestTerminationMode(; max_stalled_steps = 32))
1+
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
2+
return init_termination_cache(prob, abstol, reltol, du, u,
3+
AbsSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32))
44
end
5-
function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
6-
tc_cache = init(du, u, tc; abstol, reltol, use_deprecated_retcodes = Val(false))
5+
function init_termination_cache(
6+
prob::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing)
7+
return init_termination_cache(prob, abstol, reltol, du, u,
8+
AbsSafeBestTerminationMode(Base.Fix2(norm, 2); max_stalled_steps = 32))
9+
end
10+
11+
function init_termination_cache(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
12+
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
13+
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
14+
internalnorm = ifelse(
15+
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
16+
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
17+
else
18+
tc
19+
end
20+
tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false))
721
return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache
822
end
923

0 commit comments

Comments
 (0)