Skip to content

Commit 3a99433

Browse files
committed
Fix the tracing for NLLS
1 parent 63bd4a9 commit 3a99433

7 files changed

+69
-52
lines changed

src/core/approximate_jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ function SciMLBase.__init(
206206
update_rule_cache = __internal_init(
207207
prob, alg.update_rule, J, fu, u, du; internalnorm)
208208

209-
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du;
209+
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
210210
uses_jacobian_inverse = Val(INV), kwargs...)
211211

212212
return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(

src/core/generalized_first_order.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ function SciMLBase.__init(
191191
GB = :LineSearch
192192
end
193193

194-
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
194+
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
195195

196196
return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
197197
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,

src/core/spectral_methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane
134134

135135
abstol, reltol, tc_cache = init_termination_cache(
136136
prob, abstol, reltol, fu, u_cache, termination_condition)
137-
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
137+
trace = init_nonlinearsolve_trace(prob, alg, u, fu, nothing, du; kwargs...)
138138

139139
if alg.σ_1 === nothing
140140
σ_n = dot(u, u) / dot(u, fu)

src/internal/tracing.jl

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ for Tr in (:TraceMinimal, :TraceWithJacobianConditionNumber, :TraceAll)
5252
end
5353

5454
# NonlinearSolve Tracing Utilities
55-
@concrete struct NonlinearSolveTraceEntry
55+
@concrete struct NonlinearSolveTraceEntry{nType}
5656
iteration::Int
5757
fnorm
5858
stepnorm
@@ -63,19 +63,27 @@ end
6363
δu
6464
end
6565

66-
function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry)
66+
function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry{nType}) where {nType}
6767
if entry.condJ === nothing
6868
@printf io "%-8s %-20s %-20s\n" "----" "-------------" "-----------"
69-
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
69+
if nType === :L2
70+
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm"
71+
else
72+
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
73+
end
7074
@printf io "%-8s %-20s %-20s\n" "----" "-------------" "-----------"
7175
else
7276
@printf io "%-8s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
73-
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
77+
if nType === :L2
78+
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm" "cond(J)"
79+
else
80+
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
81+
end
7482
@printf io "%-8s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
7583
end
7684
end
7785

78-
function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
86+
function Base.show(io::IO, entry::NonlinearSolveTraceEntry{nType}) where {nType}
7987
entry.iteration == 0 && __show_top_level(io, entry)
8088
if entry.iteration < 0
8189
# Special case for final entry
@@ -89,25 +97,32 @@ function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
8997
return nothing
9098
end
9199

92-
function NonlinearSolveTraceEntry(iteration, fu, δu)
93-
return NonlinearSolveTraceEntry(
94-
iteration, norm(fu, Inf), norm(δu, 2), nothing, nothing, nothing, nothing, nothing)
100+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu)
101+
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
102+
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
103+
return NonlinearSolveTraceEntry{nType}(
104+
iteration, fnorm, norm(δu, 2), nothing, nothing, nothing, nothing, nothing)
95105
end
96106

97-
function NonlinearSolveTraceEntry(iteration, fu, δu, J)
98-
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2),
107+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J)
108+
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
109+
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
110+
return NonlinearSolveTraceEntry{nType}(iteration, fnorm, norm(δu, 2),
99111
__cond(J), nothing, nothing, nothing, nothing)
100112
end
101113

102-
function NonlinearSolveTraceEntry(iteration, fu, δu, J, u)
103-
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2), __cond(J),
114+
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J, u)
115+
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
116+
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
117+
return NonlinearSolveTraceEntry{nType}(iteration, fnorm, norm(δu, 2), __cond(J),
104118
__copy(J), __copy(u), __copy(fu), __copy(δu))
105119
end
106120

107121
@concrete struct NonlinearSolveTrace{
108122
show_trace, store_trace, Tr <: AbstractNonlinearSolveTraceLevel}
109123
history
110124
trace_level::Tr
125+
prob
111126
end
112127

113128
function reset!(trace::NonlinearSolveTrace)
@@ -123,61 +138,63 @@ function Base.show(io::IO, trace::NonlinearSolveTrace)
123138
return nothing
124139
end
125140

