@@ -148,6 +148,8 @@ ur_result_t ur_queue_immediate_in_order_t::queueGetNativeHandle(
148
148
ur_result_t ur_queue_immediate_in_order_t::queueFinish () {
149
149
TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::queueFinish" );
150
150
151
+ hContext->getAsyncPool ()->cleanupPoolsForQueue (this );
152
+
151
153
auto commandListLocked = commandListManager.lock ();
152
154
// TODO: use zeEventHostSynchronize instead?
153
155
TRACK_SCOPE_LATENCY (
@@ -701,31 +703,142 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueWriteHostPipe(
701
703
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
702
704
}
703
705
706
+ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMAllocHelper (
707
+ ur_usm_pool_handle_t pPool, const size_t size,
708
+ const ur_exp_async_usm_alloc_properties_t *, uint32_t numEventsInWaitList,
709
+ const ur_event_handle_t *phEventWaitList, void **ppMem,
710
+ ur_event_handle_t *phEvent, ur_usm_type_t type) {
711
+ auto commandListLocked = commandListManager.lock ();
712
+
713
+ if (!pPool) {
714
+ pPool = hContext->getAsyncPool ();
715
+ }
716
+
717
+ auto device = (type == UR_USM_TYPE_HOST) ? nullptr : hDevice;
718
+ auto waitListView =
719
+ getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList);
720
+
721
+ auto asyncAlloc =
722
+ pPool->allocateEnqueued (hContext, this , device, nullptr , type, size);
723
+ if (!asyncAlloc) {
724
+ auto Ret = pPool->allocate (hContext, device, nullptr , type, size, ppMem);
725
+ if (Ret) {
726
+ return Ret;
727
+ }
728
+ } else {
729
+ ur_event_handle_t originAllocEvent;
730
+ std::tie (*ppMem, originAllocEvent) = *asyncAlloc;
731
+ waitListView = getWaitListView (commandListLocked, phEventWaitList,
732
+ numEventsInWaitList, originAllocEvent);
733
+ }
734
+
735
+ ur_command_t commandType = UR_COMMAND_FORCE_UINT32;
736
+ switch (type) {
737
+ case UR_USM_TYPE_HOST:
738
+ commandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
739
+ break ;
740
+ case UR_USM_TYPE_DEVICE:
741
+ commandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
742
+ break ;
743
+ case UR_USM_TYPE_SHARED:
744
+ commandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
745
+ break ;
746
+ default :
747
+ logger::error (" enqueueUSMAllocHelper: unsupported USM type" );
748
+ throw UR_RESULT_ERROR_UNKNOWN;
749
+ }
750
+
751
+ auto zeSignalEvent = getSignalEvent (commandListLocked, phEvent, commandType);
752
+ auto [pWaitEvents, numWaitEvents] = waitListView;
753
+ if (numWaitEvents > 0 ) {
754
+ ZE2UR_CALL (
755
+ zeCommandListAppendWaitOnEvents,
756
+ (commandListLocked->getZeCommandList (), numWaitEvents, pWaitEvents));
757
+ }
758
+ if (zeSignalEvent) {
759
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
760
+ (commandListLocked->getZeCommandList (), zeSignalEvent));
761
+ }
762
+
763
+ return UR_RESULT_SUCCESS;
764
+ }
765
+
704
766
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMDeviceAllocExp (
705
- ur_usm_pool_handle_t , const size_t ,
706
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
707
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
708
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
767
+ ur_usm_pool_handle_t pPool, const size_t size,
768
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
769
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
770
+ void **ppMem, ur_event_handle_t *phEvent) {
771
+ TRACK_SCOPE_LATENCY (
772
+ " ur_queue_immediate_in_order_t::enqueueUSMDeviceAllocExp" );
773
+
774
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
775
+ phEventWaitList, ppMem, phEvent,
776
+ UR_USM_TYPE_DEVICE);
709
777
}
710
778
711
779
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMSharedAllocExp (
712
- ur_usm_pool_handle_t , const size_t ,
713
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
714
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
715
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
780
+ ur_usm_pool_handle_t pPool, const size_t size,
781
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
782
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
783
+ void **ppMem, ur_event_handle_t *phEvent) {
784
+ TRACK_SCOPE_LATENCY (
785
+ " ur_queue_immediate_in_order_t::enqueueUSMSharedAllocExp" );
786
+
787
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
788
+ phEventWaitList, ppMem, phEvent,
789
+ UR_USM_TYPE_SHARED);
716
790
}
717
791
718
792
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMHostAllocExp (
719
- ur_usm_pool_handle_t , const size_t ,
720
- const ur_exp_async_usm_alloc_properties_t *, uint32_t ,
721
- const ur_event_handle_t *, void **, ur_event_handle_t *) {
722
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
793
+ ur_usm_pool_handle_t pPool, const size_t size,
794
+ const ur_exp_async_usm_alloc_properties_t *pProperties,
795
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
796
+ void **ppMem, ur_event_handle_t *phEvent) {
797
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::enqueueUSMHostAllocExp" );
798
+
799
+ return enqueueUSMAllocHelper (pPool, size, pProperties, numEventsInWaitList,
800
+ phEventWaitList, ppMem, phEvent,
801
+ UR_USM_TYPE_HOST);
723
802
}
724
803
725
804
ur_result_t ur_queue_immediate_in_order_t::enqueueUSMFreeExp (
726
- ur_usm_pool_handle_t , void *, uint32_t , const ur_event_handle_t *,
727
- ur_event_handle_t *) {
728
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
805
+ ur_usm_pool_handle_t pPool, void *pMem, uint32_t numEventsInWaitList,
806
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
807
+ TRACK_SCOPE_LATENCY (" ur_queue_immediate_in_order_t::enqueueUSMFreeExp" );
808
+ auto commandListLocked = commandListManager.lock ();
809
+
810
+ auto zeSignalEvent = getSignalEvent (commandListLocked, phEvent,
811
+ UR_COMMAND_ENQUEUE_USM_FREE_EXP);
812
+ auto [pWaitEvents, numWaitEvents] =
813
+ getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList);
814
+
815
+ umf_memory_pool_handle_t hPool = umfPoolByPtr (pMem);
816
+ if (!hPool) {
817
+ return UR_RESULT_SUCCESS;
818
+ }
819
+
820
+ UsmPool *usmPool = nullptr ;
821
+ auto ret = umfPoolGetTag (hPool, (void **)&usmPool);
822
+ if (ret != UR_RESULT_SUCCESS || !usmPool) {
823
+ // This should never happen
824
+ return UR_RESULT_ERROR_UNKNOWN;
825
+ }
826
+
827
+ size_t size = umfPoolMallocUsableSize (hPool, pMem);
828
+ usmPool->asyncPool .insert (pMem, size, *phEvent, this );
829
+
830
+ if (numWaitEvents > 0 ) {
831
+ ZE2UR_CALL (
832
+ zeCommandListAppendWaitOnEvents,
833
+ (commandListLocked->getZeCommandList (), numWaitEvents, pWaitEvents));
834
+ }
835
+
836
+ if (zeSignalEvent) {
837
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
838
+ (commandListLocked->getZeCommandList (), zeSignalEvent));
839
+ }
840
+
841
+ return UR_RESULT_SUCCESS;
729
842
}
730
843
731
844
ur_result_t ur_queue_immediate_in_order_t::bindlessImagesImageCopyExp (
@@ -855,9 +968,9 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueGenericCommandListsExp(
855
968
" ur_queue_immediate_in_order_t::enqueueGenericCommandListsExp" );
856
969
857
970
auto commandListLocked = commandListManager.lock ();
971
+
858
972
auto zeSignalEvent =
859
973
getSignalEvent (commandListLocked, phEvent, callerCommand);
860
-
861
974
auto [pWaitEvents, numWaitEvents] =
862
975
getWaitListView (commandListLocked, phEventWaitList, numEventsInWaitList,
863
976
additionalWaitEvent);
0 commit comments