Skip to content

Commit f17a3bb

Browse files
authored
SYCL: implement memset ggml backend buffer interface (#12580)
* SYCL: implement memset ggml backend buffer interface * use GGML_ABORT macro * Do not wait for all queues to finish for memset operation
1 parent bd40678 commit f17a3bb

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

+19-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "ggml-backend-impl.h"
3838

3939
#include "ggml-sycl/backend.hpp"
40+
#include "ggml-sycl/common.hpp"
4041
#include "ggml-sycl/presets.hpp"
4142
#include "ggml-sycl/gemm.hpp"
4243
#include "ggml-sycl/sycl_hw.hpp"
@@ -490,6 +491,23 @@ catch (sycl::exception const &exc) {
490491
std::exit(1);
491492
}
492493

494+
static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
495+
size_t offset, size_t size) {
496+
GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__);
497+
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
498+
SYCL_CHECK(ggml_sycl_set_device(ctx->device));
499+
auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
500+
if (size == 0) {
501+
return; // Nothing to do
502+
}
503+
if (tensor->data == nullptr) {
504+
GGML_ABORT("Error: Tensor data pointer is null.\n");
505+
}
506+
void * target_ptr = static_cast<char *>(tensor->data) + offset;
507+
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
508+
SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
509+
}
510+
493511
static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
494512
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
495513
if (buffer == nullptr) {
@@ -510,7 +528,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
510528
/* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
511529
/* .get_base = */ ggml_backend_sycl_buffer_get_base,
512530
/* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
513-
/* .memset_tensor = */ NULL,
531+
/* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
514532
/* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
515533
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
516534
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,

0 commit comments

Comments
 (0)