Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 2eeb715

Browse files
Fix return type of WMMA intrinsics
1 parent c335366 commit 2eeb715

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

src/device/cuda/wmma.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ if addrspaceptr_available
7171
end
7272
end
7373

74+
# Workaround for https://github.com/JuliaGPU/CUDAnative.jl/issues/587
#
# The WMMA intrinsics are not ccall'ed with NTuple{N, T} as their return type.
# Instead, for every distinct fragment size N found in `map_frag_sizes`, a
# struct LLVMStructN{T} with N fields of type T is generated and used as the
# return type; `convert` turns such a struct back into an NTuple{N, T}.
for N in unique(values(map_frag_sizes))
    struct_ty = Symbol("LLVMStruct$N")

    # Field declarations x_1::T, x_2::T, ..., x_N::T.
    field_decls = [:($(Symbol("x_", i))::T) for i in 1:N]

    @eval struct $struct_ty{T}
        $(field_decls...)
    end

    # Collect the struct's fields, in declaration order, into a plain tuple.
    @eval function Base.convert(::Type{NTuple{$N, T}}, x::$struct_ty{T}) where {T}
        return ntuple(i -> getfield(x, i), $N)
    end
end
87+
7488
################################################################################
7589
# LOW LEVEL API
7690
################################################################################
@@ -126,8 +140,9 @@ for mat in ["a", "b", "c"],
126140
ccall_name = "extern $llvm_intr"
127141

128142
ptr_ty = addrspaceptr_available ? Core.AddrSpacePtr{arr_ty, addr_space_int} : Ref{arr_ty}
143+
struct_ty = Symbol("LLVMStruct$sz")
129144

130-
@eval $func_name(src_addr, stride) = ccall($ccall_name, llvmcall, NTuple{$sz, $frag_ty}, ($ptr_ty, Int32), src_addr, stride)
145+
@eval $func_name(src_addr, stride) = convert(NTuple{$sz, $frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$frag_ty}, ($ptr_ty, Int32), src_addr, stride))
131146
@eval export $func_name
132147
@eval @doc (@doc llvm_wmma_load) $func_name
133148
end
@@ -245,7 +260,9 @@ for a_layout in ["col", "row"],
245260
b_vars = ntuple(i -> :(b[$i]), b_sz)
246261
c_vars = ntuple(i -> :(c[$i]), c_sz)
247262

248-
@eval $func_name(a, b, c) = ccall($ccall_name, llvmcall, NTuple{$d_sz, $d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...))
263+
struct_ty = Symbol("LLVMStruct$d_sz")
264+
265+
@eval $func_name(a, b, c) = convert(NTuple{$d_sz, $d_frag_ty}, ccall($ccall_name, llvmcall, $struct_ty{$d_frag_ty}, ($(a_types...), $(b_types...), $(c_types...)), $(a_vars...), $(b_vars...), $(c_vars...)))
249266
@eval export $func_name
250267
@eval @doc (@doc llvm_wmma_mma) $func_name
251268
end

0 commit comments

Comments
 (0)