Skip to content

WIP: Is7510/catalog service refactor #7546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,44 @@
from typing import Final

from common_library.errors_classes import OsparcErrorMixin
from fastapi import FastAPI
from fastapi_lifespan_manager import State


class LifespanError(OsparcErrorMixin, RuntimeError): ...


class LifespanOnStartupError(LifespanError):
msg_template = "Failed during startup of {module}"
msg_template = "Failed during startup of {lifespan_name}"


class LifespanOnShutdownError(LifespanError):
msg_template = "Failed during shutdown of {module}"
msg_template = "Failed during shutdown of {lifespan_name}"


class LifespanAlreadyCalledError(LifespanError):
msg_template = "The lifespan '{lifespan_name}' has already been called."


_CALLED_LIFESPANS_KEY: Final[str] = "_CALLED_LIFESPANS"


def is_lifespan_called(state: State, lifespan_name: str) -> bool:
called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
return lifespan_name in called_lifespans


def record_lifespan_called_once(state: State, lifespan_name: str) -> State:
"""Validates if a lifespan has already been called and records it in the state.
Raises LifespanAlreadyCalledError if the lifespan has already been called.
"""
assert not isinstance( # nosec
state, FastAPI
), "TIP: lifespan func has (app, state) positional arguments"

if is_lifespan_called(state, lifespan_name):
raise LifespanAlreadyCalledError(lifespan_name=lifespan_name)

called_lifespans = state.get(_CALLED_LIFESPANS_KEY, set())
called_lifespans.add(lifespan_name)
return {_CALLED_LIFESPANS_KEY: called_lifespans}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine

from ..db_asyncpg_utils import create_async_engine_and_database_ready
from .lifespan_utils import LifespanOnStartupError
from .lifespan_utils import LifespanOnStartupError, record_lifespan_called_once

_logger = logging.getLogger(__name__)

