@@ -71,6 +71,20 @@ if addrspaceptr_available
71
71
end
72
72
end
73
73
74
# Fix for https://github.com/JuliaGPU/CUDAnative.jl/issues/587
# Instead of ccall'ing the intrinsics with NTuple{N, T}, we generate
# custom structs LLVMStructN{T}, containing N fields of type T, and use
# those as return type.
#
# For every distinct size N found among the values of `map_frag_sizes`, this
# loop defines at module level:
#   * a parametric struct with N fields `x_1, ..., x_N`, each of type `T`;
#   * a `Base.convert` method that turns such a struct into an `NTuple{N, T}`
#     by reading its fields in declaration order.
for N in unique(values(map_frag_sizes))
    # NOTE(review): the padding spaces inside this name literal look like an
    # extraction artifact. They are kept as-is because the wrapper generators
    # further down build the lookup name with the same literal form -- confirm
    # against the upstream file before normalizing.
    struct_ty = Symbol(" LLVMStruct$N ")

    # `@nexprs N i -> x_i::T` expands into the N field declarations.
    @eval struct $struct_ty{T}
        Base.Cartesian.@nexprs $N i -> x_i::T
    end

    # The length is passed as `Val` so the resulting tuple type is inferred as
    # `NTuple{N, T}` for every N, not only for lengths small enough to be
    # unrolled by the integer-length method of `ntuple`.
    @eval function Base.convert(::Type{NTuple{$N, T}}, x::$struct_ty{T}) where {T}
        return ntuple(i -> getfield(x, i), Val($N))
    end
end
87
+
74
88
# ###############################################################################
75
89
# LOW LEVEL API
76
90
# ###############################################################################
@@ -126,8 +140,9 @@ for mat in ["a", "b", "c"],
126
140
ccall_name = " extern $llvm_intr "
127
141
128
142
ptr_ty = addrspaceptr_available ? Core. AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
143
+ struct_ty = Symbol (" LLVMStruct$sz " )
129
144
130
- @eval $ func_name (src_addr, stride) = ccall ($ ccall_name, llvmcall, NTuple{ $ sz, $ frag_ty}, ($ ptr_ty, Int32), src_addr, stride)
145
+ @eval $ func_name (src_addr, stride) = convert (NTuple{ $ sz, $ frag_ty}, ccall ($ ccall_name, llvmcall, $ struct_ty{ $ frag_ty}, ($ ptr_ty, Int32), src_addr, stride) )
131
146
@eval export $ func_name
132
147
@eval @doc (@doc llvm_wmma_load) $ func_name
133
148
end
@@ -245,7 +260,9 @@ for a_layout in ["col", "row"],
245
260
b_vars = ntuple (i -> :(b[$ i]), b_sz)
246
261
c_vars = ntuple (i -> :(c[$ i]), c_sz)
247
262
248
- @eval $ func_name (a, b, c) = ccall ($ ccall_name, llvmcall, NTuple{$ d_sz, $ d_frag_ty}, ($ (a_types... ), $ (b_types... ), $ (c_types... )), $ (a_vars... ), $ (b_vars... ), $ (c_vars... ))
263
+ struct_ty = Symbol (" LLVMStruct$d_sz " )
264
+
265
+ @eval $ func_name (a, b, c) = convert (NTuple{$ d_sz, $ d_frag_ty}, ccall ($ ccall_name, llvmcall, $ struct_ty{$ d_frag_ty}, ($ (a_types... ), $ (b_types... ), $ (c_types... )), $ (a_vars... ), $ (b_vars... ), $ (c_vars... )))
249
266
@eval export $ func_name
250
267
@eval @doc (@doc llvm_wmma_mma) $ func_name
251
268
end
0 commit comments