@@ -3,22 +3,8 @@ program test_linalg
3
3
use stdlib_experimental_kinds, only: sp, dp, qp, int8, int16, int32, int64
4
4
use stdlib_experimental_linalg, only: diag, eye, trace
5
5
implicit none
6
-
7
- real (sp) :: a(5 )
8
- real (sp), allocatable :: d(:,:)
9
- integer :: i, j
10
6
logical :: warn
11
-
12
- warn = .true.
13
- a = [1 ,2 ,3 ,4 ,5 ]
14
- d = diag(a,0 )
15
-
16
- do i = 1 , size (d,1 )
17
- write (* ,* ) (d(i,j),j= 1 ,size (d,2 ))
18
- end do
19
-
20
- a = diag(d,0 )
21
- print * , " a = " , a
7
+ warn = .false.
22
8
23
9
!
24
10
! eye
@@ -28,7 +14,19 @@ program test_linalg
28
14
!
29
15
! diag
30
16
!
17
+ call test_diag_rsp
18
+ call test_diag_rsp_k
19
+ call test_diag_rdp
20
+ call test_diag_rqp
21
+
22
+ call test_diag_csp
23
+ call test_diag_cdp
24
+ call test_diag_cqp
31
25
26
+ call test_diag_int8
27
+ call test_diag_int16
28
+ call test_diag_int32
29
+ call test_diag_int64
32
30
33
31
!
34
32
! trace
@@ -69,9 +67,198 @@ subroutine test_eye
69
67
msg= " abs(trace(cye) - complex(7.0_sp,0.0_sp)) < epsilon(1.0_sp) failed." ,warn= warn)
70
68
end subroutine
71
69
70
+ subroutine test_diag_rsp
71
+ integer , parameter :: n = 3
72
+ real (sp) :: v(n), a(n,n), b(n,n)
73
+ integer :: i,j
74
+ write (* ,* ) " test_diag_rsp"
75
+ v = [(i,i= 1 ,n)]
76
+ a = diag(v)
77
+ b = reshape ([((merge (i,0 ,i== j), i= 1 ,n), j= 1 ,n)], [n,n])
78
+ call check(all (a == b), &
79
+ msg= " all(a == b) failed." ,warn= warn)
80
+
81
+ call check(all (diag(3 * a) == 3 * v), &
82
+ msg= " all(diag(3*a) == 3*v) failed." ,warn= warn)
83
+ end subroutine
84
+
85
+ subroutine test_diag_rsp_k
86
+ integer , parameter :: n = 4
87
+ real (sp) :: a(n,n), b(n,n)
88
+ integer :: i,j
89
+ write (* ,* ) " test_diag_rsp_k"
90
+
91
+ a = diag([(1._sp ,i= 1 ,n-1 )],- 1 )
92
+
93
+ b = reshape ([((merge (1 ,0 ,i== j+1 ), i= 1 ,n), j= 1 ,n)], [n,n])
94
+
95
+ call check(all (a == b), &
96
+ msg= " all(a == b) failed." ,warn= warn)
97
+
98
+ call check(sum (diag(a,- 1 )) - (n-1 ) < epsilon (1.0_sp ), &
99
+ msg= " sum(diag(a,-1)) - (n-1) < epsilon(1.0_sp) failed." ,warn= warn)
100
+
101
+ call check(all (a == transpose (diag([(1._sp ,i= 1 ,n-1 )],1 ))), &
102
+ msg= " all(a == transpose(diag([(1._sp,i=1,n-1)],1))) failed" ,warn= warn)
103
+
104
+ call random_number (a)
105
+ do i = 1 , n
106
+ call check(size (diag(a,i)) == n- i, &
107
+ msg= " size(diag(a,i)) == n-i failed." ,warn= warn)
108
+ end do
109
+ call check(size (diag(a,n+1 )) == 0 , &
110
+ msg= " size(diag(a,n+1)) == 0 failed." ,warn= warn)
111
+ end subroutine
112
+
113
+ subroutine test_diag_rdp
114
+ integer , parameter :: n = 3
115
+ real (dp) :: v(n), a(n,n), b(n,n)
116
+ integer :: i,j
117
+ write (* ,* ) " test_diag_rdp"
118
+ v = [(i,i= 1 ,n)]
119
+ a = diag(v)
120
+ b = reshape ([((merge (i,0 ,i== j), i= 1 ,n), j= 1 ,n)], [n,n])
121
+ call check(all (a == b), &
122
+ msg= " all(a == b) failed." ,warn= warn)
123
+
124
+ call check(all (diag(3 * a) == 3 * v), &
125
+ msg= " all(diag(3*a) == 3*v) failed." ,warn= warn)
126
+ end subroutine
127
+
128
+ subroutine test_diag_rqp
129
+ integer , parameter :: n = 3
130
+ real (qp) :: v(n), a(n,n), b(n,n)
131
+ integer :: i,j
132
+ write (* ,* ) " test_diag_rqp"
133
+ v = [(i,i= 1 ,n)]
134
+ a = diag(v)
135
+ b = reshape ([((merge (i,0 ,i== j), i= 1 ,n), j= 1 ,n)], [n,n])
136
+ call check(all (a == b), &
137
+ msg= " all(a == b) failed." , warn= warn)
138
+
139
+ call check(all (diag(3 * a) == 3 * v), &
140
+ msg= " all(diag(3*a) == 3*v) failed." , warn= warn)
141
+ end subroutine
142
+
143
+ subroutine test_diag_csp
144
+ integer , parameter :: n = 3
145
+ complex (sp) :: v(n), a(n,n), b(n,n)
146
+ complex (sp), parameter :: i_ = complex (0 ,1 )
147
+ integer :: i,j
148
+ write (* ,* ) " test_diag_csp"
149
+ a = diag([(i,i= 1 ,n)]) + diag([(i_,i= 1 ,n)])
150
+ b = reshape ([((merge (i + 1 * i_,0 * i_,i== j), i= 1 ,n), j= 1 ,n)], [n,n])
151
+ call check(all (a == b), &
152
+ msg= " all(a == b) failed." ,warn= warn)
153
+
154
+ call check(all (abs (real (diag(a)) - [(i,i= 1 ,n)]) < epsilon (1.0_sp )), &
155
+ msg= " all(abs(real(diag(a)) - [(i,i=1,n)]) < epsilon(1.0_sp))" , warn= warn)
156
+ call check(all (abs (aimag (diag(a)) - [(1 ,i= 1 ,n)]) < epsilon (1.0_sp )), &
157
+ msg= " all(abs(aimag(diag(a)) - [(1,i=1,n)]) < epsilon(1.0_sp))" , warn= warn)
158
+ end subroutine
159
+
160
+ subroutine test_diag_cdp
161
+ integer , parameter :: n = 3
162
+ complex (dp) :: v(n), a(n,n), b(n,n)
163
+ complex (dp), parameter :: i_ = complex (0 ,1 )
164
+ integer :: i,j
165
+ write (* ,* ) " test_diag_cdp"
166
+ a = diag([i_],- 2 ) + diag([i_],2 )
167
+ call check(a(3 ,1 ) == i_ .and. a(1 ,3 ) == i_, &
168
+ msg= " a(3,1) == i_ .and. a(1,3) == i_ failed." ,warn= warn)
169
+ end subroutine
170
+
171
+ subroutine test_diag_cqp
172
+ integer , parameter :: n = 3
173
+ complex (qp) :: v(n), a(n,n), b(n,n)
174
+ complex (qp), parameter :: i_ = complex (0 ,1 )
175
+ integer :: i,j
176
+ write (* ,* ) " test_diag_cqp"
177
+ a = diag([i_,i_],- 1 ) + diag([i_,i_],1 )
178
+ call check(all (diag(a,- 1 ) == i_) .and. all (diag(a,1 ) == i_), &
179
+ msg= " all(diag(a,-1) == i_) .and. all(diag(a,1) == i_) failed." ,warn= warn)
180
+ end subroutine
181
+
182
+ subroutine test_diag_int8
183
+ integer , parameter :: n = 3
184
+ integer (int8), allocatable :: a(:,:)
185
+ integer :: i
186
+ logical , allocatable :: mask(:,:)
187
+ write (* ,* ) " test_diag_int8"
188
+ a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
189
+ mask = merge (.true. ,.false. ,eye(n) == 1 )
190
+ call check(all (diag(a) == pack (a,mask)), &
191
+ msg= " all(diag(a) == pack(a,mask)) failed." , warn= warn)
192
+ call check(all (diag(diag(a)) == merge (a,0_int8 ,mask)), &
193
+ msg= " all(diag(diag(a)) == merge(a,0_int8,mask)) failed." , warn= warn)
194
+ end subroutine
195
+ subroutine test_diag_int16
196
+ integer , parameter :: n = 4
197
+ integer (int16), allocatable :: a(:,:)
198
+ integer :: i
199
+ logical , allocatable :: mask(:,:)
200
+ write (* ,* ) " test_diag_int16"
201
+ a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
202
+ mask = merge (.true. ,.false. ,eye(n) == 1 )
203
+ call check(all (diag(a) == pack (a,mask)), &
204
+ msg= " all(diag(a) == pack(a,mask))" , warn= warn)
205
+ call check(all (diag(diag(a)) == merge (a,0_int16 ,mask)), &
206
+ msg= " all(diag(diag(a)) == merge(a,0_int16,mask)) failed." , warn= warn)
207
+ a = unpack (int ([1 ,2 ,3 ,4 ],int16),eye(n)==1 ,a)
208
+ end subroutine
209
+ subroutine test_diag_int32
210
+ integer , parameter :: n = 3
211
+ integer (int32) :: a(n,n)
212
+ logical :: mask(n,n)
213
+ integer :: i, j
214
+ write (* ,* ) " test_diag_int32"
215
+ mask = reshape ([((merge (.true. ,.false. ,i== j+1 ), i= 1 ,n), j= 1 ,n)], [n,n])
216
+ a = 0
217
+ a = unpack ([1_int32 ,1_int32 ],mask,a)
218
+ call check(all (diag([1 ,1 ],- 1 ) == a), &
219
+ msg= " all(diag([1,1],-1) == a) failed." , warn= warn)
220
+ call check(all (diag([1 ,1 ],1 ) == transpose (a)), &
221
+ msg= " all(diag([1,1],1) == transpose(a)) failed." , warn= warn)
222
+ end subroutine
223
+ subroutine test_diag_int64
224
+ integer , parameter :: n = 4
225
+ integer (int64) :: a(n,n), c(2 * n-1 )
226
+ logical :: mask(n,n)
227
+ integer :: i, j
228
+
229
+ write (* ,* ) " test_diag_int64"
230
+
231
+ mask = reshape ([((merge (.true. ,.false. ,i+1 == j), i= 1 ,n), j= 1 ,n)], [n,n])
232
+ a = 0
233
+ a = unpack ([1_int64 ,1_int64 ,1_int64 ],mask,a)
234
+
235
+ call check(all (diag([1 ,1 ,1 ],1 ) == a), &
236
+ msg= " all(diag([1,1,1],1) == a) failed." , warn= warn)
237
+ call check(all (diag([1 ,1 ,1 ],- 1 ) == transpose (a)), &
238
+ msg= " all(diag([1,1,1],-1) == transpose(a)) failed." , warn= warn)
239
+
240
+
241
+ ! Fill array c with Catalan numbers
242
+ do i = 0 , 2 * n-1
243
+ c(i) = catalan_number(i)
244
+ end do
245
+ ! Symmetric Hankel matrix filled with Catalan numbers (det(H) = 1)
246
+ do i = 1 , n
247
+ do j = 1 , n
248
+ a(i,j) = c(i-1 + (j-1 ))
249
+ end do
250
+ end do
251
+ call check(all (diag(a,- 2 ) == diag(a,2 )), &
252
+ msg= " all(diag(a,-2) == diag(a,2))" , warn= warn)
253
+ end subroutine
254
+
255
+
256
+
257
+
72
258
subroutine test_trace_rsp
73
259
integer , parameter :: n = 5
74
260
real (sp) :: a(n,n)
261
+ integer :: i
75
262
write (* ,* ) " test_trace_rsp"
76
263
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
77
264
call check(abs (trace(a) - sum (diag(a))) < epsilon (1.0_sp ), &
@@ -81,22 +268,24 @@ subroutine test_trace_rsp
81
268
subroutine test_trace_rsp_nonsquare
82
269
integer , parameter :: n = 4
83
270
real (sp) :: a(n,n+1 ), ans
271
+ integer :: i
84
272
write (* ,* ) " test_trace_rsp_nonsquare"
85
- a = reshape ([(i,i= 1 ,n* (n+1 ))],[n,n+1 ])
86
273
87
274
! 1 5 9 13 17
88
275
! 2 6 10 14 18
89
276
! 3 7 11 15 19
90
277
! 4 8 12 16 20
91
-
278
+ a = reshape ([(i,i = 1 ,n * (n +1 ))],[n,n +1 ])
92
279
ans = sum ([1._sp ,6._sp ,11._sp ,16._sp ])
280
+
93
281
call check(abs (trace(a) - ans) < epsilon (1.0_sp ), &
94
282
msg= " abs(trace(a) - ans) < epsilon(1.0_sp) failed." ,warn= warn)
95
283
end subroutine
96
284
97
285
subroutine test_trace_rdp
98
286
integer , parameter :: n = 4
99
287
real (dp) :: a(n,n)
288
+ integer :: i
100
289
write (* ,* ) " test_trace_rdp"
101
290
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
102
291
call check(abs (trace(a) - sum (diag(a))) < epsilon (1.0_dp ), &
@@ -106,22 +295,24 @@ subroutine test_trace_rdp
106
295
subroutine test_trace_rdp_nonsquare
107
296
integer , parameter :: n = 4
108
297
real (dp) :: a(n,n-1 ), ans
298
+ integer :: i
109
299
write (* ,* ) " test_trace_rdp_nonsquare"
110
- a = reshape ([(i** 2 ,i= 1 ,n* (n-1 ))],[n,n-1 ])
111
300
112
301
! 1 25 81
113
302
! 4 36 100
114
303
! 9 49 121
115
304
! 16 64 144
116
-
305
+ a = reshape ([(i ** 2 ,i = 1 ,n * (n -1 ))],[n,n -1 ])
117
306
ans = sum ([1._dp ,36._dp ,121._dp ])
307
+
118
308
call check(abs (trace(a) - ans) < epsilon (1.0_dp ), &
119
309
msg= " abs(trace(a) - ans) < epsilon(1.0_sp) failed." ,warn= warn)
120
310
end subroutine
121
311
122
312
subroutine test_trace_rqp
123
313
integer , parameter :: n = 3
124
314
real (qp) :: a(n,n)
315
+ integer :: i
125
316
write (* ,* ) " test_trace_rqp"
126
317
a = reshape ([(i,i= 1 ,n** 2 )],[n,n])
127
318
call check(abs (trace(a) - sum (diag(a))) < epsilon (1.0_qp ), &
@@ -205,9 +396,9 @@ subroutine test_trace_int32
205
396
end subroutine
206
397
207
398
subroutine test_trace_int64
208
- integer (int64) , parameter :: n = 5
209
- integer (int64) , parameter :: nd = 2 * n-1 ! number of diagonals
210
- integer (int64) :: i, j
399
+ integer , parameter :: n = 5
400
+ integer , parameter :: nd = 2 * n-1 ! number of diagonals
401
+ integer :: i, j
211
402
integer (int64) :: c(0 :nd), H(n,n)
212
403
write (* ,* ) " test_trace_int64"
213
404
@@ -229,15 +420,15 @@ subroutine test_trace_int64
229
420
end subroutine
230
421
231
422
pure recursive function catalan_number(n) result(value)
232
- integer (int64) , intent (in ) :: n
233
- integer (int64) :: value
423
+ integer , intent (in ) :: n
424
+ integer :: value
234
425
integer :: i
235
426
if (n <= 1 ) then
236
427
value = 1
237
428
else
238
429
value = 0
239
430
do i = 0 , n-1
240
- value = value + catalan_number(int (i,int64) )* catalan_number(n- i-1 )
431
+ value = value + catalan_number(i )* catalan_number(n- i-1 )
241
432
end do
242
433
end if
243
434
end function
0 commit comments