Skip to content

Commit a0a7c93

Browse files
authored
Use Sparspak for general Reals and \ for non-GPL version
* Generic element type handling based on Sparspak * Create Sparspak solver directly from matrix * Use sparspak to define `\` for all cases
1 parent a4a5bc0 commit a0a7c93

26 files changed

+986
-458
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
name = "ExtendableSparse"
22
uuid = "95c220a8-a1cf-11e9-0c77-dbfce5f500b3"
33
authors = ["Juergen Fuhrmann <[email protected]>"]
4-
version = "0.8.0"
4+
version = "0.9.0"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1010
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1111
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
12+
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
1213
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1314
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1415

1516
[compat]
1617
DocStringExtensions = "0.8.0,0.9"
1718
Requires = "1.1.3"
1819
julia = "1.6"
20+
Sparspak= "0.3.0"
1921

2022
[extras]
2123
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

src/ExtendableSparse.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
module ExtendableSparse
22
using SparseArrays
33
using LinearAlgebra
4+
using Sparspak
45

5-
if Base.USE_GPL_LIBS
6+
7+
# Define our own constant here in order to be able to
8+
# test things at least a little bit..
9+
const USE_GPL_LIBS=Base.USE_GPL_LIBS
10+
11+
12+
if USE_GPL_LIBS
613
using SuiteSparse
714
end
815

@@ -42,7 +49,6 @@ function __init__()
4249
@require Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" include("pardiso_lu.jl")
4350
@require IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895" include("ilut.jl")
4451
@require AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c" include("amg.jl")
45-
@require Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac" include("sparspak.jl")
4652
end
4753

4854

src/SparspakCSCInterface.jl

+278
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
module SparspakCSCInterface
2+
#
3+
# Essentially, this is the code for https://github.com/PetrKryslUCSD/Sparspak.jl/pull/9
4+
# If that PR is accepted, we can remove the code from here
5+
#
6+
using SparseArrays, LinearAlgebra
7+
using Sparspak.SpkGraph: Graph
8+
using Sparspak.SpkOrdering: Ordering
9+
using Sparspak.SpkETree: ETree
10+
using Sparspak.SpkSparseBase: _SparseBase
11+
using Sparspak.SpkSparseSolver: Problem,SparseSolver, findorder!,symbolicfactor!,inmatrix!,factor!,triangularsolve!
12+
import Sparspak.SpkSparseSolver: solve!, inmatrix!
13+
import Sparspak.SpkSparseBase: _inmatrix!
14+
15+
16+
function Graph(m::SparseMatrixCSC{FT,IT}, diagonal=false) where {FT,IT}
17+
nv = size(m,1)
18+
nrows = size(m,2)
19+
ncols = size(m,1)
20+
colptr = SparseArrays.getcolptr(m)
21+
rowval = SparseArrays.getrowval(m)
22+
23+
24+
if (diagonal)
25+
nedges = nnz(m)
26+
else
27+
dedges=0
28+
for i in 1:ncols
29+
for iptr in colptr[i]:colptr[i+1]-1
30+
if rowval[iptr]==i
31+
dedges+=1
32+
continue
33+
end
34+
end
35+
end
36+
nedges = nnz(m) - dedges
37+
end
38+
39+
#jf if diagonal == true, we possibly can just use colptr & rowval
40+
#jf and skip the loop
41+
42+
xadj = fill(zero(IT), nv + 1)
43+
adj = fill(zero(IT), nedges)
44+
45+
k = 1
46+
for i in 1:ncols
47+
xadj[i] = k
48+
for iptr in colptr[i]:colptr[i+1]-1
49+
j = rowval[iptr]
50+
if (i != j || diagonal)
51+
adj[k] = j
52+
k = k + 1
53+
end
54+
end
55+
end
56+
57+
xadj[ncols+1] = k
58+
59+
return Graph(nv, nedges, nrows, ncols, xadj, adj)
60+
end
61+
62+
63+
64+
65+
function _SparseBase(m::SparseMatrixCSC{FT,IT}) where {IT,FT}
66+
maxblocksize = 30 # This can be set by the user
67+
68+
tempsizeneed = zero(IT)
69+
n = size(m,2)
70+
nnz = SparseArrays.nnz(m)
71+
nnzl = zero(IT)
72+
nsub = zero(IT)
73+
74+
nsuper = zero(IT)
75+
if (n > zero(IT))
76+
nsuper = 1
77+
end
78+
79+
factorops = zero(FT)
80+
solveops = zero(FT)
81+
realstore = zero(FT)
82+
integerstore = zero(FT)
83+
errflag = 0
84+
85+
order = Ordering(n) # ordering object for the solver
86+
g = Graph(m)
87+
t = ETree(n)
88+
89+
colcnt = IT[]
90+
snode = IT[]
91+
xsuper = IT[]
92+
xlindx = IT[]
93+
lindx = IT[]
94+
xlnz = IT[]
95+
xunz = IT[]
96+
ipiv = IT[]
97+
lnz = FT[]
98+
unz = FT[]
99+
100+
return _SparseBase(order, t, g, errflag, n, nnz, nnzl, nsub, nsuper, maxblocksize,
101+
tempsizeneed, factorops, solveops, realstore, integerstore,
102+
colcnt, snode, xsuper, xlindx, lindx, xlnz, xunz, ipiv,
103+
lnz, unz)
104+
end
105+
106+
107+
function _inmatrix!(s::_SparseBase{IT, FT}, m::SparseMatrixCSC{FT,IT}) where {IT, FT}
108+
if (s.n == 0)
109+
@error "$(@__FILE__): An empty problem. No matrix."
110+
return false
111+
end
112+
113+
s.lnz .= zero(FT)
114+
s.unz .= zero(FT)
115+
s.ipiv .= zero(IT)
116+
117+
function doit(ncols, colptr, rowval, cinvp, rinvp, snode, xsuper, xlindx, lindx, nzval, xlnz, lnz, xunz, unz)
118+
for i in 1:ncols
119+
for iptr in colptr[i]:colptr[i+1]-1
120+
inew = rinvp[rowval[iptr]];
121+
jnew = cinvp[i]
122+
value = nzval[iptr]
123+
## jf: all of this could go into a function, so we can keep things synced
124+
if (inew >= xsuper[snode[jnew]])
125+
# Lies in L. get pointers and lengths needed to search
126+
# column jnew of L for location l(inew, jnew).
127+
jsup = snode[jnew];
128+
fstcol = xsuper[jsup]
129+
fstsub = xlindx[jsup]
130+
lstsub = xlindx[jsup + 1] - 1
131+
nnzloc = 0;
132+
for nxtsub in fstsub:lstsub
133+
irow = lindx[nxtsub]
134+
if (irow > inew)
135+
@error "$(@__FILE__): No space for matrix element $(inew), $(jnew)."
136+
return false
137+
end
138+
if (irow == inew)
139+
# find a proper offset into lnz and increment by value
140+
_p = xlnz[jnew] + nnzloc
141+
lnz[_p] += value
142+
break
143+
end
144+
nnzloc = nnzloc + 1
145+
end
146+
else
147+
# Lies in U
148+
jsup = snode[inew]
149+
fstcol = xsuper[jsup]
150+
lstcol = xsuper[jsup + 1] - 1
151+
width = lstcol - fstcol + 1
152+
lstsub = xlindx[jsup + 1] - 1
153+
fstsub = xlindx[jsup] + width
154+
nnzloc = 0;
155+
for nxtsub in fstsub:lstsub
156+
irow = lindx[nxtsub]
157+
if (irow > jnew)
158+
@error "$(@__FILE__): No space for matrix element $(inew), $(jnew)."
159+
return false
160+
end
161+
if (irow == jnew)
162+
# find a proper offset into unz and increment by value
163+
_p = xunz[inew] + nnzloc
164+
unz[_p] += value
165+
break
166+
end
167+
nnzloc = nnzloc + 1
168+
end
169+
end
170+
end
171+
end
172+
return true
173+
end
174+
return doit(size(m,1), m.colptr, m.rowval, s.order.rinvp, s.order.cinvp, s.snode, s.xsuper, s.xlindx, s.lindx, m.nzval, s.xlnz, s.lnz, s.xunz, s.unz)
175+
end
176+
177+
178+
function inmatrix!(s::SparseSolver, m::SparseMatrixCSC)
179+
if (s._inmatrixdone)
180+
return true
181+
end
182+
if ( ! s._symbolicdone)
183+
error("Sequence error. Symbolic factor not done yet.")
184+
return false
185+
end
186+
success = _inmatrix!(s.slvr, m)
187+
s._inmatrixdone = true
188+
s._factordone = false
189+
return success
190+
end
191+
192+
193+
194+
function SparseSolver(m::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
195+
ma = size(m,2)
196+
na = size(m,1)
197+
mc = 0
198+
nc = 0
199+
n = ma
200+
slvr = _SparseBase(m)
201+
_orderdone = false
202+
_symbolicdone = false
203+
_inmatrixdone = false
204+
_factordone = false
205+
_trisolvedone = false
206+
_refinedone = false
207+
_condestdone = false
208+
dummyproblem=Problem(1,1,1,zero(Tv))
209+
return SparseSolver(dummyproblem, slvr, n, ma, na, mc, nc, _inmatrixdone, _orderdone, _symbolicdone, _factordone, _trisolvedone, _refinedone, _condestdone)
210+
end
211+
212+
213+
function solve!(s::SparseSolver, m::SparseMatrixCSC,rhs)
214+
findorder!(s) || ErrorException("Finding Order.")
215+
symbolicfactor!(s) || ErrorException("Symbolic Factorization.")
216+
inmatrix!(s,m) || ErrorException("Matrix input.")
217+
factor!(s) || ErrorException("Numerical Factorization.")
218+
temp=copy(rhs)
219+
triangularsolve!(s,temp) || ErrorException("Triangular Solve.")
220+
return temp
221+
end
222+
223+
#########################################################################
224+
# SparspakLU
225+
226+
"""
227+
sparspaklu(m)
228+
229+
Calculate LU factorization using Sparspak. Steps are
230+
`findorder`, `symbolicfactor`, `factor`.
231+
232+
Returns a Sparspak.SpkSparseSolver.SparseSolver instance,
233+
which has methods for `LinearAlgebra.ldiv!` and `Base.:\` .
234+
"""
235+
function sparspaklu(m::SparseMatrixCSC)
236+
lu=SparseSolver(m)
237+
findorder!(lu) || ErrorException("Finding Order.")
238+
symbolicfactor!(lu) || ErrorException("Symbolic Factorization.")
239+
inmatrix!(lu,m) || ErrorException("Matrix input.")
240+
factor!(lu) || ErrorException("Numerical Factorization.")
241+
lu
242+
end
243+
244+
245+
"""
246+
sparspaklu!(lu,m)
247+
248+
Calculate numerical LU factorization,
249+
reusing ordering and symbolic factorization.
250+
"""
251+
function sparspaklu!(lu::SparseSolver, m::SparseMatrixCSC)
252+
#JF: check if structure is the same
253+
lu.p=m
254+
lu._inmatrixdone = false
255+
lu._factordone = false
256+
lu._trisolvedone = false
257+
inmatrix!(lu,m) || ErrorException("Matrix input.")
258+
factor!(lu) || ErrorException("Numerical Factorization.")
259+
lu
260+
end
261+
262+
function LinearAlgebra.ldiv!(u, lu::SparseSolver, v)
263+
u.=v
264+
triangularsolve!(lu,u) || ErrorException("Triangular Solve.")
265+
lu._trisolvedone = false
266+
u
267+
end
268+
269+
function LinearAlgebra.ldiv!(lu::SparseSolver, v)
270+
triangularsolve!(lu,v) || ErrorException("Triangular Solve.")
271+
lu._trisolvedone = false
272+
v
273+
end
274+
275+
Base.:\(lu::SparseSolver, v)=ldiv!(lu,copy(v))
276+
277+
export sparsepaklu!,sparspaklu
278+
end

src/extendable.jl

+23-4
Original file line numberDiff line numberDiff line change
@@ -265,19 +265,38 @@ function SparseArrays.findnz(ext::ExtendableSparseMatrix)
265265
return findnz(ext.cscmatrix)
266266
end
267267

268+
if USE_GPL_LIBS
268269

269270

271+
for (Tv) in (:Float64,:ComplexF64)
272+
@eval begin
270273
"""
271-
$(SIGNATURES)
274+
$(TYPEDSIGNATURES)
272275
273-
[`\\`](@ref) for extmatrix
276+
[`\\`](@ref) for ExtendableSparse for $($Tv)
274277
"""
275-
function LinearAlgebra.:\(ext::ExtendableSparseMatrix,B::AbstractVecOrMat{T} where T)
278+
function LinearAlgebra.:\(ext::ExtendableSparseMatrix{$Tv,Ti}, B::AbstractVecOrMat{$Tv}) where Ti
279+
flush!(ext)
280+
ext.cscmatrix\B
281+
end
282+
end
283+
end
284+
285+
end # USE_GPL_LIBS
286+
287+
288+
"""
289+
$(TYPEDSIGNATURES)
290+
291+
[`\\`](@ref) for ExtendableSparse for generic floating point. This calls Sparspak.jl.
292+
"""
293+
function LinearAlgebra.:\(ext::ExtendableSparseMatrix{Tv,Ti}, b::AbstractVector{Tv}) where {Tv,Ti}
276294
flush!(ext)
277-
ext.cscmatrix\B
295+
SparspakLU(ext)\b
278296
end
279297

280298

299+
281300
"""
282301
$(SIGNATURES)
283302

0 commit comments

Comments
 (0)