From 40dfbd10f7181b716d3d1c9bd79a448e747454a9 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 28 Dec 2023 08:10:45 -0500 Subject: [PATCH 1/4] Fix save_end overriding behavior Fixes https://github.com/SciML/OrdinaryDiffEq.jl/issues/1842 --- src/integrators/integrator_utils.jl | 5 +++++ src/solve.jl | 13 ++++++++++--- test/interface/ode_saveat_tests.jl | 25 +++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index 4bc032d1a6..da1ef70034 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -79,6 +79,10 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool} integrator.cache.current) end else # ==t, just save + if curt == integrator.sol.prob.tspan[2] && !integrator.opts.save_end + integrator.saveiter -= 1 + continue + end savedexactly = true copyat_or_push!(integrator.sol.t, integrator.saveiter, integrator.t) if integrator.opts.save_idxs === nothing @@ -145,6 +149,7 @@ postamble!(integrator::ODEIntegrator) = _postamble!(integrator) function _postamble!(integrator) DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator) solution_endpoint_match_cur_integrator!(integrator) + save resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) if !(integrator.sol isa DAESolution) diff --git a/src/solve.jl b/src/solve.jl index acb6ffee54..0661c72434 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -293,9 +293,16 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem, sizehint!(ts, 50) sizehint!(ks, 50) elseif !isempty(saveat_internal) - sizehint!(timeseries, length(saveat_internal) + 1) - sizehint!(ts, length(saveat_internal) + 1) - sizehint!(ks, length(saveat_internal) + 1) + savelength = length(saveat_internal) + 1 + if save_start == false + savelength -= 1 + end + if save_end == false && prob.tspan[2] in saveat_internal.valtree + savelength -= 1 + end + sizehint!(timeseries, savelength) + sizehint!(ts, savelength) + sizehint!(ks, savelength) else sizehint!(timeseries, 2) sizehint!(ts, 2) diff --git a/test/interface/ode_saveat_tests.jl b/test/interface/ode_saveat_tests.jl index 858fada876..52df1a9276 100644 --- a/test/interface/ode_saveat_tests.jl +++ b/test/interface/ode_saveat_tests.jl @@ -187,3 +187,28 @@ prob = ODEProblem(SIR!, [0.99, 0.01, 0.0], (t_obs[1], t_obs[end]), [0.20, 0.15]) sol = solve(prob, DP5(), reltol = 1e-6, abstol = 1e-6, saveat = t_obs) @test maximum(sol) <= 1 @test minimum(sol) >= 0 + +@testset "Proper save_start and save_end behavior" begin + function f2(du, u, p, t) + du[1] = -cos(u[1]) * u[1] + end + prob = ODEProblem(f2, [10], (0.0, 0.4)) + + @test solve(prob, Tsit5(); saveat = 0:.1:.4).t == [0.0; 0.1; 0.2; 0.3; 0.4] + @test solve(prob, Tsit5(); saveat = 0:.1:.4, save_start = true, save_end = true).t == [0.0; 0.1; 0.2; 0.3; 0.4] + @test solve(prob, Tsit5(); saveat = 0:.1:.4, save_start = false, save_end = false).t == [0.1; 0.2; 0.3] + + ts = solve(prob, Tsit5()).t + @test 0.0 in ts + @test 0.4 in ts + ts = solve(prob, Tsit5(); save_start = true, save_end = true).t + @test 0.0 in ts + @test 0.4 in ts + ts = solve(prob, Tsit5(); save_start = false, save_end = false).t + @test 0.0 ∉ ts + @test 0.4 ∉ ts + + @test solve(prob, Tsit5(); saveat = [.2]).t == [0.2] + @test solve(prob, Tsit5(); saveat = [.2], save_start = true, save_end = true).t == [0.0; 0.2; 0.4] + @test solve(prob, Tsit5(); saveat = [.2], save_start = false, save_end = false).t == [0.2] +end \ No newline at end of file From 8643fdbb670f7aa51929432de7f66063dacbebe3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 28 Dec 2023 08:14:07 -0500 Subject: [PATCH 2/4] Update src/integrators/integrator_utils.jl --- src/integrators/integrator_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index da1ef70034..8edec505a3 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -149,7 +149,6 @@ postamble!(integrator::ODEIntegrator) = _postamble!(integrator) function _postamble!(integrator) DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator) solution_endpoint_match_cur_integrator!(integrator) - save resize!(integrator.sol.t, integrator.saveiter) resize!(integrator.sol.u, integrator.saveiter) if !(integrator.sol isa DAESolution) From f3751847255d069eb5fdcbccabb95150760f0788 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 28 Dec 2023 08:14:33 -0500 Subject: [PATCH 3/4] fix save forcing --- src/integrators/integrator_utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index da1ef70034..eab7044f36 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -111,7 +111,9 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool} end end if force_save || (integrator.opts.save_everystep && - (isempty(integrator.sol.t) || (integrator.t !== integrator.sol.t[end]))) + (isempty(integrator.sol.t) || (integrator.t !== integrator.sol.t[end]) && + (integrator.opts.save_end || integrator.t !== integrator.sol.prob.tspan[2]) + )) integrator.saveiter += 1 saved, savedexactly = true, true if integrator.opts.save_idxs === nothing From d054a0e3db98e437cedbc8cca8f6a12751c77643 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 28 Dec 2023 09:47:48 -0500 Subject: [PATCH 4/4] Update ode_saveat_tests.jl --- test/interface/ode_saveat_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/interface/ode_saveat_tests.jl b/test/interface/ode_saveat_tests.jl index 52df1a9276..5e1a6de97e 100644 --- a/test/interface/ode_saveat_tests.jl +++ b/test/interface/ode_saveat_tests.jl @@ -160,7 +160,7 @@ integ = init(ODEProblem((u, p, t) -> u, 0.0, (0.0, 1.0)), Tsit5(), saveat = _sav save_end = false) add_tstop!(integ, 2.0) solve!(integ) -@test integ.sol.t == _saveat +@test integ.sol.t == _saveat[1:end-1] # Catch save for maxiters ode = ODEProblem((u, p, t) -> u, 1.0, (0.0, 1.0)) @@ -211,4 +211,4 @@ sol = solve(prob, DP5(), reltol = 1e-6, abstol = 1e-6, saveat = t_obs) @test solve(prob, Tsit5(); saveat = [.2]).t == [0.2] @test solve(prob, Tsit5(); saveat = [.2], save_start = true, save_end = true).t == [0.0; 0.2; 0.4] @test solve(prob, Tsit5(); saveat = [.2], save_start = false, save_end = false).t == [0.2] -end \ No newline at end of file +end