Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/CodeGen_Metal_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,8 +863,40 @@ vector<char> CodeGen_Metal_Dev::compile_to_src() {
string str = src_stream.str();
debug(1) << "Metal kernel:\n"
<< str << "\n";

vector<char> buffer(str.begin(), str.end());
buffer.push_back(0);

auto metal_compiler = get_env_variable("HL_METAL_COMPILER");
auto metal_linker = get_env_variable("HL_METAL_LINKER");
if (!metal_compiler.empty() && !metal_linker.empty()) {
// The user has specified the Metal compiler and linker to use, so instead of embedding
// the shader as a string, we will embed it as a metallib
// Write the source to a temporary file.
auto tmpfile = file_make_temp("metal", ".metal");
write_entire_file(tmpfile, buffer);

// Compile the Metal source to a metallib.
string metalir = tmpfile + ".ir";
string metallib = tmpfile + "lib";
string cmd = string(metal_compiler) + " -c -o " + metalir + " " + tmpfile;
debug(2) << "Running: " << cmd << "\n";

int ret = system(cmd.c_str());
user_assert(ret == 0) << "HL_METAL_COMPILER set, but failed to compile Metal source to Metal IR.\n";

cmd = string(metal_linker) + " -o " + metallib + " " + metalir;
debug(2) << "Running: " << cmd << "\n";

ret = system(cmd.c_str());
user_assert(ret == 0) << "HL_METAL_LINKER set, but failed to compile Metal IR to Metal library.\n";

// Read the metallib into a buffer.
buffer = read_entire_file(metallib);
debug(2) << "Metallib size: " << buffer.size() << "\n";
} else {
buffer.push_back(0);
}

return buffer;
}

Expand Down
52 changes: 50 additions & 2 deletions src/runtime/metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ extern struct ObjectiveCClass _NSConcreteGlobalBlock;
extern struct ObjectiveCClass _NSConcreteStackBlock;
void *dlsym(void *, const char *);
#define RTLD_DEFAULT ((void *)-2)
extern objc_id /*dispatch_data_t*/ dispatch_data_create(const void * buffer, size_t size, objc_id queue, objc_id destructor);
}

namespace Halide {
Expand Down Expand Up @@ -199,6 +200,35 @@ WEAK mtl_library *new_library_with_source(mtl_device *device, const char *source
return result;
}

WEAK mtl_library *new_library_from_data(mtl_device *device, const char *source_data, size_t source_len) {
objc_id error_return;

debug(nullptr) << "source_len: " << source_len << "\n";

// With DISPATCH_DATA_DESTRUCTOR_DEFAULT, the data is copied, and no dispatch queue needs to be provided
// for the destructor
objc_id dispatch_data = dispatch_data_create((void*)source_data, source_len, nullptr,
/* DISPATCH_DATA_DESTRUCTOR_DEFAULT */ nullptr);
if (dispatch_data == nullptr) {
debug(nullptr) << "dispatch_data_create failed\n";
} else {
debug(nullptr) << "dispatch_data_create succeeded\n";
}



typedef mtl_library *(*new_library_with_data_method)(objc_id device, objc_sel sel, objc_id data, objc_id *error_return);
new_library_with_data_method method2 = (new_library_with_data_method)&objc_msgSend;
mtl_library *result = (*method2)(device, sel_getUid("newLibraryWithData:error:"),
dispatch_data, &error_return);

if (result == nullptr) {
ns_log_object(error_return);
}

return result;
}

WEAK mtl_function *new_function_with_name(mtl_library *library, const char *name, size_t name_len) {
objc_id name_str = wrap_string_as_ns_string(name, name_len);
typedef mtl_function *(*new_function_with_name_method)(objc_id library, objc_sel sel, objc_id name);
Expand Down Expand Up @@ -663,6 +693,17 @@ WEAK int halide_metal_device_free(void *user_context, halide_buffer_t *buf) {
return halide_error_code_success;
}

namespace {
bool is_compiled_metallib(const char *source, int source_size) {
// Check for the magic bytes at the beginning of the file
if (source_size < 4) {
return false;
}
const uint8_t metal_magic_number[4] = {0x4D, 0x54, 0x4c, 0x42}; // MTLB
return memcmp(source, metal_magic_number, 4) == 0;
}
}

WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, const char *source, int source_size) {
MetalContextHolder metal_context(user_context, true);
if (metal_context.error()) {
Expand All @@ -673,9 +714,16 @@ WEAK int halide_metal_initialize_kernels(void *user_context, void **state_ptr, c
#endif

mtl_library *library{};
bool precompiled_shader = is_compiled_metallib(source, source_size);
if (precompiled_shader) {
debug(user_context) << "halide_metal_initialize_kernels: using embedded metallib\n";
} else {
debug(user_context) << "halide_metal_initialize_kernels: using embedded shader source\n";
}

const bool setup = compilation_cache.kernel_state_setup(user_context, state_ptr, metal_context.device, library,
new_library_with_source, metal_context.device,
source, source_size);
precompiled_shader ? new_library_from_data : new_library_with_source,
metal_context.device, source, source_size);
if (!setup || library == nullptr) {
error(user_context) << "halide_metal_initialize_kernels: setup failed.\n";
return halide_error_code_generic_error;
Expand Down
Loading