Skip to content

🎨 Add Reusable Lifespan Contexts for RabbitMQ and Redis in servicelib.fastapi #7547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,77 @@
import contextlib
from collections.abc import Iterator
from typing import Final

from common_library.errors_classes import OsparcErrorMixin
from fastapi import FastAPI
from fastapi_lifespan_manager import State

from ..logging_utils import log_context


class LifespanError(OsparcErrorMixin, RuntimeError): ...


class LifespanOnStartupError(LifespanError):
msg_template = "Failed during startup of {module}"
msg_template = "Failed during startup of {lifespan_name}"


class LifespanOnShutdownError(LifespanError):
msg_template = "Failed during shutdown of {module}"
msg_template = "Failed during shutdown of {lifespan_name}"


class LifespanAlreadyCalledError(LifespanError):
msg_template = "The lifespan '{lifespan_name}' has already been called."


class LifespanExpectedCalledError(LifespanError):
msg_template = "The lifespan '{lifespan_name}' was not called. Ensure it is properly configured and invoked."


_CALLED_LIFESPANS_KEY: Final[str] = "_CALLED_LIFESPANS"


def is_lifespan_called(state: State, lifespan_name: str) -> bool:
# NOTE: This assert is meant to catch a common mistake:
# The `lifespan` function should accept up to two *optional* positional arguments: (app: FastAPI, state: State).
# Valid signatures include: `()`, `(app)`, `(app, state)`, or even `(_, state)`.
# It's easy to accidentally swap or misplace these arguments.
assert not isinstance( # nosec
state, FastAPI
), "Did you swap arguments? `lifespan(app, state)` expects (app: FastAPI, state: State)"

called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
return lifespan_name in called_lifespans


def mark_lifespace_called(state: State, lifespan_name: str) -> State:
"""Validates if a lifespan has already been called and records it in the state.
Raises LifespanAlreadyCalledError if the lifespan has already been called.
"""
if is_lifespan_called(state, lifespan_name):
raise LifespanAlreadyCalledError(lifespan_name=lifespan_name)

called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
called_lifespans.add(lifespan_name)
return {_CALLED_LIFESPANS_KEY: called_lifespans}


def ensure_lifespan_called(state: State, lifespan_name: str) -> None:
"""Ensures that a lifespan has been called.
Raises LifespanNotCalledError if the lifespan has not been called.
"""
if not is_lifespan_called(state, lifespan_name):
raise LifespanExpectedCalledError(lifespan_name=lifespan_name)


@contextlib.contextmanager
def lifespan_context(
logger, level, lifespan_name: str, state: State
) -> Iterator[State]:
"""Helper context manager to log lifespan event and mark lifespan as called."""

with log_context(logger, level, lifespan_name):
# Check if lifespan has already been called
called_state = mark_lifespace_called(state, lifespan_name)

yield called_state
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from fastapi import FastAPI
from fastapi_lifespan_manager import State
from servicelib.logging_utils import log_catch, log_context
from settings_library.postgres import PostgresSettings
from sqlalchemy.ext.asyncio import AsyncEngine

from ..db_asyncpg_utils import create_async_engine_and_database_ready
from .lifespan_utils import LifespanOnStartupError
from ..logging_utils import log_catch
from .lifespan_utils import LifespanOnStartupError, lifespan_context

_logger = logging.getLogger(__name__)

Expand All @@ -30,8 +30,10 @@ def create_postgres_database_input_state(settings: PostgresSettings) -> State:

async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:

with log_context(_logger, logging.INFO, f"{__name__}"):
_lifespan_name = f"{__name__}.{postgres_database_lifespan.__name__}"

with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
# Validate input state
settings = state[PostgresLifespanState.POSTGRES_SETTINGS]

if settings is None or not isinstance(settings, PostgresSettings):
Expand All @@ -48,6 +50,7 @@ async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[

yield {
PostgresLifespanState.POSTGRES_ASYNC_ENGINE: async_engine,
**called_state,
}

finally:
Expand Down
8 changes: 8 additions & 0 deletions packages/service-library/src/servicelib/fastapi/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings

from fastapi import FastAPI
from models_library.rabbitmq_messages import RabbitMessageBase
Expand Down Expand Up @@ -55,6 +56,13 @@ def setup_rabbit(
settings -- Rabbit settings or if None, the connection to rabbit is not done upon startup
name -- name for the rmq client name
"""
warnings.warn(
"The 'setup_rabbit' function is deprecated and will be removed in a future release. "
"Please use 'rabbitmq_lifespan' for managing RabbitMQ connections.",
DeprecationWarning,
stacklevel=2,
)

app.state.rabbitmq_client = None # RabbitMQClient | None
app.state.rabbitmq_client_name = name
app.state.rabbitmq_settings = settings
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from collections.abc import AsyncIterator

from fastapi import FastAPI
from fastapi_lifespan_manager import State
from pydantic import BaseModel, ValidationError
from settings_library.rabbit import RabbitSettings

from ..rabbitmq import wait_till_rabbitmq_responsive
from .lifespan_utils import (
LifespanOnStartupError,
lifespan_context,
)

_logger = logging.getLogger(__name__)


class RabbitMQConfigurationError(LifespanOnStartupError):
msg_template = "Invalid RabbitMQ config on startup : {validation_error}"


class RabbitMQLifespanState(BaseModel):
RABBIT_SETTINGS: RabbitSettings


async def rabbitmq_connectivity_lifespan(
_: FastAPI, state: State
) -> AsyncIterator[State]:
"""Ensures RabbitMQ connectivity during lifespan.

For creating clients, use additional lifespans like rabbitmq_rpc_client_context.
"""
_lifespan_name = f"{__name__}.{rabbitmq_connectivity_lifespan.__name__}"

with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:

# Validate input state
try:
rabbit_state = RabbitMQLifespanState.model_validate(state)
rabbit_dsn_with_secrets = rabbit_state.RABBIT_SETTINGS.dsn
except ValidationError as exc:
raise RabbitMQConfigurationError(validation_error=exc, state=state) from exc

# Wait for RabbitMQ to be responsive
await wait_till_rabbitmq_responsive(rabbit_dsn_with_secrets)

yield {"RABBIT_CONNECTIVITY_LIFESPAN_NAME": _lifespan_name, **called_state}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from typing import Annotated

from fastapi import FastAPI
from fastapi_lifespan_manager import State
from pydantic import BaseModel, StringConstraints, ValidationError
from settings_library.redis import RedisDatabase, RedisSettings

from ..logging_utils import log_catch, log_context
from ..redis import RedisClientSDK
from .lifespan_utils import LifespanOnStartupError, lifespan_context

_logger = logging.getLogger(__name__)


class RedisConfigurationError(LifespanOnStartupError):
msg_template = "Invalid redis config on startup : {validation_error}"


class RedisLifespanState(BaseModel):
REDIS_SETTINGS: RedisSettings
REDIS_CLIENT_NAME: Annotated[str, StringConstraints(min_length=3, max_length=32)]
REDIS_CLIENT_DB: RedisDatabase


async def redis_client_sdk_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
_lifespan_name = f"{__name__}.{redis_client_sdk_lifespan.__name__}"

with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:

# Validate input state
try:
redis_state = RedisLifespanState.model_validate(state)
redis_dsn_with_secrets = redis_state.REDIS_SETTINGS.build_redis_dsn(
redis_state.REDIS_CLIENT_DB
)
except ValidationError as exc:
raise RedisConfigurationError(validation_error=exc, state=state) from exc

# Setup client
with log_context(
_logger,
logging.INFO,
f"Creating redis client with name={redis_state.REDIS_CLIENT_NAME}",
):
# NOTE: sdk integrats waiting until connection is ready
# and will raise an exception if it cannot connect
redis_client = RedisClientSDK(
redis_dsn_with_secrets,
client_name=redis_state.REDIS_CLIENT_NAME,
)

try:
yield {"REDIS_CLIENT_SDK": redis_client, **called_state}
finally:
# Teardown client
with log_catch(_logger, reraise=False):
await asyncio.wait_for(
redis_client.shutdown(),
# NOTE: shutdown already has a _HEALTHCHECK_TASK_TIMEOUT_S of 10s
timeout=20,
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from models_library.rabbitmq_basic_types import RPCNamespace

from ._client import RabbitMQClient
from ._client_rpc import RabbitMQRPCClient
from ._client_rpc import RabbitMQRPCClient, rabbitmq_rpc_client_context
from ._constants import BIND_TO_ALL_TOPICS, RPC_REQUEST_DEFAULT_TIMEOUT_S
from ._errors import (
RemoteMethodNotRegisteredError,
Expand All @@ -28,6 +28,7 @@
"RabbitMQRPCClient",
"RemoteMethodNotRegisteredError",
"is_rabbitmq_responsive",
"rabbitmq_rpc_client_context",
"wait_till_rabbitmq_responsive",
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import functools
import logging
from collections.abc import Callable
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any

Expand Down Expand Up @@ -156,3 +157,19 @@ async def unregister_handler(self, handler: Callable[..., Any]) -> None:
raise RPCNotInitializedError

await self._rpc.unregister(handler)


@asynccontextmanager
async def rabbitmq_rpc_client_context(
rpc_client_name: str, settings: RabbitSettings, **kwargs
) -> AsyncIterator[RabbitMQRPCClient]:
"""
Adapter to create and close a RabbitMQRPCClient using an async context manager.
"""
rpc_client = await RabbitMQRPCClient.create(
client_name=rpc_client_name, settings=settings, **kwargs
)
try:
yield rpc_client
finally:
await rpc_client.close()
43 changes: 39 additions & 4 deletions packages/service-library/tests/fastapi/test_lifespan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
from pytest_mock import MockerFixture
from pytest_simcore.helpers.logging_tools import log_context
from servicelib.fastapi.lifespan_utils import (
LifespanAlreadyCalledError,
LifespanExpectedCalledError,
LifespanOnShutdownError,
LifespanOnStartupError,
ensure_lifespan_called,
mark_lifespace_called,
)


Expand Down Expand Up @@ -186,7 +190,7 @@ async def lifespan_failing_on_startup(app: FastAPI) -> AsyncIterator[State]:
startup_step(_name)
except RuntimeError as exc:
handle_error(_name, exc)
raise LifespanOnStartupError(module=_name) from exc
raise LifespanOnStartupError(lifespan_name=_name) from exc
yield {}
shutdown_step(_name)

Expand All @@ -201,7 +205,7 @@ async def lifespan_failing_on_shutdown(app: FastAPI) -> AsyncIterator[State]:
shutdown_step(_name)
except RuntimeError as exc:
handle_error(_name, exc)
raise LifespanOnShutdownError(module=_name) from exc
raise LifespanOnShutdownError(lifespan_name=_name) from exc

return {
"startup_step": startup_step,
Expand All @@ -228,7 +232,7 @@ async def test_app_lifespan_with_error_on_startup(
assert not failing_lifespan_manager["startup_step"].called
assert not failing_lifespan_manager["shutdown_step"].called
assert exception.error_context() == {
"module": "lifespan_failing_on_startup",
"lifespan_name": "lifespan_failing_on_startup",
"message": "Failed during startup of lifespan_failing_on_startup",
"code": "RuntimeError.LifespanError.LifespanOnStartupError",
}
Expand All @@ -250,7 +254,38 @@ async def test_app_lifespan_with_error_on_shutdown(
assert failing_lifespan_manager["startup_step"].called
assert not failing_lifespan_manager["shutdown_step"].called
assert exception.error_context() == {
"module": "lifespan_failing_on_shutdown",
"lifespan_name": "lifespan_failing_on_shutdown",
"message": "Failed during shutdown of lifespan_failing_on_shutdown",
"code": "RuntimeError.LifespanError.LifespanOnShutdownError",
}


async def test_lifespan_called_more_than_once(is_pdb_enabled: bool):
app_lifespan = LifespanManager()

@app_lifespan.add
async def _one(_, state: State) -> AsyncIterator[State]:
called_state = mark_lifespace_called(state, "test_lifespan_one")
yield {"other": 0, **called_state}

@app_lifespan.add
async def _two(_, state: State) -> AsyncIterator[State]:
ensure_lifespan_called(state, "test_lifespan_one")

with pytest.raises(LifespanExpectedCalledError):
ensure_lifespan_called(state, "test_lifespan_three")

called_state = mark_lifespace_called(state, "test_lifespan_two")
yield {"something": 0, **called_state}

app_lifespan.add(_one) # added "by mistake"

with pytest.raises(LifespanAlreadyCalledError) as err_info:
async with ASGILifespanManager(
FastAPI(lifespan=app_lifespan),
startup_timeout=None if is_pdb_enabled else 10,
shutdown_timeout=None if is_pdb_enabled else 10,
):
...

assert err_info.value.lifespan_name == "test_lifespan_one"
Loading
Loading