Skip to content

Commit f382da6

Browse files
committed
Add SVD support for BlockArrays
1 parent 1e5feaa commit f382da6

File tree

4 files changed

+96
-0
lines changed

4 files changed

+96
-0
lines changed

src/BlockArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ include("blocks.jl")
6666

6767
include("blockbroadcast.jl")
6868
include("blockcholesky.jl")
69+
include("blocksvd.jl")
6970
include("blocklinalg.jl")
7071
include("blockproduct.jl")
7172
include("show.jl")

src/blocksvd.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#=
2+
SVD on blockmatrices:
3+
Interpret the matrix as a linear map acting on vector spaces with a direct sum structure, which is reflected in the structure of U and V.
4+
In the generic case, the SVD does not preserve this structure, and can mix up the blocks, so S becomes a single block.
5+
(Implementation-wise, also most efficiently done by first mapping to a `BlockedArray`)
6+
=#
7+
8+
LinearAlgebra.eigencopy_oftype(A::AbstractBlockMatrix, S) = BlockedArray(Array{S}(A), blocksizes(A, 1), blocksizes(A, 2))
9+
10+
function LinearAlgebra.svd!(A::BlockedMatrix; full::Bool=false, alg::LinearAlgebra.Algorithm=default_svd_alg(A))
11+
USV = svd!(parent(A); full, alg)
12+
13+
# restore block pattern
14+
m = length(USV.S)
15+
bsz1, bsz2, bsz3 = blocksizes(A, 1), [m], blocksizes(A, 2)
16+
17+
u = BlockedArray(USV.U, bsz1, bsz2)
18+
s = BlockedVector(USV.S, bsz2)
19+
vt = BlockedArray(USV.Vt, bsz2, bsz3)
20+
return SVD(u, s, vt)
21+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ include("test_blockrange.jl")
2222
include("test_blockarrayinterface.jl")
2323
include("test_blockbroadcast.jl")
2424
include("test_blocklinalg.jl")
25+
include("test_blocksvd.jl")
2526
include("test_blockproduct.jl")
2627
include("test_blockreduce.jl")
2728
include("test_blockdeque.jl")

test/test_blocksvd.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
module TestBlockSVD
2+
3+
using BlockArrays, Test, LinearAlgebra, Random
4+
5+
Random.seed!(0)
6+
7+
eltypes = (Float32, Float64, ComplexF32, ComplexF64, Int)
8+
9+
@testset "Block SVD ($T)" for T in eltypes
10+
x = rand(T, 4, 4)
11+
USV = svd(x)
12+
U, S, Vt = USV.U, USV.S, USV.Vt
13+
14+
y = BlockArray(x, [2, 2], [2, 2])
15+
# https://github.com/JuliaArrays/BlockArrays.jl/issues/425
16+
# USV_blocked = @inferred svd(y)
17+
USV_block = svd(y)
18+
U_block, S_block, Vt_block = USV_block.U, USV_block.S, USV_block.Vt
19+
20+
# test types
21+
@test U_block isa BlockedMatrix
22+
@test eltype(U_block) == float(T)
23+
@test S_block isa BlockedVector
24+
@test eltype(S_block) == real(float(T))
25+
@test Vt_block isa BlockedMatrix
26+
@test eltype(Vt_block) == float(T)
27+
28+
# test structure
29+
@test blocksizes(U_block, 1) == blocksizes(y, 1)
30+
@test length(blocksizes(U_block, 2)) == 1
31+
@test blocksizes(Vt_block, 2) == blocksizes(y, 2)
32+
@test length(blocksizes(Vt_block, 1)) == 1
33+
34+
# test correctness
35+
@test U_block U
36+
@test S_block S
37+
@test Vt_block Vt
38+
@test U_block * Diagonal(S_block) * Vt_block y
39+
end
40+
41+
@testset "Blocked SVD ($T)" for T in eltypes
42+
x = rand(T, 4, 4)
43+
USV = svd(x)
44+
U, S, Vt = USV.U, USV.S, USV.Vt
45+
46+
y = BlockedArray(x, [2, 2], [2, 2])
47+
# https://github.com/JuliaArrays/BlockArrays.jl/issues/425
48+
# USV_blocked = @inferred svd(y)
49+
USV_blocked = svd(y)
50+
U_blocked, S_blocked, Vt_blocked = USV_blocked.U, USV_blocked.S, USV_blocked.Vt
51+
52+
# test types
53+
@test U_blocked isa BlockedMatrix
54+
@test eltype(U_blocked) == float(T)
55+
@test S_blocked isa BlockedVector
56+
@test eltype(S_blocked) == real(float(T))
57+
@test Vt_blocked isa BlockedMatrix
58+
@test eltype(Vt_blocked) == float(T)
59+
60+
# test structure
61+
@test blocksizes(U_blocked, 1) == blocksizes(y, 1)
62+
@test length(blocksizes(U_blocked, 2)) == 1
63+
@test blocksizes(Vt_blocked, 2) == blocksizes(y, 2)
64+
@test length(blocksizes(Vt_blocked, 1)) == 1
65+
66+
# test correctness
67+
@test U_blocked U
68+
@test S_blocked S
69+
@test Vt_blocked Vt
70+
@test U_blocked * Diagonal(S_blocked) * Vt_blocked y
71+
end
72+
73+
end # module

0 commit comments

Comments
 (0)