Skip to content

Commit b389d0e

Browse files
authored
Merge pull request #411 from SciML/ap/untag
Don't use custom tags
2 parents 9a0bb8d + c60a2b6 commit b389d0e

File tree

3 files changed

+6
-14
lines changed

3 files changed

+6
-14
lines changed

ext/NonlinearSolveNLSolversExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLSolversJL, args...;
3434

3535
if autodiff_concrete === :forwarddiff
3636
fj_scalar = @closure (Jx, x) -> begin
37-
T = typeof(NonlinearSolve.NonlinearSolveTag())
37+
T = typeof(ForwardDiff.Tag(prob.f, eltype(x)))
3838
x_dual = ForwardDiff.Dual{T}(x, one(x))
3939
y = prob.f(x_dual, prob.p)
4040
return ForwardDiff.value(y), ForwardDiff.extract_derivative(T, y)

src/internal/helpers.jl

+1-9
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ function evaluate_f!!(f::NonlinearFunction{iip}, fu, u, p) where {iip}
3030
end
3131

3232
# AutoDiff Selection Functions
33-
struct NonlinearSolveTag end
34-
35-
function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:NonlinearSolveTag, <:T}},
36-
f::F, x::AbstractArray{T}) where {T, F}
37-
return true
38-
end
39-
4033
function get_concrete_forward_ad(
4134
autodiff::Union{ADTypes.AbstractForwardMode, ADTypes.AbstractFiniteDifferencesMode},
4235
prob, sp::Val{test_sparse} = True, args...; kwargs...) where {test_sparse}
@@ -62,8 +55,7 @@ function get_concrete_forward_ad(
6255
ad = if !ForwardDiff.can_dual(eltype(prob.u0)) # Use Finite Differencing
6356
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
6457
else
65-
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0))
66-
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(; tag)
58+
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)()
6759
end
6860
return ad
6961
end

src/internal/operators.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
109109
if jvp_autodiff isa AutoForwardDiff || jvp_autodiff isa AutoPolyesterForwardDiff
110110
if iip
111111
# FIXME: Technically we should propagate the tag but ignoring that for now
112-
cache1 = Dual{typeof(ForwardDiff.Tag(NonlinearSolveTag(), eltype(u))),
113-
eltype(u), 1}.(similar(u), ForwardDiff.Partials.(tuple.(u)))
114-
cache2 = Dual{typeof(ForwardDiff.Tag(NonlinearSolveTag(), eltype(fu))),
115-
eltype(fu), 1}.(similar(fu), ForwardDiff.Partials.(tuple.(fu)))
112+
cache1 = Dual{typeof(ForwardDiff.Tag(uf, eltype(u))), eltype(u),
113+
1}.(similar(u), ForwardDiff.Partials.(tuple.(u)))
114+
cache2 = Dual{typeof(ForwardDiff.Tag(uf, eltype(fu))), eltype(fu),
115+
1}.(similar(fu), ForwardDiff.Partials.(tuple.(fu)))
116116
@closure (Jv, v, u, p) -> auto_jacvec!(Jv, uf, u, v, cache1, cache2)
117117
else
118118
@closure (v, u, p) -> auto_jacvec(uf, u, v)

0 commit comments

Comments
 (0)