@@ -12,20 +12,16 @@ const unittriangularwrappers = (
12
12
(:UnitLowerTriangular , :LowerTriangular )
13
13
)
14
14
15
- @kernel function kernel_generic (ctx, B, J, min_size )
15
+ @kernel function kernel_generic (B, J)
16
16
lin_idx = @index (Global, Linear)
17
- if lin_idx <= min_size
18
- @inbounds diag_idx = diagind (B)[lin_idx]
19
- @inbounds B[diag_idx] += J
20
- end
17
+ @inbounds diag_idx = diagind (B)[lin_idx]
18
+ @inbounds B[diag_idx] += J
21
19
end
22
20
23
- @kernel function kernel_unittriangular (ctx, B, J, diagonal_val, min_size )
21
+ @kernel function kernel_unittriangular (B, J, diagonal_val)
24
22
lin_idx = @index (Global, Linear)
25
- if lin_idx <= min_size
26
- @inbounds diag_idx = diagind (B)[lin_idx]
27
- @inbounds B[diag_idx] = diagonal_val + J
28
- end
23
+ @inbounds diag_idx = diagind (B)[lin_idx]
24
+ @inbounds B[diag_idx] = diagonal_val + J
29
25
end
30
26
31
27
for (t1, t2) in unittriangularwrappers
@@ -34,17 +30,15 @@ for (t1, t2) in unittriangularwrappers
34
30
B = similar (parent (A), typeof (oneunit (T) + J))
35
31
copyto! (B, parent (A))
36
32
min_size = minimum (size (B))
37
- kernel = kernel_unittriangular (get_backend (B))
38
- kernel (B, J, one (eltype (B)), min_size; ndrange= min_size)
33
+ kernel_unittriangular (get_backend (B))(B, J, one (eltype (B)); ndrange= min_size)
39
34
return $ t2 (B)
40
35
end
41
36
42
37
function (- )(J:: UniformScaling , A:: $t1{T, <:AbstractGPUMatrix} ) where T
43
38
B = similar (parent (A), typeof (J - oneunit (T)))
44
39
B .= .- parent (A)
45
40
min_size = minimum (size (B))
46
- kernel = kernel_unittriangular (get_backend (B))
47
- kernel (B, J, - one (eltype (B)), min_size; ndrange= min_size)
41
+ kernel_unittriangular (get_backend (B))(B, J, - one (eltype (B)); ndrange= min_size)
48
42
return $ t2 (B)
49
43
end
50
44
end
@@ -56,17 +50,15 @@ for t in genericwrappers
56
50
B = similar (parent (A), typeof (oneunit (T) + J))
57
51
copyto! (B, parent (A))
58
52
min_size = minimum (size (B))
59
- kernel = kernel_generic (get_backend (B))
60
- kernel (B, J, min_size; ndrange= min_size)
53
+ kernel_generic (get_backend (B))(B, J; ndrange= min_size)
61
54
return $ t (B)
62
55
end
63
56
64
57
function (- )(J:: UniformScaling , A:: $t{T, <:AbstractGPUMatrix} ) where T
65
58
B = similar (parent (A), typeof (J - oneunit (T)))
66
59
B .= .- parent (A)
67
60
min_size = minimum (size (B))
68
- kernel = kernel_generic (get_backend (B))
69
- kernel (B, J, min_size; ndrange= min_size)
61
+ kernel_generic (get_backend (B))(B, J; ndrange= min_size)
70
62
return $ t (B)
71
63
end
72
64
end
@@ -77,17 +69,15 @@ function (+)(A::Hermitian{T,<:AbstractGPUMatrix}, J::UniformScaling{<:Complex})
77
69
B = similar (parent (A), typeof (oneunit (T) + J))
78
70
copyto! (B, parent (A))
79
71
min_size = minimum (size (B))
80
- kernel = kernel_generic (get_backend (B))
81
- kernel (B, J, min_size; ndrange= min_size)
72
+ kernel_generic (get_backend (B))(B, J; ndrange= min_size)
82
73
return B
83
74
end
84
75
85
76
function (- )(J:: UniformScaling{<:Complex} , A:: Hermitian{T,<:AbstractGPUMatrix} ) where T
86
77
B = similar (parent (A), typeof (J - oneunit (T)))
87
78
B .= .- parent (A)
88
79
min_size = minimum (size (B))
89
- kernel = kernel_generic (get_backend (B))
90
- kernel (B, J, min_size; ndrange= min_size)
80
+ kernel_generic (get_backend (B))(B, J; ndrange= min_size)
91
81
return B
92
82
end
93
83
@@ -96,16 +86,14 @@ function (+)(A::AbstractGPUMatrix{T}, J::UniformScaling) where T
96
86
B = similar (A, typeof (oneunit (T) + J))
97
87
copyto! (B, A)
98
88
min_size = minimum (size (B))
99
- kernel = kernel_generic (get_backend (B))
100
- kernel (B, J, min_size; ndrange= min_size)
89
+ kernel_generic (get_backend (B))(B, J; ndrange= min_size)
101
90
return B
102
91
end
103
92
104
93
function (- )(J:: UniformScaling , A:: AbstractGPUMatrix{T} ) where T
105
94
B = similar (A, typeof (J - oneunit (T)))
106
95
B .= .- A
107
96
min_size = minimum (size (B))
108
- kernel = kernel_generic (get_backend (B))
109
- kernel (B, J, min_size; ndrange= min_size)
97
+ kernel_generic (get_backend (B))(B, J; ndrange= min_size)
110
98
return B
111
99
end
0 commit comments