Single Database Session Per FastAPI Request Lifecycle #728

Open
Trinkes opened this issue Jul 27, 2023 · 46 comments

Comments

@Trinkes

Trinkes commented Jul 27, 2023

Hello,
I would like to implement a mechanism that ensures only one database session is created and tied to the FastAPI request lifecycle. The goal is to have a single shared database session across all resources/classes within a request, allowing easy rollback of operations in case of any request-related issues.

Here's the current code example:

import os

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from fastapi import FastAPI, Depends
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
import uvicorn

engine = create_engine("sqlite://")

Base = declarative_base()
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)


class ApplicationContainer(containers.DeclarativeContainer):
    wiring_config = containers.WiringConfiguration(modules=[__name__])
    database: SessionLocal = providers.Factory(SessionLocal)


app = FastAPI()


@app.get("/")
@inject
async def root(
    session_1=Depends(Provide[ApplicationContainer.database]),
    session_2=Depends(Provide[ApplicationContainer.database]),
):
    return {"session_1_id": id(session_1), "session_2_id": id(session_2)}


container = ApplicationContainer()

if __name__ == "__main__":
    uvicorn.run(
        os.path.basename(__file__).replace(".py", "") + ":app",
        host="127.0.0.1",
        port=5000,
        log_level="info",
        reload=True,
    )

Currently, when calling the root endpoint, two separate database sessions are created, resulting in different session IDs:

{
  "session_1_id": 4347665504,
  "session_2_id": 4347668912
}

However, the desired behavior is for both arguments (session_1 and session_2) to reference the same database session within a request. Calling the endpoint again should then return a different ID, indicating that a new session was created for the new request:

{
  "session_1_id": 4347665504,
  "session_2_id": 4347665504
}

The ultimate objective is to achieve a single database session per request, which would simplify the rollback process for any issues that might arise during the request.

Thank you for your attention to this matter, and I look forward to your guidance and suggestions.

@dandiep

dandiep commented Aug 8, 2023

I am struggling with this exact same thing: because multiple sessions are created, you can end up with deadlocks. Ideally, dependency injector would offer a Request scope for this kind of thing.

@theobouwman

@Trinkes did you find a solution? I have the exact same problem.

@Trinkes
Author

Trinkes commented Nov 8, 2023

@theobouwman not yet. I haven't found the time to investigate @jess-hwang's suggestion.

@theobouwman

@Trinkes (#760) fixes the one-session-per-request problem.

@Trinkes
Author

Trinkes commented Nov 8, 2023

I did some local testing with Locust: once many requests execute concurrently on multiple threads (I'm using sync endpoints), the sessions don't behave as expected.

@theobouwman It seems it doesn't work as expected when there is more than 1 request being processed.

@theobouwman

@Trinkes you are right.

@theobouwman

@jess-hwang do you know the solution?

@jess-hwang

jess-hwang commented Nov 9, 2023

Use async_sessionmaker instead of sessionmaker. FastAPI creates a new async task per request.

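import asyncio

from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine

async_engine = create_async_engine("sqlite+aiosqlite://")  # example URL; use your own database URL

async_session_factory = async_sessionmaker(
    async_engine,
    expire_on_commit=False,
)
async_scoped_session_factory = async_scoped_session(
    async_session_factory,
    scopefunc=asyncio.current_task,
)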

Using async_scoped_session, you can bind the session to the task.
If you call the session factory within the same task (i.e. the same request), the same session is returned.
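To make the task-scoping idea concrete, here is a minimal stdlib-only sketch. TaskScopedRegistry is a hypothetical, heavily simplified stand-in for async_scoped_session (the real class also closes sessions and proxies session methods); it only shows how scopefunc=asyncio.current_task decides which calls share an object.

```python
import asyncio
from typing import Callable, Dict, Generic, Hashable, Tuple, TypeVar

T = TypeVar("T")


class TaskScopedRegistry(Generic[T]):
    """Hypothetical simplification of async_scoped_session: one object per scope key."""

    def __init__(self, factory: Callable[[], T], scopefunc: Callable[[], Hashable]) -> None:
        self._factory = factory
        self._scopefunc = scopefunc
        self._objects: Dict[Hashable, T] = {}

    def __call__(self) -> T:
        # Create the object on first use within a scope, then keep returning it.
        key = self._scopefunc()
        if key not in self._objects:
            self._objects[key] = self._factory()
        return self._objects[key]

    def remove(self) -> None:
        # Forget the object for the current scope (the real registry also closes it).
        self._objects.pop(self._scopefunc(), None)


# `object` stands in for a session factory such as async_sessionmaker(...).
registry: TaskScopedRegistry[object] = TaskScopedRegistry(object, scopefunc=asyncio.current_task)


async def handle_request() -> Tuple[object, object]:
    # Two lookups in the same task (the same request) return the same object.
    first, second = registry(), registry()
    registry.remove()
    return first, second


async def main() -> list:
    # gather runs each coroutine in its own Task, like two concurrent requests.
    return await asyncio.gather(handle_request(), handle_request())


results = asyncio.run(main())
(a1, a2), (b1, b2) = results
print(a1 is a2, b1 is b2, a1 is b1)  # True True False
```

The registry only works as "one session per request" as long as every piece of code in the request runs in the same task.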

@theobouwman

@jess-hwang I have implemented the code you gave me:

class Database:

    def __init__(self, db_url: str) -> None:
        self._engine = create_async_engine(
            db_url,
            echo=get_config().QUERY_ECHO,
            echo_pool=get_config().ECHO_POOL,
            json_serializer=_custom_json_serializer,
            pool_pre_ping=True,
            pool_size=get_config().DB_POOL_SIZE,
        )
        async_session_factory = sessionmaker(
            bind=self._engine, 
            autocommit=False,
            autoflush=False,
            expire_on_commit=False,
            class_= AsyncSession
        )
        self._async_scoped_session_factory = async_scoped_session(
            async_session_factory,
            scopefunc=asyncio.current_task,
        )

    def create_database(self) -> None:
        Base.metadata.create_all(self._engine)

    @contextmanager
    def session(self) -> Callable[..., AbstractContextManager[AsyncSession]]:
        session: AsyncSession = self._async_scoped_session_factory()
        try:
            yield session
        except Exception as e:
            # logger.exception("Session rollback because of exception")
            session.rollback()
            raise e
        finally:
            session.close()

But in my repository a session is still created for each query:

class BaseRepository(Generic[T]):
    _model: T # TODO: find out if this is best solution

    def __init__(self, session_factory: Callable[..., AbstractContextManager[AsyncSession]]) -> None:
        self.session_factory = session_factory

    async def get_by_id(self, id: str) -> T:
        with self.session_factory() as session:
            r = await session.execute(select(self._model).filter(self._model.id == id))
            return r.scalar_one_or_none()

This is how I create the dependency injector:

class Container(containers.DeclarativeContainer):

    wiring_config = containers.WiringConfiguration(packages=[
        "api.routes",
        "tasks.routes",
        "common.observability"
    ])

    config = providers.Configuration()
    db = providers.Singleton(Database, db_url=get_config().DB_URL())

    event_repository = providers.Factory(
        EventRepository, session_factory=db.provided.session)

So I don't understand what I am doing wrong. Shouldn't with self.session_factory() as session: reuse the already created session?

@jess-hwang

jess-hwang commented Nov 10, 2023

@theobouwman
I think you should use async_session instead of session.

from contextlib import asynccontextmanager
from typing import AsyncIterator

@asynccontextmanager
async def async_session(self) -> AsyncIterator[AsyncSession]:
    # use the task-scoped factory created in __init__
    session = self._async_scoped_session_factory()
    try:
        yield session
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()

@theobouwman

@jess-hwang it is still creating 2 sessions when I implement your code and call

async def get_by_id(self, id: str) -> T:
    async with self.session_factory() as session:
        r = await session.execute(select(self._model).filter(self._model.id == id))
        return r.scalar_one_or_none()

@theobouwman

@jess-hwang the session only gets created once the get_by_id function is called in the BaseRepository. How could I create a session which is reused throughout the request and loaded with the dependency injector?

@theobouwman

And this is how I inject the services:

@router.get('/event/{event_id}')
@inject
async def testt(event_id: str, event_service: EventService = Depends(Provide[Container.event_service])):
    event1 = await event_service.get_event(event_id)
    event2 = await event_service.get_event(event_id)
    return BaseResponse[List[GetEventResponse]](data=[event1, event2])

@ebrahimradi

Same problem here. Is there any solution?

@philipbjorge
Contributor

Here is some pseudocode on how to make this work...

class MyContainer(containers.DeclarativeContainer):
  db_session_provider = providers.Factory(async_session_factory)
  db_session = providers.ContextLocalSingleton(provides=AsyncSession)
class DbSessionMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: ASGIApp,
        db_session_provider: Provider[MyContainer.db_session_provider],
        dispatch: DispatchFunction | None = None,
    ) -> None:
        super().__init__(app, dispatch)
        self._db_session_provider = db_session_provider

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        self._db_session_provider()
        return await call_next(request)
fastapi_app.add_middleware(DbSessionMiddleware)

The trick is that you need to initialize the database session within the FastAPI request context, which is what initializing the context-local singleton from the middleware accomplishes.

@theobouwman

theobouwman commented Feb 9, 2024

@philipbjorge do you have a fully working example of this?

@JobaDiniz

Same problem. I did not understand how to implement the solution.

Shouldn't this be a feature of python-dependency-injector? This is usually called a scoped lifecycle for container instances. The library could provide something like providers.Scoped(...) and build a FastAPI extension on top of it, such as providers.FastApiRequestScoped(...).

@rumbarum

rumbarum commented Aug 28, 2024

@Trinkes @theobouwman @ebrahimradi
I've been using a ContextVar to guarantee a single session per FastAPI request.
Using scopefunc=asyncio.current_task can result in different sessions within the same request if the session is used under asyncio.gather(...), because gather creates a separate task (and therefore a different current_task) for each coroutine.
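A quick stdlib-only way to see the asyncio.gather effect: request_scope below is a hypothetical ContextVar playing the role of session_context. gather wraps each coroutine in its own Task, so asyncio.current_task() differs, but every new Task copies the current context, so the ContextVar value is shared.

```python
import asyncio
import uuid
from contextvars import ContextVar
from typing import List, Optional, Tuple

# Hypothetical per-request id, playing the role of session_context.
request_scope: ContextVar[str] = ContextVar("request_scope")


async def scope_keys() -> Tuple[Optional[asyncio.Task], str]:
    # Return both candidate scope keys as seen from this coroutine.
    return asyncio.current_task(), request_scope.get()


async def handle_request() -> List[Tuple[Optional[asyncio.Task], str]]:
    token = request_scope.set(str(uuid.uuid4()))
    try:
        # gather wraps each coroutine in its own Task; each Task copies the current context.
        return await asyncio.gather(scope_keys(), scope_keys())
    finally:
        request_scope.reset(token)


(task_a, ctx_a), (task_b, ctx_b) = asyncio.run(handle_request())
print(task_a is task_b)  # False: a current_task-keyed registry would create two sessions
print(ctx_a == ctx_b)    # True: a ContextVar-keyed registry would reuse one session
```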

Here is the code. It is quoted from my project, so it may contain some errors.
If you still get different sessions within a single request with the code below, check that your sessions come from the same container.
Currently, PDI has no built-in mechanism to ensure that a single global container is created and maintained, so this responsibility belongs to the user.

# context_func.py
from contextvars import ContextVar, Token

session_context: ContextVar[str] = ContextVar("session_context")

def get_session_context() -> str:
    return session_context.get()

def set_session_context(session_id: str) -> Token:
    return session_context.set(session_id)

def reset_session_context(context: Token) -> None:
    session_context.reset(context)

# container.py
from dependency_injector import containers, providers
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session, create_async_engine
from sqlalchemy.orm import sessionmaker

from context_func import get_session_context

class Container(containers.DeclarativeContainer):
    config = providers.Configuration()  # load your settings from json, yaml or .env
    engine = providers.Singleton(
        create_async_engine,
        config.READER_DB_URL,
    )
    async_session_factory = providers.Singleton(
        sessionmaker,
        bind=engine,  # without a bind, the sessions are not attached to the engine
        class_=AsyncSession,
        autocommit=False,
        autoflush=True,
        expire_on_commit=False,
    )
    # one AsyncSession per value of session_context
    session = providers.Singleton(
        async_scoped_session,
        session_factory=async_session_factory,
        scopefunc=get_session_context,
    )

# dependency.py
from uuid import uuid4

from dependency_injector.providers import Provider
from dependency_injector.wiring import Provide
from sqlalchemy.ext.asyncio import async_scoped_session

from context_func import get_session_context, set_session_context, reset_session_context

session_factory: Provider[async_scoped_session] = Provide["session.provider"]

async def get_db():
    # This code targets FastAPI >= 0.106. Before that, dependencies with `yield` were
    # closed only after the response was sent, so the session sat idle until then.
    # If you need to stay on an older version, implementing a `@contextmanager`
    # is a more efficient way to release the session.
    session_id = str(uuid4())
    context = set_session_context(session_id=session_id)
    session = session_factory()
    try:
        yield session
    finally:
        await session.remove()
        reset_session_context(context=context)
  1. You can inject the session directly into a repository or router.
  2. If you test without the router layer, use get_db (or similar code) in your tests so that the test session is attached to the same context.

Here is another version implemented on Middleware.
https://github.com/rumbarum/fastapi-boilerplate-on-di/blob/main/src/application/core/middlewares/sqlalchemy.py

@AndBondStyle

@rumbarum
Is it possible to use your idea together with Depends(Provide[...])? I'm trying to make this code work. It works fine with just Depends(init_session), but when using Depends(Provide[Container.db_session]) FastAPI fails to recognize that it is an async generator and just passes it to the view function as-is. Do you know any workaround?

@rumbarum

rumbarum commented Oct 2, 2024

@AndBondStyle Of course it works. I use this in all my routers.

I think the following should work:

@app.get("/test2")
@inject
async def test2(db=Depends(Provide[Container.db_scoped_session])):
    res = await db.execute(sa.text("select version()"))
    return res.scalar()

@AndBondStyle

AndBondStyle commented Oct 2, 2024

@rumbarum seems like it doesn't, because session object is not closed properly after request has finished. After a few requests there's an error: QueuePool limit of size 5 overflow 10 reached, connection timed out.

It can be fixed by rewriting the view like this, but I really want to avoid that:

@app.get("/test3")
@inject
async def test3(db=Depends(Provide[Container.db_scoped_session])):
    async with db():
        res = await db.execute(sa.text("select version()"))
        return res.scalar()

Is there any way to avoid using async with inside the view function, like in test1?

@rumbarum

rumbarum commented Oct 2, 2024

@AndBondStyle

You should close the session in a dependency function or in middleware.

# this provides the scoped-session constructor, not a session
session_factory = Provide["db_scoped_session.provider"]

# on dependency func
async def init_session():
    session = session_factory()
    try:
        yield session
    except Exception as e:
        raise e
    finally:
        await session.remove()

@AndBondStyle

AndBondStyle commented Oct 2, 2024

@rumbarum but does this mean I can't use Depends(Provide[...]) directly, and instead use Depends(init_session)?

@rumbarum

rumbarum commented Oct 2, 2024

Now I understand what you want.

If you need db_session directly, you should first create the session and bind it to the request task's coroutine context.
After that, you can get the session anywhere through DI. Don't forget to remove the session as well.

# this provides the scoped-session constructor, not a session
session_factory = Provide["db_scoped_session.provider"]

# on dependency func
async def init_session():
    session = session_factory()
    try:
        yield session
    except Exception as e:
        raise e
    finally:
        await session.remove()

# this will trigger session and close session after request finished
app = FastAPI(lifespan=lifespan, dependencies=[Depends(init_session)])

Then /test2 will work.

@AndBondStyle

Ended up with this setup (gist). It's pretty hacky inside, but view functions look very clean. Maybe this example will be useful to someone

fastapi_pdi_alchemy.py
import asyncio
import os
from contextlib import asynccontextmanager
from typing import Any

import sqlalchemy as sa
from dependency_injector import providers
from dependency_injector.containers import DeclarativeContainer
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, FastAPI
from sqlalchemy.ext.asyncio import (
  AsyncSession,
  async_scoped_session,
  async_sessionmaker,
  create_async_engine,
)


async def init_db_engine():
  dsn = os.environ["POSTGRES_DSN"]
  engine = create_async_engine(dsn, echo=True)
  print("engine start")
  yield engine
  print("engine stop")
  await engine.dispose()


class Container(DeclarativeContainer):
  db_engine = providers.Resource(init_db_engine)
  db_session_factory = providers.Resource(async_sessionmaker, db_engine)
  db_scoped_session = providers.ThreadSafeSingleton(
      async_scoped_session,
      session_factory=db_session_factory,
      scopefunc=asyncio.current_task,
  )
  db_session = providers.Object(None)  # dummy provider
  something = providers.Factory(lambda: 123)  # example of regular dependency


session_factory = Provide["db_scoped_session"]


# Async generator to use directly with fastapi's `Depends(...)`
async def init_session():
  session = (await session_factory)()
  async with session:
      print("session before")
      yield session
      print("session after")


# This replaces the `Depends(Provide[...])` with just `Dep(...)`
# For `db_session` we want to pass `init_session` function directly, avoiding PDI
# Not an elegant solution, but works fine and adds no overhead
def wrap_dependency(dependency: Any) -> Any:
  if dependency is Container.db_session or dependency == "db_session":
      return Depends(init_session)
  return Depends(Provide[dependency])


Dep = wrap_dependency  # shortcut


# This function patches the `APIRouter.api_route` so that PDI's `@inject` decorator
# is added for every view. This needs to be called before any views are defined
def fasatpi_auto_inject():
  original = APIRouter.api_route
  def api_route_patched(self, *args, **kwargs):
      print("api route patched:", kwargs.get("path"))
      decorator = original(self, *args, **kwargs)
      # Composition of two decorators
      return lambda func: decorator(inject(func))
  
  APIRouter.api_route = api_route_patched  # type: ignore

# Call before any view definitions
fasatpi_auto_inject()


@asynccontextmanager
async def lifespan(app: FastAPI):
  container = Container()
  container.wire(modules=[__name__])
  await container.init_resources()  # type: ignore
  yield
  await container.shutdown_resources()  # type: ignore


app = FastAPI(lifespan=lifespan)


@app.get("/test")  # @inject decorator implicitly added
async def test(
  db: AsyncSession = Dep(Container.db_session),  # expands to: Depends(init_session)
  something: int = Dep(Container.something),  # expands to: Depends(Provide[...])
):
  res = await db.execute(sa.text("select version()"))
  return {
      "version": res.scalar(),
      "something": something,
  }

@AndBondStyle

@rumbarum one more question, is it important to specifically use ThreadSafeSingleton for async_scoped_session? Seems to work fine with just Singleton or Resource, is it more safe your way?

@rumbarum

rumbarum commented Oct 3, 2024

@AndBondStyle
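I'm not sure why it needs to be that tricky, but if it works for you, no problem.

FastAPI runs async endpoints on a single event-loop thread, so for them Singleton and ThreadSafeSingleton behave the same (sync def endpoints run in a thread pool, though).
I am not sure about Resource.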

@nightblure

nightblure commented Nov 13, 2024

@dandiep @ebrahimradi @JobaDiniz @Trinkes
hi!

I think you have the wrong idea of how this should work. One SQLAlchemy session object = one database connection. You must create it per request as a new object, hence a new connection; otherwise your application will by definition work with at most one client.

Here is a working example that should behave correctly and efficiently:

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from fastapi import Depends
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


class DIContainer(containers.DeclarativeContainer):
    settings = providers.Singleton(YourSettingsClass)

    # sync engine, matching the sync `with session_factory()` usage below
    sqla_engine = providers.Singleton(
        create_engine,
        url=settings.provided.db_url,
        pool_size=20,
        max_overflow=0,
        echo=settings.provided.sqla_echo,
    )

    # sessionmaker receives the engine through `bind`
    db_session_factory = providers.Singleton(
        sessionmaker,
        bind=sqla_engine,
        autoflush=False,
    )

# use case with FastAPI endpoint
@fastapi_router.get('/')
@inject
def some_handler(session_factory=Depends(Provide[DIContainer.db_session_factory])):
    with session_factory() as sqla_session:
        ...

This will work correctly, creating a session instance on every request (this is equivalent to the request scope mentioned by @dandiep).

if something doesn't work out, I'm ready to help

@bolshakov

@nightblure, The problem with this approach is that you don't have a session in your request but a session factory. This means you cannot instantiate your dependencies with an instance of per-request session using DI Factories. The provided approach requires doing it manually in each router:

@fastapi_router.get('/')
@inject
def some_handler(session_factory = Provide[DIContainer.db_session_factory]):
    with session_factory() as sqla_session:
        user_repository = UserRepository(sqla_session)
        account_repository = AccountRepository(sqla_session)
        use_case = UseCase(user_repository, account_repository)
        use_case.execute(request)

What we really want to do is to inject "use case" (in my example) and let DI container do all the instantiation for us:

@fastapi_router.get('/')
@inject
def some_handler(use_case = Provide[DIContainer.some_handler_use_case]):
  use_case.execute(request)

@nightblure

nightblure commented Nov 28, 2024

@bolshakov hi!

The advantage of this approach is that we do not create a session unnecessarily. It is also an implicit request scope: each session is created from the factory inside the handler and does not live longer than one request.

But if you want the same behavior with an injected session instead of a factory, plus a request scope, check out the "Resources, wiring, and per-function execution scope" section of the docs, which uses the Closing marker.

@bolshakov

@nightblure, thank you for pointing this out. I tried doing this, and it does not look like the per-function execution scope works properly with FastAPI:

async def get_session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession]:
    async with AsyncSession(engine) as session, session.begin():
        yield session
        await session.commit()

class Container(containers.DeclarativeContainer):
    config = providers.Configuration()

    engine = providers.Singleton(create_async_engine, config.database.url, echo=config.database.echo)
    session = providers.Resource(get_session, engine=engine)

FastAPI Router

@router.get("/test")
@inject
async def test(session: AsyncSession = Depends(Closing[Provide[Container.session]])) -> str:
    await session.execute(text("SELECT 1"))
    return "ok"

When I call this test endpoint, I see the following output:

2024-11-28 18:41:32,532 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-11-28 18:41:32,533 INFO sqlalchemy.engine.Engine SELECT 1
2024-11-28 18:41:32,533 INFO sqlalchemy.engine.Engine [generated in 0.00009s] ()
INFO:     127.0.0.1:52057 - "GET /account/test HTTP/1.1" 200 OK

As you can see, it opens a transaction but does not close it at the end of the test function. To be clear, this is the expected output:

2024-11-28 18:44:16,452 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-11-28 18:44:16,453 INFO sqlalchemy.engine.Engine SELECT 1
2024-11-28 18:44:16,453 INFO sqlalchemy.engine.Engine [generated in 0.00008s] ()
2024-11-28 18:44:16,453 INFO sqlalchemy.engine.Engine COMMIT
INFO:     127.0.0.1:52111 - "GET /account/test HTTP/1.1" 200 OK

I understand we can always provide a session factory instead, but this requires instantiating other dependencies in the router or introducing more factories for dependencies that require sessions.

The sad part is that this kind of session injection works out of the box with native FastAPI DI, and I managed to combine the two and get a working solution by multiplying the number of factories I use:

@inject
async def _get_session(engine: AsyncEngine = Provide[Container.engine]) -> AsyncGenerator[AsyncSession]:
    async with AsyncSession(engine) as session, session.begin():
        yield session
        await session.commit()


async def get_session() -> AsyncGenerator[AsyncSession]:
    async for session in _get_session():
        yield session

class Container(containers.DeclarativeContainer):
    config = providers.Configuration()
    engine = providers.Singleton(create_async_engine, config.database.url, echo=config.database.echo)
    user_repository_factory = providers.Factory(UserRepository)    
    account_repository_factory = providers.Factory(AccountRepository)
    # ... other 

    create_account_factory = providers.Factory(
        CreateAccountFactory,
        user_repository_factory=user_repository_factory.provider,
        account_repository_factory=account_repository_factory.provider,
        # ... others
    )

@router.post("/account", status_code=status.HTTP_201_CREATED, response_model=CreateAccountResponse)
@inject
async def create_account_route(
    create_account_request: CreateAccountRequest,
    session: AsyncSession = Depends(get_session),
    create_account_factory: CreateAccountFactory = Depends(Provide[Container.create_account_factory]),
) -> CreateAccountResponse:
    create_account: CreateAccount = create_account_factory.create(session)
    account, admin = await create_account.execute(create_account_request.to_command())

    return CreateAccountResponse.from_entity(account, admin)

While this works, all the wiring happens in the CreateAccountFactory now instead of the DI container:

class CreateAccountFactory:
    def __init__(
        self,
        user_repository_factory: Callable[[AsyncSession], UserRepository],
        account_repository_factory: Callable[[AsyncSession], AccountRepository],
        # ... other deps
    ) -> None:
        self._user_repository_factory = user_repository_factory
        self._account_repository_factory = account_repository_factory

    def create(self, session: AsyncSession) -> CreateAccount:
        user_repository = self._user_repository_factory(session)
        account_repository = self._account_repository_factory(session)

        return CreateAccount(
            account_repository=account_repository,
            user_repository=user_repository,
            # ... rest 
        )

Of course, I might be missing something, and I would appreciate it if someone could steer me toward a way to offload all the instantiation to the DI container rather than doing it inside a router.

@nightblure

nightblure commented Nov 28, 2024

@bolshakov I spent time trying to get it to work and, to my surprise, it didn't! Moreover, different people have repeatedly opened issues about this problem.

But I found a workaround that avoids the Closing marker: a wrapper around the original inject decorator that, if the session provider appears among the handler's arguments, initializes it before the call and shuts it down after the request ends. This should work fine in all cases.
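The wrapper idea reduces to a framework-free sketch. FakeResource and with_resource are hypothetical names, not part of dependency_injector; the point is that initialization happens before the handler and shutdown happens in a finally block, so the session is released even when the handler raises.

```python
import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, List


class FakeResource:
    """Hypothetical stand-in for a providers.Resource with async init/shutdown."""

    def __init__(self) -> None:
        self.events: List[str] = []

    async def init(self) -> None:
        self.events.append("init")

    async def shutdown(self) -> None:
        self.events.append("shutdown")


resource = FakeResource()


def with_resource(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    @wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        await resource.init()
        try:
            return await func(*args, **kwargs)
        finally:
            # Runs on success and on error, so the resource is never leaked.
            await resource.shutdown()

    return wrapper


@with_resource
async def ok_handler() -> str:
    return "ok"


@with_resource
async def failing_handler() -> str:
    raise RuntimeError("boom")


async def main() -> None:
    print(await ok_handler())  # ok
    try:
        await failing_handler()
    except RuntimeError:
        pass
    print(resource.events)  # ['init', 'shutdown', 'init', 'shutdown']


asyncio.run(main())
```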

Full code example:

from dataclasses import dataclass
from functools import wraps
from typing import AsyncIterator

import uvicorn
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide
from fastapi import FastAPI, APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine


async def get_session(engine: AsyncEngine) -> AsyncIterator[AsyncSession]:
    async with AsyncSession(engine) as session, session.begin():
        yield session
        await session.commit()


@dataclass
class Settings:
    database_url: str = "sqlite+aiosqlite://"
    sqla_echo: bool = True


class Container(containers.DeclarativeContainer):
    config = providers.Singleton(Settings)
    engine = providers.Singleton(create_async_engine, config.provided.database_url, echo=config.provided.sqla_echo)
    session = providers.Resource(get_session, engine=engine)


# should be singleton by docs
container_instance = Container()


def inject_fixed(f):
    @wraps(f)
    async def wrapper(*args, **kwargs):
        # Initialize the session resource for this request
        await container_instance.session.init()
        try:
            # Resolve session provider markers into the actual session
            for arg, value in kwargs.items():
                if isinstance(value, Provide) and value.provider is container_instance.session:
                    kwargs[arg] = await container_instance.session()

            # Use the original inject decorator to inject the remaining dependencies
            return await inject(f)(*args, **kwargs)
        finally:
            # Shut the session resource down even if the handler raised
            await container_instance.session.shutdown()

    return wrapper


router = APIRouter()


@router.get("/test")
@inject_fixed
async def test(session: AsyncSession = Depends(Provide[container_instance.session])) -> str:
    await session.execute(text("SELECT 1"))
    return "ok"


container_instance.wire(modules=[__name__])
app = FastAPI()
app.include_router(router)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Console output:

INFO:     Started server process [64989]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
2024-11-28 22:46:02,725 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2024-11-28 22:46:02,726 INFO sqlalchemy.engine.Engine SELECT 1
2024-11-28 22:46:02,726 INFO sqlalchemy.engine.Engine [generated in 0.00016s] ()
2024-11-28 22:46:02,727 INFO sqlalchemy.engine.Engine COMMIT
INFO:     127.0.0.1:52091 - "GET /test HTTP/1.1" 200 OK

I also want to note: if you don't like this wrapper, you can do exactly the same thing with a FastAPI/Starlette middleware or a pure ASGI middleware.
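For the middleware route, a pure ASGI middleware only needs the standard (scope, receive, send) callable interface. In this hypothetical sketch, FakeSession and current_session stand in for real session handling (for example a dependency_injector Resource); the middleware opens one session per HTTP request, exposes it through a ContextVar, and closes it in finally.

```python
import asyncio
from contextvars import ContextVar
from typing import Any, Awaitable, Callable, Dict, List

Scope = Dict[str, Any]
Receive = Callable[[], Awaitable[Dict[str, Any]]]
Send = Callable[[Dict[str, Any]], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]


class FakeSession:
    """Hypothetical placeholder for AsyncSession."""

    def __init__(self) -> None:
        self.closed = False

    async def close(self) -> None:
        self.closed = True


current_session: ContextVar[FakeSession] = ContextVar("current_session")


class SessionMiddleware:
    """Opens one session per HTTP request and closes it when the request ends."""

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            # Lifespan and websocket traffic passes through untouched.
            await self.app(scope, receive, send)
            return
        session = FakeSession()
        token = current_session.set(session)
        try:
            await self.app(scope, receive, send)
        finally:
            await session.close()
            current_session.reset(token)


seen: List[FakeSession] = []


async def inner_app(scope: Scope, receive: Receive, send: Send) -> None:
    # Any code running inside the request reads the same per-request session.
    seen.append(current_session.get())
    seen.append(current_session.get())


async def receive() -> Dict[str, Any]:
    return {"type": "http.request", "body": b"", "more_body": False}


async def send(message: Dict[str, Any]) -> None:
    pass


async def main() -> None:
    app = SessionMiddleware(inner_app)
    for _ in range(2):  # two sequential requests
        await app({"type": "http"}, receive, send)


asyncio.run(main())
print(seen[0] is seen[1], seen[0] is seen[2], all(s.closed for s in seen))  # True False True
```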


P.S.
I would recommend the following way of obtaining a session. It rolls back automatically on errors (you can also log the exception there), so this option is more reliable:

async def get_session(engine: AsyncEngine) -> AsyncIterator[AsyncSession]:
    session = AsyncSession(engine)
    try:
        print("before")
        yield session
        await session.commit()
        print("after")
    except Exception:
        await session.rollback()
        raise  # re-raise so the error is not silently swallowed
    finally:
        await session.close()
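The commit/rollback/close pattern can be checked with a fake session. RecordingSession is hypothetical and only records which methods were awaited; asynccontextmanager drives the generator the same way a yield dependency would. The bare raise after the rollback matters: without it, the handler's error would be swallowed.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, List


class RecordingSession:
    """Hypothetical stand-in for AsyncSession that records awaited calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    async def commit(self) -> None:
        self.calls.append("commit")

    async def rollback(self) -> None:
        self.calls.append("rollback")

    async def close(self) -> None:
        self.calls.append("close")


async def get_session(session: RecordingSession) -> AsyncIterator[RecordingSession]:
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise  # re-raise so the caller still sees the error
    finally:
        await session.close()


ok_session = RecordingSession()
failed_session = RecordingSession()


async def main() -> None:
    async with asynccontextmanager(get_session)(ok_session):
        pass  # the handler succeeded
    try:
        async with asynccontextmanager(get_session)(failed_session):
            raise ValueError("handler failed")
    except ValueError:
        pass  # the error propagates out of the context manager


asyncio.run(main())
print(ok_session.calls)      # ['commit', 'close']
print(failed_session.calls)  # ['rollback', 'close']
```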

@bolshakov

@nightblure, I appreciate you spending so much time digging into this problem!

While inject_fixed works, the whole setup becomes more rigid in tests. However, your approach gave me inspiration for another idea - you don't really need to use a container singleton for the injection itself. Let me show how it works on the user/account example I showed before (I like this example because it requires an initialized session in the container to construct dependencies)

# container.py
from collections.abc import AsyncIterator

from dependency_injector import containers, providers
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

from application.account.use_cases.create_account import CreateAccount
from application.settings import Settings
from infrastructure.persistence.repositories import AccountRepository, UserRepository
from infrastructure.security import password_service


async def get_session(engine: AsyncEngine) -> AsyncIterator[AsyncSession]:
    async with AsyncSession(engine) as session, session.begin():
        yield session
        await session.commit()


class Container(containers.DeclarativeContainer):
    """Dependency injection container"""

    config = providers.Configuration()
    config.from_dict(Settings().dict())

    engine = providers.Singleton(create_async_engine, config.database.url, echo=config.database.echo)
    session = providers.Resource(get_session, engine=engine)


    user_repository = providers.Factory(UserRepository, session=session)
    account_repository = providers.Factory(AccountRepository, session=session)
    
    password_hashing_service = providers.Singleton(
        password_service.Argon2PasswordHashingService,
        parameters=providers.Callable(
            lambda profile_name: password_service.ARGON2_PROFILES[profile_name],
            config.security.argon2_password_hashing_service.profile,
        ),
    )

    create_account = providers.Factory(
        CreateAccount,
        password_hashing_service=password_hashing_service,
        user_repository=user_repository,
        account_repository=account_repository,
        session=session,
    )

Now, let's take a look at the app initialization:

from collections.abc import AsyncGenerator, Callable

from fastapi import Depends, FastAPI

from application.containers import Container
from interface.api.routes.account import router as account_routes


def session_initializer(container: Container) -> Callable[[], AsyncGenerator]:
    async def init() -> AsyncGenerator:
        try:
            # session is an async Resource, so init() returns an awaitable; await it
            await container.session.init()
            yield
        finally:
            await container.session.shutdown()

    return init


def init_app(container: Container) -> FastAPI:
    app = FastAPI(dependencies=[Depends(session_initializer(container))])
    container.wire(packages=["my_app"])

    app.include_router(account_routes, tags=["Account"])
    return app

As you can see, I added a global dependency here, which is called on every request. This lets us inject the "create account" use case right into the router:

from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends
from starlette import status

from application.account.use_cases.create_account import CreateAccount
from application.containers import Container
from interface.api.dto.account import CreateAccountRequest, CreateAccountResponse

router = APIRouter(prefix="/account", tags=["Account"])


@router.post("/", status_code=status.HTTP_201_CREATED, response_model=CreateAccountResponse)
@inject
async def create_account_route(
    create_account_request: CreateAccountRequest,
    create_account: CreateAccount = Depends(Provide[Container.create_account]),
) -> CreateAccountResponse:
    """Create a new account"""
    account, admin = await create_account(create_account_request.to_command())

    return CreateAccountResponse.from_entity(account, admin)

It behaves as expected, and the session is automatically committed and closed at the end of each request. This approach also gives you more flexibility in tests, since you can use your own container there:

from collections.abc import AsyncGenerator

import pytest
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession

from application.containers import Container
from infrastructure.persistence.models import BaseModel

@pytest.fixture
def container() -> Container:
    container = Container()
    # Use faster hashing for tests
    container.config.security.argon2_password_hashing_service.profile.from_value("CHEAPEST")
    # Setup in-memory SQLite database
    container.config.database.url.from_value("sqlite+aiosqlite:///:memory:")
    container.config.database.echo.from_value(True)

    container.wire(packages=["my_app"])
    return container


@pytest.fixture
def engine(container: Container) -> AsyncEngine:
    return container.engine()

@pytest.fixture
async def setup_database(engine: AsyncEngine) -> None:  # noqa: PT004
    """
    Create all tables in the test database
    Usage:
        @pytest.mark.usefixtures("setup_database")
        def describe_something():
    """
    async with engine.begin() as conn:
        await conn.run_sync(BaseModel.metadata.create_all)

Then you just need to set up an application with this container and a client:

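from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient

# init_app is the application factory shown earlier


@pytest.fixture
def app(container: Container) -> FastAPI:
    return init_app(container)


@pytest.fixture
async def client(app: FastAPI) -> AsyncGenerator[AsyncClient]:
    # Recent httpx versions dropped the app= shortcut; pass the app through ASGITransport instead
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        yield client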

This setup allows you to perform integration testing. For unit testing, you can set up your session like this:

@pytest.fixture
async def session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession]:
    async with AsyncSession(engine) as session, session.begin():
        yield session

While this works in my tests, I am concerned about whether it works correctly with asynchronous code: I can imagine the session leaking between requests when they are executed concurrently.
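That concern can be illustrated without FastAPI. A `providers.Resource` keeps a single value per provider, so if every request re-initializes it, concurrent requests are writing to one shared slot. The stdlib-only sketch below models that slot as a module-level dict (a simplification, not the library's actual internals) and runs two overlapping "requests" with `asyncio.gather`; `await asyncio.sleep(0)` stands in for awaiting a database call.

```python
import asyncio

# One slot shared by the whole process, standing in for a container-wide resource
shared_slot: dict[str, str] = {}


async def handle_request(request_id: str) -> str:
    # "Initialize the resource" for this request
    shared_slot["session"] = f"session-{request_id}"
    # Yield control, as a real handler would while awaiting a DB query
    await asyncio.sleep(0)
    # Read the session back: another request may have replaced it meanwhile
    return shared_slot["session"]


async def main() -> list[str]:
    return await asyncio.gather(handle_request("a"), handle_request("b"))


print(asyncio.run(main()))  # ['session-b', 'session-b']: request "a" ended up with b's session
```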

@nightblure

@bolshakov Looks like a completely workable option!
But I have an important note about testing, because it seems to me that you did some extra work there. I'd like to show you how to make it simpler with one of the main features for tests: provider overriding.

First, make the container a global instance:

# your container.py
class Container(containers.DeclarativeContainer):
    ...  # your providers

# use only this instance everywhere in your code
container_instance = Container()

Secondly, let's get to the main thing: overriding dependencies. It's easier to store settings in a dataclass or, for example, a pydantic BaseSettings class. Let's say you have them in this form:

# Short examples that don't match your structure exactly; with a little adaptation they will still work for it
from dependency_injector import containers, providers
from pydantic import BaseModel
from pydantic_settings import BaseSettings  # pydantic v2; on v1 use pydantic.BaseSettings (or a dataclass)
from sqlalchemy.ext.asyncio import create_async_engine


class SecuritySettings(BaseModel):
    argon2_profile: str


class YourSettingsClass(BaseSettings):
    security: SecuritySettings
    database_url: str
    database_echo: bool = False
    # and other config values


class Container(containers.DeclarativeContainer):
    settings = providers.Singleton(YourSettingsClass)
    # example of using the settings
    engine = providers.Singleton(create_async_engine, settings.provided.database_url, echo=settings.provided.database_echo)
    # other providers...

Finally, override the dependencies:

from collections.abc import Iterator

import pytest

# adjust these import paths to your project layout
from ...container import Container, SecuritySettings, YourSettingsClass, container_instance
from infrastructure.persistence.models import BaseModel


@pytest.fixture(scope="session")
def container() -> Container:
    return container_instance


@pytest.fixture(scope="session")
def test_settings() -> YourSettingsClass:
    return YourSettingsClass(database_url="sqlite+aiosqlite:///:memory:", security=SecuritySettings(argon2_profile="CHEAPEST"))


# autouse makes pytest apply this fixture automatically; you don't need to request it
@pytest.fixture(scope="session", autouse=True)
def _provider_override(container: Container, test_settings: YourSettingsClass) -> Iterator[None]:
    overrides = {"settings": test_settings}  # other overrides can go here too; each key is a provider name
    with container.override_providers(**overrides):
        yield


@pytest.fixture(scope="session")
async def setup_database(container: Container) -> None:  # noqa: PT004
    """
    Create all tables in the test database
    Usage:
        @pytest.mark.usefixtures("setup_database")
        def describe_something():
    """
    # container.engine is a provider; call it to get the AsyncEngine instance
    async with container.engine().begin() as conn:
        await conn.run_sync(BaseModel.metadata.create_all)

# Since we've overridden the dependencies for the entire test session, you don't have to repeat any setup: just run the tests!

More about overriding: https://python-dependency-injector.ets-labs.org/providers/overriding.html

@bolshakov

@nightblure, thank you for your guidance. I'm new to Python, and you helped me a lot 🙇‍♂️

@alk3mist

alk3mist commented Jan 6, 2025

import os

import uvicorn
from dependency_injector import containers, providers
from dependency_injector.wiring import Provide, inject
from fastapi import Depends, FastAPI
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, declarative_base, sessionmaker

engine = create_engine("sqlite://")

Base = declarative_base()
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)


class ApplicationContainer(containers.DeclarativeContainer):
    wiring_config = containers.WiringConfiguration(modules=[__name__])
    database = providers.Factory(SessionLocal)


app = FastAPI()


@inject  # needed so the Provide marker is resolved when FastAPI calls get_db
def get_db(s: Session = Depends(Provide[ApplicationContainer.database])) -> Session:
    return s


@app.get("/")
@inject
async def root(
    session_1: Session = Depends(get_db),
    session_2: Session = Depends(get_db),
):
    return {"session_1_id": id(session_1), "session_2_id": id(session_2)}


container = ApplicationContainer()

if __name__ == "__main__":
    uvicorn.run(
        os.path.basename(__file__).replace(".py", "") + ":app",
        host="127.0.0.1",
        port=5000,
        log_level="info",
        reload=True,
    )

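See use_cache: https://fastapi.tiangolo.com/reference/dependencies/?h=use_cach#depends. By default, FastAPI reuses a dependency's result for the rest of the request, so both parameters receive the same session: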

{
  "session_1_id": 140567562331600,
  "session_2_id": 140567562331600
}

@Trinkes
Author

Trinkes commented Jan 7, 2025

Thank you for sharing your solution, @alk3mist!

While your approach works for this particular case, it relies on FastAPI itself rather than the DI (Dependency Injection) framework. If I understand correctly, this means dependencies cannot be nested within containers as part of the DI setup.

For example, imagine a scenario where we want to inject a database session into two different objects (Foo and Foo2). Instead of making the request directly depend on the database session, we make it depend on Foo and Foo2. In this case, the database session injected into Foo will be different from the one injected into Foo2.

class Foo:
    def __init__(self, session):
        self.session = session

class Foo2:
    def __init__(self, session):
        self.session = session

In this setup, foo.session and foo2.session could be different, which isn't the intended behavior.

Warning

I didn't test this out

@alk3mist

alk3mist commented Jan 7, 2025

Why can't you create your Foo objects the same way you create the session?

def get_foo(session: Session = Depends(get_db)) -> Foo:
    return Foo(session)

def get_foo2(session: Session = Depends(get_db)) -> Foo2:
    return Foo2(session)

@app.get('/foo_with_foo2/')
def read_foo_with_foo2(foo: Foo = Depends(get_foo), foo2: Foo2 = Depends(get_foo2)):
    return {"session_foo_id": id(foo.session), "session_foo2_id": id(foo2.session)}
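The reason foo and foo2 would share one session here is FastAPI's per-request dependency cache (use_cache=True by default): within one request, each dependency callable runs at most once and its result is reused, even through nested dependencies. The toy resolver below is a stdlib-only sketch with hypothetical `Dep` and `resolve` names, not FastAPI's real machinery; it mimics that rule so the sharing can be checked directly.

```python
import inspect
from typing import Any, Callable


class Dep:
    """Hypothetical marker playing the role of FastAPI's Depends(...)."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn


def resolve(fn: Callable[..., Any], cache: dict) -> Any:
    """Call fn after resolving its Dep(...) defaults; each callable runs once per cache."""
    if fn in cache:
        return cache[fn]
    kwargs = {
        name: resolve(param.default.fn, cache)
        for name, param in inspect.signature(fn).parameters.items()
        if isinstance(param.default, Dep)
    }
    cache[fn] = result = fn(**kwargs)
    return result


class Session:  # stand-in for a SQLAlchemy Session
    pass


class Foo:
    def __init__(self, session: Session) -> None:
        self.session = session


class Foo2:
    def __init__(self, session: Session) -> None:
        self.session = session


def get_db() -> Session:
    return Session()


def get_foo(session: Session = Dep(get_db)) -> Foo:
    return Foo(session)


def get_foo2(session: Session = Dep(get_db)) -> Foo2:
    return Foo2(session)


def endpoint(foo: Foo = Dep(get_foo), foo2: Foo2 = Dep(get_foo2)) -> tuple:
    return foo, foo2


# A fresh cache models a fresh request
foo, foo2 = resolve(endpoint, cache={})
print(foo.session is foo2.session)  # True: get_db ran once for this request
other_foo, _ = resolve(endpoint, cache={})
print(foo.session is other_foo.session)  # False: a new request gets a new session
```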

@Trinkes
Author

Trinkes commented Jan 7, 2025

We can, but at this point we're not using the DI framework.

@ivanbicalho

ivanbicalho commented Feb 9, 2025

Hey!

I have also been struggling with this, but after spending some time with the docs, ChatGPT, Stack Overflow and GitHub issues, I think I found a clear and simple way to solve it.

One of the providers we have is the Dependency provider, where you only define the type you need:

class Container(containers.DeclarativeContainer):
    uow = providers.Dependency(instance_of=UnitOfWork)

If you try to run your program without defining the actual dependency, you'll get an error like ('Dependency "Container.uow" is not defined',)

OK, that's half of the problem solved. Now we can take advantage of FastAPI's HTTP middleware:

@app.middleware("http")
async def uow_dependency_injection_middleware(request: Request, call_next):
    uow = UnitOfWork()
    container.uow.override(provider=providers.Callable(lambda: uow))
    return await call_next(request)

The Callable provider calls the given callable every time the dependency is resolved, and in my case that callable always returns the same uow instance.

That's all you need: now your object is 100% request-scoped. Let me know if it solves your issue and whether there are any pitfalls I may have missed.

UPDATE:
Never mind, we can't do that: every request adds another override, so overridden objects pile up in memory and cause a memory leak. I'm still trying to figure out a way to bind it to the request; not sure if it's possible yet.
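The leak comes from overrides stacking: in Dependency Injector, each `override()` call is stacked on the provider (see its `overridden` attribute and `reset_last_overriding()`), and nothing removes it unless you reset it. The toy class below, a hypothetical `MiniProvider` written only for illustration and not the library's implementation, shows the stack growing by one per request, and how undoing each override in a `with` block keeps it bounded. Even the bounded version is still one process-wide override, so concurrent requests would see each other's object.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


class MiniProvider:
    """Hypothetical toy provider with an override stack, for illustration only."""

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory
        self.overridden: list[Callable[[], Any]] = []

    def override(self, factory: Callable[[], Any]) -> None:
        self.overridden.append(factory)  # pushed, never popped

    def reset_last_overriding(self) -> None:
        self.overridden.pop()

    @contextmanager
    def overriding(self, factory: Callable[[], Any]) -> Iterator[None]:
        self.override(factory)
        try:
            yield
        finally:
            self.reset_last_overriding()  # undo the override when the request ends

    def __call__(self) -> Any:
        # The most recent override wins; otherwise fall back to the original factory
        return (self.overridden[-1] if self.overridden else self._factory)()


uow = MiniProvider(object)

# Middleware style from the comment: override on every request, never reset
for _ in range(3):
    request_uow = object()
    uow.override(lambda obj=request_uow: obj)
print(len(uow.overridden))  # 3: one leftover override per request

uow.overridden.clear()
for _ in range(3):
    request_uow = object()
    with uow.overriding(lambda obj=request_uow: obj):
        assert uow() is request_uow
print(len(uow.overridden))  # 0
```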

@ivanbicalho

ivanbicalho commented Feb 12, 2025

I found something promising. Dependency Injector has a ContextLocalSingleton provider, which the docs describe this way: "Context-local singleton provides single objects in scope of a context." Since my API is 100% async, this provider can take advantage of ContextVars, which isolate the current context in asynchronous frameworks. I did some tests and it appears to isolate my single instances per request perfectly.

UPDATE:
Just found out that @philipbjorge had mentioned it before! 😄
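The idea behind a context-local singleton can be sketched with the standard `contextvars` module: the instance is cached in a `ContextVar`, and because asyncio runs each task in its own copy of the context, every task (for example, concurrent request handlers running as separate tasks) builds and keeps its own instance. The `ContextLocal` class below is a hypothetical illustration of the concept, not Dependency Injector's implementation.

```python
import asyncio
import contextvars
import itertools
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class ContextLocal(Generic[T]):
    """Hypothetical sketch: one lazily created instance per context."""

    def __init__(self, factory: Callable[[], T]) -> None:
        self._factory = factory
        self._var = contextvars.ContextVar("instance", default=None)

    def __call__(self) -> T:
        instance = self._var.get()
        if instance is None:
            instance = self._factory()
            self._var.set(instance)  # visible only in the current context
        return instance


_ids = itertools.count(1)
session = ContextLocal(lambda: f"session-{next(_ids)}")


async def handle_request() -> tuple[str, str]:
    first = session()
    await asyncio.sleep(0)  # let the other request run in between
    second = session()      # same context, so the same instance comes back
    return first, second


async def main() -> list[tuple[str, str]]:
    # gather wraps each coroutine in its own task, each with its own context copy
    return await asyncio.gather(handle_request(), handle_request())


print(asyncio.run(main()))  # [('session-1', 'session-1'), ('session-2', 'session-2')]
```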

@huuvan023

huuvan023 commented Feb 13, 2025

Here is my solution:

container.py

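from typing import Optional

from dependency_injector import containers, providers, resources
from sqlalchemy.orm import Session


class DBSession(resources.Resource):
    def __init__(self) -> None:
        super().__init__()
        self._session: Optional[Session] = None

    def _create_session(self) -> None:
        if self._session is None:
            # get_session_maker() is your own helper that returns a sessionmaker
            self._session = get_session_maker()()

    def commit(self) -> None:
        if self._session:
            self._session.commit()

    def close(self) -> None:
        if self._session:
            self._session.close()

    def rollback(self) -> None:
        if self._session:
            self._session.rollback()

    def init(self, *args, **kwargs) -> "DBSession":
        self._create_session()
        return self

    def shutdown(self, resource: Optional["DBSession"]) -> None:
        self.close()

    @property
    def session(self) -> Optional[Session]:
        return self._session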


class Container(containers.DeclarativeContainer):
    config = providers.Configuration()
    session = providers.Resource(DBSession)

    license_repository = providers.Factory(
        domain.repository.license.LicenseRepository, session=session.provided.session
    )
    license_service = providers.Factory(
        app.service.license.LicenseService, license_repository=license_repository
    )

    user_repository = providers.Factory(
        domain.repository.user.UserRepository, session=session.provided.session
    )
    user_service = providers.Factory(
        app.service.user.UserProfilesService, user_repository=user_repository
    )

dependency/session.py

from contextlib import contextmanager
from typing import Generator

def session_initializer_depend() -> Generator[None, None, None]:
    from di.config import container

    session = container.session()
    try:
        yield
        session.commit()
    except Exception as e:
        session.rollback()
        raise e
    finally:
        session.shutdown(None)

@contextmanager
def session_initializer() -> Generator[None, None, None]:
    from di.config import container

    session = container.session()
    try:
        yield
        session.commit()
    except Exception as e:
        session.rollback()
        raise e
    finally:
        session.shutdown(None)

container = Container() is defined in your main file when you init the app.

Now you can use it like this:

- As a dependency in your router:

@router.post("/signup", dependencies=[Depends(session_initializer_depend)])
@inject
def sign_up(
    user_service: UserProfilesService = Depends(Provide[Container.user_service]),
    license_service: LicenseService = Depends(Provide[Container.license_service]),
):
    ...  # your logic goes here

- In a with statement (in a background task or something similar):

@inject
def task_abc(
    user_service: UserProfilesService = Depends(Provide[Container.user_service]),
    license_service: LicenseService = Depends(Provide[Container.license_service]),
):
    db_id_1 = user_service.test_license()
    db_id_2 = license_service.test_license()
    user_service.create_user_profile()
    return {
        "db_id_1": db_id_1,
        "db_id_2": db_id_2,
    }

def sample_task_wrapper():
    with session_initializer():
        return task_abc()

@router.post("/signup")
def sign_up():
    sample_task_wrapper()
    return True

What do you think about this?

@huuvan023

huuvan023 commented Feb 14, 2025

UPDATED
I found a new, shorter way to do this:

#get_session.py

def get_session():
    session = get_session_maker()()
    try:
        yield session
    finally:
        # Always close the session, even if the generator is closed early
        session.close()

In the Container, we define get_session as a Resource:
#container.py

class Container(containers.DeclarativeContainer):
    session = providers.Resource(get_session)
    user_repository = providers.Factory(UserRepository, session=session)
    user_profile_service = providers.Factory(
        UserProfilesService, user_repository=user_repository
    )

Here is how to use it:
#dependency.py

from typing import Generator


def session_initializer_depend() -> Generator[None, None, None]:
    from main import container

    container.session.init()
    try:
        yield
        container.session().commit()
    except Exception:
        container.session().rollback()
        raise
    finally:
        container.session.shutdown()

#router.py

user_router = APIRouter(dependencies=[Depends(session_initializer_depend)])

@user_router.post("/signup", response_model=SuccessResponse)
@inject
def sign_up(
    request: SignUpDTO,
    user_facade: UserFacade = Depends(Provide[Container.user_facade]),
):
    return SuccessResponse(
        data=user_facade.sign_up(
            request.email, request.password, request.first_name, request.last_name
        )
    )

For more info, see https://python-dependency-injector.ets-labs.org/providers/resource.html#generator-initializer

@amoncusir

amoncusir commented Feb 22, 2025

I just created a PR that manages different resources using the container's init and shutdown methods. It can be used to create scoped resources and manage different lifecycles among them; the tests show how to do it. I hope it helps and is useful!

PR: #858

@dandiep

dandiep commented Mar 15, 2025

Hi all, I'm wondering if you can just use ContextLocalSingleton for this and then reset it at the end of the request.

For example, my database:

import logging
from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

logger = logging.getLogger(__name__)


class Database(BaseService):
    def __init__(self, db_url: str, echo: bool) -> None:
        super().__init__()

        self._engine = create_engine(db_url, echo=echo, future=True)
        # self._session_factory = scoped_session(sessionmaker(bind=self._engine, future=True, class_=_Session))
        self._session_factory = sessionmaker(bind=self._engine, future=True, class_=_Session, expire_on_commit=False)
        self.logger.info("Loaded DB")

    @contextmanager
    def session_factory(self, autoclose: bool = True) -> Iterator[Session]:
        session: Session = self._session_factory()
        try:
            logger.info("Starting session %s", hash(session))
            yield session
            logger.info("Committing session %s", hash(session))
            session.commit()
            # Detach objects only after committing; expunging first would drop pending, unflushed objects
            session.expunge_all()
        except BaseException as e:
            logger.exception("Session rollback because of exception", exc_info=e)
            session.rollback()
            raise
        finally:
            if autoclose:
                session.close()

Services:

class Container(containers.DeclarativeContainer):

    __self__ = providers.Self()

    config = providers.Configuration()

    database = providers.Singleton(Database, db_url=config.database.url, echo=config.database.echo.as_(bool))

    # per request services
    user_service = providers.ContextLocalSingleton(UserService, session=database.provided.session.call())

FastAPI request:

    @app.middleware("http")
    async def manage_session(request: Request, call_next):
        response = None
        try:
            for p in container.traverse([providers.ContextLocalSingleton]):
                p.reset()

            with database.session_factory() as session:
                try:
                    response = await call_next(request)
                finally:
                    if response and response.status_code >= 400 and session.is_active:
                        session.rollback()
....

        return response

The weird thing is that resetting at the end of the request didn't work, but resetting at the beginning seems to.

Is this a viable approach? Why wouldn't resetting at the end of the request work?

EDIT: one reason why resetting at the end may not have worked for me is that I have background tasks. Putting the reset up front though seems to be working and giving me the desired behavior.
