Skip to content

Commit 49115cb

Browse files
refactor: improve precompilation of MTKBase, MTK
1 parent 69d2800 commit 49115cb

File tree

3 files changed

+136
-122
lines changed

3 files changed

+136
-122
lines changed

lib/ModelingToolkitBase/src/ModelingToolkitBase.jl

Lines changed: 1 addition & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -354,102 +354,5 @@ function __init__()
354354
SU.hashcons(COMMON_INF, true)
355355
end
356356

357-
PrecompileTools.@compile_workload begin
358-
fold1 = Val{false}()
359-
using SymbolicUtils
360-
using SymbolicUtils: shape
361-
using Symbolics
362-
@syms x y f(t) q[1:5]
363-
SymbolicUtils.Sym{SymReal}(:a; type = Real, shape = SymbolicUtils.ShapeVecT())
364-
x + y
365-
x * y
366-
x / y
367-
x ^ y
368-
x ^ 5
369-
6 ^ x
370-
x - y
371-
-y
372-
2y
373-
z = 2
374-
dict = SymbolicUtils.ACDict{VartypeT}()
375-
dict[x] = 1
376-
dict[y] = 1
377-
type::typeof(DataType) = rand() < 0.5 ? Real : Float64
378-
nt = (; type, shape, unsafe = true)
379-
Base.pairs(nt)
380-
BSImpl.AddMul{VartypeT}(1, dict, SymbolicUtils.AddMulVariant.MUL; type, shape = SymbolicUtils.ShapeVecT(), unsafe = true)
381-
*(y, z)
382-
*(z, y)
383-
SymbolicUtils.symtype(y)
384-
f(x)
385-
(5x / 5)
386-
expand((x + y) ^ 2)
387-
simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y)
388-
ex = x + 2y + sin(x)
389-
rules1 = Dict(x => y)
390-
rules2 = Dict(x => 1)
391-
Dx = Differential(x)
392-
Differential(y)(ex)
393-
uex = unwrap(ex)
394-
Symbolics.executediff(Dx, uex)
395-
# Running `fold = Val(true)` invalidates the precompiled statements
396-
# for `fold = Val(false)` and itself doesn't precompile anyway.
397-
# substitute(ex, rules1)
398-
substitute(ex, rules1; fold = fold1)
399-
substitute(ex, rules2; fold = fold1)
400-
@variables foo
401-
f(foo)
402-
@variables x y f(::Real) q[1:5]
403-
x + y
404-
x * y
405-
x / y
406-
x ^ y
407-
x ^ 5
408-
# 6 ^ x
409-
x - y
410-
-y
411-
2y
412-
symtype(y)
413-
z = 2
414-
*(y, z)
415-
*(z, y)
416-
f(x)
417-
(5x / 5)
418-
[x, y]
419-
[x, f, f]
420-
promote_type(Int, Num)
421-
promote_type(Real, Num)
422-
promote_type(Float64, Num)
423-
# expand((x + y) ^ 2)
424-
# simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y)
425-
ex = x + 2y + sin(x)
426-
rules1 = Dict(x => y)
427-
# rules2 = Dict(x => 1)
428-
# Running `fold = Val(true)` invalidates the precompiled statements
429-
# for `fold = Val(false)` and itself doesn't precompile anyway.
430-
# substitute(ex, rules1)
431-
substitute(ex, rules1; fold = fold1)
432-
Symbolics.linear_expansion(ex, y)
433-
# substitute(ex, rules2; fold = fold1)
434-
# substitute(ex, rules2)
435-
# substitute(ex, rules1; fold = fold2)
436-
# substitute(ex, rules2; fold = fold2)
437-
q[1]
438-
q'q
439-
using ModelingToolkitBase
440-
@variables x(ModelingToolkitBase.t_nounits) y(ModelingToolkitBase.t_nounits)
441-
isequal(ModelingToolkitBase.D_nounits.x, ModelingToolkitBase.t_nounits)
442-
sys = System([ModelingToolkitBase.D_nounits(x) ~ x * y, y ~ 2x], ModelingToolkitBase.t_nounits, [x, y], Num[]; name = :sys)
443-
complete(sys)
444-
mtkcompile(sys)
445-
@syms p[1:2]
446-
ndims(p)
447-
size(p)
448-
axes(p)
449-
length(p)
450-
v = [p]
451-
isempty(v)
452-
# mtkcompile(sys)
453-
end
454-
357+
include("precompile.jl")
455358
end # module
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
PrecompileTools.@compile_workload begin
2+
fold1 = Val{false}()
3+
using SymbolicUtils
4+
using SymbolicUtils: shape
5+
using Symbolics
6+
@syms x y f(t) q[1:5]
7+
SymbolicUtils.Sym{SymReal}(:a; type = Real, shape = SymbolicUtils.ShapeVecT())
8+
x + y
9+
x * y
10+
x / y
11+
x ^ y
12+
x ^ 5
13+
6 ^ x
14+
x - y
15+
-y
16+
2y
17+
z = 2
18+
dict = SymbolicUtils.ACDict{VartypeT}()
19+
dict[x] = 1
20+
dict[y] = 1
21+
type::typeof(DataType) = rand() < 0.5 ? Real : Float64
22+
nt = (; type, shape, unsafe = true)
23+
Base.pairs(nt)
24+
BSImpl.AddMul{VartypeT}(1, dict, SymbolicUtils.AddMulVariant.MUL; type, shape = SymbolicUtils.ShapeVecT(), unsafe = true)
25+
*(y, z)
26+
*(z, y)
27+
SymbolicUtils.symtype(y)
28+
f(x)
29+
(5x / 5)
30+
expand((x + y) ^ 2)
31+
simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y)
32+
ex = x + 2y + sin(x)
33+
rules1 = Dict(x => y)
34+
rules2 = Dict(x => 1)
35+
Dx = Differential(x)
36+
Differential(y)(ex)
37+
uex = unwrap(ex)
38+
Symbolics.executediff(Dx, uex)
39+
# Running `fold = Val(true)` invalidates the precompiled statements
40+
# for `fold = Val(false)` and itself doesn't precompile anyway.
41+
# substitute(ex, rules1)
42+
substitute(ex, rules1; fold = fold1)
43+
substitute(ex, rules2; fold = fold1)
44+
@variables foo
45+
f(foo)
46+
@variables x y f(::Real) q[1:5]
47+
x + y
48+
x * y
49+
x / y
50+
x ^ y
51+
x ^ 5
52+
# 6 ^ x
53+
x - y
54+
-y
55+
2y
56+
symtype(y)
57+
z = 2
58+
*(y, z)
59+
*(z, y)
60+
f(x)
61+
(5x / 5)
62+
[x, y]
63+
[x, f, f]
64+
promote_type(Int, Num)
65+
promote_type(Real, Num)
66+
promote_type(Float64, Num)
67+
# expand((x + y) ^ 2)
68+
# simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y)
69+
ex = x + 2y + sin(x)
70+
rules1 = Dict(x => y)
71+
# rules2 = Dict(x => 1)
72+
# Running `fold = Val(true)` invalidates the precompiled statements
73+
# for `fold = Val(false)` and itself doesn't precompile anyway.
74+
# substitute(ex, rules1)
75+
substitute(ex, rules1; fold = fold1)
76+
Symbolics.linear_expansion(ex, y)
77+
# substitute(ex, rules2; fold = fold1)
78+
# substitute(ex, rules2)
79+
# substitute(ex, rules1; fold = fold2)
80+
# substitute(ex, rules2; fold = fold2)
81+
q[1]
82+
q'q
83+
using ModelingToolkitBase
84+
@variables x(ModelingToolkitBase.t_nounits) y(ModelingToolkitBase.t_nounits)
85+
isequal(ModelingToolkitBase.D_nounits.x, ModelingToolkitBase.t_nounits)
86+
ics = Dict{SymbolicT, SymbolicT}()
87+
ics[x] = 2.3
88+
sys = System([ModelingToolkitBase.D_nounits(x) ~ x * y, y ~ 2x], ModelingToolkitBase.t_nounits, [x, y], Num[]; initial_conditions = ics, guesses = ics, name = :sys)
89+
complete(sys)
90+
@static if @isdefined(ModelingToolkit)
91+
TearingState(sys)
92+
end
93+
mtkcompile(sys)
94+
@syms p[1:2]
95+
ndims(p)
96+
size(p)
97+
axes(p)
98+
length(p)
99+
v = [p]
100+
isempty(v)
101+
# mtkcompile(sys)
102+
end
103+
104+
precompile(Tuple{typeof(SymbolicUtils.isequal_somescalar), Float64, Float64})
105+
precompile(Tuple{typeof(Base.:(var"==")), ModelingToolkitBase.Initial, ModelingToolkitBase.Initial})

