diff --git a/Project.toml b/Project.toml index 51859d2e..601641bd 100644 --- a/Project.toml +++ b/Project.toml @@ -53,3 +53,7 @@ pocl_jll = "7" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[sources] +GPUCompiler = {url="https://github.com/JuliaGPU/GPUCompiler.jl", rev="vc/mtv"} +SPIRVIntrinsics = {url="https://github.com/JuliaGPU/OpenCL.jl", rev="vc/mtv", subdir="lib/intrinsics"} \ No newline at end of file diff --git a/src/pocl/compiler/compilation.jl b/src/pocl/compiler/compilation.jl index 5f88fba3..7e4ef5c1 100644 --- a/src/pocl/compiler/compilation.jl +++ b/src/pocl/compiler/compilation.jl @@ -6,7 +6,7 @@ const OpenCLCompilerJob = CompilerJob{SPIRVCompilerTarget, OpenCLCompilerParams} GPUCompiler.runtime_module(::CompilerJob{<:Any, OpenCLCompilerParams}) = POCL -GPUCompiler.method_table(::OpenCLCompilerJob) = method_table +GPUCompiler.method_table_view(job::OpenCLCompilerJob) = GPUCompiler.StackedMethodTable(job.world, method_table, spirv_method_table) # filter out OpenCL built-ins # TODO: eagerly lower these using the translator API @@ -50,7 +50,7 @@ end # create GPUCompiler objects - target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, version = v"1.2", kwargs...) + target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate = true, kwargs...) params = OpenCLCompilerParams() return CompilerConfig(target, params; kernel, name, always_inline) end diff --git a/src/pocl/pocl.jl b/src/pocl/pocl.jl index b16114d5..501e8b28 100644 --- a/src/pocl/pocl.jl +++ b/src/pocl/pocl.jl @@ -44,10 +44,42 @@ using GPUCompiler import LLVM using Adapt +## device overrides + +# local method table for device functions +Base.Experimental.@MethodTable(method_table) + +macro device_override(ex) + return esc( + quote + Base.Experimental.@overlay($method_table, $ex) + end + ) +end + +macro device_function(ex) + ex = macroexpand(__module__, ex) + def = ExprTools.splitdef(ex) + + # generate a function that errors + def[:body] = quote + error("This function is not intended for use on the CPU") + end + + return esc( + quote + $(ExprTools.combinedef(def)) + @device_override $ex + end + ) +end + import SPIRVIntrinsics SPIRVIntrinsics.@import_all SPIRVIntrinsics.@reexport_public +const spirv_method_table = SPIRVIntrinsics.method_table + include("compiler/compilation.jl") include("compiler/execution.jl") include("compiler/reflection.jl")