Skip to content

Commit b2f9c51

Browse files
Merge pull request #406 from AayushSabharwal/as/nlintegrator-indexing
feat: support indexing AbstractNonlinearSolveCache
2 parents e3238a7 + 1b501f8 commit b2f9c51

11 files changed

+109
-6
lines changed

.github/workflows/CI.yml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ jobs:
1919
matrix:
2020
group:
2121
- Core
22+
- Downstream
2223
version:
2324
- "1"
2425
os:

Project.toml

+7-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2727
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2828
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
2929
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
30+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3031
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3132

3233
[weakdeps]
@@ -35,8 +36,8 @@ FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
3536
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
3637
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
3738
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
38-
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
3939
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
40+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
4041
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
4142
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
4243
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
@@ -48,8 +49,8 @@ NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
4849
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
4950
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
5051
NonlinearSolveMINPACKExt = "MINPACK"
51-
NonlinearSolveNLsolveExt = "NLsolve"
5252
NonlinearSolveNLSolversExt = "NLSolvers"
53+
NonlinearSolveNLsolveExt = "NLsolve"
5354
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
5455
NonlinearSolveSpeedMappingExt = "SpeedMapping"
5556
NonlinearSolveSymbolicsExt = "Symbolics"
@@ -61,8 +62,8 @@ Aqua = "0.8"
6162
ArrayInterface = "7.9"
6263
BandedMatrices = "1.4"
6364
BenchmarkTools = "1.4"
64-
ConcreteStructs = "0.2.3"
6565
CUDA = "5.2"
66+
ConcreteStructs = "0.2.3"
6667
DiffEqBase = "6.149.0"
6768
Enzyme = "0.11.15"
6869
FastBroadcast = "0.2.8"
@@ -78,8 +79,8 @@ LinearAlgebra = "1.10"
7879
LinearSolve = "2.21"
7980
MINPACK = "1.2"
8081
MaybeInplace = "0.1.1"
81-
NLsolve = "4.5"
8282
NLSolvers = "0.5"
83+
NLsolve = "4.5"
8384
NaNMath = "1"
8485
NonlinearProblemLibrary = "0.1.2"
8586
OrdinaryDiffEq = "6.74"
@@ -103,6 +104,7 @@ StaticArrays = "1.7"
103104
StaticArraysCore = "1.4"
104105
Sundials = "4.23.1"
105106
Symbolics = "5.13"
107+
SymbolicIndexingInterface = "0.3.3"
106108
Test = "1.10"
107109
TimerOutputs = "0.5.23"
108110
Zygote = "0.6.69"
@@ -122,8 +124,8 @@ LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
122124
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
123125
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
124126
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
125-
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
126127
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
128+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
127129
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
128130
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
129131
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

