Skip to content

Commit e660791

Browse files
authored
Merge pull request #737 from bratpiorka/rrudnick_l0_minor
minor L0 provider improvements
2 parents 87f9cdc + abc31b9 commit e660791

File tree

2 files changed

+65
-28
lines changed

2 files changed

+65
-28
lines changed

benchmark/ubench.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030

3131
#include "utils_common.h"
3232

33-
#if (defined UMF_BUILD_GPU_TESTS)
33+
#if (defined UMF_BUILD_LIBUMF_POOL_DISJOINT && \
34+
defined UMF_BUILD_LEVEL_ZERO_PROVIDER && defined UMF_BUILD_GPU_TESTS)
3435
#include "utils_level_zero.h"
3536
#endif
3637

src/provider/provider_level_zero.c

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ static void store_last_native_error(int32_t native_error) {
6969
TLS_last_native_error = native_error;
7070
}
7171

72-
umf_result_t ze2umf_result(ze_result_t result) {
72+
static umf_result_t ze2umf_result(ze_result_t result) {
7373
switch (result) {
7474
case ZE_RESULT_SUCCESS:
7575
return UMF_RESULT_SUCCESS;
@@ -125,7 +125,8 @@ static void init_ze_global_state(void) {
125125
}
126126
}
127127

128-
umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
128+
static umf_result_t ze_memory_provider_initialize(void *params,
129+
void **provider) {
129130
if (provider == NULL || params == NULL) {
130131
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
131132
}
@@ -181,8 +182,11 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
181182
return UMF_RESULT_SUCCESS;
182183
}
183184

184-
void ze_memory_provider_finalize(void *provider) {
185-
assert(provider);
185+
static void ze_memory_provider_finalize(void *provider) {
186+
if (provider == NULL) {
187+
ASSERT(0);
188+
return;
189+
}
186190

187191
util_init_once(&ze_is_initialized, init_ze_global_state);
188192
umf_ba_global_free(provider);
@@ -194,8 +198,10 @@ void ze_memory_provider_finalize(void *provider) {
194198

195199
static bool use_relaxed_allocation(ze_memory_provider_t *ze_provider,
196200
size_t size) {
201+
assert(ze_provider);
197202
assert(ze_provider->device);
198203
assert(ze_provider->device_properties.maxMemAllocSize > 0);
204+
199205
return size > ze_provider->device_properties.maxMemAllocSize;
200206
}
201207

@@ -207,8 +213,9 @@ static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
207213
static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
208214
size_t alignment,
209215
void **resultPtr) {
210-
assert(provider);
211-
assert(resultPtr);
216+
if (provider == NULL || resultPtr == NULL) {
217+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
218+
}
212219

213220
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
214221

@@ -256,7 +263,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
256263
break;
257264
}
258265
default:
259-
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
266+
// this shouldn't happen as we check the memory_type settings during
267+
// the initialization
268+
LOG_ERR("unsupported USM memory type");
269+
return UMF_RESULT_ERROR_UNKNOWN;
260270
}
261271

262272
if (ze_result != ZE_RESULT_SUCCESS) {
@@ -279,29 +289,41 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
279289
size_t bytes) {
280290
(void)bytes;
281291

282-
assert(provider);
292+
if (provider == NULL) {
293+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
294+
}
295+
296+
if (ptr == NULL) {
297+
return UMF_RESULT_SUCCESS;
298+
}
299+
283300
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
284301
ze_result_t ze_result = g_ze_ops.zeMemFree(ze_provider->context, ptr);
285302
return ze2umf_result(ze_result);
286303
}
287304

288-
void ze_memory_provider_get_last_native_error(void *provider,
289-
const char **ppMessage,
290-
int32_t *pError) {
305+
static void ze_memory_provider_get_last_native_error(void *provider,
306+
const char **ppMessage,
307+
int32_t *pError) {
291308
(void)provider;
292-
(void)ppMessage;
293309

294-
assert(pError);
310+
if (ppMessage == NULL || pError == NULL) {
311+
ASSERT(0);
312+
return;
313+
}
295314

296315
*pError = TLS_last_native_error;
297316
}
298317

299318
static umf_result_t ze_memory_provider_get_min_page_size(void *provider,
300319
void *ptr,
301320
size_t *pageSize) {
302-
(void)provider;
303321
(void)ptr;
304322

323+
if (provider == NULL || pageSize == NULL) {
324+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
325+
}
326+
305327
// TODO
306328
*pageSize = 1024 * 64;
307329
return UMF_RESULT_SUCCESS;
@@ -330,15 +352,18 @@ static umf_result_t ze_memory_provider_purge_force(void *provider, void *ptr,
330352
static umf_result_t
331353
ze_memory_provider_get_recommended_page_size(void *provider, size_t size,
332354
size_t *pageSize) {
333-
(void)provider;
334355
(void)size;
335356

357+
if (provider == NULL || pageSize == NULL) {
358+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
359+
}
360+
336361
// TODO
337362
*pageSize = 1024 * 64;
338363
return UMF_RESULT_SUCCESS;
339364
}
340365

341-
const char *ze_memory_provider_get_name(void *provider) {
366+
static const char *ze_memory_provider_get_name(void *provider) {
342367
(void)provider;
343368
return "LEVEL_ZERO";
344369
}
@@ -376,8 +401,10 @@ typedef struct ze_ipc_data_t {
376401

377402
static umf_result_t ze_memory_provider_get_ipc_handle_size(void *provider,
378403
size_t *size) {
379-
(void)provider;
380-
ASSERT(size != NULL);
404+
if (provider == NULL || size == NULL) {
405+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
406+
}
407+
381408
*size = sizeof(ze_ipc_data_t);
382409
return UMF_RESULT_SUCCESS;
383410
}
@@ -386,9 +413,12 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
386413
const void *ptr,
387414
size_t size,
388415
void *providerIpcData) {
389-
ASSERT(ptr != NULL);
390-
ASSERT(providerIpcData != NULL);
391416
(void)size;
417+
418+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
419+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
420+
}
421+
392422
ze_result_t ze_result;
393423
ze_ipc_data_t *ze_ipc_data = (ze_ipc_data_t *)providerIpcData;
394424
struct ze_memory_provider_t *ze_provider =
@@ -408,8 +438,10 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
408438

409439
static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
410440
void *providerIpcData) {
411-
ASSERT(provider != NULL);
412-
ASSERT(providerIpcData != NULL);
441+
if (provider == NULL || providerIpcData == NULL) {
442+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
443+
}
444+
413445
ze_result_t ze_result;
414446
struct ze_memory_provider_t *ze_provider =
415447
(struct ze_memory_provider_t *)provider;
@@ -434,9 +466,10 @@ static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
434466
static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
435467
void *providerIpcData,
436468
void **ptr) {
437-
ASSERT(provider != NULL);
438-
ASSERT(providerIpcData != NULL);
439-
ASSERT(ptr != NULL);
469+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
470+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
471+
}
472+
440473
ze_result_t ze_result;
441474
ze_ipc_data_t *ze_ipc_data = (ze_ipc_data_t *)providerIpcData;
442475
struct ze_memory_provider_t *ze_provider =
@@ -471,9 +504,12 @@ static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
471504

472505
static umf_result_t
473506
ze_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
474-
ASSERT(provider != NULL);
475-
ASSERT(ptr != NULL);
476507
(void)size;
508+
509+
if (provider == NULL || ptr == NULL) {
510+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
511+
}
512+
477513
ze_result_t ze_result;
478514
struct ze_memory_provider_t *ze_provider =
479515
(struct ze_memory_provider_t *)provider;

0 commit comments

Comments
 (0)