Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 107 additions & 21 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import json
import os
import signal
import smtplib
import sys
import threading
Expand Down Expand Up @@ -2850,21 +2851,26 @@ def __init__(
self._db_reconnect_lock = asyncio.Lock()
self._db_health_watchdog_task: Optional[asyncio.Task] = None
self._db_last_reconnect_attempt_ts: float = 0.0
self._db_last_health_watchdog_reconnect_attempt_ts: float = 0.0
self._db_reconnect_cooldown_seconds: int = max(
1, int(os.getenv("PRISMA_RECONNECT_COOLDOWN_SECONDS", "15"))
)
self._db_health_watchdog_interval_seconds: int = max(
5, int(os.getenv("PRISMA_HEALTH_WATCHDOG_INTERVAL_SECONDS", "30"))
5, int(os.getenv("PRISMA_HEALTH_WATCHDOG_INTERVAL_SECONDS", "120"))
)
self._db_health_watchdog_enabled: bool = (
str_to_bool(os.getenv("PRISMA_HEALTH_WATCHDOG_ENABLED", "true")) is True
)
self._db_health_watchdog_probe_timeout_seconds: float = max(
0.5,
float(os.getenv("PRISMA_HEALTH_WATCHDOG_PROBE_TIMEOUT_SECONDS", "5.0")),
float(os.getenv("PRISMA_HEALTH_WATCHDOG_PROBE_TIMEOUT_SECONDS", "15")),
)
self._db_health_watchdog_reconnect_cooldown_seconds: int = max(
1,
int(os.getenv("PRISMA_HEALTH_WATCHDOG_RECONNECT_COOLDOWN_SECONDS", "120")),
)
self._db_watchdog_reconnect_timeout_seconds: float = max(
1.0, float(os.getenv("PRISMA_WATCHDOG_RECONNECT_TIMEOUT_SECONDS", "30.0"))
1.0, float(os.getenv("PRISMA_WATCHDOG_RECONNECT_TIMEOUT_SECONDS", "60.0"))
)
self._db_auth_reconnect_timeout_seconds: float = max(
0.5, float(os.getenv("PRISMA_AUTH_RECONNECT_TIMEOUT_SECONDS", "2.0"))
Expand Down Expand Up @@ -4202,6 +4208,65 @@ def _is_engine_alive(self) -> bool:
except (PermissionError, OSError):
return True

@staticmethod
def _format_signal_name(signal_number: int) -> str:
try:
return signal.Signals(signal_number).name
except ValueError:
return f"UNKNOWN_SIGNAL_{signal_number}"

@staticmethod
def _format_engine_wait_status(wait_status: int) -> str:
if os.WIFEXITED(wait_status):
return f"exit_code={os.WEXITSTATUS(wait_status)}"
elif os.WIFSIGNALED(wait_status):
signal_number = os.WTERMSIG(wait_status)
signal_name = PrismaClient._format_signal_name(signal_number)
core_dumped = (
os.WCOREDUMP(wait_status) if hasattr(os, "WCOREDUMP") else False
)
return (
f"signal={signal_name} signal_number={signal_number} "
f"core_dumped={core_dumped}"
)
elif os.WIFSTOPPED(wait_status):
signal_number = os.WSTOPSIG(wait_status)
signal_name = PrismaClient._format_signal_name(signal_number)
return f"stopped_by_signal={signal_name} signal_number={signal_number}"
elif hasattr(os, "WIFCONTINUED") and os.WIFCONTINUED(wait_status):
return "continued=True"
else:
return f"raw_wait_status={wait_status}"

@staticmethod
def _format_prisma_engine_exit_reason(
*,
detection_method: str,
wait_status: Optional[int],
) -> str:
if wait_status is None:
return f"detection_method={detection_method} exit_status=unavailable"
return (
f"detection_method={detection_method} "
f"{PrismaClient._format_engine_wait_status(wait_status)}"
)

@staticmethod
def _log_prisma_engine_exit_reason(
*,
pid: int,
detection_method: str,
wait_status: Optional[int],
) -> None:
Comment on lines +4242 to +4260
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _format_prisma_engine_exit_reason and _log_prisma_engine_exit_reason reference no instance state (self is only used to dispatch to _format_engine_wait_status, which is already a @staticmethod). Decorating them as @staticmethod keeps the API consistent with _format_engine_wait_status and signals to readers that they are pure utilities with no side effects on the object.

Suggested change
def _format_prisma_engine_exit_reason(
self,
*,
detection_method: str,
wait_status: Optional[int],
) -> str:
if wait_status is None:
return f"detection_method={detection_method} exit_status=unavailable"
return (
f"detection_method={detection_method} "
f"{self._format_engine_wait_status(wait_status)}"
)
def _log_prisma_engine_exit_reason(
self,
*,
pid: int,
detection_method: str,
wait_status: Optional[int],
) -> None:
@staticmethod
def _format_prisma_engine_exit_reason(
*,
detection_method: str,
wait_status: Optional[int],
) -> str:
if wait_status is None:
return f"detection_method={detection_method} exit_status=unavailable"
return (
f"detection_method={detection_method} "
f"{PrismaClient._format_engine_wait_status(wait_status)}"
)
@staticmethod
def _log_prisma_engine_exit_reason(
*,
pid: int,
detection_method: str,
wait_status: Optional[int],
) -> None:

verbose_proxy_logger.error(
"prisma-query-engine PID %s exited; %s; triggering reconnect.",
pid,
PrismaClient._format_prisma_engine_exit_reason(
detection_method=detection_method,
wait_status=wait_status,
),
)

@staticmethod
def _reap_all_zombies() -> set:
"""Reap ALL zombie child processes via waitpid(-1, WNOHANG).
Expand Down Expand Up @@ -4240,7 +4305,7 @@ def _try_waitpid_watch(self, pid: int) -> bool:
if sys.platform == "win32":
return False
try:
probe_pid, _ = os.waitpid(pid, os.WNOHANG)
probe_pid, wait_status = os.waitpid(pid, os.WNOHANG)
except ChildProcessError:
verbose_proxy_logger.debug(
"PID %s is not a child process; skipping waitpid watch.",
Expand All @@ -4249,9 +4314,10 @@ def _try_waitpid_watch(self, pid: int) -> bool:
return False

if probe_pid == pid:
verbose_proxy_logger.warning(
"prisma-query-engine PID %s already dead at watch start.",
pid,
self._log_prisma_engine_exit_reason(
pid=pid,
detection_method="waitpid watch start",
wait_status=wait_status,
)
self._engine_confirmed_dead = True
self._reap_all_zombies()
Expand Down Expand Up @@ -4286,26 +4352,32 @@ def _waitpid_thread_func(self, pid: int, loop: asyncio.AbstractEventLoop) -> Non
in its SIGCHLD handler. In that case our waitpid raises ChildProcessError.
we still notify the event loop because the engine is dead either way.
"""
wait_status: Optional[int] = None
try:
os.waitpid(pid, 0)
_, wait_status = os.waitpid(pid, 0)
except ChildProcessError:
pass
except OSError:
pass
try:
loop.call_soon_threadsafe(self._on_engine_death_from_thread, pid)
loop.call_soon_threadsafe(
self._on_engine_death_from_thread, pid, wait_status
)
except RuntimeError:
pass

def _on_engine_death_from_thread(self, dead_pid: int) -> None:
def _on_engine_death_from_thread(
self, dead_pid: int, wait_status: Optional[int] = None
) -> None:
"""Called on the event loop thread when the waitpid thread detects engine death."""
if self._engine_confirmed_dead:
return
if dead_pid != self._engine_pid:
return
verbose_proxy_logger.error(
"prisma-query-engine PID %s exited (waitpid thread); triggering reconnect.",
dead_pid,
self._log_prisma_engine_exit_reason(
pid=dead_pid,
detection_method="waitpid thread",
wait_status=wait_status,
)
self._engine_confirmed_dead = True
self._reap_all_zombies()
Expand Down Expand Up @@ -4357,9 +4429,10 @@ def _on_pidfd_readable(self) -> None:
self._engine_pidfd = -1
return
dead_pid = self._engine_pid
verbose_proxy_logger.error(
"prisma-query-engine PID %s exited (pidfd event); triggering reconnect.",
dead_pid,
self._log_prisma_engine_exit_reason(
pid=dead_pid,
detection_method="pidfd event",
wait_status=None,
)
self._engine_confirmed_dead = True
self._reap_all_zombies()
Expand All @@ -4380,9 +4453,11 @@ async def _poll_engine_proc(self) -> None:
try:
os.kill(self._engine_pid, 0)
except ProcessLookupError:
verbose_proxy_logger.error(
"prisma-query-engine PID %s gone; triggering reconnect.",
self._engine_pid,
dead_pid = self._engine_pid
self._log_prisma_engine_exit_reason(
pid=dead_pid,
detection_method="os.kill polling",
wait_status=None,
)
self._engine_confirmed_dead = True
self._reap_all_zombies()
Expand Down Expand Up @@ -4702,9 +4777,9 @@ async def start_db_health_watchdog_task(self) -> None:
self._db_health_watchdog_loop()
)
verbose_proxy_logger.info(
"Started Prisma DB health watchdog (interval=%ss, reconnect_cooldown=%ss, probe_timeout=%ss, reconnect_timeout=%ss)",
"Started Prisma DB health watchdog (interval=%ss, watchdog_reconnect_cooldown=%ss, probe_timeout=%ss, reconnect_timeout=%ss)",
self._db_health_watchdog_interval_seconds,
self._db_reconnect_cooldown_seconds,
self._db_health_watchdog_reconnect_cooldown_seconds,
self._db_health_watchdog_probe_timeout_seconds,
self._db_watchdog_reconnect_timeout_seconds,
)
Expand Down Expand Up @@ -4737,8 +4812,19 @@ async def _db_health_watchdog_loop(self) -> None:
if isinstance(
e, asyncio.TimeoutError
) or PrismaDBExceptionHandler.is_database_connection_error(e):
now = time.time()
if (
now - self._db_last_health_watchdog_reconnect_attempt_ts
< self._db_health_watchdog_reconnect_cooldown_seconds
):
verbose_proxy_logger.debug(
"Skipping DB health watchdog reconnect due to watchdog cooldown."
)
continue
self._db_last_health_watchdog_reconnect_attempt_ts = now
await self.attempt_db_reconnect(
reason="db_health_watchdog_connection_error",
force=True,
timeout_seconds=self._db_watchdog_reconnect_timeout_seconds,
)
else:
Expand Down
35 changes: 31 additions & 4 deletions tests/litellm/proxy/test_prisma_engine_watchdog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

import asyncio
import os
import threading
import time
import signal
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -91,6 +90,19 @@ def test_is_engine_alive_returns_true_for_running_process(engine_client):
assert engine_client._is_engine_alive() is True


def test_format_engine_wait_status_for_exit_code(engine_client):
wait_status = 7 << 8

assert engine_client._format_engine_wait_status(wait_status) == "exit_code=7"


def test_format_engine_wait_status_for_signal(engine_client):
assert (
engine_client._format_engine_wait_status(signal.SIGTERM.value)
== "signal=SIGTERM signal_number=15 core_dumped=False"
)


# ---------------------------------------------------------------------------
# _poll_engine_proc — calls attempt_db_reconnect on death
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -399,14 +411,29 @@ def test_try_waitpid_watch_starts_thread_for_child(engine_client):
with (
patch("os.waitpid", return_value=(0, 0)),
patch("asyncio.get_running_loop", return_value=mock_loop),
patch("threading.Thread", return_value=mock_thread) as mock_thread_cls,
patch("threading.Thread", return_value=mock_thread),
):
result = engine_client._try_waitpid_watch(1234)
assert result is True
mock_thread.start.assert_called_once()
assert engine_client._engine_wait_thread is mock_thread


def test_waitpid_thread_passes_exit_status_to_event_loop(engine_client):
"""waitpid thread forwards the raw wait status so logs can include the exit reason."""
mock_loop = MagicMock()
wait_status = 9 << 8

with patch("os.waitpid", return_value=(1234, wait_status)):
engine_client._waitpid_thread_func(1234, mock_loop)

mock_loop.call_soon_threadsafe.assert_called_once_with(
engine_client._on_engine_death_from_thread,
1234,
wait_status,
)


@pytest.mark.asyncio
async def test_try_waitpid_watch_handles_already_dead_engine(engine_client) -> None:
"""_try_waitpid_watch detects engine already dead at watch start."""
Expand Down Expand Up @@ -451,7 +478,7 @@ def capture_task(coro):
return MagicMock()

with patch("asyncio.create_task", side_effect=capture_task):
engine_client._on_engine_death_from_thread(1234)
engine_client._on_engine_death_from_thread(1234, 7 << 8)

assert len(created_coros) == 1
await created_coros[0]
Expand Down
30 changes: 30 additions & 0 deletions tests/test_litellm/proxy/db/test_prisma_self_heal.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ async def test_db_health_watchdog_should_trigger_reconnect_on_db_error(

client.attempt_db_reconnect.assert_awaited_once_with(
reason="db_health_watchdog_connection_error",
force=True,
timeout_seconds=7.0,
)

Expand Down Expand Up @@ -302,10 +303,39 @@ async def test_db_health_watchdog_should_trigger_reconnect_on_probe_timeout(

client.attempt_db_reconnect.assert_awaited_once_with(
reason="db_health_watchdog_connection_error",
force=True,
timeout_seconds=9.0,
)


@pytest.mark.asyncio
async def test_db_health_watchdog_should_skip_reconnect_during_watchdog_cooldown(
mock_proxy_logging,
):
client = PrismaClient(
database_url="mock://test", proxy_logging_obj=mock_proxy_logging
)
client.db.query_raw = AsyncMock(side_effect=Exception("db connection dropped"))
client.attempt_db_reconnect = AsyncMock(return_value=True)
client._db_health_watchdog_interval_seconds = 1
client._db_health_watchdog_reconnect_cooldown_seconds = 3600
client._db_last_health_watchdog_reconnect_attempt_ts = time.time()

with (
patch(
"litellm.proxy.utils.asyncio.sleep",
AsyncMock(side_effect=[None, asyncio.CancelledError()]),
),
patch(
"litellm.proxy.db.exception_handler.PrismaDBExceptionHandler.is_database_connection_error",
return_value=True,
),
):
await client._db_health_watchdog_loop()

client.attempt_db_reconnect.assert_not_called()


@pytest.mark.asyncio
async def test_db_health_watchdog_start_stop_lifecycle(mock_proxy_logging):
client = PrismaClient(
Expand Down
Loading