@@ -70,34 +70,34 @@ tspan = (0.0, 100.0)
70
70
# and with the initialization corrected to satisfy the algebraic equation
71
71
prob_incorrectu0 = ODEProblem (sys, u0_incorrect, tspan, p, jac = true , guesses = [w2 => 0.0 ])
72
72
mtkparams_incorrectu0 = SciMLSensitivity. parameter_values (prob_incorrectu0)
73
+ test_sol = solve (prob_incorrectu0, Rodas5P (), abstol = 1e-6 , reltol = 1e-3 )
73
74
74
75
u0_timedep = [D (x) => 2.0 ,
75
76
x => 1.0 ,
76
77
y => t,
77
- z => 0.0 ,
78
- w2 => 0.0 ,]
78
+ z => 0.0 ]
79
79
# this ensures that `y => t` is not applied in the adjoint equation
80
80
# If the MTK init is called for the reverse, then `y0` in the backwards
81
81
# pass will be extremely far off and cause an incorrect gradient
82
82
prob_timedepu0 = ODEProblem (sys, u0_timedep, tspan, p, jac = true , guesses = [w2 => 0.0 ])
83
83
mtkparams_timedepu0 = SciMLSensitivity. parameter_values (prob_incorrectu0)
84
+ test_sol = solve (prob_timedepu0, Rodas5P (), abstol = 1e-6 , reltol = 1e-3 )
84
85
85
86
u0_correct = [D (x) => 2.0 ,
86
87
x => 1.0 ,
87
88
y => 0.0 ,
88
- z => 0.0 ,
89
- w2 => - 1.0 ,]
89
+ z => 0.0 ,]
90
90
prob_correctu0 = ODEProblem (sys, u0_correct, tspan, p, jac = true , guesses = [w2 => - 1.0 ])
91
91
mtkparams_correctu0 = SciMLSensitivity. parameter_values (prob_correctu0)
92
- prob_correctu0. u0[5 ] = - 1.0
93
-
92
+ test_sol = solve (prob_correctu0, Rodas5P (), abstol = 1e-6 , reltol = 1e-3 )
94
93
u0_overdetermined = [D (x) => 2.0 ,
95
94
x => 1.0 ,
96
95
y => 0.0 ,
97
96
z => 0.0 ,
98
97
w2 => - 1.0 ,]
99
98
prob_overdetermined = ODEProblem (sys, u0_overdetermined, tspan, p, jac = true )
100
99
mtkparams_overdetermined = SciMLSensitivity. parameter_values (prob_overdetermined)
100
+ test_sol = solve (prob_overdetermined, Rodas5P (), abstol = 1e-6 , reltol = 1e-3 )
101
101
102
102
sensealg = GaussAdjoint (; autojacvec = SciMLSensitivity. ZygoteVJP ())
103
103
@@ -115,25 +115,26 @@ setups = [
115
115
(prob_correctu0, mtkparams_correctu0, BrownFullBasicInit ()),
116
116
(prob_correctu0, mtkparams_correctu0, OrdinaryDiffEqCore. DefaultInit ()),
117
117
118
- (prob_correctu0, mtkparams_correctu0, NoInit ()),
118
+ (prob_correctu0, mtkparams_correctu0, NoInit ()),
119
119
(prob_correctu0, mtkparams_correctu0, nothing ),
120
120
121
121
(prob_overdetermined, mtkparams_overdetermined, BrownFullBasicInit ()),
122
122
(prob_overdetermined, mtkparams_overdetermined, OrdinaryDiffEq. OrdinaryDiffEqCore. DefaultInit ()),
123
123
124
124
(prob_overdetermined, mtkparams_overdetermined, NoInit ()),
125
125
(prob_overdetermined, mtkparams_overdetermined, nothing ),
126
- ]
126
+ ];
127
127
128
128
grads = map (setups) do setup
129
129
prob, ps, init = setup
130
130
@show init
131
131
u0 = prob. u0
132
132
Zygote. gradient (u0, ps) do u0,p
133
+ new_prob = remake (prob, u0 = u0, p = p)
133
134
if init === nothing
134
- new_sol = solve (prob , Rodas5P (); u0 = u0, p = ps, sensealg, abstol = 1e-6 , reltol = 1e-3 )
135
+ new_sol = solve (new_prob , Rodas5P (); sensealg, abstol = 1e-6 , reltol = 1e-3 )
135
136
else
136
- new_sol = solve (prob , Rodas5P (); u0 = u0, p = ps, initializealg = init, sensealg, abstol = 1e-6 , reltol = 1e-3 )
137
+ new_sol = solve (new_prob , Rodas5P (); initializealg = init, sensealg, abstol = 1e-6 , reltol = 1e-3 )
137
138
end
138
139
gt = Zygote. ChainRules. ChainRulesCore. ignore_derivatives () do
139
140
@test new_sol. retcode == SciMLBase. ReturnCode. Success
148
149
149
150
u0grads = getindex .(grads,1 )
150
151
pgrads = getproperty .(getindex .(grads, 2 ), (:tunable ,))
151
- @test all (x ≈ u0grads[1 ] for x in grads )
152
- @test all (x ≈ pgrads[1 ] for x in grads )
152
+ @test all (x ≈ u0grads[1 ] for x in u0grads )
153
+ @test all (x ≈ pgrads[1 ] for x in pgrads )
0 commit comments