src/ModelingToolkit.jl

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ import FillArrays
8888
using BipartiteGraphs
8989
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
9090
undef_blocks, blocks
91-
import StateSelection
92-
import StateSelection: CLIL
93-
import ModelingToolkitTearing as MTKTearing
94-
using ModelingToolkitTearing: TearingState, SystemStructure
9591

96-
ModelingToolkitBase.complete(dg::StateSelection.DiffGraph) = BipartiteGraphs.complete(dg)
92+
@recompile_invalidations begin
93+
import StateSelection
94+
import StateSelection: CLIL
95+
import ModelingToolkitTearing as MTKTearing
96+
using ModelingToolkitTearing: TearingState, SystemStructure
97+
98+
ModelingToolkitBase.complete(dg::StateSelection.DiffGraph) = BipartiteGraphs.complete(dg)
99+
end
97100

98101
macro import_mtkbase()
99102
allnames = names(MTKBase; all = true)
@@ -125,25 +128,27 @@ end
125128
using ModelingToolkitBase: COMMON_SENTINEL, COMMON_NOTHING, COMMON_MISSING,
126129
COMMON_TRUE, COMMON_FALSE, COMMON_INF
127130

128-
include("linearization.jl")
129-
include("systems/analysis_points.jl")
130-
include("systems/solver_nlprob.jl")
131-
132-
include("problems/docs.jl")
133-
include("systems/codegen.jl")
134-
include("problems/semilinearodeproblem.jl")
135-
include("problems/sccnonlinearproblem.jl")
136-
137-
include("discretedomain.jl")
138-
include("systems/systemstructure.jl")
139-
include("initialization.jl")
140-
include("systems/systems.jl")
141-
include("systems/clock_inference.jl")
142-
include("systems/if_lifting.jl")
143-
include("systems/substitute_component.jl")
144-
145-
include("systems/alias_elimination.jl")
146-
include("structural_transformation/StructuralTransformations.jl")
131+
@recompile_invalidations begin
132+
include("linearization.jl")
133+
include("systems/analysis_points.jl")
134+
include("systems/solver_nlprob.jl")
135+
136+
include("problems/docs.jl")
137+
include("systems/codegen.jl")
138+
include("problems/semilinearodeproblem.jl")
139+
include("problems/sccnonlinearproblem.jl")
140+
141+
include("discretedomain.jl")
142+
include("systems/systemstructure.jl")
143+
include("initialization.jl")
144+
include("systems/systems.jl")
145+
include("systems/clock_inference.jl")
146+
include("systems/if_lifting.jl")
147+
include("systems/substitute_component.jl")
148+
149+
include("systems/alias_elimination.jl")
150+
include("structural_transformation/StructuralTransformations.jl")
151+
end
147152

148153
@reexport using .StructuralTransformations
149154

@@ -165,4 +170,5 @@ function FMIComponent end
165170
@public linearize_symbolic, reorder_unknowns
166171
@public similarity_transform
167172

173+
include(pkgdir(ModelingToolkitBase, "src", "precompile.jl"))
168174
end # module

0 commit comments

Comments
 (0)