@@ -69,7 +69,7 @@ static void store_last_native_error(int32_t native_error) {
69
69
TLS_last_native_error = native_error ;
70
70
}
71
71
72
- umf_result_t ze2umf_result (ze_result_t result ) {
72
+ static umf_result_t ze2umf_result (ze_result_t result ) {
73
73
switch (result ) {
74
74
case ZE_RESULT_SUCCESS :
75
75
return UMF_RESULT_SUCCESS ;
@@ -125,7 +125,8 @@ static void init_ze_global_state(void) {
125
125
}
126
126
}
127
127
128
- umf_result_t ze_memory_provider_initialize (void * params , void * * provider ) {
128
+ static umf_result_t ze_memory_provider_initialize (void * params ,
129
+ void * * provider ) {
129
130
if (provider == NULL || params == NULL ) {
130
131
return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
131
132
}
@@ -181,8 +182,11 @@ umf_result_t ze_memory_provider_initialize(void *params, void **provider) {
181
182
return UMF_RESULT_SUCCESS ;
182
183
}
183
184
184
- void ze_memory_provider_finalize (void * provider ) {
185
- assert (provider );
185
+ static void ze_memory_provider_finalize (void * provider ) {
186
+ if (provider == NULL ) {
187
+ ASSERT (0 );
188
+ return ;
189
+ }
186
190
187
191
util_init_once (& ze_is_initialized , init_ze_global_state );
188
192
umf_ba_global_free (provider );
@@ -194,8 +198,10 @@ void ze_memory_provider_finalize(void *provider) {
194
198
195
199
static bool use_relaxed_allocation (ze_memory_provider_t * ze_provider ,
196
200
size_t size ) {
201
+ assert (ze_provider );
197
202
assert (ze_provider -> device );
198
203
assert (ze_provider -> device_properties .maxMemAllocSize > 0 );
204
+
199
205
return size > ze_provider -> device_properties .maxMemAllocSize ;
200
206
}
201
207
@@ -207,8 +213,9 @@ static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
207
213
static umf_result_t ze_memory_provider_alloc (void * provider , size_t size ,
208
214
size_t alignment ,
209
215
void * * resultPtr ) {
210
- assert (provider );
211
- assert (resultPtr );
216
+ if (provider == NULL || resultPtr == NULL ) {
217
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
218
+ }
212
219
213
220
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
214
221
@@ -256,7 +263,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
256
263
break ;
257
264
}
258
265
default :
259
- return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
266
+ // this shouldn't happen as we check the memory_type settings during
267
+ // the initialization
268
+ LOG_ERR ("unsupported USM memory type" );
269
+ return UMF_RESULT_ERROR_UNKNOWN ;
260
270
}
261
271
262
272
if (ze_result != ZE_RESULT_SUCCESS ) {
@@ -279,29 +289,41 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
279
289
size_t bytes ) {
280
290
(void )bytes ;
281
291
282
- assert (provider );
292
+ if (provider == NULL ) {
293
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
294
+ }
295
+
296
+ if (ptr == NULL ) {
297
+ return UMF_RESULT_SUCCESS ;
298
+ }
299
+
283
300
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
284
301
ze_result_t ze_result = g_ze_ops .zeMemFree (ze_provider -> context , ptr );
285
302
return ze2umf_result (ze_result );
286
303
}
287
304
288
- void ze_memory_provider_get_last_native_error (void * provider ,
289
- const char * * ppMessage ,
290
- int32_t * pError ) {
305
+ static void ze_memory_provider_get_last_native_error (void * provider ,
306
+ const char * * ppMessage ,
307
+ int32_t * pError ) {
291
308
(void )provider ;
292
- (void )ppMessage ;
293
309
294
- assert (pError );
310
+ if (ppMessage == NULL || pError == NULL ) {
311
+ ASSERT (0 );
312
+ return ;
313
+ }
295
314
296
315
* pError = TLS_last_native_error ;
297
316
}
298
317
299
318
static umf_result_t ze_memory_provider_get_min_page_size (void * provider ,
300
319
void * ptr ,
301
320
size_t * pageSize ) {
302
- (void )provider ;
303
321
(void )ptr ;
304
322
323
+ if (provider == NULL || pageSize == NULL ) {
324
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
325
+ }
326
+
305
327
// TODO
306
328
* pageSize = 1024 * 64 ;
307
329
return UMF_RESULT_SUCCESS ;
@@ -330,15 +352,18 @@ static umf_result_t ze_memory_provider_purge_force(void *provider, void *ptr,
330
352
static umf_result_t
331
353
ze_memory_provider_get_recommended_page_size (void * provider , size_t size ,
332
354
size_t * pageSize ) {
333
- (void )provider ;
334
355
(void )size ;
335
356
357
+ if (provider == NULL || pageSize == NULL ) {
358
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
359
+ }
360
+
336
361
// TODO
337
362
* pageSize = 1024 * 64 ;
338
363
return UMF_RESULT_SUCCESS ;
339
364
}
340
365
341
- const char * ze_memory_provider_get_name (void * provider ) {
366
+ static const char * ze_memory_provider_get_name (void * provider ) {
342
367
(void )provider ;
343
368
return "LEVEL_ZERO" ;
344
369
}
@@ -376,8 +401,10 @@ typedef struct ze_ipc_data_t {
376
401
377
402
static umf_result_t ze_memory_provider_get_ipc_handle_size (void * provider ,
378
403
size_t * size ) {
379
- (void )provider ;
380
- ASSERT (size != NULL );
404
+ if (provider == NULL || size == NULL ) {
405
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
406
+ }
407
+
381
408
* size = sizeof (ze_ipc_data_t );
382
409
return UMF_RESULT_SUCCESS ;
383
410
}
@@ -386,9 +413,12 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
386
413
const void * ptr ,
387
414
size_t size ,
388
415
void * providerIpcData ) {
389
- ASSERT (ptr != NULL );
390
- ASSERT (providerIpcData != NULL );
391
416
(void )size ;
417
+
418
+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
419
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
420
+ }
421
+
392
422
ze_result_t ze_result ;
393
423
ze_ipc_data_t * ze_ipc_data = (ze_ipc_data_t * )providerIpcData ;
394
424
struct ze_memory_provider_t * ze_provider =
@@ -408,8 +438,10 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
408
438
409
439
static umf_result_t ze_memory_provider_put_ipc_handle (void * provider ,
410
440
void * providerIpcData ) {
411
- ASSERT (provider != NULL );
412
- ASSERT (providerIpcData != NULL );
441
+ if (provider == NULL || providerIpcData == NULL ) {
442
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
443
+ }
444
+
413
445
ze_result_t ze_result ;
414
446
struct ze_memory_provider_t * ze_provider =
415
447
(struct ze_memory_provider_t * )provider ;
@@ -434,9 +466,10 @@ static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
434
466
static umf_result_t ze_memory_provider_open_ipc_handle (void * provider ,
435
467
void * providerIpcData ,
436
468
void * * ptr ) {
437
- ASSERT (provider != NULL );
438
- ASSERT (providerIpcData != NULL );
439
- ASSERT (ptr != NULL );
469
+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
470
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
471
+ }
472
+
440
473
ze_result_t ze_result ;
441
474
ze_ipc_data_t * ze_ipc_data = (ze_ipc_data_t * )providerIpcData ;
442
475
struct ze_memory_provider_t * ze_provider =
@@ -471,9 +504,12 @@ static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
471
504
472
505
static umf_result_t
473
506
ze_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
474
- ASSERT (provider != NULL );
475
- ASSERT (ptr != NULL );
476
507
(void )size ;
508
+
509
+ if (provider == NULL || ptr == NULL ) {
510
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
511
+ }
512
+
477
513
ze_result_t ze_result ;
478
514
struct ze_memory_provider_t * ze_provider =
479
515
(struct ze_memory_provider_t * )provider ;
0 commit comments