@@ -183,6 +183,10 @@ static umf_result_t ze_memory_provider_initialize(void *params,
183
183
}
184
184
185
185
static void ze_memory_provider_finalize (void * provider ) {
186
+ if (provider == NULL ) {
187
+ ASSERT (0 );
188
+ return ;
189
+ }
186
190
187
191
util_init_once (& ze_is_initialized , init_ze_global_state );
188
192
umf_ba_global_free (provider );
@@ -194,8 +198,10 @@ static void ze_memory_provider_finalize(void *provider) {
194
198
195
199
static bool use_relaxed_allocation (ze_memory_provider_t * ze_provider ,
196
200
size_t size ) {
201
+ assert (ze_provider );
197
202
assert (ze_provider -> device );
198
203
assert (ze_provider -> device_properties .maxMemAllocSize > 0 );
204
+
199
205
return size > ze_provider -> device_properties .maxMemAllocSize ;
200
206
}
201
207
@@ -207,8 +213,9 @@ static ze_relaxed_allocation_limits_exp_desc_t relaxed_device_allocation_desc =
207
213
static umf_result_t ze_memory_provider_alloc (void * provider , size_t size ,
208
214
size_t alignment ,
209
215
void * * resultPtr ) {
210
- assert (provider );
211
- assert (resultPtr );
216
+ if (provider == NULL || resultPtr == NULL ) {
217
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
218
+ }
212
219
213
220
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
214
221
@@ -256,7 +263,10 @@ static umf_result_t ze_memory_provider_alloc(void *provider, size_t size,
256
263
break ;
257
264
}
258
265
default :
259
- return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
266
+ // this shouldn't happen as we check the memory_type settings during
267
+ // the initialization
268
+ LOG_ERR ("unsupported USM memory type" );
269
+ return UMF_RESULT_ERROR_UNKNOWN ;
260
270
}
261
271
262
272
if (ze_result != ZE_RESULT_SUCCESS ) {
@@ -279,7 +289,14 @@ static umf_result_t ze_memory_provider_free(void *provider, void *ptr,
279
289
size_t bytes ) {
280
290
(void )bytes ;
281
291
282
- assert (provider );
292
+ if (provider == NULL ) {
293
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
294
+ }
295
+
296
+ if (ptr == NULL ) {
297
+ return UMF_RESULT_SUCCESS ;
298
+ }
299
+
283
300
ze_memory_provider_t * ze_provider = (ze_memory_provider_t * )provider ;
284
301
ze_result_t ze_result = g_ze_ops .zeMemFree (ze_provider -> context , ptr );
285
302
return ze2umf_result (ze_result );
@@ -290,17 +307,23 @@ static void ze_memory_provider_get_last_native_error(void *provider,
290
307
int32_t * pError ) {
291
308
(void )provider ;
292
309
293
- assert (pError );
310
+ if (ppMessage == NULL || pError == NULL ) {
311
+ ASSERT (0 );
312
+ return ;
313
+ }
294
314
295
315
* pError = TLS_last_native_error ;
296
316
}
297
317
298
318
static umf_result_t ze_memory_provider_get_min_page_size (void * provider ,
299
319
void * ptr ,
300
320
size_t * pageSize ) {
301
- (void )provider ;
302
321
(void )ptr ;
303
322
323
+ if (provider == NULL || pageSize == NULL ) {
324
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
325
+ }
326
+
304
327
// TODO
305
328
* pageSize = 1024 * 64 ;
306
329
return UMF_RESULT_SUCCESS ;
@@ -329,9 +352,12 @@ static umf_result_t ze_memory_provider_purge_force(void *provider, void *ptr,
329
352
static umf_result_t
330
353
ze_memory_provider_get_recommended_page_size (void * provider , size_t size ,
331
354
size_t * pageSize ) {
332
- (void )provider ;
333
355
(void )size ;
334
356
357
+ if (provider == NULL || pageSize == NULL ) {
358
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
359
+ }
360
+
335
361
// TODO
336
362
* pageSize = 1024 * 64 ;
337
363
return UMF_RESULT_SUCCESS ;
@@ -375,8 +401,10 @@ typedef struct ze_ipc_data_t {
375
401
376
402
static umf_result_t ze_memory_provider_get_ipc_handle_size (void * provider ,
377
403
size_t * size ) {
378
- (void )provider ;
379
- ASSERT (size != NULL );
404
+ if (provider == NULL || size == NULL ) {
405
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
406
+ }
407
+
380
408
* size = sizeof (ze_ipc_data_t );
381
409
return UMF_RESULT_SUCCESS ;
382
410
}
@@ -385,9 +413,12 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
385
413
const void * ptr ,
386
414
size_t size ,
387
415
void * providerIpcData ) {
388
- ASSERT (ptr != NULL );
389
- ASSERT (providerIpcData != NULL );
390
416
(void )size ;
417
+
418
+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
419
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
420
+ }
421
+
391
422
ze_result_t ze_result ;
392
423
ze_ipc_data_t * ze_ipc_data = (ze_ipc_data_t * )providerIpcData ;
393
424
struct ze_memory_provider_t * ze_provider =
@@ -407,8 +438,10 @@ static umf_result_t ze_memory_provider_get_ipc_handle(void *provider,
407
438
408
439
static umf_result_t ze_memory_provider_put_ipc_handle (void * provider ,
409
440
void * providerIpcData ) {
410
- ASSERT (provider != NULL );
411
- ASSERT (providerIpcData != NULL );
441
+ if (provider == NULL || providerIpcData == NULL ) {
442
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
443
+ }
444
+
412
445
ze_result_t ze_result ;
413
446
struct ze_memory_provider_t * ze_provider =
414
447
(struct ze_memory_provider_t * )provider ;
@@ -433,9 +466,10 @@ static umf_result_t ze_memory_provider_put_ipc_handle(void *provider,
433
466
static umf_result_t ze_memory_provider_open_ipc_handle (void * provider ,
434
467
void * providerIpcData ,
435
468
void * * ptr ) {
436
- ASSERT (provider != NULL );
437
- ASSERT (providerIpcData != NULL );
438
- ASSERT (ptr != NULL );
469
+ if (provider == NULL || ptr == NULL || providerIpcData == NULL ) {
470
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
471
+ }
472
+
439
473
ze_result_t ze_result ;
440
474
ze_ipc_data_t * ze_ipc_data = (ze_ipc_data_t * )providerIpcData ;
441
475
struct ze_memory_provider_t * ze_provider =
@@ -470,9 +504,12 @@ static umf_result_t ze_memory_provider_open_ipc_handle(void *provider,
470
504
471
505
static umf_result_t
472
506
ze_memory_provider_close_ipc_handle (void * provider , void * ptr , size_t size ) {
473
- ASSERT (provider != NULL );
474
- ASSERT (ptr != NULL );
475
507
(void )size ;
508
+
509
+ if (provider == NULL || ptr == NULL ) {
510
+ return UMF_RESULT_ERROR_INVALID_ARGUMENT ;
511
+ }
512
+
476
513
ze_result_t ze_result ;
477
514
struct ze_memory_provider_t * ze_provider =
478
515
(struct ze_memory_provider_t * )provider ;
0 commit comments