Skip to content

Commit 06baa06

Browse files
authored
[requests] remove sync filelock operations in async codepath (#7752)
* remove sync code in async codepath
* init
* refactor
1 parent f387949 commit 06baa06

File tree

4 files changed

+129
-95
lines changed

4 files changed

+129
-95
lines changed

sky/jobs/server/server.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,8 +206,8 @@ async def pool_tail_logs(
206206
request_cluster_name=common.JOB_CONTROLLER_NAME,
207207
)
208208

209-
request_task = api_requests.get_request(request.state.request_id,
210-
fields=['request_id'])
209+
request_task = await api_requests.get_request_async(
210+
request.state.request_id, fields=['request_id'])
211211

212212
return stream_utils.stream_response_for_long_request(
213213
request_id=request_task.request_id,

sky/server/requests/requests.py

Lines changed: 124 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -394,94 +394,6 @@ def _update_request_row_fields(
394394
return tuple(content[col] for col in REQUEST_COLUMNS)
395395

396396

397-
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
398-
"""Kill all pending and running requests for a cluster.
399-
400-
Args:
401-
cluster_name: the name of the cluster.
402-
exclude_request_names: exclude requests with these names. This is to
403-
prevent killing the caller request.
404-
"""
405-
request_ids = [
406-
request_task.request_id
407-
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
408-
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
409-
exclude_request_names=[exclude_request_name],
410-
cluster_names=[cluster_name],
411-
fields=['request_id']))
412-
]
413-
kill_requests(request_ids)
414-
415-
416-
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
417-
user_id: Optional[str] = None) -> List[str]:
418-
"""Kill requests with a given request ID prefix."""
419-
expanded_request_ids: Optional[List[str]] = None
420-
if request_ids is not None:
421-
expanded_request_ids = []
422-
for request_id in request_ids:
423-
request_tasks = get_requests_with_prefix(request_id,
424-
fields=['request_id'])
425-
if request_tasks is None or len(request_tasks) == 0:
426-
continue
427-
if len(request_tasks) > 1:
428-
raise ValueError(f'Multiple requests found for '
429-
f'request ID prefix: {request_id}')
430-
expanded_request_ids.append(request_tasks[0].request_id)
431-
return kill_requests(request_ids=expanded_request_ids, user_id=user_id)
432-
433-
434-
def kill_requests(request_ids: Optional[List[str]] = None,
435-
user_id: Optional[str] = None) -> List[str]:
436-
"""Kill a SkyPilot API request and set its status to cancelled.
437-
438-
Args:
439-
request_ids: The request IDs to kill. If None, all requests for the
440-
user are killed.
441-
user_id: The user ID to kill requests for. If None, all users are
442-
killed.
443-
444-
Returns:
445-
A list of request IDs that were cancelled.
446-
"""
447-
if request_ids is None:
448-
request_ids = [
449-
request_task.request_id
450-
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
451-
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
452-
# Avoid cancelling the cancel request itself.
453-
exclude_request_names=['sky.api_cancel'],
454-
user_id=user_id,
455-
fields=['request_id']))
456-
]
457-
cancelled_request_ids = []
458-
for request_id in request_ids:
459-
with update_request(request_id) as request_record:
460-
if request_record is None:
461-
logger.debug(f'No request ID {request_id}')
462-
continue
463-
# Skip internal requests. The internal requests are scheduled with
464-
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
465-
if request_record.request_id in set(
466-
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
467-
continue
468-
if request_record.status > RequestStatus.RUNNING:
469-
logger.debug(f'Request {request_id} already finished')
470-
continue
471-
if request_record.pid is not None:
472-
logger.debug(f'Killing request process {request_record.pid}')
473-
# Use SIGTERM instead of SIGKILL:
474-
# - The executor can handle SIGTERM gracefully
475-
# - After SIGTERM, the executor can reuse the request process
476-
# for other requests, avoiding the overhead of forking a new
477-
# process for each request.
478-
os.kill(request_record.pid, signal.SIGTERM)
479-
request_record.status = RequestStatus.CANCELLED
480-
request_record.finished_at = time.time()
481-
cancelled_request_ids.append(request_id)
482-
return cancelled_request_ids
483-
484-
485397
def create_table(cursor, conn):
486398
# Enable WAL mode to avoid locking issues.
487399
# See: issue #1441 and PR #1509
@@ -625,6 +537,128 @@ def request_lock_path(request_id: str) -> str:
625537
return os.path.join(lock_path, f'.{request_id}.lock')
626538

627539

540+
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
541+
"""Kill all pending and running requests for a cluster.
542+
543+
Args:
544+
cluster_name: the name of the cluster.
545+
exclude_request_names: exclude requests with these names. This is to
546+
prevent killing the caller request.
547+
"""
548+
request_ids = [
549+
request_task.request_id
550+
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
551+
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
552+
exclude_request_names=[exclude_request_name],
553+
cluster_names=[cluster_name],
554+
fields=['request_id']))
555+
]
556+
_kill_requests(request_ids)
557+
558+
559+
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
560+
user_id: Optional[str] = None) -> List[str]:
561+
"""Kill requests with a given request ID prefix."""
562+
expanded_request_ids: Optional[List[str]] = None
563+
if request_ids is not None:
564+
expanded_request_ids = []
565+
for request_id in request_ids:
566+
request_tasks = get_requests_with_prefix(request_id,
567+
fields=['request_id'])
568+
if request_tasks is None or len(request_tasks) == 0:
569+
continue
570+
if len(request_tasks) > 1:
571+
raise ValueError(f'Multiple requests found for '
572+
f'request ID prefix: {request_id}')
573+
expanded_request_ids.append(request_tasks[0].request_id)
574+
return _kill_requests(request_ids=expanded_request_ids, user_id=user_id)
575+
576+
577+
def _should_kill_request(request_id: str,
578+
request_record: Optional[Request]) -> bool:
579+
if request_record is None:
580+
logger.debug(f'No request ID {request_id}')
581+
return False
582+
# Skip internal requests. The internal requests are scheduled with
583+
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
584+
if request_record.request_id in set(
585+
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
586+
return False
587+
if request_record.status > RequestStatus.RUNNING:
588+
logger.debug(f'Request {request_id} already finished')
589+
return False
590+
return True
591+
592+
593+
def _kill_requests(request_ids: Optional[List[str]] = None,
594+
user_id: Optional[str] = None) -> List[str]:
595+
"""Kill a SkyPilot API request and set its status to cancelled.
596+
597+
Args:
598+
request_ids: The request IDs to kill. If None, all requests for the
599+
user are killed.
600+
user_id: The user ID to kill requests for. If None, all users are
601+
killed.
602+
603+
Returns:
604+
A list of request IDs that were cancelled.
605+
"""
606+
if request_ids is None:
607+
request_ids = [
608+
request_task.request_id
609+
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
610+
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
611+
# Avoid cancelling the cancel request itself.
612+
exclude_request_names=['sky.api_cancel'],
613+
user_id=user_id,
614+
fields=['request_id']))
615+
]
616+
cancelled_request_ids = []
617+
for request_id in request_ids:
618+
with update_request(request_id) as request_record:
619+
if not _should_kill_request(request_id, request_record):
620+
continue
621+
if request_record.pid is not None:
622+
logger.debug(f'Killing request process {request_record.pid}')
623+
# Use SIGTERM instead of SIGKILL:
624+
# - The executor can handle SIGTERM gracefully
625+
# - After SIGTERM, the executor can reuse the request process
626+
# for other requests, avoiding the overhead of forking a new
627+
# process for each request.
628+
os.kill(request_record.pid, signal.SIGTERM)
629+
request_record.status = RequestStatus.CANCELLED
630+
request_record.finished_at = time.time()
631+
cancelled_request_ids.append(request_id)
632+
return cancelled_request_ids
633+
634+
635+
@init_db_async
636+
@asyncio_utils.shield
637+
async def kill_request_async(request_id: str) -> bool:
638+
"""Kill a SkyPilot API request and set its status to cancelled.
639+
640+
Returns:
641+
True if the request was killed, False otherwise.
642+
"""
643+
async with filelock.AsyncFileLock(request_lock_path(request_id)):
644+
request = await _get_request_no_lock_async(request_id)
645+
if not _should_kill_request(request_id, request):
646+
return False
647+
assert request is not None
648+
if request.pid is not None:
649+
logger.debug(f'Killing request process {request.pid}')
650+
# Use SIGTERM instead of SIGKILL:
651+
# - The executor can handle SIGTERM gracefully
652+
# - After SIGTERM, the executor can reuse the request process
653+
# for other requests, avoiding the overhead of forking a new
654+
# process for each request.
655+
os.kill(request.pid, signal.SIGTERM)
656+
request.status = RequestStatus.CANCELLED
657+
request.finished_at = time.time()
658+
await _add_or_update_request_no_lock_async(request)
659+
return True
660+
661+
628662
@contextlib.contextmanager
629663
@init_db
630664
@metrics_lib.time_me
@@ -639,7 +673,7 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
639673
_add_or_update_request_no_lock(request)
640674

