diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index a3cef155a6fa..6f1b343f42a1 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -863,8 +863,40 @@ vector CodeGen_Metal_Dev::compile_to_src() { string str = src_stream.str(); debug(1) << "Metal kernel:\n" << str << "\n"; + vector buffer(str.begin(), str.end()); - buffer.push_back(0); + + auto metal_compiler = get_env_variable("HL_METAL_COMPILER"); + auto metal_linker = get_env_variable("HL_METAL_LINKER"); + if (!metal_compiler.empty() && !metal_linker.empty()) { + // The user has specified the Metal compiler and linker to use, so instead of embedding + // the shader as a string, we will embed it as a metallib + // Write the source to a temporary file. + auto tmpfile = file_make_temp("metal", ".metal"); + write_entire_file(tmpfile, buffer); + + // Compile the Metal source to a metallib. + string metalir = tmpfile + ".ir"; + string metallib = tmpfile + "lib"; + string cmd = string(metal_compiler) + " -c -o " + metalir + " " + tmpfile; + debug(2) << "Running: " << cmd << "\n"; + + int ret = system(cmd.c_str()); + user_assert(ret == 0) << "HL_METAL_COMPILER set, but failed to compile Metal source to Metal IR.\n"; + + cmd = string(metal_linker) + " -o " + metallib + " " + metalir; + debug(2) << "Running: " << cmd << "\n"; + + ret = system(cmd.c_str()); + user_assert(ret == 0) << "HL_METAL_LINKER set, but failed to compile Metal IR to Metal library.\n"; + + // Read the metallib into a buffer. 
+ buffer = read_entire_file(metallib); + debug(2) << "Metallib size: " << buffer.size() << "\n"; + } else { + buffer.push_back(0); + } + return buffer; } diff --git a/src/runtime/metal.cpp b/src/runtime/metal.cpp index ab18ba82e318..f0e3accab9e6 100644 --- a/src/runtime/metal.cpp +++ b/src/runtime/metal.cpp @@ -15,6 +15,7 @@ extern struct ObjectiveCClass _NSConcreteGlobalBlock; extern struct ObjectiveCClass _NSConcreteStackBlock; void *dlsym(void *, const char *); #define RTLD_DEFAULT ((void *)-2) +extern objc_id /*dispatch_data_t*/ dispatch_data_create(const void * buffer, size_t size, objc_id queue, objc_id destructor); } namespace Halide { @@ -199,6 +200,35 @@ WEAK mtl_library *new_library_with_source(mtl_device *device, const char *source return result; } +WEAK mtl_library *new_library_from_data(mtl_device *device, const char *source_data, size_t source_len) { + objc_id error_return; + + debug(nullptr) << "source_len: " << source_len << "\n"; + + // With DISPATCH_DATA_DESTRUCTOR_DEFAULT, the data is copied, and no dispatch queue needs to be provided + // for the destructor + objc_id dispatch_data = dispatch_data_create((void*)source_data, source_len, nullptr, + /* DISPATCH_DATA_DESTRUCTOR_DEFAULT */ nullptr); + if (dispatch_data == nullptr) { + debug(nullptr) << "dispatch_data_create failed\n"; + } else { + debug(nullptr) << "dispatch_data_create succeeded\n"; + } + + + + typedef mtl_library *(*new_library_with_data_method)(objc_id device, objc_sel sel, objc_id data, objc_id *error_return); + new_library_with_data_method method2 = (new_library_with_data_method)&objc_msgSend; + mtl_library *result = (*method2)(device, sel_getUid("newLibraryWithData:error:"), + dispatch_data, &error_return); + + if (result == nullptr) { + ns_log_object(error_return); + } + + return result; +} + WEAK mtl_function *new_function_with_name(mtl_library *library, const char *name, size_t name_len) { objc_id name_str = wrap_string_as_ns_string(name, name_len); typedef mtl_function 
namespace {

// Reports whether an embedded kernel blob is a compiled Metal library rather
// than shader source text, by checking for the "MTLB" magic bytes at the very
// start of the buffer.
//
// source      - start of the embedded kernel data; may be null.
// source_size - number of bytes available at `source`.
//
// Returns true only when at least four bytes are present and they equal
// 'M','T','L','B'. A null pointer, or a size below four (including zero and
// negative values), yields false.
bool is_compiled_metallib(const char *source, int source_size) {
    constexpr uint8_t metal_magic_number[] = {0x4D, 0x54, 0x4C, 0x42};  // MTLB
    constexpr int magic_len = static_cast<int>(sizeof(metal_magic_number));

    // Reject a null buffer explicitly so memcmp is never handed an invalid
    // pointer, and reject buffers too short to hold the magic number.
    if (source == nullptr || source_size < magic_len) {
        return false;
    }
    return memcmp(source, metal_magic_number, sizeof(metal_magic_number)) == 0;
}

}  // namespace