126-
function init_nonlinearsolve_trace(alg, u, fu, J, δu; show_trace::Val = Val(false),
141+
function init_nonlinearsolve_trace(prob, alg, u, fu, J, δu; show_trace::Val = Val(false),
127142
trace_level::AbstractNonlinearSolveTraceLevel = TraceMinimal(),
128143
store_trace::Val = Val(false), uses_jac_inverse = Val(false), kwargs...)
129144
return init_nonlinearsolve_trace(
130-
alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse)
145+
prob, alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse)
131146
end
132147

133-
function init_nonlinearsolve_trace(
134-
alg, ::Val{show_trace}, trace_level::AbstractNonlinearSolveTraceLevel,
135-
::Val{store_trace}, u, fu, J, δu,
136-
::Val{uses_jac_inverse}) where {show_trace, store_trace, uses_jac_inverse}
148+
function init_nonlinearsolve_trace(prob::AbstractNonlinearProblem, alg, ::Val{show_trace},
149+
trace_level::AbstractNonlinearSolveTraceLevel, ::Val{store_trace}, u, fu, J,
150+
δu, ::Val{uses_jac_inverse}) where {show_trace, store_trace, uses_jac_inverse}
137151
if show_trace
138152
print("\nAlgorithm: ")
139153
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
140154
end
141155
J_ = uses_jac_inverse ? (trace_level isa TraceMinimal ? J : __safe_inv(J)) : J
142156
history = __init_trace_history(
143-
Val{show_trace}(), trace_level, Val{store_trace}(), u, fu, J_, δu)
144-
return NonlinearSolveTrace{show_trace, store_trace}(history, trace_level)
157+
prob, Val{show_trace}(), trace_level, Val{store_trace}(), u, fu, J_, δu)
158+
return NonlinearSolveTrace{show_trace, store_trace}(history, trace_level, prob)
145159
end
146160

147-
function __init_trace_history(::Val{show_trace}, trace_level, ::Val{store_trace},
148-
u, fu, J, δu) where {show_trace, store_trace}
161+
function __init_trace_history(
162+
prob::AbstractNonlinearProblem, ::Val{show_trace}, trace_level,
163+
::Val{store_trace}, u, fu, J, δu) where {show_trace, store_trace}
149164
!store_trace && !show_trace && return nothing
150-
entry = __trace_entry(trace_level, 0, u, fu, J, δu)
165+
entry = __trace_entry(prob, trace_level, 0, u, fu, J, δu)
151166
show_trace && show(entry)
152167
store_trace && return NonlinearSolveTraceEntry[entry]
153168
return nothing
154169
end
155170

156-
function __trace_entry(::TraceMinimal, iter, u, fu, J, δu, α = 1)
157-
return NonlinearSolveTraceEntry(iter, fu, δu .* α)
171+
function __trace_entry(prob, ::TraceMinimal, iter, u, fu, J, δu, α = 1)
172+
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α)
158173
end
159-
function __trace_entry(::TraceWithJacobianConditionNumber, iter, u, fu, J, δu, α = 1)
160-
return NonlinearSolveTraceEntry(iter, fu, δu .* α, J)
174+
function __trace_entry(prob, ::TraceWithJacobianConditionNumber, iter, u, fu, J, δu, α = 1)
175+
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α, J)
161176
end
162-
function __trace_entry(::TraceAll, iter, u, fu, J, δu, α = 1)
163-
return NonlinearSolveTraceEntry(iter, fu, δu .* α, J, u)
177+
function __trace_entry(prob, ::TraceAll, iter, u, fu, J, δu, α = 1)
178+
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α, J, u)
164179
end
165180

166181
function update_trace!(trace::NonlinearSolveTrace{ShT, StT}, iter, u, fu, J, δu,
167182
α = 1; last::Val{L} = Val(false)) where {ShT, StT, L}
168183
!StT && !ShT && return nothing
169184

