Skip to content

Commit afaf14a

Browse files
Merge pull request #1193 from SciML/dg/core1_init
Run `CheckInit` after Initialization and Fix Core1 tests
2 parents 5cc0724 + 93d97c8 commit afaf14a

File tree

6 files changed

+35
-28
lines changed

6 files changed

+35
-28
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ LinearAlgebra = "1.10"
7777
LinearSolve = "2, 3"
7878
Lux = "1"
7979
Markdown = "1.10"
80-
ModelingToolkit = "9.74"
80+
ModelingToolkit = "9.78"
8181
ModelingToolkitStandardLibrary = "2"
8282
Mooncake = "0.4.52"
8383
NLsolve = "4.5.1"

src/SciMLSensitivity.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm,
3333
AbstractNonlinearProblem, AbstractSensitivityAlgorithm,
3434
AbstractDiffEqFunction, AbstractODEFunction, unwrapped_f, CallbackSet,
3535
ContinuousCallback, DESolution, NonlinearFunction, NonlinearProblem,
36-
DiscreteCallback, LinearProblem, ODEFunction, ODEProblem,
36+
DiscreteCallback, LinearProblem, ODEFunction, ODEProblem, DAEProblem,
3737
RODEFunction, RODEProblem, ReturnCode, SDEFunction,
3838
SDEProblem, VectorContinuousCallback, deleteat!,
3939
get_tmp_cache, has_adjoint, isinplace, reinit!, remake,

src/adjoint_common.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -761,4 +761,4 @@ if !hasmethod(Zygote.adjoint,
761761
end
762762
sol.u, solu_adjoint
763763
end
764-
end
764+
end

src/concrete_solve.jl

+18-12
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ function DiffEqBase._concrete_solve_adjoint(
408408
end
409409

410410
# Remove callbacks, saveat, etc. from kwargs since it's handled separately
411-
kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback,))}(values(kwargs))
411+
kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback, :initializealg))}(values(kwargs))
412412

413413
# Capture the callback_adj for the reverse pass and remove both callbacks
414414
kwargs_adj = NamedTuple{
@@ -454,10 +454,11 @@ function DiffEqBase._concrete_solve_adjoint(
454454
end
455455
igs = back(one(iy))[1] .- one(eltype(tunables))
456456

457-
igs, new_u0, new_p, SciMLBase.NoInit()
457+
igs, new_u0, new_p, SciMLBase.CheckInit()
458458
else
459459
nothing, u0, p, initializealg
460460
end
461+
461462
_prob = remake(_prob, u0 = new_u0, p = new_p)
462463

463464
if sensealg isa BacksolveAdjoint
@@ -672,18 +673,20 @@ function DiffEqBase._concrete_solve_adjoint(
672673
else
673674
cb2 = cb
674675
end
675-
if ArrayInterface.ismutable(eltype(state_values(sol)))
676+
677+
if prob isa Union{ODEProblem, DAEProblem}
676678
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
677-
dgdu_discrete = df_iip,
678-
sensealg = sensealg,
679-
callback = cb2,
680-
kwargs_init...)
679+
dgdu_discrete = ArrayInterface.ismutable(eltype(state_values(sol))) ? df_iip : df_oop,
680+
sensealg = sensealg,
681+
callback = cb2,
682+
initializealg = BrownFullBasicInit(),
683+
kwargs_init...)
681684
else
682685
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
683-
dgdu_discrete = df_oop,
684-
sensealg = sensealg,
685-
callback = cb2,
686-
kwargs_init...)
686+
dgdu_discrete = ArrayInterface.ismutable(eltype(state_values(sol))) ? df_iip : df_oop,
687+
sensealg = sensealg,
688+
callback = cb2,
689+
kwargs_init...)
687690
end
688691

689692
du0 = reshape(du0, size(u0))
@@ -1581,6 +1584,8 @@ function DiffEqBase._concrete_solve_adjoint(
15811584
Array(ybar)
15821585
elseif eltype(ybar) <: AbstractArray
15831586
Array(VectorOfArray(ybar))
1587+
elseif ybar isa Tangent
1588+
Array(VectorOfArray(ybar.u))
15841589
else
15851590
ybar
15861591
end
@@ -1769,7 +1774,8 @@ function DiffEqBase._concrete_solve_adjoint(
17691774
@. _out[_save_idxs] = Δ.u[_save_idxs]
17701775
end
17711776
end
1772-
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df, initializealg = BrownFullBasicInit())
1777+
1778+
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df)
17731779

17741780
dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
17751781
# if Δ isa AbstractArray, the gradients correspond to `u`

test/desauty_dae_mwe.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ desauty_model = create_model()
3636
sys = structural_simplify(desauty_model)
3737

3838

39-
prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.])
39+
prob = ODEProblem(sys, [sys.resistor1.v => 1.], (0.0, 0.1))
4040
iprob = prob.f.initialization_data.initializeprob
4141
isys = iprob.f.sys
4242

