diff --git a/Project.toml b/Project.toml
index faf9e4eaf9..293e3c2734 100644
--- a/Project.toml
+++ b/Project.toml
@@ -45,6 +45,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
 OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
+PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -142,6 +143,7 @@ OrdinaryDiffEq = "6.82.0"
 OrdinaryDiffEqCore = "1.15.0"
 OrdinaryDiffEqDefault = "1.2"
 OrdinaryDiffEqNonlinearSolve = "1.5.0"
+PreallocationTools = "0.4.27"
 PrecompileTools = "1"
 Pyomo = "0.1.0"
 REPL = "1"
diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl
index 0ad1080965..c3f95f839c 100644
--- a/src/ModelingToolkit.jl
+++ b/src/ModelingToolkit.jl
@@ -99,6 +99,8 @@ const DQ = DynamicQuantities
 import DifferentiationInterface as DI
 using ADTypes: AutoForwardDiff
 import SciMLPublic: @public
+import PreallocationTools
+import PreallocationTools: DiffCache
 
 export @derivatives
 
@@ -287,6 +289,7 @@ export IntervalNonlinearProblem
 export OptimizationProblem, constraints
 export SteadyStateProblem
 export JumpProblem
+export SemilinearODEFunction, SemilinearODEProblem
 export alias_elimination, flatten
 export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
        instream
diff --git a/src/problems/odeproblem.jl b/src/problems/odeproblem.jl
index 6726322907..62705e381d 100644
--- a/src/problems/odeproblem.jl
+++ b/src/problems/odeproblem.jl
@@ -98,9 +98,153 @@ end
     maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
 end
 
