Skip to content

Commit 9a0bb8d

Browse files
Merge pull request #409 from SciML/ap/scalar_finitediff
Allow FiniteDiff propagation for scalar problems
2 parents e295922 + 0dc75b4 commit 9a0bb8d

File tree

6 files changed

+57
-17
lines changed

6 files changed

+57
-17
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.10.0"
4+
version = "3.10.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/make.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ makedocs(; sitename = "NonlinearSolve.jl",
1919
doctest = false,
2020
linkcheck = true,
2121
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615",
22-
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
22+
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
2323
checkdocs = :exports,
2424
warnonly = [:missing_docs],
2525
plugins = [bib],

src/globalization/line_search.jl

+11-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,17 @@ function __internal_init(
101101
args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN}
102102
T = promote_type(eltype(fu), eltype(u))
103103
if u isa Number
104-
grad_op = @closure (u, fu, p) -> last(__value_derivative(Base.Fix2(f, p), u)) * fu
104+
autodiff = get_concrete_forward_ad(alg.autodiff, prob; check_forward_mode = true)
105+
if !(autodiff isa AutoForwardDiff ||
106+
autodiff isa AutoPolyesterForwardDiff ||
107+
autodiff isa AutoFiniteDiff)
108+
autodiff = AutoFiniteDiff()
109+
# Other cases are not properly supported so we fallback to finite differencing
110+
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
111+
Detected $(autodiff). Falling back to AutoFiniteDiff."
112+
end
113+
grad_op = @closure (u, fu, p) -> last(__value_derivative(
114+
autodiff, Base.Fix2(f, p), u)) * fu
105115
else
106116
if SciMLBase.has_jvp(f)
107117
if isinplace(prob)

src/internal/jacobian.jl

+19-4
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,20 @@ function JacobianCache(
9797
J, f, uf, fu, u, p, jac_cache, alg, 0, autodiff, vjp_autodiff, jvp_autodiff)
9898
end
9999

100-
function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; kwargs...) where {F}
100+
function JacobianCache(
101+
prob, alg, f::F, ::Number, u::Number, p; autodiff = nothing, kwargs...) where {F}
101102
uf = JacobianWrapper{false}(f, p)
103+
autodiff = get_concrete_forward_ad(autodiff, prob; check_reverse_mode = false)
104+
if !(autodiff isa AutoForwardDiff ||
105+
autodiff isa AutoPolyesterForwardDiff ||
106+
autodiff isa AutoFiniteDiff)
107+
autodiff = AutoFiniteDiff()
108+
# Other cases are not properly supported so we fallback to finite differencing
109+
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
110+
Detected $(autodiff). Falling back to AutoFiniteDiff."
111+
end
102112
return JacobianCache{false}(
103-
u, f, uf, u, u, p, nothing, alg, 0, nothing, nothing, nothing)
113+
u, f, uf, u, u, p, nothing, alg, 0, autodiff, nothing, nothing)
104114
end
105115

106116
@inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
@@ -115,7 +125,7 @@ function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p)
115125
end
116126
function (cache::JacobianCache)(::Number, u, p = cache.p) # Scalar
117127
cache.njacs += 1
118-
J = last(__value_derivative(cache.uf, u))
128+
J = last(__value_derivative(cache.autodiff, cache.uf, u))
119129
return J
120130
end
121131
# Compute the Jacobian
@@ -181,12 +191,17 @@ end
181191
end
182192
end
183193

184-
@inline function __value_derivative(f::F, x::R) where {F, R}
194+
@inline function __value_derivative(
195+
::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, f::F, x::R) where {F, R}
185196
T = typeof(ForwardDiff.Tag(f, R))
186197
out = f(ForwardDiff.Dual{T}(x, one(x)))
187198
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
188199
end
189200

201+
@inline function __value_derivative(ad::AutoFiniteDiff, f::F, x::R) where {F, R}
202+
return f(x), FiniteDiff.finite_difference_derivative(f, x, ad.fdtype)
203+
end
204+
190205
@inline function __scalar_jacvec(f::F, x::R, v::V) where {F, R, V}
191206
T = typeof(ForwardDiff.Tag(f, R))
192207
out = f(ForwardDiff.Dual{T}(x, v))

src/internal/operators.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,11 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
6262
elseif SciMLBase.has_vjp(f)
6363
f.vjp
6464
elseif u isa Number # Ignore vjp directives
65-
if ForwardDiff.can_dual(typeof(u))
66-
@closure (v, u, p) -> last(__value_derivative(uf, u)) * v
65+
if ForwardDiff.can_dual(typeof(u)) && (vjp_autodiff === nothing ||
66+
vjp_autodiff isa AutoForwardDiff ||
67+
vjp_autodiff isa AutoPolyesterForwardDiff)
68+
# VJP is same as VJP for scalars
69+
@closure (v, u, p) -> last(__scalar_jacvec(uf, u, v))
6770
else
6871
@closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v
6972
end
@@ -92,8 +95,11 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
9295
elseif SciMLBase.has_jvp(f)
9396
f.jvp
9497
elseif u isa Number # Ignore jvp directives
95-
if ForwardDiff.can_dual(typeof(u))
96-
@closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) * v
98+
# Only ForwardDiff if user didn't override
99+
if ForwardDiff.can_dual(typeof(u)) && (jvp_autodiff === nothing ||
100+
jvp_autodiff isa AutoForwardDiff ||
101+
jvp_autodiff isa AutoPolyesterForwardDiff)
102+
@closure (v, u, p) -> last(__scalar_jacvec(uf, u, v))
97103
else
98104
@closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v
99105
end

test/misc/polyalg_tests.jl

+15-6
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,22 @@ end
8686
# Uses the `__solve` function
8787
@test_throws MethodError solve(probN; abstol = 1e-9)
8888
@test_throws MethodError solve(probN, RobustMultiNewton(); abstol = 1e-9)
89-
solver = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
90-
@test SciMLBase.successful_retcode(solver)
91-
solver = solve(
89+
sol = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
90+
@test SciMLBase.successful_retcode(sol)
91+
sol = solve(
9292
probN, FastShortcutNonlinearPolyalg(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
93-
@test SciMLBase.successful_retcode(solver)
94-
solver = solve(probN, custom_polyalg; abstol = 1e-9)
95-
@test SciMLBase.successful_retcode(solver)
93+
@test SciMLBase.successful_retcode(sol)
94+
sol = solve(probN, custom_polyalg; abstol = 1e-9)
95+
@test SciMLBase.successful_retcode(sol)
96+
97+
quadratic_f(u::Float64, p) = u^2 - p
98+
99+
prob = NonlinearProblem(quadratic_f, 2.0, 4.0)
100+
101+
@test_throws MethodError solve(prob)
102+
@test_throws MethodError solve(prob, RobustMultiNewton())
103+
sol = solve(prob, RobustMultiNewton(; autodiff = AutoFiniteDiff()))
104+
@test SciMLBase.successful_retcode(sol)
96105
end
97106

98107
@testitem "Simple Scalar Problem #187" begin

0 commit comments

Comments
 (0)