641675

642-
@init_db
676+
@init_db_async
643677
@metrics_lib.time_me
644678
@asyncio_utils.shield
645679
async def update_status_async(request_id: str, status: RequestStatus) -> None:
@@ -651,7 +685,7 @@ async def update_status_async(request_id: str, status: RequestStatus) -> None:
651685
await _add_or_update_request_no_lock_async(request)
652686

653687

654-
@init_db
688+
@init_db_async
655689
@metrics_lib.time_me
656690
@asyncio_utils.shield
657691
async def update_status_msg_async(request_id: str, status_msg: str) -> None:

sky/server/stream_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -373,7 +373,7 @@ def stream_response(
373373
async def on_disconnect():
374374
logger.info(f'User terminated the connection for request '
375375
f'{request_id}')
376-
requests_lib.kill_requests([request_id])
376+
await requests_lib.kill_request_async(request_id)
377377

378378
# The background task will be run after returning a response.
379379
# https://fastapi.tiangolo.com/tutorial/background-tasks/

tests/unit_tests/test_sky/server/requests/test_executor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -156,8 +156,8 @@ async def test_api_cancel_race_condition(isolated_database):
156156
assert await requests_lib.create_if_not_exists_async(req) is True
157157

158158
# Cancel the request before the executor starts.
159-
cancelled = requests_lib.kill_requests(['race-cancel-before'])
160-
assert cancelled == ['race-cancel-before']
159+
cancelled = await requests_lib.kill_request_async('race-cancel-before')
160+
assert cancelled is True
161161

162162
# Execute wrapper should detect CANCELLED and return immediately.
163163
executor._request_execution_wrapper('race-cancel-before',

0 commit comments

Comments
 (0)