Type inference on solution output fails for DAE with callback #2594
Turns out that, on master at least, making the callback a `const`:

```julia
const constcallback = ContinuousCallback((u, _, __) -> u[1] - 1, terminate!)

function eval_prm_ccb(pr)
    newprob = ODEProblem{true}(dae!, [2.0, -2.0], (0.0, 1.0), (pr,))
    sol = solve(newprob, Rosenbrock23(); callback = constcallback)
    evalsol(sol)
end
```

while
Oh right. That is expected. Without `const`, this is just the usual type instability from a non-constant global variable.
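To illustrate the point about globals: a non-`const` global can be rebound to a value of any type. Any function that reads it is therefore inferred as returning `Any`. A `const` global has a fixed type. The sketch below uses toy functions whose names are illustrative, not from the thread, and checks both cases with `Base.return_types`:

```julia
# Non-const global: it can be rebound to a value of any type at any time,
# so the compiler must assume `Any` when a function reads it.
scale = 2.0
scaled_nonconst(x::Float64) = x * scale

# const global: the binding (and hence its type) is fixed, so `x * SCALE`
# is inferred as Float64.
const SCALE = 2.0
scaled_const(x::Float64) = x * SCALE

nonconst_T = only(Base.return_types(scaled_nonconst, (Float64,)))
const_T = only(Base.return_types(scaled_const, (Float64,)))
println(nonconst_T, " vs ", const_T)
```

The same reasoning applies to a callback stored in a non-`const` global and passed to `solve`.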
In my code base (which has some more intricacies), I'm still seeing something that isn't type-inferable. I'll open another issue later if I manage to identify where the problem actually is. In the meantime, I can probably get away with a type assertion in my squared-error calculation.
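The type-assertion workaround can be sketched as follows. `UNTYPED_RESULT` is a hypothetical stand-in for a value whose type the compiler cannot infer, such as an uninferable solution. The `::Float64` assertion lets everything after it be inferred concretely. The cost is a `TypeError` at runtime if the assumption is ever wrong:

```julia
# Hypothetical stand-in for a value whose type inference cannot see
# (e.g. an uninferable solution): the Ref's contents are typed as Any.
const UNTYPED_RESULT = Ref{Any}(3.0)

# No assertion: the subtraction and the square are inferred as Any.
sqerr_plain(target::Float64) = (UNTYPED_RESULT[] - target)^2

# With `::Float64`, code after the assertion is inferred concretely;
# a TypeError is thrown at runtime if the value is not a Float64.
sqerr_asserted(target::Float64) = (UNTYPED_RESULT[]::Float64 - target)^2

plain_T = only(Base.return_types(sqerr_plain, (Float64,)))
asserted_T = only(Base.return_types(sqerr_asserted, (Float64,)))
```

The assertion only contains the instability. The call producing the value is still dynamically typed.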
Are you using `save_idxs`?
No, not currently (I didn't know that was an option). My problem only has two equations, so I haven't been worried on that front. However, I have a parameter struct full of Unitful types, and the endpoint in time depends on the system dynamics (hence a callback that terminates integration when one of my two variables approaches zero).
Oh, Unitful 😅. If you don't use Unitful, do you still have issues?
I can replicate my problem without Unitful. Here is an example that resembles my code base:

```julia
using OrdinaryDiffEqRosenbrock
using LinearAlgebra: Diagonal
using Accessors
# using OrdinaryDiffEqNonlinearSolve

struct Coeffs{T1, T2}
    p1::T1
    p2::T2
end

function Base.getindex(c::Coeffs, i::Int)
    if i == 1
        return c.p1
    elseif i == 2
        return c.p2
    else
        throw(ArgumentError("Index out of bounds"))
    end
end

function dae!(du, u, p, t)
    du[1] = -u[1] * p[1]
    du[2] = u[2] + u[1]   # algebraic constraint: 0 = u[2] + u[1] (zero mass-matrix row)
    nothing
end

const dae_fc = ODEFunction(dae!, mass_matrix = Diagonal([1.0, 0.0]))

calc_u0(c::Coeffs) = [2.0, -2.0]

function OrdinaryDiffEqRosenbrock.ODEProblem(c::Coeffs)
    u0 = calc_u0(c)
    return ODEProblem{true}(dae_fc, u0, (0.0, 1.0), c; initializealg = CheckInit())
    # return ODEProblem{true}(dae_fc, u0, (0.0, 1.0), c;)
end
```

With this example, `@code_warntype solve(ODEProblem(Coeffs(-1.0, 1.0)), Rosenbrock23())` gives me this, where it's a little hard to tell where the instability happens:
Dropping the kwarg for
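A side note on `Coeffs` from the MWE above: its `getindex` branches on a runtime `Int`. When `p1` and `p2` have different types (as Unitful parameters with different units would), the inferred type of `c[i]` for a non-constant `i` is a `Union`. Constant propagation can often resolve a literal index like `p[1]`, so this may or may not matter inside `dae!`. The sketch below uses only Base and a copy of the MWE's `Coeffs`, and shows what inference sees for a non-constant index:

```julia
# Copy of the MWE's parameter container and its indexing method.
struct Coeffs{T1, T2}
    p1::T1
    p2::T2
end

function Base.getindex(c::Coeffs, i::Int)
    if i == 1
        return c.p1
    elseif i == 2
        return c.p2
    else
        throw(ArgumentError("Index out of bounds"))
    end
end

# Same field types: both branches return Float64, so the result is concrete.
homog = only(Base.return_types(getindex, (Coeffs{Float64, Float64}, Int)))

# Different field types: the branch on `i` yields a Union.
heterog = only(Base.return_types(getindex, (Coeffs{Float64, Int}, Int)))
```

Accessing the field directly (`c.p1`) sidesteps the branch entirely.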
Using

```julia
prob = ODEProblem(cbase)
@report_opt solve(prob)
```

If I load OrdinaryDiffEqNonlinearSolve and drop `CheckInit` as a kwarg, JET reports many more possible errors (321). If it matters, these solves are put inside a call like the one below, but the type instability already happens at the `solve` step. In this case I can annotate the return type to match the input type, but it would be nice for it to be inferable:

```julia
function eval_prm_ccb(pr, c::Coeffs)
    # rtype = typeof(pr)
    newc = @set c.p1 = pr
    newprob = ODEProblem(newc)
    sol = solve(newprob, Rosenbrock23(); callback = ccallback)
    evalsol(sol)#::rtype
end
```
Exploring this MWE with Cthulhu, it looks like this may be related to #2613. Inside it (if I'm exploring this correctly, which I might not be), the call chain eventually leads to a culprit here: OrdinaryDiffEq.jl/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl, lines 74 to 87 in 786689e.
@oscardssmith did you ever look into this?

I checked this MWE again after #2617 merged, and it still occurs, so something else seems to be in play here.
@jClugstor this appears to be due to the remake that happens in

If I remember right, using the

So the problem is actually slightly earlier. The actual problem is that
means that the type of the
This was the behavior before the ADTypes change. I guess it's only unnecessary if we don't care about picking the chunk size based on the colorvec? I had a similar problem with types not being inferred in #2567, and changing to the other constructor fixed it, but this might not be the same thing. Let me check it out. It's also possible that this is not a problem in #2567.
Nope, it's still an issue, so something else is going on. This is the previous code. Wouldn't it still have the same issue? The type of the alg would change, especially since the chunk size used to be a type parameter. Hmm.

```julia
L = StaticArrayInterface.known_length(typeof(u0))
if L === nothing # dynamically sized
    # If chunksize is zero, pick chunksize right at the start of solve and
    # then do function barrier to infer the full solve
    x = if prob.f.colorvec === nothing
        length(u0)
    else
        maximum(prob.f.colorvec)
    end
    cs = ForwardDiff.pickchunksize(x)
    return remake(alg, chunk_size = Val{cs}())
else # statically sized
    cs = pick_static_chunksize(Val{L}())
    return remake(alg, chunk_size = cs)
end
```
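The concern here can be reproduced in miniature. When a type parameter is computed from a runtime value, the constructed object's type is not known to the compiler. A compile-time constant keeps it concrete. `ToyAlg` and the clamp to 12 are hypothetical stand-ins for the algorithm type and `ForwardDiff.pickchunksize`, not the real implementations:

```julia
# Hypothetical algorithm type with the chunk size as a type parameter,
# loosely mirroring the old `chunk_size = Val{cs}()` design.
struct ToyAlg{CS} end

# Chunk size computed from a runtime length (like the dynamically sized
# branch): CS is unknown at compile time.
toy_alg_dynamic(n::Int) = ToyAlg{min(n, 12)}()

# Chunk size fixed at compile time (like the statically sized branch).
toy_alg_static() = ToyAlg{4}()

dyn_T = only(Base.return_types(toy_alg_dynamic, (Int,)))
static_T = only(Base.return_types(toy_alg_static, ()))
```

`dyn_T` comes out as the non-concrete `ToyAlg`. This is why the dynamically sized branch needs a function barrier afterwards for the rest of the solve to infer.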
I think the problem is that the function barrier promised by the comment ("then do function barrier to infer the full solve") must have disappeared.
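For context, a function barrier works as follows. The outer function computes something whose type depends on a runtime value and passes it to an inner function. That single call is dynamically dispatched, but the inner function is compiled for the concrete type, so everything inside it infers. The sketch below is generic; `toy_pick_chunk`, `chunked_sum` and `toy_solve` are hypothetical names, not OrdinaryDiffEq functions:

```julia
# Hypothetical stand-in for picking a chunk size at runtime.
toy_pick_chunk(n::Int) = min(n, 12)

# Inner kernel: N is a compile-time constant here, so the body is
# specialized and fully inferred for each concrete Val{N}.
function chunked_sum(x::Vector{Float64}, ::Val{N}) where {N}
    s = 0.0
    for i in 1:N:length(x)
        for j in i:min(i + N - 1, length(x))
            s += x[j]
        end
    end
    return s
end

# Outer function: Val(cs) is built from a runtime value, so its type is
# unknown here; the call to chunked_sum is the function barrier.
function toy_solve(x::Vector{Float64})
    cs = toy_pick_chunk(length(x))
    return chunked_sum(x, Val(cs))
end
```

The instability is confined to the one call across the barrier. Without it, every operation depending on `cs` in the outer function would be dynamically typed.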
Yeah, I'm not sure what was supposed to be there before. BTW, on my branch with DI I get the exact same output from
I believe there are tests that go through this exact code and check that everything is inferred, so I'm not sure why this is happening here.
I don't think we can get around the fact that the chunk size depends on the number of colors (see also JuliaDiff/DifferentiationInterface.jl#593), for the same reason that in the dense case it depends on the array length.

Yes, I was testing this yesterday, and on the DI branch the type instability didn't seem to come from
Describe the bug 🐞
DAE solutions with callbacks work, but according to `@code_warntype` they are type unstable. ODE solutions with callbacks and DAE solutions without callbacks are both type stable. See also #2530, #2558.
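`@code_warntype` is good for reading inference results. The Test stdlib's `@inferred` turns the same check into something a test suite can assert. The sketch below uses hypothetical toy functions. In the setting of this issue, one could wrap the `solve(...)` call the same way:

```julia
using Test

# Toy stand-ins (hypothetical) for a stable and an unstable result.
value_stable(x::Float64) = 2x                 # always Float64
value_unstable(x::Float64) = x > 0 ? x : 0    # inferred Union{Float64, Int}

# Passes: the inferred type matches the runtime type, and the value is returned.
stable_result = @inferred value_stable(1.0)

# @inferred throws an ErrorException when the runtime return type does not
# match the inferred return type (here Float64 vs. the Union).
unstable_flagged = try
    @inferred value_unstable(1.0)
    false
catch err
    err isa ErrorException
end
```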
Minimal Reproducible Example 👇
Error & Stacktrace ⚠️
Without callback, is type stable:
With callback, is unstable:
Environment (please complete the following information):
```julia
using Pkg; Pkg.status()
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
```
Updated today. The SciML and OrdinaryDiffEq packages: