Skip to content

Commit bc3a67d

Browse files
committed
Added tests for diag
1 parent 76d0a74 commit bc3a67d

File tree

1 file changed

+216
-25
lines changed

1 file changed

+216
-25
lines changed

src/tests/linalg/test_linalg.f90

Lines changed: 216 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,8 @@ program test_linalg
33
use stdlib_experimental_kinds, only: sp, dp, qp, int8, int16, int32, int64
44
use stdlib_experimental_linalg, only: diag, eye, trace
55
implicit none
6-
7-
real(sp) :: a(5)
8-
real(sp), allocatable :: d(:,:)
9-
integer :: i, j
106
logical :: warn
11-
12-
warn = .true.
13-
a = [1,2,3,4,5]
14-
d = diag(a,0)
15-
16-
do i = 1, size(d,1)
17-
write(*,*) (d(i,j),j=1,size(d,2))
18-
end do
19-
20-
a = diag(d,0)
21-
print *, "a = ", a
7+
warn = .false.
228

239
!
2410
! eye
@@ -28,7 +14,19 @@ program test_linalg
2814
!
2915
! diag
3016
!
17+
call test_diag_rsp
18+
call test_diag_rsp_k
19+
call test_diag_rdp
20+
call test_diag_rqp
21+
22+
call test_diag_csp
23+
call test_diag_cdp
24+
call test_diag_cqp
3125

26+
call test_diag_int8
27+
call test_diag_int16
28+
call test_diag_int32
29+
call test_diag_int64
3230

3331
!
3432
! trace
@@ -69,9 +67,198 @@ subroutine test_eye
6967
msg="abs(trace(cye) - complex(7.0_sp,0.0_sp)) < epsilon(1.0_sp) failed.",warn=warn)
7068
end subroutine
7169

70+
subroutine test_diag_rsp
71+
integer, parameter :: n = 3
72+
real(sp) :: v(n), a(n,n), b(n,n)
73+
integer :: i,j
74+
write(*,*) "test_diag_rsp"
75+
v = [(i,i=1,n)]
76+
a = diag(v)
77+
b = reshape([((merge(i,0,i==j), i=1,n), j=1,n)], [n,n])
78+
call check(all(a == b), &
79+
msg="all(a == b) failed.",warn=warn)
80+
81+
call check(all(diag(3*a) == 3*v), &
82+
msg="all(diag(3*a) == 3*v) failed.",warn=warn)
83+
end subroutine
84+
85+
subroutine test_diag_rsp_k
86+
integer, parameter :: n = 4
87+
real(sp) :: a(n,n), b(n,n)
88+
integer :: i,j
89+
write(*,*) "test_diag_rsp_k"
90+
91+
a = diag([(1._sp,i=1,n-1)],-1)
92+
93+
b = reshape([((merge(1,0,i==j+1), i=1,n), j=1,n)], [n,n])
94+
95+
call check(all(a == b), &
96+
msg="all(a == b) failed.",warn=warn)
97+
98+
call check(sum(diag(a,-1)) - (n-1) < epsilon(1.0_sp), &
99+
msg="sum(diag(a,-1)) - (n-1) < epsilon(1.0_sp) failed.",warn=warn)
100+
101+
call check(all(a == transpose(diag([(1._sp,i=1,n-1)],1))), &
102+
msg="all(a == transpose(diag([(1._sp,i=1,n-1)],1))) failed",warn=warn)
103+
104+
call random_number(a)
105+
do i = 1, n
106+
call check(size(diag(a,i)) == n-i, &
107+
msg="size(diag(a,i)) == n-i failed.",warn=warn)
108+
end do
109+
call check(size(diag(a,n+1)) == 0, &
110+
msg="size(diag(a,n+1)) == 0 failed.",warn=warn)
111+
end subroutine
112+
113+
subroutine test_diag_rdp
114+
integer, parameter :: n = 3
115+
real(dp) :: v(n), a(n,n), b(n,n)
116+
integer :: i,j
117+
write(*,*) "test_diag_rdp"
118+
v = [(i,i=1,n)]
119+
a = diag(v)
120+
b = reshape([((merge(i,0,i==j), i=1,n), j=1,n)], [n,n])
121+
call check(all(a == b), &
122+
msg="all(a == b) failed.",warn=warn)
123+
124+
call check(all(diag(3*a) == 3*v), &
125+
msg="all(diag(3*a) == 3*v) failed.",warn=warn)
126+
end subroutine
127+
128+
subroutine test_diag_rqp
129+
integer, parameter :: n = 3
130+
real(qp) :: v(n), a(n,n), b(n,n)
131+
integer :: i,j
132+
write(*,*) "test_diag_rqp"
133+
v = [(i,i=1,n)]
134+
a = diag(v)
135+
b = reshape([((merge(i,0,i==j), i=1,n), j=1,n)], [n,n])
136+
call check(all(a == b), &
137+
msg="all(a == b) failed.", warn=warn)
138+
139+
call check(all(diag(3*a) == 3*v), &
140+
msg="all(diag(3*a) == 3*v) failed.", warn=warn)
141+
end subroutine
142+
143+
subroutine test_diag_csp
144+
integer, parameter :: n = 3
145+
complex(sp) :: v(n), a(n,n), b(n,n)
146+
complex(sp), parameter :: i_ = complex(0,1)
147+
integer :: i,j
148+
write(*,*) "test_diag_csp"
149+
a = diag([(i,i=1,n)]) + diag([(i_,i=1,n)])
150+
b = reshape([((merge(i + 1*i_,0*i_,i==j), i=1,n), j=1,n)], [n,n])
151+
call check(all(a == b), &
152+
msg="all(a == b) failed.",warn=warn)
153+
154+
call check(all(abs(real(diag(a)) - [(i,i=1,n)]) < epsilon(1.0_sp)), &
155+
msg="all(abs(real(diag(a)) - [(i,i=1,n)]) < epsilon(1.0_sp))", warn=warn)
156+
call check(all(abs(aimag(diag(a)) - [(1,i=1,n)]) < epsilon(1.0_sp)), &
157+
msg="all(abs(aimag(diag(a)) - [(1,i=1,n)]) < epsilon(1.0_sp))", warn=warn)
158+
end subroutine
159+
160+
subroutine test_diag_cdp
161+
integer, parameter :: n = 3
162+
complex(dp) :: v(n), a(n,n), b(n,n)
163+
complex(dp), parameter :: i_ = complex(0,1)
164+
integer :: i,j
165+
write(*,*) "test_diag_cdp"
166+
a = diag([i_],-2) + diag([i_],2)
167+
call check(a(3,1) == i_ .and. a(1,3) == i_, &
168+
msg="a(3,1) == i_ .and. a(1,3) == i_ failed.",warn=warn)
169+
end subroutine
170+
171+
subroutine test_diag_cqp
172+
integer, parameter :: n = 3
173+
complex(qp) :: v(n), a(n,n), b(n,n)
174+
complex(qp), parameter :: i_ = complex(0,1)
175+
integer :: i,j
176+
write(*,*) "test_diag_cqp"
177+
a = diag([i_,i_],-1) + diag([i_,i_],1)
178+
call check(all(diag(a,-1) == i_) .and. all(diag(a,1) == i_), &
179+
msg="all(diag(a,-1) == i_) .and. all(diag(a,1) == i_) failed.",warn=warn)
180+
end subroutine
181+
182+
subroutine test_diag_int8
183+
integer, parameter :: n = 3
184+
integer(int8), allocatable :: a(:,:)
185+
integer :: i
186+
logical, allocatable :: mask(:,:)
187+
write(*,*) "test_diag_int8"
188+
a = reshape([(i,i=1,n**2)],[n,n])
189+
mask = merge(.true.,.false.,eye(n) == 1)
190+
call check(all(diag(a) == pack(a,mask)), &
191+
msg="all(diag(a) == pack(a,mask)) failed.", warn=warn)
192+
call check(all(diag(diag(a)) == merge(a,0_int8,mask)), &
193+
msg="all(diag(diag(a)) == merge(a,0_int8,mask)) failed.", warn=warn)
194+
end subroutine
195+
subroutine test_diag_int16
196+
integer, parameter :: n = 4
197+
integer(int16), allocatable :: a(:,:)
198+
integer :: i
199+
logical, allocatable :: mask(:,:)
200+
write(*,*) "test_diag_int16"
201+
a = reshape([(i,i=1,n**2)],[n,n])
202+
mask = merge(.true.,.false.,eye(n) == 1)
203+
call check(all(diag(a) == pack(a,mask)), &
204+
msg="all(diag(a) == pack(a,mask))", warn=warn)
205+
call check(all(diag(diag(a)) == merge(a,0_int16,mask)), &
206+
msg="all(diag(diag(a)) == merge(a,0_int16,mask)) failed.", warn=warn)
207+
a = unpack(int([1,2,3,4],int16),eye(n)==1,a)
208+
end subroutine
209+
subroutine test_diag_int32
210+
integer, parameter :: n = 3
211+
integer(int32) :: a(n,n)
212+
logical :: mask(n,n)
213+
integer :: i, j
214+
write(*,*) "test_diag_int32"
215+
mask = reshape([((merge(.true.,.false.,i==j+1), i=1,n), j=1,n)], [n,n])
216+
a = 0
217+
a = unpack([1_int32,1_int32],mask,a)
218+
call check(all(diag([1,1],-1) == a), &
219+
msg="all(diag([1,1],-1) == a) failed.", warn=warn)
220+
call check(all(diag([1,1],1) == transpose(a)), &
221+
msg="all(diag([1,1],1) == transpose(a)) failed.", warn=warn)
222+
end subroutine
223+
subroutine test_diag_int64
224+
integer, parameter :: n = 4
225+
integer(int64) :: a(n,n), c(2*n-1)
226+
logical :: mask(n,n)
227+
integer :: i, j
228+
229+
write(*,*) "test_diag_int64"
230+
231+
mask = reshape([((merge(.true.,.false.,i+1==j), i=1,n), j=1,n)], [n,n])
232+
a = 0
233+
a = unpack([1_int64,1_int64,1_int64],mask,a)
234+
235+
call check(all(diag([1,1,1],1) == a), &
236+
msg="all(diag([1,1,1],1) == a) failed.", warn=warn)
237+
call check(all(diag([1,1,1],-1) == transpose(a)), &
238+
msg="all(diag([1,1,1],-1) == transpose(a)) failed.", warn=warn)
239+
240+
241+
! Fill array c with Catalan numbers
242+
do i = 0, 2*n-1
243+
c(i) = catalan_number(i)
244+
end do
245+
! Symmetric Hankel matrix filled with Catalan numbers (det(H) = 1)
246+
do i = 1, n
247+
do j = 1, n
248+
a(i,j) = c(i-1 + (j-1))
249+
end do
250+
end do
251+
call check(all(diag(a,-2) == diag(a,2)), &
252+
msg="all(diag(a,-2) == diag(a,2))", warn=warn)
253+
end subroutine
254+
255+
256+
257+
72258
subroutine test_trace_rsp
73259
integer, parameter :: n = 5
74260
real(sp) :: a(n,n)
261+
integer :: i
75262
write(*,*) "test_trace_rsp"
76263
a = reshape([(i,i=1,n**2)],[n,n])
77264
call check(abs(trace(a) - sum(diag(a))) < epsilon(1.0_sp), &
@@ -81,22 +268,24 @@ subroutine test_trace_rsp
81268
subroutine test_trace_rsp_nonsquare
82269
integer, parameter :: n = 4
83270
real(sp) :: a(n,n+1), ans
271+
integer :: i
84272
write(*,*) "test_trace_rsp_nonsquare"
85-
a = reshape([(i,i=1,n*(n+1))],[n,n+1])
86273

87274
! 1 5 9 13 17
88275
! 2 6 10 14 18
89276
! 3 7 11 15 19
90277
! 4 8 12 16 20
91-
278+
a = reshape([(i,i=1,n*(n+1))],[n,n+1])
92279
ans = sum([1._sp,6._sp,11._sp,16._sp])
280+
93281
call check(abs(trace(a) - ans) < epsilon(1.0_sp), &
94282
msg="abs(trace(a) - ans) < epsilon(1.0_sp) failed.",warn=warn)
95283
end subroutine
96284

97285
subroutine test_trace_rdp
98286
integer, parameter :: n = 4
99287
real(dp) :: a(n,n)
288+
integer :: i
100289
write(*,*) "test_trace_rdp"
101290
a = reshape([(i,i=1,n**2)],[n,n])
102291
call check(abs(trace(a) - sum(diag(a))) < epsilon(1.0_dp), &
@@ -106,22 +295,24 @@ subroutine test_trace_rdp
106295
subroutine test_trace_rdp_nonsquare
107296
integer, parameter :: n = 4
108297
real(dp) :: a(n,n-1), ans
298+
integer :: i
109299
write(*,*) "test_trace_rdp_nonsquare"
110-
a = reshape([(i**2,i=1,n*(n-1))],[n,n-1])
111300

112301
! 1 25 81
113302
! 4 36 100
114303
! 9 49 121
115304
! 16 64 144
116-
305+
a = reshape([(i**2,i=1,n*(n-1))],[n,n-1])
117306
ans = sum([1._dp,36._dp,121._dp])
307+
118308
call check(abs(trace(a) - ans) < epsilon(1.0_dp), &
119309
msg="abs(trace(a) - ans) < epsilon(1.0_sp) failed.",warn=warn)
120310
end subroutine
121311

122312
subroutine test_trace_rqp
123313
integer, parameter :: n = 3
124314
real(qp) :: a(n,n)
315+
integer :: i
125316
write(*,*) "test_trace_rqp"
126317
a = reshape([(i,i=1,n**2)],[n,n])
127318
call check(abs(trace(a) - sum(diag(a))) < epsilon(1.0_qp), &
@@ -205,9 +396,9 @@ subroutine test_trace_int32
205396
end subroutine
206397

207398
subroutine test_trace_int64
208-
integer(int64), parameter :: n = 5
209-
integer(int64), parameter :: nd = 2*n-1 ! number of diagonals
210-
integer(int64) :: i, j
399+
integer, parameter :: n = 5
400+
integer, parameter :: nd = 2*n-1 ! number of diagonals
401+
integer :: i, j
211402
integer(int64) :: c(0:nd), H(n,n)
212403
write(*,*) "test_trace_int64"
213404

@@ -229,15 +420,15 @@ subroutine test_trace_int64
229420
end subroutine
230421

231422
pure recursive function catalan_number(n) result(value)
232-
integer(int64), intent(in) :: n
233-
integer(int64) :: value
423+
integer, intent(in) :: n
424+
integer :: value
234425
integer :: i
235426
if (n <= 1) then
236427
value = 1
237428
else
238429
value = 0
239430
do i = 0, n-1
240-
value = value + catalan_number(int(i,int64))*catalan_number(n-i-1)
431+
value = value + catalan_number(i)*catalan_number(n-i-1)
241432
end do
242433
end if
243434
end function

0 commit comments

Comments
 (0)