@@ -394,94 +394,6 @@ def _update_request_row_fields(
394394 return tuple (content [col ] for col in REQUEST_COLUMNS )
395395
396396
397- def kill_cluster_requests (cluster_name : str , exclude_request_name : str ):
398- """Kill all pending and running requests for a cluster.
399-
400- Args:
401- cluster_name: the name of the cluster.
402- exclude_request_names: exclude requests with these names. This is to
403- prevent killing the caller request.
404- """
405- request_ids = [
406- request_task .request_id
407- for request_task in get_request_tasks (req_filter = RequestTaskFilter (
408- status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
409- exclude_request_names = [exclude_request_name ],
410- cluster_names = [cluster_name ],
411- fields = ['request_id' ]))
412- ]
413- kill_requests (request_ids )
414-
415-
416- def kill_requests_with_prefix (request_ids : Optional [List [str ]] = None ,
417- user_id : Optional [str ] = None ) -> List [str ]:
418- """Kill requests with a given request ID prefix."""
419- expanded_request_ids : Optional [List [str ]] = None
420- if request_ids is not None :
421- expanded_request_ids = []
422- for request_id in request_ids :
423- request_tasks = get_requests_with_prefix (request_id ,
424- fields = ['request_id' ])
425- if request_tasks is None or len (request_tasks ) == 0 :
426- continue
427- if len (request_tasks ) > 1 :
428- raise ValueError (f'Multiple requests found for '
429- f'request ID prefix: { request_id } ' )
430- expanded_request_ids .append (request_tasks [0 ].request_id )
431- return kill_requests (request_ids = expanded_request_ids , user_id = user_id )
432-
433-
434- def kill_requests (request_ids : Optional [List [str ]] = None ,
435- user_id : Optional [str ] = None ) -> List [str ]:
436- """Kill a SkyPilot API request and set its status to cancelled.
437-
438- Args:
439- request_ids: The request IDs to kill. If None, all requests for the
440- user are killed.
441- user_id: The user ID to kill requests for. If None, all users are
442- killed.
443-
444- Returns:
445- A list of request IDs that were cancelled.
446- """
447- if request_ids is None :
448- request_ids = [
449- request_task .request_id
450- for request_task in get_request_tasks (req_filter = RequestTaskFilter (
451- status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
452- # Avoid cancelling the cancel request itself.
453- exclude_request_names = ['sky.api_cancel' ],
454- user_id = user_id ,
455- fields = ['request_id' ]))
456- ]
457- cancelled_request_ids = []
458- for request_id in request_ids :
459- with update_request (request_id ) as request_record :
460- if request_record is None :
461- logger .debug (f'No request ID { request_id } ' )
462- continue
463- # Skip internal requests. The internal requests are scheduled with
464- # request_id in range(len(INTERNAL_REQUEST_EVENTS)).
465- if request_record .request_id in set (
466- event .id for event in daemons .INTERNAL_REQUEST_DAEMONS ):
467- continue
468- if request_record .status > RequestStatus .RUNNING :
469- logger .debug (f'Request { request_id } already finished' )
470- continue
471- if request_record .pid is not None :
472- logger .debug (f'Killing request process { request_record .pid } ' )
473- # Use SIGTERM instead of SIGKILL:
474- # - The executor can handle SIGTERM gracefully
475- # - After SIGTERM, the executor can reuse the request process
476- # for other requests, avoiding the overhead of forking a new
477- # process for each request.
478- os .kill (request_record .pid , signal .SIGTERM )
479- request_record .status = RequestStatus .CANCELLED
480- request_record .finished_at = time .time ()
481- cancelled_request_ids .append (request_id )
482- return cancelled_request_ids
483-
484-
485397def create_table (cursor , conn ):
486398 # Enable WAL mode to avoid locking issues.
487399 # See: issue #1441 and PR #1509
@@ -625,6 +537,128 @@ def request_lock_path(request_id: str) -> str:
625537 return os .path .join (lock_path , f'.{ request_id } .lock' )
626538
627539
540+ def kill_cluster_requests (cluster_name : str , exclude_request_name : str ):
541+ """Kill all pending and running requests for a cluster.
542+
543+ Args:
544+ cluster_name: the name of the cluster.
545+ exclude_request_name: exclude requests with this name. This is to
546+ prevent killing the caller request.
547+ """
548+ request_ids = [
549+ request_task .request_id
550+ for request_task in get_request_tasks (req_filter = RequestTaskFilter (
551+ status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
552+ exclude_request_names = [exclude_request_name ],
553+ cluster_names = [cluster_name ],
554+ fields = ['request_id' ]))
555+ ]
556+ _kill_requests (request_ids )
557+
558+
559+ def kill_requests_with_prefix (request_ids : Optional [List [str ]] = None ,
560+ user_id : Optional [str ] = None ) -> List [str ]:
561+ """Kill requests with a given request ID prefix."""
562+ expanded_request_ids : Optional [List [str ]] = None
563+ if request_ids is not None :
564+ expanded_request_ids = []
565+ for request_id in request_ids :
566+ request_tasks = get_requests_with_prefix (request_id ,
567+ fields = ['request_id' ])
568+ if request_tasks is None or len (request_tasks ) == 0 :
569+ continue
570+ if len (request_tasks ) > 1 :
571+ raise ValueError (f'Multiple requests found for '
572+ f'request ID prefix: { request_id } ' )
573+ expanded_request_ids .append (request_tasks [0 ].request_id )
574+ return _kill_requests (request_ids = expanded_request_ids , user_id = user_id )
575+
576+
577+ def _should_kill_request (request_id : str ,
578+ request_record : Optional [Request ]) -> bool :
579+ if request_record is None :
580+ logger .debug (f'No request ID { request_id } ' )
581+ return False
582+ # Skip internal daemon requests: their request IDs are the ids of the
583+ # entries in daemons.INTERNAL_REQUEST_DAEMONS.
584+ if request_record .request_id in set (
585+ event .id for event in daemons .INTERNAL_REQUEST_DAEMONS ):
586+ return False
587+ if request_record .status > RequestStatus .RUNNING :
588+ logger .debug (f'Request { request_id } already finished' )
589+ return False
590+ return True
591+
592+
593+ def _kill_requests (request_ids : Optional [List [str ]] = None ,
594+ user_id : Optional [str ] = None ) -> List [str ]:
595+ """Kill SkyPilot API requests and set their status to cancelled.
596+
597+ Args:
598+ request_ids: The request IDs to kill. If None, all pending and
599+ running requests for the user are killed.
600+ user_id: Only used when request_ids is None. The user whose requests
601+ are killed; if None, requests of all users are killed.
602+
603+ Returns:
604+ A list of request IDs that were cancelled.
605+ """
606+ if request_ids is None :
607+ request_ids = [
608+ request_task .request_id
609+ for request_task in get_request_tasks (req_filter = RequestTaskFilter (
610+ status = [RequestStatus .PENDING , RequestStatus .RUNNING ],
611+ # Avoid cancelling the cancel request itself.
612+ exclude_request_names = ['sky.api_cancel' ],
613+ user_id = user_id ,
614+ fields = ['request_id' ]))
615+ ]
616+ cancelled_request_ids = []
617+ for request_id in request_ids :
618+ with update_request (request_id ) as request_record :
619+ if not _should_kill_request (request_id , request_record ):
620+ continue
621+ if request_record .pid is not None :
622+ logger .debug (f'Killing request process { request_record .pid } ' )
623+ # Use SIGTERM instead of SIGKILL:
624+ # - The executor can handle SIGTERM gracefully
625+ # - After SIGTERM, the executor can reuse the request process
626+ # for other requests, avoiding the overhead of forking a new
627+ # process for each request.
628+ os .kill (request_record .pid , signal .SIGTERM )
629+ request_record .status = RequestStatus .CANCELLED
630+ request_record .finished_at = time .time ()
631+ cancelled_request_ids .append (request_id )
632+ return cancelled_request_ids
633+
634+
635+ @init_db_async
636+ @asyncio_utils .shield
637+ async def kill_request_async (request_id : str ) -> bool :
638+ """Kill a SkyPilot API request and set its status to cancelled.
639+
640+ Returns:
641+ True if the request was killed, False otherwise.
642+ """
643+ async with filelock .AsyncFileLock (request_lock_path (request_id )):
644+ request = await _get_request_no_lock_async (request_id )
645+ if not _should_kill_request (request_id , request ):
646+ return False
647+ assert request is not None
648+ if request .pid is not None :
649+ logger .debug (f'Killing request process { request .pid } ' )
650+ # Use SIGTERM instead of SIGKILL:
651+ # - The executor can handle SIGTERM gracefully
652+ # - After SIGTERM, the executor can reuse the request process
653+ # for other requests, avoiding the overhead of forking a new
654+ # process for each request.
655+ os .kill (request .pid , signal .SIGTERM )
656+ request .status = RequestStatus .CANCELLED
657+ request .finished_at = time .time ()
658+ await _add_or_update_request_no_lock_async (request )
659+ return True
660+
661+
628662@contextlib .contextmanager
629663@init_db
630664@metrics_lib .time_me
@@ -639,7 +673,7 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
639673 _add_or_update_request_no_lock (request )
640674
641675
642- @init_db
676+ @init_db_async
643677@metrics_lib .time_me
644678@asyncio_utils .shield
645679async def update_status_async (request_id : str , status : RequestStatus ) -> None :
@@ -651,7 +685,7 @@ async def update_status_async(request_id: str, status: RequestStatus) -> None:
651685 await _add_or_update_request_no_lock_async (request )
652686
653687
654- @init_db
688+ @init_db_async
655689@metrics_lib .time_me
656690@asyncio_utils .shield
657691async def update_status_msg_async (request_id : str , status_msg : str ) -> None :
0 commit comments