docs/make.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ makedocs(; sitename = "NonlinearSolve.jl",
1818
clean = true,
1919
doctest = false,
2020
linkcheck = true,
21-
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615"],
21+
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615",
22+
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
2223
checkdocs = :exports,
2324
warnonly = [:missing_docs],
2425
plugins = [bib],

src/NonlinearSolve.jl

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
2929
AbstractSciMLOperator, NLStats, _unwrap_val, has_jac, isinplace
3030
import SparseDiffTools: AbstractSparsityDetection, AutoSparseEnzyme
3131
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
32+
import SymbolicIndexingInterface: SymbolicIndexingInterface, ParameterIndexingProxy,
33+
symbolic_container, parameter_values, state_values,
34+
getu
3235
end
3336

3437
@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve

src/abstract_types.jl

+21
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,27 @@ Abstract Type for all NonlinearSolve.jl Caches.
207207
"""
208208
abstract type AbstractNonlinearSolveCache{iip, timeit} end
209209

210+
function SymbolicIndexingInterface.symbolic_container(cache::AbstractNonlinearSolveCache)
211+
cache.prob
212+
end
213+
function SymbolicIndexingInterface.parameter_values(cache::AbstractNonlinearSolveCache)
214+
parameter_values(symbolic_container(cache))
215+
end
216+
function SymbolicIndexingInterface.state_values(cache::AbstractNonlinearSolveCache)
217+
state_values(symbolic_container(cache))
218+
end
219+
220+
function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol)
221+
if sym == :ps
222+
return ParameterIndexingProxy(cache)
223+
end
224+
return getfield(cache, sym)
225+
end
226+
227+
function Base.getindex(cache::AbstractNonlinearSolveCache, sym)
228+
return getu(cache, sym)(cache)
229+
end
230+
210231
function Base.show(io::IO, cache::AbstractNonlinearSolveCache)
211232
__show_cache(io, cache, 0)
212233
end

src/core/generalized_first_order.jl

+5
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ concrete_jac(::GeneralizedFirstOrderAlgorithm{CJ}) where {CJ} = CJ
113113
force_stop::Bool
114114
end
115115

116+
SymbolicIndexingInterface.state_values(cache::GeneralizedFirstOrderAlgorithmCache) = cache.u
117+
function SymbolicIndexingInterface.parameter_values(cache::GeneralizedFirstOrderAlgorithmCache)
118+
cache.p
119+
end
120+
116121
function __reinit_internal!(
117122
cache::GeneralizedFirstOrderAlgorithmCache{iip}, args...; p = cache.p, u0 = cache.u,
118123
alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs...) where {iip}

src/default.jl

+5
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ end
7070
alias_u0::Bool
7171
end
7272

73+
function SymbolicIndexingInterface.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
74+
cache.caches[cache.current]
75+
end
76+
SymbolicIndexingInterface.state_values(cache::NonlinearSolvePolyAlgorithmCache) = cache.u0
77+
7378
function Base.show(
7479
io::IO, cache::NonlinearSolvePolyAlgorithmCache{pType, N}) where {pType, N}
7580
problem_kind = ifelse(pType == :NLS, "NonlinearProblem", "NonlinearLeastSquaresProblem")

test/downstream/Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
3+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
4+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

test/downstream/cache_indexing.jl

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using ModelingToolkit, NonlinearSolve
2+
using ModelingToolkit: t_nounits as t
3+
4+
@parameters p d
5+
@variables X(t)
6+
eqs = [0 ~ sin(X + p) - d * sqrt(X + 1)]
7+
@mtkbuild nlsys = NonlinearSystem(eqs, [X], [p, d])
8+
9+
# Creates an integrator.
10+
nlprob = NonlinearProblem(nlsys, [X => 1.0], [p => 2.0, d => 3.0])
11+
12+
@testset "GeneralizedFirstOrderAlgorithmCache" begin
13+
nint = init(nlprob, NewtonRaphson())
14+
@test nint isa NonlinearSolve.GeneralizedFirstOrderAlgorithmCache
15+
16+
@test nint[X] == 1.0
17+
@test nint[nlsys.X] == 1.0
18+
@test nint[:X] == 1.0
19+
@test nint.ps[p] == 2.0
20+
@test nint.ps[nlsys.p] == 2.0
21+
@test nint.ps[:p] == 2.0
22+
end
23+
24+
@testset "NonlinearSolvePolyAlgorithmCache" begin
25+
nint = init(nlprob, FastShortcutNonlinearPolyalg())
26+
@test nint isa NonlinearSolve.NonlinearSolvePolyAlgorithmCache
27+
28+
@test nint[X] == 1.0
29+
@test nint[nlsys.X] == 1.0
30+
@test nint[:X] == 1.0
31+
@test nint.ps[p] == 2.0
32+
@test nint.ps[nlsys.p] == 2.0
33+
@test nint.ps[:p] == 2.0
34+
end
35+
36+
@testset "NonlinearSolveNoInitCache" begin
37+
nint = init(nlprob, SimpleNewtonRaphson())
38+
@test nint isa NonlinearSolve.NonlinearSolveNoInitCache
39+
40+
@test nint[X] == 1.0
41+
@test nint[nlsys.X] == 1.0
42+
@test nint[:X] == 1.0
43+
@test nint.ps[p] == 2.0
44+
@test nint.ps[nlsys.p] == 2.0
45+
@test nint.ps[:p] == 2.0
46+
end

test/downstream/downstream_tests.jl

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Pkg
2+
using SafeTestsets
3+
4+
function activate_downstream_env()
5+
Pkg.activate("downstream")
6+
Pkg.develop(PackageSpec(path = dirname(dirname(@__DIR__))))
7+
Pkg.instantiate()
8+
end
9+
10+
activate_downstream_env()
11+
@safetestset "Cache indexing test" include("cache_indexing.jl")

test/runtests.jl

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ if GROUP == "All" || GROUP == "Core"
77
joinpath(@__DIR__, "wrappers/"))
88
end
99

10+
if GROUP == "Downstream"
11+
include("downstream/downstream_tests.jl")
12+
end
13+
1014
if GROUP == "GPU"
1115
ReTestItems.runtests(joinpath(@__DIR__, "gpu/"))
1216
end

0 commit comments

Comments
 (0)