Skip to content

Commit f56147d

Browse files
authored
Merge pull request #21184 from iamnapo/anj/powm
Fixed the algorithm for powers of a matrix.
2 parents 8df5fbe + 2fbeba3 commit f56147d

File tree

7 files changed

+414
-29
lines changed

7 files changed

+414
-29
lines changed

base/linalg/dense.jl

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function scale!(X::Array{T}, s::Real) where T<:BlasComplex
2929
X
3030
end
3131

32-
#Test whether a matrix is positive-definite
32+
# Test whether a matrix is positive-definite
3333
isposdef!(A::StridedMatrix{<:BlasFloat}, UL::Symbol) = LAPACK.potrf!(char_uplo(UL), A)[2] == 0
3434

3535
"""
@@ -323,18 +323,67 @@ kron(a::AbstractVector, b::AbstractVector)=vec(kron(reshape(a,length(a),1),resha
323323
kron(a::AbstractMatrix, b::AbstractVector)=kron(a,reshape(b,length(b),1))
324324
kron(a::AbstractVector, b::AbstractMatrix)=kron(reshape(a,length(a),1),b)
325325

326-
^(A::AbstractMatrix, p::Integer) = p < 0 ? inv(A^-p) : Base.power_by_squaring(A,p)
327-
328-
function ^(A::AbstractMatrix, p::Number)
326+
# Matrix power
327+
^{T}(A::AbstractMatrix{T}, p::Integer) = p < 0 ? Base.power_by_squaring(inv(A), -p) : Base.power_by_squaring(A, p)
328+
function ^{T}(A::AbstractMatrix{T}, p::Real)
329+
# For integer powers, use repeated squaring
329330
if isinteger(p)
330-
return A^Integer(real(p))
331+
TT = Base.promote_op(^, eltype(A), typeof(p))
332+
return (TT == eltype(A) ? A : copy!(similar(A, TT), A))^Integer(p)
333+
end
334+
335+
# If possible, use diagonalization
336+
if T <: Real && issymmetric(A)
337+
return (Symmetric(A)^p)
338+
end
339+
if ishermitian(A)
340+
return (Hermitian(A)^p)
341+
end
342+
343+
n = checksquare(A)
344+
345+
# Quicker return if A is diagonal
346+
if isdiag(A)
347+
retmat = copy(A)
348+
for i in 1:n
349+
retmat[i, i] = retmat[i, i] ^ p
350+
end
351+
return retmat
352+
end
353+
354+
# Otherwise, use Schur decomposition
355+
if istriu(A)
356+
# Integer part
357+
retmat = A ^ floor(p)
358+
# Real part
359+
if p - floor(p) == 0.5
360+
# special case: A^0.5 === sqrtm(A)
361+
retmat = retmat * sqrtm(A)
362+
else
363+
retmat = retmat * powm!(UpperTriangular(float.(A)), real(p - floor(p)))
364+
end
365+
else
366+
S,Q,d = schur(complex(A))
367+
# Integer part
368+
R = S ^ floor(p)
369+
# Real part
370+
if p - floor(p) == 0.5
371+
# special case: A^0.5 === sqrtm(A)
372+
R = R * sqrtm(S)
373+
else
374+
R = R * powm!(UpperTriangular(float.(S)), real(p - floor(p)))
375+
end
376+
retmat = Q * R * Q'
377+
end
378+
379+
# if A has nonpositive real eigenvalues, retmat is a nonprincipal matrix power.
380+
if isreal(retmat)
381+
return real(retmat)
382+
else
383+
return retmat
331384
end
332-
checksquare(A)
333-
v, X = eig(A)
334-
any(v.<0) && (v = complex(v))
335-
Xinv = ishermitian(A) ? X' : inv(X)
336-
(X * Diagonal(v.^p)) * Xinv
337385
end
386+
^(A::AbstractMatrix, p::Number) = expm(p*logm(A))
338387

339388
# Matrix exponential
340389

@@ -466,7 +515,7 @@ function rcswap!(i::Integer, j::Integer, X::StridedMatrix{<:Number})
466515
end
467516

468517
"""
469-
logm(A::StridedMatrix)
518+
logm(A{T}::StridedMatrix{T})
470519
471520
If `A` has no negative real eigenvalue, compute the principal matrix logarithm of `A`, i.e.
472521
the unique matrix ``X`` such that ``e^X = A`` and ``-\\pi < Im(\\lambda) < \\pi`` for all
@@ -497,8 +546,11 @@ julia> logm(A)
497546
0.0 1.0
498547
```
499548
"""
500-
function logm(A::StridedMatrix)
549+
function logm{T}(A::StridedMatrix{T})
501550
# If possible, use diagonalization
551+
if issymmetric(A) && T <: Real
552+
return full(logm(Symmetric(A)))
553+
end
502554
if ishermitian(A)
503555
return full(logm(Hermitian(A)))
504556
end

base/linalg/diagonal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ end
288288
# identity matrices via eye(Diagonal{type},n)
289289
eye(::Type{Diagonal{T}}, n::Int) where {T} = Diagonal(ones(T,n))
290290

291+
# Matrix functions
291292
expm(D::Diagonal) = Diagonal(exp.(D.diag))
292293
expm(D::Diagonal{<:AbstractMatrix}) = Diagonal(expm.(D.diag))
293294
logm(D::Diagonal) = Diagonal(log.(D.diag))

base/linalg/linalg.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ include("hessenberg.jl")
262262
include("lq.jl")
263263
include("eigen.jl")
264264
include("svd.jl")
265-
include("schur.jl")
266265
include("symmetric.jl")
267266
include("cholesky.jl")
268267
include("lu.jl")
@@ -274,6 +273,8 @@ include("givens.jl")
274273
include("special.jl")
275274
include("bitarray.jl")
276275
include("ldlt.jl")
276+
include("schur.jl")
277+
277278

278279
include("arpack.jl")
279280
include("arnoldi.jl")

base/linalg/schur.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ function schur(A::StridedMatrix)
107107
SchurF = schurfact(A)
108108
SchurF[:T], SchurF[:Z], SchurF[:values]
109109
end
110+
schur(A::Symmetric) = schur(full(A))
111+
schur(A::Hermitian) = schur(full(A))
112+
schur(A::UpperTriangular) = schur(full(A))
113+
schur(A::LowerTriangular) = schur(full(A))
114+
schur(A::Tridiagonal) = schur(full(A))
115+
110116

111117
"""
112118
ordschur!(F::Schur, select::Union{Vector{Bool},BitVector}) -> F::Schur

base/linalg/symmetric.jl

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
#Symmetric and Hermitian matrices
3+
# Symmetric and Hermitian matrices
44
struct Symmetric{T,S<:AbstractMatrix} <: AbstractMatrix{T}
55
data::S
66
uplo::Char
@@ -181,7 +181,7 @@ trace(A::Hermitian) = real(trace(A.data))
181181
Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo)
182182
Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo)
183183

184-
#tril/triu
184+
# tril/triu
185185
function tril(A::Hermitian, k::Integer=0)
186186
if A.uplo == 'U' && k <= 0
187187
return tril!(A.data',k)
@@ -235,7 +235,7 @@ end
235235
## Matvec
236236
A_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::Symmetric{T,<:StridedMatrix}, x::StridedVector{T}) = BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)
237237
A_mul_B!{T<:BlasComplex}(y::StridedVector{T}, A::Hermitian{T,<:StridedMatrix}, x::StridedVector{T}) = BLAS.hemv!(A.uplo, one(T), A.data, x, zero(T), y)
238-
##Matmat
238+
## Matmat
239239
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::Symmetric{T,<:StridedMatrix}, B::StridedMatrix{T}) = BLAS.symm!('L', A.uplo, one(T), A.data, B, zero(T), C)
240240
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Symmetric{T,<:StridedMatrix}) = BLAS.symm!('R', B.uplo, one(T), B.data, A, zero(T), C)
241241
A_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::Hermitian{T,<:StridedMatrix}, B::StridedMatrix{T}) = BLAS.hemm!('L', A.uplo, one(T), A.data, B, zero(T), C)
@@ -403,7 +403,54 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{
403403
return sort!(vals, rev = true)
404404
end
405405

406-
#Matrix-valued functions
406+
# Matrix functions
407+
function ^{T<:Real}(A::Symmetric{T}, p::Integer)
408+
if p < 0
409+
return Symmetric(Base.power_by_squaring(inv(A), -p))
410+
else
411+
return Symmetric(Base.power_by_squaring(A, p))
412+
end
413+
end
414+
function ^{T<:Real}(A::Symmetric{T}, p::Real)
415+
F = eigfact(A)
416+
if all-> λ 0, F.values)
417+
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
418+
else
419+
retmat = (F.vectors * Diagonal((complex(F.values)).^p)) * F.vectors'
420+
end
421+
return Symmetric(retmat)
422+
end
423+
function ^(A::Hermitian, p::Integer)
424+
n = checksquare(A)
425+
if p < 0
426+
retmat = Base.power_by_squaring(inv(A), -p)
427+
else
428+
retmat = Base.power_by_squaring(A, p)
429+
end
430+
for i = 1:n
431+
retmat[i,i] = real(retmat[i,i])
432+
end
433+
return Hermitian(retmat)
434+
end
435+
function ^{T}(A::Hermitian{T}, p::Real)
436+
n = checksquare(A)
437+
F = eigfact(A)
438+
if all-> λ 0, F.values)
439+
retmat = (F.vectors * Diagonal((F.values).^p)) * F.vectors'
440+
if T <: Real
441+
return Hermitian(retmat)
442+
else
443+
for i = 1:n
444+
retmat[i,i] = real(retmat[i,i])
445+
end
446+
return Hermitian(retmat)
447+
end
448+
else
449+
retmat = (F.vectors * Diagonal((complex(F.values).^p))) * F.vectors'
450+
return retmat
451+
end
452+
end
453+
407454
function expm(A::Symmetric)
408455
F = eigfact(A)
409456
return Symmetric((F.vectors * Diagonal(exp.(F.values))) * F.vectors')
@@ -423,10 +470,8 @@ function expm(A::Hermitian{T}) where T
423470
end
424471

425472
for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])
426-
427473
@eval begin
428-
429-
function ($funm)(A::Symmetric)
474+
function ($funm){T<:Real}(A::Symmetric{T})
430475
F = eigfact(A)
431476
if isposdef(F)
432477
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
@@ -454,7 +499,5 @@ for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])
454499
return retmat
455500
end
456501
end
457-
458502
end
459-
460503
end

0 commit comments

Comments
 (0)