|
| 1 | +module SparspakCSCInterface |
| 2 | +# |
| 3 | +# Essentially, this is the code for https://github.com/PetrKryslUCSD/Sparspak.jl/pull/9 |
| 4 | +# If that PR is accepted, we can remove the code from here |
| 5 | +# |
| 6 | +using SparseArrays, LinearAlgebra |
| 7 | +using Sparspak.SpkGraph: Graph |
| 8 | +using Sparspak.SpkOrdering: Ordering |
| 9 | +using Sparspak.SpkETree: ETree |
| 10 | +using Sparspak.SpkSparseBase: _SparseBase |
| 11 | +using Sparspak.SpkSparseSolver: Problem,SparseSolver, findorder!,symbolicfactor!,inmatrix!,factor!,triangularsolve! |
| 12 | +import Sparspak.SpkSparseSolver: solve!, inmatrix! |
| 13 | +import Sparspak.SpkSparseBase: _inmatrix! |
| 14 | + |
| 15 | + |
| 16 | +function Graph(m::SparseMatrixCSC{FT,IT}, diagonal=false) where {FT,IT} |
| 17 | + nv = size(m,1) |
| 18 | + nrows = size(m,2) |
| 19 | + ncols = size(m,1) |
| 20 | + colptr = SparseArrays.getcolptr(m) |
| 21 | + rowval = SparseArrays.getrowval(m) |
| 22 | + |
| 23 | + |
| 24 | + if (diagonal) |
| 25 | + nedges = nnz(m) |
| 26 | + else |
| 27 | + dedges=0 |
| 28 | + for i in 1:ncols |
| 29 | + for iptr in colptr[i]:colptr[i+1]-1 |
| 30 | + if rowval[iptr]==i |
| 31 | + dedges+=1 |
| 32 | + continue |
| 33 | + end |
| 34 | + end |
| 35 | + end |
| 36 | + nedges = nnz(m) - dedges |
| 37 | + end |
| 38 | + |
| 39 | + #jf if diagonal == true, we possibly can just use colptr & rowval |
| 40 | + #jf and skip the loop |
| 41 | + |
| 42 | + xadj = fill(zero(IT), nv + 1) |
| 43 | + adj = fill(zero(IT), nedges) |
| 44 | + |
| 45 | + k = 1 |
| 46 | + for i in 1:ncols |
| 47 | + xadj[i] = k |
| 48 | + for iptr in colptr[i]:colptr[i+1]-1 |
| 49 | + j = rowval[iptr] |
| 50 | + if (i != j || diagonal) |
| 51 | + adj[k] = j |
| 52 | + k = k + 1 |
| 53 | + end |
| 54 | + end |
| 55 | + end |
| 56 | + |
| 57 | + xadj[ncols+1] = k |
| 58 | + |
| 59 | + return Graph(nv, nedges, nrows, ncols, xadj, adj) |
| 60 | +end |
| 61 | + |
| 62 | + |
| 63 | + |
| 64 | + |
| 65 | +function _SparseBase(m::SparseMatrixCSC{FT,IT}) where {IT,FT} |
| 66 | + maxblocksize = 30 # This can be set by the user |
| 67 | + |
| 68 | + tempsizeneed = zero(IT) |
| 69 | + n = size(m,2) |
| 70 | + nnz = SparseArrays.nnz(m) |
| 71 | + nnzl = zero(IT) |
| 72 | + nsub = zero(IT) |
| 73 | + |
| 74 | + nsuper = zero(IT) |
| 75 | + if (n > zero(IT)) |
| 76 | + nsuper = 1 |
| 77 | + end |
| 78 | + |
| 79 | + factorops = zero(FT) |
| 80 | + solveops = zero(FT) |
| 81 | + realstore = zero(FT) |
| 82 | + integerstore = zero(FT) |
| 83 | + errflag = 0 |
| 84 | + |
| 85 | + order = Ordering(n) # ordering object for the solver |
| 86 | + g = Graph(m) |
| 87 | + t = ETree(n) |
| 88 | + |
| 89 | + colcnt = IT[] |
| 90 | + snode = IT[] |
| 91 | + xsuper = IT[] |
| 92 | + xlindx = IT[] |
| 93 | + lindx = IT[] |
| 94 | + xlnz = IT[] |
| 95 | + xunz = IT[] |
| 96 | + ipiv = IT[] |
| 97 | + lnz = FT[] |
| 98 | + unz = FT[] |
| 99 | + |
| 100 | + return _SparseBase(order, t, g, errflag, n, nnz, nnzl, nsub, nsuper, maxblocksize, |
| 101 | + tempsizeneed, factorops, solveops, realstore, integerstore, |
| 102 | + colcnt, snode, xsuper, xlindx, lindx, xlnz, xunz, ipiv, |
| 103 | + lnz, unz) |
| 104 | +end |
| 105 | + |
| 106 | + |
| 107 | +function _inmatrix!(s::_SparseBase{IT, FT}, m::SparseMatrixCSC{FT,IT}) where {IT, FT} |
| 108 | + if (s.n == 0) |
| 109 | + @error "$(@__FILE__): An empty problem. No matrix." |
| 110 | + return false |
| 111 | + end |
| 112 | + |
| 113 | + s.lnz .= zero(FT) |
| 114 | + s.unz .= zero(FT) |
| 115 | + s.ipiv .= zero(IT) |
| 116 | + |
| 117 | + function doit(ncols, colptr, rowval, cinvp, rinvp, snode, xsuper, xlindx, lindx, nzval, xlnz, lnz, xunz, unz) |
| 118 | + for i in 1:ncols |
| 119 | + for iptr in colptr[i]:colptr[i+1]-1 |
| 120 | + inew = rinvp[rowval[iptr]]; |
| 121 | + jnew = cinvp[i] |
| 122 | + value = nzval[iptr] |
| 123 | +## jf: all of this could go into a function, so we can keep things synced |
| 124 | + if (inew >= xsuper[snode[jnew]]) |
| 125 | +# Lies in L. get pointers and lengths needed to search |
| 126 | +# column jnew of L for location l(inew, jnew). |
| 127 | + jsup = snode[jnew]; |
| 128 | + fstcol = xsuper[jsup] |
| 129 | + fstsub = xlindx[jsup] |
| 130 | + lstsub = xlindx[jsup + 1] - 1 |
| 131 | + nnzloc = 0; |
| 132 | + for nxtsub in fstsub:lstsub |
| 133 | + irow = lindx[nxtsub] |
| 134 | + if (irow > inew) |
| 135 | + @error "$(@__FILE__): No space for matrix element $(inew), $(jnew)." |
| 136 | + return false |
| 137 | + end |
| 138 | + if (irow == inew) |
| 139 | +# find a proper offset into lnz and increment by value |
| 140 | + _p = xlnz[jnew] + nnzloc |
| 141 | + lnz[_p] += value |
| 142 | + break |
| 143 | + end |
| 144 | + nnzloc = nnzloc + 1 |
| 145 | + end |
| 146 | + else |
| 147 | +# Lies in U |
| 148 | + jsup = snode[inew] |
| 149 | + fstcol = xsuper[jsup] |
| 150 | + lstcol = xsuper[jsup + 1] - 1 |
| 151 | + width = lstcol - fstcol + 1 |
| 152 | + lstsub = xlindx[jsup + 1] - 1 |
| 153 | + fstsub = xlindx[jsup] + width |
| 154 | + nnzloc = 0; |
| 155 | + for nxtsub in fstsub:lstsub |
| 156 | + irow = lindx[nxtsub] |
| 157 | + if (irow > jnew) |
| 158 | + @error "$(@__FILE__): No space for matrix element $(inew), $(jnew)." |
| 159 | + return false |
| 160 | + end |
| 161 | + if (irow == jnew) |
| 162 | +# find a proper offset into unz and increment by value |
| 163 | + _p = xunz[inew] + nnzloc |
| 164 | + unz[_p] += value |
| 165 | + break |
| 166 | + end |
| 167 | + nnzloc = nnzloc + 1 |
| 168 | + end |
| 169 | + end |
| 170 | + end |
| 171 | + end |
| 172 | + return true |
| 173 | + end |
| 174 | + return doit(size(m,1), m.colptr, m.rowval, s.order.rinvp, s.order.cinvp, s.snode, s.xsuper, s.xlindx, s.lindx, m.nzval, s.xlnz, s.lnz, s.xunz, s.unz) |
| 175 | +end |
| 176 | + |
| 177 | + |
| 178 | +function inmatrix!(s::SparseSolver, m::SparseMatrixCSC) |
| 179 | + if (s._inmatrixdone) |
| 180 | + return true |
| 181 | + end |
| 182 | + if ( ! s._symbolicdone) |
| 183 | + error("Sequence error. Symbolic factor not done yet.") |
| 184 | + return false |
| 185 | + end |
| 186 | + success = _inmatrix!(s.slvr, m) |
| 187 | + s._inmatrixdone = true |
| 188 | + s._factordone = false |
| 189 | + return success |
| 190 | +end |
| 191 | + |
| 192 | + |
| 193 | + |
| 194 | +function SparseSolver(m::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} |
| 195 | + ma = size(m,2) |
| 196 | + na = size(m,1) |
| 197 | + mc = 0 |
| 198 | + nc = 0 |
| 199 | + n = ma |
| 200 | + slvr = _SparseBase(m) |
| 201 | + _orderdone = false |
| 202 | + _symbolicdone = false |
| 203 | + _inmatrixdone = false |
| 204 | + _factordone = false |
| 205 | + _trisolvedone = false |
| 206 | + _refinedone = false |
| 207 | + _condestdone = false |
| 208 | + dummyproblem=Problem(1,1,1,zero(Tv)) |
| 209 | + return SparseSolver(dummyproblem, slvr, n, ma, na, mc, nc, _inmatrixdone, _orderdone, _symbolicdone, _factordone, _trisolvedone, _refinedone, _condestdone) |
| 210 | +end |
| 211 | + |
| 212 | + |
| 213 | +function solve!(s::SparseSolver, m::SparseMatrixCSC,rhs) |
| 214 | + findorder!(s) || ErrorException("Finding Order.") |
| 215 | + symbolicfactor!(s) || ErrorException("Symbolic Factorization.") |
| 216 | + inmatrix!(s,m) || ErrorException("Matrix input.") |
| 217 | + factor!(s) || ErrorException("Numerical Factorization.") |
| 218 | + temp=copy(rhs) |
| 219 | + triangularsolve!(s,temp) || ErrorException("Triangular Solve.") |
| 220 | + return temp |
| 221 | +end |
| 222 | + |
| 223 | +######################################################################### |
| 224 | +# SparspakLU |
| 225 | + |
| 226 | +""" |
| 227 | + sparspaklu(m) |
| 228 | +
|
| 229 | +Calculate LU factorization using Sparspak. Steps are |
| 230 | +`findorder`, `symbolicfactor`, `factor`. |
| 231 | +
|
| 232 | +Returns a Sparspak.SpkSparseSolver.SparseSolver instance, |
| 233 | +which has methods for `LinearAlgebra.ldiv!` and `Base.:\` . |
| 234 | +""" |
| 235 | +function sparspaklu(m::SparseMatrixCSC) |
| 236 | + lu=SparseSolver(m) |
| 237 | + findorder!(lu) || ErrorException("Finding Order.") |
| 238 | + symbolicfactor!(lu) || ErrorException("Symbolic Factorization.") |
| 239 | + inmatrix!(lu,m) || ErrorException("Matrix input.") |
| 240 | + factor!(lu) || ErrorException("Numerical Factorization.") |
| 241 | + lu |
| 242 | +end |
| 243 | + |
| 244 | + |
| 245 | +""" |
| 246 | + sparspaklu!(lu,m) |
| 247 | +
|
| 248 | +Calculate numerical LU factorization, |
| 249 | +reusing ordering and symbolic factorization. |
| 250 | +""" |
| 251 | +function sparspaklu!(lu::SparseSolver, m::SparseMatrixCSC) |
| 252 | + #JF: check if structure is the same |
| 253 | + lu.p=m |
| 254 | + lu._inmatrixdone = false |
| 255 | + lu._factordone = false |
| 256 | + lu._trisolvedone = false |
| 257 | + inmatrix!(lu,m) || ErrorException("Matrix input.") |
| 258 | + factor!(lu) || ErrorException("Numerical Factorization.") |
| 259 | + lu |
| 260 | +end |
| 261 | + |
| 262 | +function LinearAlgebra.ldiv!(u, lu::SparseSolver, v) |
| 263 | + u.=v |
| 264 | + triangularsolve!(lu,u) || ErrorException("Triangular Solve.") |
| 265 | + lu._trisolvedone = false |
| 266 | + u |
| 267 | +end |
| 268 | + |
| 269 | +function LinearAlgebra.ldiv!(lu::SparseSolver, v) |
| 270 | + triangularsolve!(lu,v) || ErrorException("Triangular Solve.") |
| 271 | + lu._trisolvedone = false |
| 272 | + v |
| 273 | +end |
| 274 | + |
| 275 | +Base.:\(lu::SparseSolver, v)=ldiv!(lu,copy(v)) |
| 276 | + |
| 277 | +export sparsepaklu!,sparspaklu |
| 278 | +end |
0 commit comments