Skip to content

Commit abc31b9

Browse files
committed
improve l0 prov error handling
1 parent 57f0c36 commit abc31b9

File tree

1 file changed

+55
-18
lines changed

1 file changed

+55
-18
lines changed

src/provider/provider_level_zero.c

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ static umf_result_t ze_memory_provider_initialize(void *params,
183183
}
184184

185185
static void ze_memory_provider_finalize(void *provider) {
186+
if (provider == NULL) {
187+
ASSERT(0);
188+
return;
189+
}
186190

187191
util_init_once(&ze_is_initialized, init_ze_global_state);
188192
umf_ba_global_free(provider);
@@ -194,8 +198,10 @@ static void ze_memory_provider_finalize(void *provider) {
194198

195199
static bool use_relaxed_allocation(ze_memory_provider_t *ze_provider,
196200
size_t size) {
201+
assert(ze_provider);
197202
assert(ze_provider->device);
198203
assert(ze_provider->device_properties.maxMemAllocSize > 0);
204+
199205
return size > ze_provider->device_properties.maxMemAllocSize;
200206
}
201207

@@ -207,8 +213,9 @@ static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
207213
static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
208214
size_t alignment,
209215
void **resultPtr) {
210-
assert(provider);
211-
assert(resultPtr);
216+
if (provider == NULL || resultPtr == NULL) {
217+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
218+
}
212219

213220
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
214221

@@ -256,7 +263,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
256263
break;
257264
}
258265
default:
259-
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
266+
// this shouldn't happen as we check the memory_type settings during
267+
// the initialization
268+
LOG_ERR("unsupported USM memory type");
269+
return UMF_RESULT_ERROR_UNKNOWN;
260270
}
261271

262272
if (ze_result != ZE_RESULT_SUCCESS) {
@@ -279,7 +289,14 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
279289
size_t bytes) {
280290
(void)bytes;
281291

282-
assert(provider);
292+
if (provider == NULL) {
293+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
294+
}
295+
296+
if (ptr == NULL) {
297+
return UMF_RESULT_SUCCESS;
298+
}
299+
283300
ze_memory_provider_t *ze_provider = (ze_memory_provider_t *)provider;
284301
ze_result_t ze_result = g_ze_ops.zeMemFree(ze_provider->context, ptr);
285302
return ze2umf_result(ze_result);
@@ -290,17 +307,23 @@ static void ze_memory_provider_get_last_native_error(void *provider,
290307
int32_t *pError) {
291308
(void)provider;
292309

293-
assert(pError);
310+
if (ppMessage == NULL || pError == NULL) {
311+
ASSERT(0);
312+
return;
313+
}
294314

295315
*pError = TLS_last_native_error;
296316
}
297317

298318
static umf_result_t ze_memory_provider_get_min_page_size(void *provider,
299319
void *ptr,
300320
size_t *pageSize) {
301-
(void)provider;
302321
(void)ptr;
303322

323+
if (provider == NULL || pageSize == NULL) {
324+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
325+
}
326+
304327
// TODO
305328
*pageSize = 1024 * 64;
306329
return UMF_RESULT_SUCCESS;
@@ -329,9 +352,12 @@ static umf_result_t ze_memory_provider_purge_force(void *provider, void *ptr,
329352
static umf_result_t
330353
ze_memory_provider_get_recommended_page_size(void *provider, size_t size,
331354
size_t *pageSize) {
332-
(void)provider;
333355
(void)size;
334356

357+
if (provider == NULL || pageSize == NULL) {
358+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
359+
}
360+
335361
// TODO
336362
*pageSize = 1024 * 64;
337363
return UMF_RESULT_SUCCESS;
@@ -375,8 +401,10 @@ typedef struct ze_ipc_data_t {
375401

376402
static umf_result_t ze_memory_provider_get_ipc_handle_size(void *provider,
377403
size_t *size) {
378-
(void)provider;
379-
ASSERT(size != NULL);
404+
if (provider == NULL || size == NULL) {
405+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
406+
}
407+
380408
*size = sizeof(ze_ipc_data_t);
381409
return UMF_RESULT_SUCCESS;
382410
}
@@ -385,9 +413,12 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
385413
const void *ptr,
386414
size_t size,
387415
void *providerIpcData) {
388-
ASSERT(ptr != NULL);
389-
ASSERT(providerIpcData != NULL);
390416
(void)size;
417+
418+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
419+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
420+
}
421+
391422
ze_result_t ze_result;
392423
ze_ipc_data_t *ze_ipc_data = (ze_ipc_data_t *)providerIpcData;
393424
struct ze_memory_provider_t *ze_provider =
@@ -407,8 +438,10 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
407438

408439
static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
409440
void *providerIpcData) {
410-
ASSERT(provider != NULL);
411-
ASSERT(providerIpcData != NULL);
441+
if (provider == NULL || providerIpcData == NULL) {
442+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
443+
}
444+
412445
ze_result_t ze_result;
413446
struct ze_memory_provider_t *ze_provider =
414447
(struct ze_memory_provider_t *)provider;
@@ -433,9 +466,10 @@ static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
433466
static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
434467
void *providerIpcData,
435468
void **ptr) {
436-
ASSERT(provider != NULL);
437-
ASSERT(providerIpcData != NULL);
438-
ASSERT(ptr != NULL);
469+
if (provider == NULL || ptr == NULL || providerIpcData == NULL) {
470+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
471+
}
472+
439473
ze_result_t ze_result;
440474
ze_ipc_data_t *ze_ipc_data = (ze_ipc_data_t *)providerIpcData;
441475
struct ze_memory_provider_t *ze_provider =
@@ -470,9 +504,12 @@ static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
470504

471505
static umf_result_t
472506
ze_memory_provider_close_ipc_handle(void *provider, void *ptr, size_t size) {
473-
ASSERT(provider != NULL);
474-
ASSERT(ptr != NULL);
475507
(void)size;
508+
509+
if (provider == NULL || ptr == NULL) {
510+
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
511+
}
512+
476513
ze_result_t ze_result;
477514
struct ze_memory_provider_t *ze_provider =
478515
(struct ze_memory_provider_t *)provider;

0 commit comments

Comments
 (0)