Skip to content

Commit 622ddd3

Browse files
test: fix discrete system tests
1 parent 40d2139 commit 622ddd3

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

test/discrete_system.jl

+13-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
- https://github.com/epirecipes/sir-julia/blob/master/markdown/function_map/function_map.md
44
- https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#Deterministic_versus_stochastic_epidemic_models
55
=#
6-
using ModelingToolkit, Test
6+
using ModelingToolkit, SymbolicIndexingInterface, Test
77
using ModelingToolkit: t_nounits as t
88
using ModelingToolkit: get_metadata, MTKParameters
99

@@ -37,13 +37,15 @@ syss = structural_simplify(sys)
3737
df = DiscreteFunction(syss)
3838
# iip
3939
du = zeros(3)
40-
u = collect(1:3)
40+
u = ModelingToolkit.better_varmap_to_vars(Dict([S => 1, I => 2, R => 3]), unknowns(syss))
4141
p = MTKParameters(syss, [c, nsteps, δt, β, γ] .=> collect(1:5))
4242
df.f(du, u, p, 0)
43-
@test du [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
43+
reorderer = getu(syss, [S, I, R])
44+
@test reorderer(du) [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
4445

4546
# oop
46-
@test df.f(u, p, 0) [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
47+
@test reorderer(df.f(u, p, 0))
48+
[0.01831563888873422, 0.9816849729159067, 4.999999388195359]
4749

4850
# Problem
4951
u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
@@ -98,12 +100,12 @@ function sir_map!(u_diff, u, p, t)
98100
end
99101
nothing
100102
end;
101-
u0 = prob_map2.u0;
103+
u0 = prob_map2[[S, I, R]];
102104
p = [0.05, 10.0, 0.25, 0.1];
103105
prob_map = DiscreteProblem(sir_map!, u0, tspan, p);
104106
sol_map2 = solve(prob_map, FunctionMap());
105107

106-
@test Array(sol_map) Array(sol_map2)
108+
@test reduce(hcat, sol_map[[S, I, R]]) Array(sol_map2)
107109

108110
# Delayed difference equation
109111
# @variables x(..) y(..) z(t)
@@ -317,9 +319,9 @@ end
317319

318320
import ModelingToolkit: shift2term
319321
# unknowns(de) = xₜ₋₁, x, zₜ₋₁, xₜ₋₂, z
320-
vars = ModelingToolkit.value.(unknowns(de))
321-
@test isequal(shift2term(Shift(t, 1)(vars[1])), vars[2])
322-
@test isequal(shift2term(Shift(t, 1)(vars[4])), vars[1])
323-
@test isequal(shift2term(Shift(t, -1)(vars[5])), vars[3])
324-
@test isequal(shift2term(Shift(t, -2)(vars[2])), vars[4])
322+
vars = sort(ModelingToolkit.value.(unknowns(de)); by = string)
323+
@test isequal(shift2term(Shift(t, 1)(vars[2])), vars[1])
324+
@test isequal(shift2term(Shift(t, 1)(vars[3])), vars[2])
325+
@test isequal(shift2term(Shift(t, -1)(vars[4])), vars[5])
326+
@test isequal(shift2term(Shift(t, -2)(vars[1])), vars[3])
325327
end

0 commit comments

Comments
 (0)