Skip to content

Commit 0933fe9

Browse files
Merge pull request #2608 from AayushSabharwal/as/reinit-dae
fix: rerun `initialize_dae!` in `reinit!`
2 parents 5e78c0a + 634c638 commit 0933fe9

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

lib/OrdinaryDiffEqCore/src/integrators/integrator_interface.jl

+7
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.
331331
d_discontinuities = integrator.opts.d_discontinuities_cache,
332332
reset_dt = (integrator.dtcache == zero(integrator.dt)) &&
333333
integrator.opts.adaptive,
334+
reinit_dae = true,
334335
reinit_callbacks = true, initialize_save = true,
335336
reinit_cache = true,
336337
reinit_retcode = true)
@@ -406,6 +407,12 @@ function DiffEqBase.reinit!(integrator::ODEIntegrator, u0 = integrator.sol.prob.
406407
auto_dt_reset!(integrator)
407408
end
408409

410+
if reinit_dae &&
411+
(integrator.isdae || SciMLBase.has_initializeprob(integrator.sol.prob.f))
412+
DiffEqBase.initialize_dae!(integrator)
413+
update_uprev!(integrator)
414+
end
415+
409416
if reinit_callbacks
410417
initialize_callbacks!(integrator, initialize_save)
411418
end

test/interface/dae_initialization_tests.jl

+27
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,30 @@ prob = ODEProblem(f, ones(3), (0.0, 1.0))
115115
integrator = init(prob, Rodas5P(),
116116
initializealg = ShampineCollocationInit(1.0, BrokenNLSolve()))
117117
@test all(isequal(reinterpret(Float64, 0xDEADBEEFDEADBEEF)), integrator.u)
118+
119+
@testset "`reinit!` reruns initialization" begin
120+
initializeprob = NonlinearProblem(1.0, [0.0]) do u, p
121+
return u^2 - p[1]^2
122+
end
123+
initializeprobmap = function (nlsol)
124+
return [nlsol.prob.p[1], nlsol.u]
125+
end
126+
update_initializeprob! = function (iprob, integ)
127+
iprob.p[1] = integ.u[1]
128+
end
129+
initialization_data = SciMLBase.OverrideInitData(
130+
initializeprob, update_initializeprob!, initializeprobmap, nothing)
131+
fn = ODEFunction(; mass_matrix = [1 0; 0 0], initialization_data) do du, u, p, t
132+
du[1] = u[1]
133+
du[2] = u[1]^2 - u[2]^2
134+
end
135+
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0))
136+
integ = init(prob, Rodas5P())
137+
@test integ.u[2.0, 2.0] atol=1e-8
138+
reinit!(integ)
139+
@test integ.u[2.0, 2.0] atol=1e-8
140+
@test_nowarn step!(integ, 0.01, true)
141+
reinit!(integ, reinit_dae = false)
142+
@test integ.u [2.0, 0.0]
143+
@test_warn ["dt", "forced below floating point epsilon"] step!(integ, 0.01, true)
144+
end

0 commit comments

Comments
 (0)