Skip to content

Commit e679e0c

Browse files
feat: allow manually choosing time-independent initialization
1 parent b3d3c38 commit e679e0c

File tree

3 files changed

+21
-13
lines changed

3 files changed

+21
-13
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
13781378
allow_incomplete = false,
13791379
force_time_independent = false,
13801380
algebraic_only = false,
1381+
time_dependent_init = is_time_dependent(sys),
13811382
kwargs...) where {iip, specialize}
13821383
if !iscomplete(sys)
13831384
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
@@ -1392,7 +1393,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
13921393
simplify_system = true
13931394
else
13941395
isys = generate_initializesystem(
1395-
sys; u0map, initialization_eqs, check_units,
1396+
sys; u0map, initialization_eqs, check_units, time_dependent_init,
13961397
pmap = parammap, guesses, extra_metadata = (; use_scc), algebraic_only)
13971398
simplify_system = true
13981399
end
@@ -1431,7 +1432,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
14311432

14321433
# TODO: throw on uninitialized arrays
14331434
filter!(x -> !(x isa Symbolics.Arr), uninit)
1434-
if is_time_dependent(sys) && !isempty(uninit)
1435+
if time_dependent_init && !isempty(uninit)
14351436
allow_incomplete || throw(IncompleteInitializationError(uninit))
14361437
# for incomplete initialization, we will add the missing variables as parameters.
14371438
# they will be updated by `update_initializeprob!` and `initializeprobmap` will

src/systems/nonlinear/initializesystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
function generate_initializesystem(sys::AbstractSystem; kwargs...)
2-
if is_time_dependent(sys)
1+
function generate_initializesystem(
2+
sys::AbstractSystem; time_dependent_init = is_time_dependent(sys), kwargs...)
3+
if time_dependent_init
34
generate_initializesystem_timevarying(sys; kwargs...)
45
else
56
generate_initializesystem_timeindependent(sys; kwargs...)
@@ -153,7 +154,7 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
153154
end
154155
meta = InitializationSystemMetadata(
155156
anydict(u0map), anydict(pmap), additional_guesses,
156-
additional_initialization_eqs, extra_metadata, nothing)
157+
additional_initialization_eqs, extra_metadata, nothing, true)
157158
return NonlinearSystem(eqs_ics,
158159
vars,
159160
pars;
@@ -254,7 +255,7 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
254255
end
255256
meta = InitializationSystemMetadata(
256257
anydict(u0map), anydict(pmap), additional_guesses,
257-
additional_initialization_eqs, extra_metadata, nothing)
258+
additional_initialization_eqs, extra_metadata, nothing, false)
258259
return NonlinearSystem(eqs_ics,
259260
vars,
260261
pars;
@@ -500,6 +501,7 @@ struct InitializationSystemMetadata
500501
additional_initialization_eqs::Vector{Equation}
501502
extra_metadata::NamedTuple
502503
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
504+
time_dependent_init::Bool
503505
end
504506

505507
function get_possibly_array_fallback_singletons(varmap, p)
@@ -609,6 +611,7 @@ function SciMLBase.remake_initialization_data(
609611
merge!(guesses, meta.additional_guesses)
610612
use_scc = get(meta.extra_metadata, :use_scc, true)
611613
initialization_eqs = meta.additional_initialization_eqs
614+
time_dependent_init = meta.time_dependent_init
612615
end
613616
else
614617
# there is no initializeprob, so the original problem construction
@@ -656,7 +659,7 @@ function SciMLBase.remake_initialization_data(
656659
u0map, pmap, defs, cmap, dvs, ps)
657660
floatT = float_type_from_varmap(op)
658661
kws = maybe_build_initialization_problem(
659-
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns;
662+
sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; time_dependent_init,
660663
use_scc, initialization_eqs, floatT, allow_incomplete = true)
661664

662665
return SciMLBase.remake_initialization_data(sys, kws, newu0, t0, newp, newu0, newp)

src/systems/problem_utils.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,15 +633,16 @@ All other keyword arguments are forwarded to `InitializationProblem`.
633633
function maybe_build_initialization_problem(
634634
sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs,
635635
guesses, missing_unknowns; implicit_dae = false,
636-
u0_constructor = identity, floatT = Float64, kwargs...)
636+
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
637+
floatT = Float64, kwargs...)
637638
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))
638639

639640
if t === nothing && is_time_dependent(sys)
640641
t = zero(floatT)
641642
end
642643

643644
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
644-
sys, t, u0map, pmap; guesses, kwargs...)
645+
sys, t, u0map, pmap; guesses, time_dependent_init, kwargs...)
645646
if state_values(initializeprob) !== nothing
646647
initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob)))
647648
end
@@ -660,7 +661,10 @@ function maybe_build_initialization_problem(
660661

661662
meta = get_metadata(initializeprob.f.sys)
662663

663-
if is_time_dependent(sys)
664+
if time_dependent_init === nothing
665+
time_dependent_init = is_time_dependent(sys)
666+
end
667+
if time_dependent_init
664668
all_init_syms = Set(all_symbols(initializeprob))
665669
solved_unknowns = filter(var -> var in all_init_syms, unknowns(sys))
666670
initializeprobmap = u0_constructor getu(initializeprob, solved_unknowns)
@@ -700,7 +704,7 @@ function maybe_build_initialization_problem(
700704
end
701705
end
702706

703-
if is_time_dependent(sys)
707+
if time_dependent_init
704708
for v in missing_unknowns
705709
op[v] = get_temporary_value(v, floatT)
706710
end
@@ -803,7 +807,7 @@ function process_SciMLProblem(
803807
symbolic_u0 = false, warn_cyclic_dependency = false,
804808
circular_dependency_max_cycle_length = length(all_symbols(sys)),
805809
circular_dependency_max_cycles = 10,
806-
substitution_limit = 100, use_scc = true,
810+
substitution_limit = 100, use_scc = true, time_dependent_init = is_time_dependent(sys),
807811
force_initialization_time_independent = false, algebraic_only = false,
808812
allow_incomplete = false, is_initializeprob = false, kwargs...)
809813
dvs = unknowns(sys)
@@ -858,7 +862,7 @@ function process_SciMLProblem(
858862
warn_cyclic_dependency, check_units = check_initialization_units,
859863
circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
860864
force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
861-
u0_constructor, floatT)
865+
u0_constructor, floatT, time_dependent_init)
862866

863867
kwargs = merge(kwargs, kws)
864868
end

0 commit comments

Comments
 (0)