Skip to content

Commit b7a9039

Browse files
committed
✨ Refactor lifespan management by introducing a context manager for logging and state tracking in FastAPI
1 parent cd98e87 commit b7a9039

File tree

4 files changed

+27
-22
lines changed

4 files changed

+27
-22
lines changed

packages/service-library/src/servicelib/fastapi/lifespan_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import contextlib
2+
from collections.abc import Iterator
13
from typing import Final
24

35
from common_library.errors_classes import OsparcErrorMixin
46
from fastapi import FastAPI
57
from fastapi_lifespan_manager import State
8+
from servicelib.logging_utils import log_context
69

710

811
class LifespanError(OsparcErrorMixin, RuntimeError): ...
@@ -54,3 +57,16 @@ def ensure_lifespan_called(state: State, lifespan_name: str) -> None:
5457
"""
5558
if not is_lifespan_called(state, lifespan_name):
5659
raise LifespanExpectedCalledError(lifespan_name=lifespan_name)
60+
61+
62+
@contextlib.contextmanager
63+
def lifespan_context(
64+
logger, level, lifespan_name: str, state: State
65+
) -> Iterator[State]:
66+
"""Helper context manager to log lifespan event and mark lifespan as called."""
67+
68+
with log_context(logger, level, lifespan_name):
69+
# Check if lifespan has already been called
70+
called_state = mark_lifespace_called(state, lifespan_name)
71+
72+
yield called_state

packages/service-library/src/servicelib/fastapi/postgres_lifespan.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from fastapi import FastAPI
77
from fastapi_lifespan_manager import State
8-
from servicelib.logging_utils import log_catch, log_context
8+
from servicelib.logging_utils import log_catch
99
from settings_library.postgres import PostgresSettings
1010
from sqlalchemy.ext.asyncio import AsyncEngine
1111

1212
from ..db_asyncpg_utils import create_async_engine_and_database_ready
13-
from .lifespan_utils import LifespanOnStartupError, mark_lifespace_called
13+
from .lifespan_utils import LifespanOnStartupError, lifespan_context
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -30,11 +30,10 @@ def create_postgres_database_input_state(settings: PostgresSettings) -> State:
3030

3131
async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
3232

33-
with log_context(_logger, logging.INFO, f"{__name__}"):
34-
35-
# Mark lifespan as called
36-
called_state = mark_lifespace_called(state, "postgres_database_lifespan")
33+
_lifespan_name = f"{__name__}.{postgres_database_lifespan.__name__}"
3734

35+
with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
36+
# Validate input state
3837
settings = state[PostgresLifespanState.POSTGRES_SETTINGS]
3938

4039
if settings is None or not isinstance(settings, PostgresSettings):

packages/service-library/src/servicelib/fastapi/rabbitmq_lifespan.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
from fastapi import FastAPI
55
from fastapi_lifespan_manager import State
66
from pydantic import BaseModel, ValidationError
7-
from servicelib.logging_utils import log_context
87
from servicelib.rabbitmq import wait_till_rabbitmq_responsive
98
from settings_library.rabbit import RabbitSettings
109

1110
from .lifespan_utils import (
1211
LifespanOnStartupError,
13-
mark_lifespace_called,
12+
lifespan_context,
1413
)
1514

1615
_logger = logging.getLogger(__name__)
@@ -33,10 +32,7 @@ async def rabbitmq_connectivity_lifespan(
3332
"""
3433
_lifespan_name = f"{__name__}.{rabbitmq_connectivity_lifespan.__name__}"
3534

36-
with log_context(_logger, logging.INFO, _lifespan_name):
37-
38-
# Check if lifespan has already been called
39-
called_state = mark_lifespace_called(state, _lifespan_name)
35+
with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
4036

4137
# Validate input state
4238
try:

packages/service-library/src/servicelib/fastapi/redis_lifespan.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from fastapi import FastAPI
77
from fastapi_lifespan_manager import State
8-
from pydantic import BaseModel, ConfigDict, StringConstraints, ValidationError
8+
from pydantic import BaseModel, StringConstraints, ValidationError
99
from servicelib.logging_utils import log_catch, log_context
1010
from settings_library.redis import RedisDatabase, RedisSettings
1111

1212
from ..redis import RedisClientSDK
13-
from .lifespan_utils import LifespanOnStartupError, mark_lifespace_called
13+
from .lifespan_utils import LifespanOnStartupError, lifespan_context
1414

1515
_logger = logging.getLogger(__name__)
1616

@@ -24,17 +24,11 @@ class RedisLifespanState(BaseModel):
2424
REDIS_CLIENT_NAME: Annotated[str, StringConstraints(min_length=3, max_length=32)]
2525
REDIS_CLIENT_DB: RedisDatabase
2626

27-
model_config = ConfigDict(
28-
extra="allow",
29-
arbitrary_types_allowed=True, # RedisClientSDK has some arbitrary types and this class will never be serialized
30-
)
31-
3227

3328
async def redis_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
34-
with log_context(_logger, logging.INFO, f"{__name__}"):
29+
_lifespan_name = f"{__name__}.{redis_database_lifespan.__name__}"
3530

36-
# Check if lifespan has already been called
37-
called_state = mark_lifespace_called(state, "redis_database_lifespan")
31+
with lifespan_context(_logger, logging.INFO, _lifespan_name, state) as called_state:
3832

3933
# Validate input state
4034
try:

0 commit comments

Comments
 (0)