170185
if L
171-
entry = NonlinearSolveTraceEntry(
172-
-1, norm(fu, Inf), NaN32, nothing, nothing, nothing, nothing, nothing)
186+
nType = ifelse(trace.prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
187+
fnorm = trace.prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
188+
entry = NonlinearSolveTraceEntry{nType}(
189+
-1, fnorm, NaN32, nothing, nothing, nothing, nothing, nothing)
173190
ShT && show(entry)
174191
return trace
175192
end
176193

177194
show_now = ShT && (mod1(iter, trace.trace_level.print_frequency) == 1)
178195
store_now = StT && (mod1(iter, trace.trace_level.store_frequency) == 1)
179196
(show_now || store_now) &&
180-
(entry = __trace_entry(trace.trace_level, iter, u, fu, J, δu, α))
197+
(entry = __trace_entry(trace.prob, trace.trace_level, iter, u, fu, J, δu, α))
181198
store_now && push!(trace.history, entry)
182199
show_now && show(entry)
183200
return trace

test/core/forward_ad_tests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ end
7979
gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
8080
gs_true = abs.(jacobian_f(u0, p))
8181
if !(isapprox(gs, gs_true, atol = 1e-5))
82-
@show sol.retcode, sol.u
8382
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
8483
else
8584
@test abs.(gs)abs.(gs_true) atol=1e-5

test/core/nlls_tests.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ using Reexport
66
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
77
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
88

9-
θ_true = [1.0, 0.1, 2.0, 0.5]
9+
const θ_true = [1.0, 0.1, 2.0, 0.5]
1010

11-
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
11+
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]
1212

13-
y_target = true_function(x, θ_true)
13+
const y_target = true_function(x, θ_true)
1414

1515
function loss_function(θ, p)
1616
= true_function(p, θ)
@@ -23,7 +23,7 @@ function loss_function(resid, θ, p)
2323
return resid
2424
end
2525

26-
θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1
26+
const θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1
2727

2828
solvers = []
2929
for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES(), KrylovJL_LSMR()]
@@ -56,9 +56,9 @@ end
5656
nlls_problems = [prob_oop, prob_iip]
5757

5858
for prob in nlls_problems, solver in solvers
59-
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
59+
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-6)
6060
@test SciMLBase.successful_retcode(sol)
61-
@test maximum(abs, sol.resid) < 1e-6
61+
@test norm(sol.resid, 2) < 1e-6
6262
end
6363
end
6464

@@ -90,8 +90,9 @@ end
9090
x)]
9191

9292
for prob in probs, solver in solvers
93-
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
94-
@test maximum(abs, sol.resid) < 1e-6
93+
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-6)
94+
@test SciMLBase.successful_retcode(sol)
95+
@test norm(sol.resid, 2) < 1e-6
9596
end
9697
end
9798

test/core/rootfind_tests.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ end
120120
@test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0))
121121
end
122122

123-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
123+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
124124
u0 in (1.0, [1.0, 1.0])
125125

126126
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -238,7 +238,7 @@ end
238238
end
239239
end
240240

241-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
241+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
242242
u0 in (1.0, [1.0, 1.0])
243243

244244
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -324,7 +324,7 @@ end
324324
end
325325
end
326326

327-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
327+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
328328
u0 in (1.0, [1.0, 1.0])
329329

330330
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -395,7 +395,7 @@ end
395395
end
396396
end
397397

398-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
398+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
399399
u0 in (1.0, [1.0, 1.0])
400400

401401
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -462,7 +462,7 @@ end
462462
sqrt(2.0))
463463
end
464464

465-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
465+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
466466
u0 in (1.0, [1.0, 1.0])
467467

468468
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -514,7 +514,7 @@ end
514514
@test nlprob_iterator_interface(quadratic_f, p, Val(false), Broyden()) sqrt.(p)
515515
@test nlprob_iterator_interface(quadratic_f!, p, Val(true), Broyden()) sqrt.(p)
516516

517-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
517+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
518518
u0 in (1.0, [1.0, 1.0])
519519

520520
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -563,7 +563,7 @@ end
563563
@test nlprob_iterator_interface(quadratic_f, p, Val(false), Klement()) sqrt.(p)
564564
@test nlprob_iterator_interface(quadratic_f!, p, Val(true), Klement()) sqrt.(p)
565565

566-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
566+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
567567
u0 in (1.0, [1.0, 1.0])
568568

569569
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@@ -613,7 +613,7 @@ end
613613
@test nlprob_iterator_interface(
614614
quadratic_f!, p, Val(true), LimitedMemoryBroyden())sqrt.(p) atol=1e-2
615615

616-
@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
616+
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
617617
u0 in (1.0, [1.0, 1.0])
618618

619619
probN = NonlinearProblem(quadratic_f, u0, 2.0)

0 commit comments

Comments
 (0)