37
37
#include " ggml-backend-impl.h"
38
38
39
39
#include " ggml-sycl/backend.hpp"
40
+ #include " ggml-sycl/common.hpp"
40
41
#include " ggml-sycl/presets.hpp"
41
42
#include " ggml-sycl/gemm.hpp"
42
43
#include " ggml-sycl/sycl_hw.hpp"
@@ -490,6 +491,23 @@ catch (sycl::exception const &exc) {
490
491
std::exit (1 );
491
492
}
492
493
494
+ static void ggml_backend_sycl_buffer_memset_tensor (ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
495
+ size_t offset, size_t size) {
496
+ GGML_SYCL_DEBUG (" [SYCL] call %s\n " , __func__);
497
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context ;
498
+ SYCL_CHECK (ggml_sycl_set_device (ctx->device ));
499
+ auto stream = &(dpct::dev_mgr::instance ().get_device (ctx->device ).default_queue ());
500
+ if (size == 0 ) {
501
+ return ; // Nothing to do
502
+ }
503
+ if (tensor->data == nullptr ) {
504
+ GGML_ABORT (" Error: Tensor data pointer is null.\n " );
505
+ }
506
+ void * target_ptr = static_cast <char *>(tensor->data ) + offset;
507
+ SYCL_CHECK (CHECK_TRY_ERROR ((*stream).memset (target_ptr, value, size)));
508
+ SYCL_CHECK (CHECK_TRY_ERROR ((*stream).wait ()));
509
+ }
510
+
493
511
static void ggml_backend_sycl_buffer_reset (ggml_backend_buffer_t buffer) {
494
512
GGML_SYCL_DEBUG (" [SYCL] call %s\n " , __func__);
495
513
if (buffer == nullptr ) {
@@ -510,7 +528,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
510
528
/* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
511
529
/* .get_base = */ ggml_backend_sycl_buffer_get_base,
512
530
/* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
513
- /* .memset_tensor = */ NULL ,
531
+ /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor ,
514
532
/* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
515
533
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
516
534
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
0 commit comments