Expand All @@ -32,6 +32,9 @@ async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[

with log_context(_logger, logging.INFO, f"{__name__}"):

# Mark lifespan as called
called_state = record_lifespan_called_once(state, "postgres_database_lifespan")

settings = state[PostgresLifespanState.POSTGRES_SETTINGS]

if settings is None or not isinstance(settings, PostgresSettings):
Expand All @@ -48,6 +51,7 @@ async def postgres_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[

yield {
PostgresLifespanState.POSTGRES_ASYNC_ENGINE: async_engine,
**called_state,
}

finally:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import asyncio
import logging
from collections.abc import AsyncIterator
from typing import Annotated

from fastapi import FastAPI
from fastapi_lifespan_manager import State
from pydantic import BaseModel, ConfigDict, StringConstraints, ValidationError
from servicelib.logging_utils import log_catch, log_context
from settings_library.redis import RedisDatabase, RedisSettings

from ..redis import RedisClientSDK
from .lifespan_utils import LifespanOnStartupError, record_lifespan_called_once

_logger = logging.getLogger(__name__)


class RedisConfigurationError(LifespanOnStartupError):
msg_template = "Invalid redis config on startup : {validation_error}"


class RedisLifespanState(BaseModel):
REDIS_SETTINGS: RedisSettings
REDIS_CLIENT_NAME: Annotated[str, StringConstraints(min_length=3, max_length=32)]
REDIS_CLIENT_DB: RedisDatabase

model_config = ConfigDict(
extra="allow",
arbitrary_types_allowed=True, # RedisClientSDK has some arbitrary types and this class will never be serialized
)


async def redis_database_lifespan(_: FastAPI, state: State) -> AsyncIterator[State]:
with log_context(_logger, logging.INFO, f"{__name__}"):

# Check if lifespan has already been called
called_state = record_lifespan_called_once(state, "redis_database_lifespan")

# Validate input state
try:
redis_state = RedisLifespanState.model_validate(state)
redis_dsn_with_secrets = redis_state.REDIS_SETTINGS.build_redis_dsn(
redis_state.REDIS_CLIENT_DB
)
except ValidationError as exc:
raise RedisConfigurationError(validation_error=exc, state=state) from exc

# Setup client
with log_context(
_logger,
logging.INFO,
f"Creating redis client with name={redis_state.REDIS_CLIENT_NAME}",
):
redis_client = RedisClientSDK(
redis_dsn_with_secrets,
client_name=redis_state.REDIS_CLIENT_NAME,
)

try:
yield {"REDIS_CLIENT_SDK": redis_client, **called_state}
finally:
# Teardown client
if redis_client:
with log_catch(_logger, reraise=False):
await asyncio.wait_for(
redis_client.shutdown(),
# NOTE: shutdown already has a _HEALTHCHECK_TASK_TIMEOUT_S of 10s
timeout=20,
)
38 changes: 34 additions & 4 deletions packages/service-library/tests/fastapi/test_lifespan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
from pytest_mock import MockerFixture
from pytest_simcore.helpers.logging_tools import log_context
from servicelib.fastapi.lifespan_utils import (
LifespanAlreadyCalledError,
LifespanOnShutdownError,
LifespanOnStartupError,
record_lifespan_called_once,
)


Expand Down Expand Up @@ -186,7 +188,7 @@ async def lifespan_failing_on_startup(app: FastAPI) -> AsyncIterator[State]:
startup_step(_name)
except RuntimeError as exc:
handle_error(_name, exc)
raise LifespanOnStartupError(module=_name) from exc
raise LifespanOnStartupError(lifespan_name=_name) from exc
yield {}
shutdown_step(_name)

Expand All @@ -201,7 +203,7 @@ async def lifespan_failing_on_shutdown(app: FastAPI) -> AsyncIterator[State]:
shutdown_step(_name)
except RuntimeError as exc:
handle_error(_name, exc)
raise LifespanOnShutdownError(module=_name) from exc
raise LifespanOnShutdownError(lifespan_name=_name) from exc

return {
"startup_step": startup_step,
Expand All @@ -228,7 +230,7 @@ async def test_app_lifespan_with_error_on_startup(
assert not failing_lifespan_manager["startup_step"].called
assert not failing_lifespan_manager["shutdown_step"].called
assert exception.error_context() == {
"module": "lifespan_failing_on_startup",
"lifespan_name": "lifespan_failing_on_startup",
"message": "Failed during startup of lifespan_failing_on_startup",
"code": "RuntimeError.LifespanError.LifespanOnStartupError",
}
Expand All @@ -250,7 +252,35 @@ async def test_app_lifespan_with_error_on_shutdown(
assert failing_lifespan_manager["startup_step"].called
assert not failing_lifespan_manager["shutdown_step"].called
assert exception.error_context() == {
"module": "lifespan_failing_on_shutdown",
"lifespan_name": "lifespan_failing_on_shutdown",
"message": "Failed during shutdown of lifespan_failing_on_shutdown",
"code": "RuntimeError.LifespanError.LifespanOnShutdownError",
}


async def test_lifespan_called_more_than_once(is_pdb_enabled: bool):
state = {}

app_lifespan = LifespanManager()

@app_lifespan.add
async def _one(_, state: State) -> AsyncIterator[State]:
called_state = record_lifespan_called_once(state, "test_lifespan_one")
yield {"other": 0, **called_state}

@app_lifespan.add
async def _two(_, state: State) -> AsyncIterator[State]:
called_state = record_lifespan_called_once(state, "test_lifespan_two")
yield {"something": 0, **called_state}

app_lifespan.add(_one) # added "by mistake"

with pytest.raises(LifespanAlreadyCalledError) as err_info:
async with ASGILifespanManager(
FastAPI(lifespan=app_lifespan),
startup_timeout=None if is_pdb_enabled else 10,
shutdown_timeout=None if is_pdb_enabled else 10,
):
...

assert err_info.value.lifespan_name == "test_lifespan_one"
130 changes: 130 additions & 0 deletions packages/service-library/tests/fastapi/test_redis_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# pylint: disable=protected-access
# pylint: disable=redefined-outer-name
# pylint: disable=too-many-arguments
# pylint: disable=unused-argument
# pylint: disable=unused-variable

from collections.abc import AsyncIterator
from typing import Annotated, Any

import pytest
import servicelib.fastapi.redis_lifespan
from asgi_lifespan import LifespanManager as ASGILifespanManager
from fastapi import FastAPI
from fastapi_lifespan_manager import LifespanManager, State
from pydantic import Field
from pytest_mock import MockerFixture, MockType
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
from pytest_simcore.helpers.typing_env import EnvVarsDict
from servicelib.fastapi.redis_lifespan import (
RedisConfigurationError,
RedisLifespanState,
redis_database_lifespan,
)
from settings_library.application import BaseApplicationSettings
from settings_library.redis import RedisDatabase, RedisSettings


@pytest.fixture
def mock_redis_client_sdk(mocker: MockerFixture) -> MockType:
return mocker.patch.object(
servicelib.fastapi.redis_lifespan,
"RedisClientSDK",
return_value=mocker.AsyncMock(),
)


@pytest.fixture
def app_environment(monkeypatch: pytest.MonkeyPatch) -> EnvVarsDict:
return setenvs_from_dict(
monkeypatch, RedisSettings.model_json_schema()["examples"][0]
)


@pytest.fixture
def app_lifespan(
app_environment: EnvVarsDict,
mock_redis_client_sdk: MockType,
) -> LifespanManager:
assert app_environment

class AppSettings(BaseApplicationSettings):
CATALOG_REDIS: Annotated[
RedisSettings,
Field(json_schema_extra={"auto_default_from_env": True}),
]

async def my_app_settings(app: FastAPI) -> AsyncIterator[State]:
app.state.settings = AppSettings.create_from_envs()

yield RedisLifespanState(
REDIS_SETTINGS=app.state.settings.CATALOG_REDIS,
REDIS_CLIENT_NAME="test_client",
REDIS_CLIENT_DB=RedisDatabase.LOCKS,
).model_dump()

app_lifespan = LifespanManager()
app_lifespan.add(my_app_settings)
app_lifespan.add(redis_database_lifespan)

assert not mock_redis_client_sdk.called

return app_lifespan


async def test_lifespan_redis_database_in_an_app(
is_pdb_enabled: bool,
app_environment: EnvVarsDict,
mock_redis_client_sdk: MockType,
app_lifespan: LifespanManager,
):
app = FastAPI(lifespan=app_lifespan)

async with ASGILifespanManager(
app,
startup_timeout=None if is_pdb_enabled else 10,
shutdown_timeout=None if is_pdb_enabled else 10,
) as asgi_manager:
# Verify that the Redis client SDK was created
mock_redis_client_sdk.assert_called_once_with(
app.state.settings.CATALOG_REDIS.build_redis_dsn(RedisDatabase.LOCKS),
client_name="test_client",
)

# Verify that the Redis client SDK is in the lifespan manager state
assert "REDIS_CLIENT_SDK" in asgi_manager._state # noqa: SLF001
assert app.state.settings.CATALOG_REDIS
assert (
asgi_manager._state["REDIS_CLIENT_SDK"] # noqa: SLF001
== mock_redis_client_sdk.return_value
)

# Verify that the Redis client SDK was shut down
redis_client: Any = mock_redis_client_sdk.return_value
redis_client.shutdown.assert_called_once()


async def test_lifespan_redis_database_with_invalid_settings(
is_pdb_enabled: bool,
):
async def my_app_settings(app: FastAPI) -> AsyncIterator[State]:
yield {"REDIS_SETTINGS": None}

app_lifespan = LifespanManager()
app_lifespan.add(my_app_settings)
app_lifespan.add(redis_database_lifespan)

app = FastAPI(lifespan=app_lifespan)

with pytest.raises(RedisConfigurationError, match="Invalid redis") as excinfo:
async with ASGILifespanManager(
app,
startup_timeout=None if is_pdb_enabled else 10,
shutdown_timeout=None if is_pdb_enabled else 10,
):
...

exception = excinfo.value
assert isinstance(exception, RedisConfigurationError)
assert exception.validation_error
assert exception.state["REDIS_SETTINGS"] is None
16 changes: 16 additions & 0 deletions packages/settings-library/src/settings_library/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic.networks import RedisDsn
from pydantic.types import SecretStr
from pydantic_settings import SettingsConfigDict

from .base import BaseCustomSettings
from .basic_types import PortInt
Expand Down Expand Up @@ -45,3 +46,18 @@ def build_redis_dsn(self, db_index: RedisDatabase) -> str:
path=f"{db_index}",
)
)

model_config = SettingsConfigDict(
json_schema_extra={
"examples": [
# minimal required
{
"REDIS_SECURE": "0",
"REDIS_HOST": "localhost",
"REDIS_PORT": "6379",
"REDIS_USER": "user",
"REDIS_PASSWORD": "secret",
}
],
}
)
Loading
Loading