+"""
+    SemilinearODEFunction{iip, specialize}(sys::System; kwargs...)
+
+Build a `SplitFunction` for `sys` from its semiquadratic form `(A, B, x2, C)`: `f1`
+evaluates `B * x2 + C` (wrapped in an `ODEFunction`, with its jacobian if `jac = true`)
+and `f2` evaluates `A * x`. The struct itself only serves as a constructor tag.
+"""
+struct SemilinearODEFunction{iip, spec} end
+
+@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
+        sys::System; u0 = nothing, p = nothing, t = nothing,
+        semiquadratic_form = nothing, semiquadratic_jacobian = nothing,
+        eval_expression = false, eval_module = @__MODULE__,
+        expression = Val{false}, sparse = false, check_compatibility = true,
+        jac = false, checkbounds = false, cse = true, initialization_data = nothing,
+        analytic = nothing, kwargs...) where {iip, specialize}
+    check_complete(sys, SemilinearODEFunction)
+    check_compatibility && check_compatible_system(SemilinearODEFunction, sys)
+
+    if semiquadratic_form === nothing
+        sys = add_semilinear_parameters(sys)
+        semiquadratic_form = calculate_split_form(sys; sparse)
+    end
+
+    A, B, x2, C = semiquadratic_form
+    M = calculate_massmatrix(sys)
+    _M = concrete_massmatrix(M; sparse, u0)
+
+    f1, f2 = generate_semiquadratic_functions(
+        sys, A, B, x2, C; expression, wrap_gfw = Val{true},
+        eval_expression, eval_module, kwargs...)
+
+    if jac
+        semiquadratic_jacobian = @something(semiquadratic_jacobian,
+            calculate_semiquadratic_jacobian(sys, B, x2, C; sparse, massmatrix = _M))
+        f1jac, x2jac, Cjac = semiquadratic_jacobian
+        _jac = generate_semiquadratic_jacobian(
+            sys, B, x2, C, f1jac, x2jac, Cjac; sparse, expression,
+            wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
+        _W_sparsity = f1jac
+        W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
+    else
+        _jac = nothing
+        W_prototype = nothing
+    end
+
+    observedfun = ObservedFunctionCache(
+        sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)
+
+    f1_args = (; f1)
+    f1_kwargs = (; jac = _jac)
+    f1 = maybe_codegen_scimlfn(
+        expression, ODEFunction{iip, specialize}, f1_args; f1_kwargs...)
+    args = (; f1, f2)
+
+    # NOTE(review): `_jac` is attached to `f1` only; `W_prototype`, `analytic` and
+    # `initialization_data` are not forwarded to the `SplitFunction` — confirm intent.
+    kwargs = (; sys, observed = observedfun, mass_matrix = _M)
+
+    return maybe_codegen_scimlfn(
+        expression, SplitFunction{iip, specialize}, args; kwargs...)
+end
+
+"""
+    SemilinearODEProblem{iip, spec}(sys::System, op, tspan; kwargs...)
+
+Build a `SplitODEProblem` for `sys` through `SemilinearODEFunction`. The matrices `A`
+and `B` of the semiquadratic form become defaults of parameters added to `sys`,
+together with a `DiffCache` scratch buffer used by the generated functions.
+"""
+struct SemilinearODEProblem{iip, spec} end
+
+@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
+        sys::System, op, tspan; check_compatibility = true,
+        u0_eltype = nothing, expression = Val{false}, callback = nothing,
+        jac = false, sparse = false, kwargs...) where {iip, spec}
+    check_complete(sys, SemilinearODEProblem)
+    check_compatibility && check_compatible_system(SemilinearODEProblem, sys)
+
+    A, B, x2, C = semiquadratic_form = calculate_split_form(sys)
+
+    semiquadratic_jacobian = nothing
+    if jac
+        f1jac, x2jac, Cjac = semiquadratic_jacobian = calculate_semiquadratic_jacobian(
+            sys, B, x2, C; sparse)
+    end
+
+    sys = add_semilinear_parameters(sys)
+    linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
+    bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
+    diffcache = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
+
+    floatT = calculate_float_type(op, typeof(op))
+    _u0_eltype = something(u0_eltype, floatT)
+
+    guess = copy(guesses(sys))
+    guess[linear_matrix_param] = fill(NaN, size(A))
+    guess[bilinear_matrix_param] = fill(NaN, size(B))
+    @set! sys.guesses = guess
+    defs = copy(defaults(sys))
+    defs[linear_matrix_param] = A
+    defs[bilinear_matrix_param] = B
+    # The same cache backs `x2` in `f1` and the jacobian of `x2` in the generated `jac`.
+    cachelen = jac ? length(x2jac) : length(x2)
+    defs[diffcache] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
+    @set! sys.defaults = defs
+
+    f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
+        t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
+        semiquadratic_form, semiquadratic_jacobian, jac, sparse, u0_eltype, kwargs...)
+
+    kwargs = process_kwargs(
+        sys; expression, callback, kwargs...)
+
+    args = (; f, u0, tspan, p)
+    return maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
+end
+
+# Return a copy of `sys` (and, recursively, of its parent) with the linear matrix,
+# bilinear matrix and `DiffCache` parameters appended and registered in `var_to_name`.
+function add_semilinear_parameters(sys::System)
+    m = length(equations(sys))
+    n = length(unknowns(sys))
+    linear_matrix_param = get_linear_matrix_param((m, n))
+    bilinear_matrix_param = get_bilinear_matrix_param((m, (n^2 + n) ÷ 2))
+    @assert !is_parameter(sys, linear_matrix_param)
+    sys = with_additional_constant_parameter(sys, linear_matrix_param)
+    @assert !is_parameter(sys, bilinear_matrix_param)
+    sys = with_additional_constant_parameter(sys, bilinear_matrix_param)
+    diffcache = get_diffcache_param(Float64)
+    @assert !is_parameter(sys, diffcache)
+    sys = with_additional_nonnumeric_parameter(sys, diffcache)
+    var_to_name = copy(get_var_to_name(sys))
+    var_to_name[LINEAR_MATRIX_PARAM_NAME] = linear_matrix_param
+    var_to_name[BILINEAR_MATRIX_PARAM_NAME] = bilinear_matrix_param
+    var_to_name[DIFFCACHE_PARAM_NAME] = diffcache
+    @set! sys.var_to_name = var_to_name
+    if get_parent(sys) !== nothing
+        @set! sys.parent = add_semilinear_parameters(get_parent(sys))
+    end
+    return sys
+end
+
 function check_compatible_system(
         T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
-            Type{DAEProblem}, Type{SteadyStateProblem}},
+            Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
+            Type{SemilinearODEProblem}},
         sys::System)
     check_time_dependent(sys, T)
     check_not_dde(sys)
diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl
index 4a68f935e8..b98693a9e7 100644
--- a/src/systems/codegen.jl
+++ b/src/systems/codegen.jl
@@ -1142,35 +1142,15 @@ Return matrix `A` and vector `b` such that the system `sys` can be represented a
 - `sparse`: return a sparse `A`.
 """
 function calculate_A_b(sys::System; sparse = false)
-    rhss = [eq.rhs for eq in full_equations(sys)]
+    rhss = [-eq.rhs for eq in full_equations(sys)]
     dvs = unknowns(sys)
 
-    A = Matrix{Any}(undef, length(rhss), length(dvs))
-    b = Vector{Any}(undef, length(rhss))
-    for (i, rhs) in enumerate(rhss)
-        # mtkcompile makes this `0 ~ rhs` which typically ends up giving
-        # unknowns negative coefficients. If given the equations `A * x ~ b`
-        # it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
-        # to more comprehensible user API.
-        resid = -rhs
-        for (j, var) in enumerate(dvs)
-            p, q, islinear = Symbolics.linear_expansion(resid, var)
-            if !islinear
-                throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
-            end
-            A[i, j] = p
-            resid = q
-        end
-        # negate beucause `resid` is the residual on the LHS
-        b[i] = -resid
-    end
-
-    @assert all(Base.Fix1(isassigned, A), eachindex(A))
-    @assert all(Base.Fix1(isassigned, A), eachindex(b))
-
-    if sparse
-        A = SparseArrays.sparse(A)
+    A, b = semilinear_form(rhss, dvs)
+    if !sparse
+        A = collect(A)
     end
+    A = unwrap.(A)
+    b = unwrap.(-b)
 
     return A, b
 end
@@ -1217,3 +1197,203 @@ function generate_update_b(sys::System, b::AbstractVector; expression = Val{true
     return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
         eval_expression, eval_module)
 end
+
+# Decompose the RHS of `sys` into `(A, B, x2, C)`; as consumed by
+# `generate_semiquadratic_functions`:
+# f1 = B * x2 + C
+# f2 = A * x
+function calculate_split_form(sys::System; sparse = false)
+    rhss = [eq.rhs for eq in full_equations(sys)]
+    dvs = unknowns(sys)
+    A, B, x2, C = semiquadratic_form(rhss, dvs)
+    if !sparse
+        A = collect(A)
+        B = collect(B)
+    end
+    A = unwrap.(A)
+    B = unwrap.(B)
+    x2 = unwrap.(x2)
+    C = unwrap.(C)
+
+    return A, B, x2, C
+end
+
+const DIFFCACHE_PARAM_NAME = :__mtk_diffcache
+
+# Symbolic constant typed as a `DiffCache`; generated code fetches scratch space from it.
+function get_diffcache_param(::Type{T}) where {T}
+    toconstant(Symbolics.variable(
+        DIFFCACHE_PARAM_NAME; T = DiffCache{Vector{T}, Vector{T}}))
+end
+
+# x2
+const BILINEAR_CACHEVAR = unwrap(only(@constants bilinear_xₘₜₖ::Vector{Real}))
+# A
+const LINEAR_MATRIX_PARAM_NAME = :linear_Aₘₜₖ
+function get_linear_matrix_param(size::NTuple{2, Int})
+    m, n = size
+    unwrap(only(@constants linear_Aₘₜₖ[1:m, 1:n]))
+end
+# B
+const BILINEAR_MATRIX_PARAM_NAME = :bilinear_Bₘₜₖ
+function get_bilinear_matrix_param(size::NTuple{2, Int})
+    m, n = size
+    unwrap(only(@constants bilinear_Bₘₜₖ[1:m, 1:n]))
+end
+
+# Generate `(f1, f2)` where `f1` evaluates `B * x2 + C` and `f2` evaluates `A * u`,
+# each compiled from an out-of-place and an in-place expression.
+function generate_semiquadratic_functions(
+        sys::System, A, B, x2, C; expression = Val{true}, wrap_gfw = Val{false},
+        eval_expression = false, eval_module = @__MODULE__, kwargs...)
+    linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
+    bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
+    diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
+    dvs = unknowns(sys)
+    ps = reorder_parameters(sys)
+    # Codegen is a bit manual, and we're manually creating an efficient IIP function.
+    # Since we explicitly provide Symbolics.DEFAULT_OUTSYM, the `u` is actually the second
+    # argument.
+    iip_x = generated_argument_name(2)
+    oop_x = generated_argument_name(1)
+
+    f1_iip_ir = Assignment[Assignment(BILINEAR_CACHEVAR,
+                               term(view,
+                                   term(PreallocationTools.get_tmp,
+                                       diffcache_par, Symbolics.DEFAULT_OUTSYM),
+                                   1:length(x2)))
+                           # write to x2
+                           Assignment(:__tmp1, SetArray(false, BILINEAR_CACHEVAR, x2))
+                           # out .= C
+                           Assignment(
+                               :__tmp2, SetArray(false, Symbolics.DEFAULT_OUTSYM, C))
+                           # mul!(out, B, x2, 1, 1)
+                           Assignment(:__tmp3,
+                               term(mul!, Symbolics.DEFAULT_OUTSYM, bilinear_matrix_param,
+                                   BILINEAR_CACHEVAR, true, true))]
+    f1_iip = build_function_wrapper(
+        sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., get_iv(sys); p_start = 3,
+        extra_assignments = f1_iip_ir, expression = Val{true}, kwargs...)
+    f1_oop = build_function_wrapper(
+        sys, term(+, term(*, bilinear_matrix_param, x2), C), dvs, ps...,
+        get_iv(sys); expression = Val{true}, iip_config = (true, false), kwargs...)
+
+    f2_iip_ir = Assignment[
+        Assignment(
+        :__tmp1, term(mul!, Symbolics.DEFAULT_OUTSYM, linear_matrix_param, iip_x))
+    ]
+    f2_iip = build_function_wrapper(
+        sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., get_iv(sys); p_start = 3,
+        extra_assignments = f2_iip_ir, expression = Val{true}, kwargs...)
+    f2_oop = build_function_wrapper(
+        sys, term(*, linear_matrix_param, oop_x), dvs, ps..., get_iv(sys);
+        expression = Val{true}, iip_config = (true, false), kwargs...)
+
+    f1 = maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)),
+        (f1_oop, f1_iip); eval_expression, eval_module)
+    f2 = maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)),
+        (f2_oop, f2_iip); eval_expression, eval_module)
+    return f1, f2
+end
+
+# Symbolic jacobians with respect to the unknowns of `sys`: returns
+# `(f1jac, x2jac, Cjac)` where `f1jac = B * x2jac + Cjac`.
+function calculate_semiquadratic_jacobian(
+        sys::System, B, x2, C; sparse = false, massmatrix = calculate_massmatrix(sys))
+    dvs = unknowns(sys)
+    if sparse
+        x2jac = Symbolics.sparsejacobian(x2, dvs)
+        Cjac = Symbolics.sparsejacobian(C, dvs)
+    else
+        x2jac = Symbolics.jacobian(x2, dvs)
+        Cjac = Symbolics.jacobian(C, dvs)
+    end
+
+    f1jac = B * x2jac + Cjac
+
+    if sparse
+        for i in eachindex(dvs)
+            massmatrix[i, i] == 0 && continue
+            _iszero(f1jac[i, i]) || continue
+            # Assign a nonzero and then zero it: presumably so the diagonal entry is
+            # kept as a stored (structural) entry of the sparse pattern.
+            f1jac[i, i] = 1
+            f1jac[i, i] = 0
+        end
+    end
+
+    return f1jac, x2jac, Cjac
+end
+
+const COLPTR_PARAM = unwrap(only(@parameters __mtk_colptr::Vector{Int}))
+const ROWVAL_PARAM = unwrap(only(@parameters __mtk_rowval::Vector{Int}))
+
+# Generate the jacobian of `f1`, `B * x2jac + Cjac`, from an out-of-place and an in-place
+# expression. The in-place form evaluates `x2jac` into the `DiffCache` buffer first.
+function generate_semiquadratic_jacobian(
+        sys::System, B, x2, C, f1jac, x2jac, Cjac; sparse = false,
+        expression = Val{true}, wrap_gfw = Val{false},
+        eval_expression = false, eval_module = @__MODULE__, kwargs...)
+    if sparse
+        @assert is_parameter(sys, COLPTR_PARAM)
+        @assert is_parameter(sys, ROWVAL_PARAM)
+    end
+    bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
+    diffcache_par = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))
+    dvs = unknowns(sys)
+    ps = reorder_parameters(sys)
+
+    iip_ir = Assignment[]
+    push!(iip_ir,
+        Assignment(:__mtk_preallocbuf,
+            term(PreallocationTools.get_tmp, diffcache_par, Symbolics.DEFAULT_OUTSYM)))
+    if sparse
+        push!(
+            iip_ir, Assignment(:__mtk_nzvals, term(view, :__mtk_preallocbuf, 1:nnz(x2jac))))
+        push!(iip_ir, Assignment(:__tmp1, SetArray(false, :__mtk_nzvals, x2jac.nzval)))
+        push!(iip_ir,
+            Assignment(:__mtk_x2jacbuf,
+                term(SparseMatrixCSC, size(x2jac)...,
+                    COLPTR_PARAM, ROWVAL_PARAM, :__mtk_nzvals)))
+        cjac_idxs = AtIndex[]
+        for (i, j, v) in zip(findnz(Cjac)...)
+            push!(cjac_idxs, AtIndex(CartesianIndex(i, j), v))
+        end
+    else
+        push!(iip_ir,
+            Assignment(:__mtk_x2jacbuf,
+                term(reshape, term(view, :__mtk_preallocbuf, 1:length(x2jac)), size(x2jac))))
+        push!(iip_ir, Assignment(:__tmp1, SetArray(false, :__mtk_x2jacbuf, x2jac)))
+        cjac_idxs = AtIndex[]
+        for i in eachindex(Cjac)
+            _iszero(Cjac[i]) && continue
+            push!(cjac_idxs, AtIndex(i, Cjac[i]))
+        end
+    end
+    push!(iip_ir, Assignment(:__tmp2, SetArray(false, Symbolics.DEFAULT_OUTSYM, cjac_idxs)))
+    push!(iip_ir,
+        Assignment(:__tmp3,
+            term(mul!, Symbolics.DEFAULT_OUTSYM,
+                bilinear_matrix_param, :__mtk_x2jacbuf, true, true)))
+
+    jaciip = build_function_wrapper(
+        sys, nothing, Symbolics.DEFAULT_OUTSYM, dvs, ps..., get_iv(sys);
+        p_start = 3, extra_assignments = iip_ir, expression = Val{true}, kwargs...)
+
+    make_x2 = if sparse
+        MakeSparseArray(x2jac)
+    else
+        MakeArray(x2jac, generated_argument_name(1))
+    end
+    make_cjac = if sparse
+        MakeSparseArray(Cjac)
+    else
+        MakeArray(Cjac, generated_argument_name(1))
+    end
+    oop_expr = term(+, term(*, bilinear_matrix_param, make_x2), make_cjac)
+    jacoop = build_function_wrapper(
+        sys, oop_expr, dvs, ps..., get_iv(sys); expression = Val{true}, kwargs...)
+
+    return maybe_compile_function(expression, wrap_gfw, (2, 3, is_split(sys)),
+        (jacoop, jaciip); eval_expression, eval_module)
+end
diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl
index dbbd7f85a8..d594a3902a 100644
--- a/src/systems/codegen_utils.jl
+++ b/src/systems/codegen_utils.jl
@@ -79,8 +79,8 @@ function array_variable_assignments(args...; argument_name = generated_argument_
             # to help reduce allocations
             if first(idxs) < last(idxs) && vec(idxs) == first(idxs):last(idxs)
                 idxs = first(idxs):last(idxs)
-            elseif vec(idxs) == last(idxs):-1:first(idxs)
-                idxs = last(idxs):-1:first(idxs)
+            elseif vec(idxs) == first(idxs):-1:last(idxs)
+                idxs = first(idxs):-1:last(idxs)
             else
                 # Otherwise, turn the indexes into an `SArray` so they're stack-allocated
                 idxs = SArray{Tuple{size(idxs)...}}(idxs)
diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl
index c66c562e9c..e58a924b88 100644
--- a/src/systems/index_cache.jl
+++ b/src/systems/index_cache.jl
@@ -63,6 +63,18 @@ struct IndexCache
     symbol_to_variable::Dict{Symbol, SymbolicParam}
 end
 
+# Copy `ic`, applying `copy` to each container field so that adding or replacing
+# entries in the result leaves `ic` untouched. The two buffer-size scalars are reused.
+function Base.copy(ic::IndexCache)
+    IndexCache(copy(ic.unknown_idx), copy(ic.discrete_idx), copy(ic.callback_to_clocks),
+        copy(ic.tunable_idx), copy(ic.initials_idx), copy(ic.constant_idx),
+        copy(ic.nonnumeric_idx), copy(ic.observed_syms_to_timeseries),
+        copy(ic.dependent_pars_to_timeseries), copy(ic.discrete_buffer_sizes),
+        ic.tunable_buffer_size, ic.initials_buffer_size,
+        copy(ic.constant_buffer_sizes), copy(ic.nonnumeric_buffer_sizes),
+        copy(ic.symbol_to_variable))
+end
+
 function IndexCache(sys::AbstractSystem)
     unks = unknowns(sys)
     unk_idxs = UnknownIndexMap()
@@ -716,3 +728,59 @@ function subset_unknowns_observed(
     @set! ic.observed_syms_to_timeseries = observed_syms_to_timeseries
     return ic
 end
+
+# Return a copy of `sys` with `par` appended to its parameters and, for split systems,
+# registered as a constant in a copied index cache (one new slot in its type's buffer).
+function with_additional_constant_parameter(sys::AbstractSystem, par)
+    par = unwrap(par)
+    ps = copy(get_ps(sys))
+    push!(ps, par)
+    @set! sys.ps = ps
+    is_split(sys) || return sys
+
+    ic = copy(get_index_cache(sys))
+    T = symtype(par)
+    bufidx = findfirst(buft -> buft.type == T, ic.constant_buffer_sizes)
+    if bufidx === nothing
+        push!(ic.constant_buffer_sizes, BufferTemplate(T, 1))
+        bufidx = length(ic.constant_buffer_sizes)
+        idx_in_buf = 1
+    else
+        buft = ic.constant_buffer_sizes[bufidx]
+        ic.constant_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1)
+        idx_in_buf = buft.length + 1
+    end
+
+    ic.constant_idx[par] = ic.constant_idx[renamespace(sys, par)] = (bufidx, idx_in_buf)
+    @set! sys.index_cache = ic
+
+    return sys
+end
+
+# Return a copy of `sys` with `par` appended to its parameters and, for split systems,
+# registered as nonnumeric in a copied index cache (one new slot in its type's buffer).
+function with_additional_nonnumeric_parameter(sys::AbstractSystem, par)
+    par = unwrap(par)
+    ps = copy(get_ps(sys))
+    push!(ps, par)
+    @set! sys.ps = ps
+    is_split(sys) || return sys
+
+    ic = copy(get_index_cache(sys))
+    T = symtype(par)
+    bufidx = findfirst(buft -> buft.type == T, ic.nonnumeric_buffer_sizes)
+    if bufidx === nothing
+        push!(ic.nonnumeric_buffer_sizes, BufferTemplate(T, 1))
+        bufidx = length(ic.nonnumeric_buffer_sizes)
+        idx_in_buf = 1
+    else
+        buft = ic.nonnumeric_buffer_sizes[bufidx]
+        ic.nonnumeric_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1)
+        idx_in_buf = buft.length + 1
+    end
+
+    ic.nonnumeric_idx[par] = ic.nonnumeric_idx[renamespace(sys, par)] = (bufidx, idx_in_buf)
+    @set! sys.index_cache = ic
+
+    return sys
+end