Skip to content

Commit a21900f

Browse files
Merge pull request #1186 from SciML/secondorderdual
Fix dual tag ordering in second order
2 parents 323e20f + 073c5b1 commit a21900f

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/second_order.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,15 @@ function _second_order_sensitivities(loss, prob, alg, sensealg::ForwardDiffOverA
88
end
99
end
1010

11+
struct SciMLSensitivityTag end
12+
13+
14+
1115
function _second_order_sensitivity_product(loss, v, prob, alg,
1216
sensealg::ForwardDiffOverAdjoint,
1317
args...; kwargs...)
14-
θ = ForwardDiff.Dual.(prob.p, v)
18+
T = typeof(ForwardDiff.Tag(SciMLSensitivityTag(),eltype(v)))
19+
θ = ForwardDiff.Dual{T,eltype(v),1}.(prob.p, ForwardDiff.Partials.(Tuple.(v)))
1520
_loss = p -> loss(solve(prob, alg, args...; p = p, sensealg = sensealg.adjalg,
1621
kwargs...))
1722
getindex.(ForwardDiff.partials.(Zygote.gradient(_loss, θ)[1]), 1)

test/second_order.jl

+14
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,17 @@ H2v = H * v
3434

3535
@test H H2
3636
@test Hv H2v
37+
38+
function lotka!(du,u,p,t)
39+
du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
40+
du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
41+
end
42+
43+
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
44+
prob = ODEProblem(lotka!,u0,(0.0,10.0),p)
45+
loss(sol) = sum(sol)
46+
v = ones(4)
47+
48+
Hv = second_order_sensitivity_product(loss,v,prob,Vern9(),saveat=0.1,abstol=1e-12,reltol=1e-12)
49+
forward_Hv = ForwardDiff.hessian(p -> sum(solve(prob, Vern9(), p=p,saveat=0.1,abstol=1e-12,reltol=1e-12)), p)*v
50+
@test Hv forward_Hv

0 commit comments

Comments
 (0)