test/mtk.jl

+13-12
Original file line numberDiff line numberDiff line change
@@ -70,34 +70,34 @@ tspan = (0.0, 100.0)
7070
# and with the initialization corrected to satisfy the algebraic equation
7171
prob_incorrectu0 = ODEProblem(sys, u0_incorrect, tspan, p, jac = true, guesses = [w2 => 0.0])
7272
mtkparams_incorrectu0 = SciMLSensitivity.parameter_values(prob_incorrectu0)
73+
test_sol = solve(prob_incorrectu0, Rodas5P(), abstol = 1e-6, reltol = 1e-3)
7374

7475
u0_timedep = [D(x) => 2.0,
7576
x => 1.0,
7677
y => t,
77-
z => 0.0,
78-
w2 => 0.0,]
78+
z => 0.0]
7979
# this ensures that `y => t` is not applied in the adjoint equation
8080
# If the MTK init is called for the reverse, then `y0` in the backwards
8181
# pass will be extremely far off and cause an incorrect gradient
8282
prob_timedepu0 = ODEProblem(sys, u0_timedep, tspan, p, jac = true, guesses = [w2 => 0.0])
8383
mtkparams_timedepu0 = SciMLSensitivity.parameter_values(prob_incorrectu0)
84+
test_sol = solve(prob_timedepu0, Rodas5P(), abstol = 1e-6, reltol = 1e-3)
8485

8586
u0_correct = [D(x) => 2.0,
8687
x => 1.0,
8788
y => 0.0,
88-
z => 0.0,
89-
w2 => -1.0,]
89+
z => 0.0,]
9090
prob_correctu0 = ODEProblem(sys, u0_correct, tspan, p, jac = true, guesses = [w2 => -1.0])
9191
mtkparams_correctu0 = SciMLSensitivity.parameter_values(prob_correctu0)
92-
prob_correctu0.u0[5] = -1.0
93-
92+
test_sol = solve(prob_correctu0, Rodas5P(), abstol = 1e-6, reltol = 1e-3)
9493
u0_overdetermined = [D(x) => 2.0,
9594
x => 1.0,
9695
y => 0.0,
9796
z => 0.0,
9897
w2 => -1.0,]
9998
prob_overdetermined = ODEProblem(sys, u0_overdetermined, tspan, p, jac = true)
10099
mtkparams_overdetermined = SciMLSensitivity.parameter_values(prob_overdetermined)
100+
test_sol = solve(prob_overdetermined, Rodas5P(), abstol = 1e-6, reltol = 1e-3)
101101

102102
sensealg = GaussAdjoint(; autojacvec = SciMLSensitivity.ZygoteVJP())
103103

@@ -115,25 +115,26 @@ setups = [
115115
(prob_correctu0, mtkparams_correctu0, BrownFullBasicInit()),
116116
(prob_correctu0, mtkparams_correctu0, OrdinaryDiffEqCore.DefaultInit()),
117117

118-
(prob_correctu0, mtkparams_correctu0, NoInit()),
118+
(prob_correctu0, mtkparams_correctu0, NoInit()),
119119
(prob_correctu0, mtkparams_correctu0, nothing),
120120

121121
(prob_overdetermined, mtkparams_overdetermined, BrownFullBasicInit()),
122122
(prob_overdetermined, mtkparams_overdetermined, OrdinaryDiffEq.OrdinaryDiffEqCore.DefaultInit()),
123123

124124
(prob_overdetermined, mtkparams_overdetermined, NoInit()),
125125
(prob_overdetermined, mtkparams_overdetermined, nothing),
126-
]
126+
];
127127

128128
grads = map(setups) do setup
129129
prob, ps, init = setup
130130
@show init
131131
u0 = prob.u0
132132
Zygote.gradient(u0, ps) do u0,p
133+
new_prob = remake(prob, u0 = u0, p = p)
133134
if init === nothing
134-
new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, sensealg, abstol = 1e-6, reltol = 1e-3)
135+
new_sol = solve(new_prob, Rodas5P(); sensealg, abstol = 1e-6, reltol = 1e-3)
135136
else
136-
new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3)
137+
new_sol = solve(new_prob, Rodas5P(); initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3)
137138
end
138139
gt = Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do
139140
@test new_sol.retcode == SciMLBase.ReturnCode.Success
@@ -148,5 +149,5 @@ end
148149

149150
u0grads = getindex.(grads,1)
150151
pgrads = getproperty.(getindex.(grads, 2), (:tunable,))
151-
@test all(x u0grads[1] for x in grads)
152-
@test all(x pgrads[1] for x in grads)
152+
@test all(x u0grads[1] for x in u0grads)
153+
@test all(x pgrads[1] for x in pgrads)

0 commit comments

Comments
 (0)