Skip to content

Commit c4052a5

Browse files
committed
Fix connection creation leak
1 parent b9dafff commit c4052a5

File tree

7 files changed

+71
-45
lines changed

7 files changed

+71
-45
lines changed

local_database/docker-compose.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,5 @@ services:
1212
- POSTGRES_PASSWORD=HanviliciousHamiltonHilltops
1313
- POSTGRES_USER=test_source_collector_user
1414
- POSTGRES_DB=source_collector_test_db
15-
command: ['postgres', '-c', 'max_connections=160']
1615
volumes:
1716
dbscripts:

src/api/main.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from discord_poster import DiscordPoster
66
from fastapi import FastAPI
77
from pdap_access_manager import AccessManager
8+
from sqlalchemy.ext.asyncio import create_async_engine
89
from starlette.responses import RedirectResponse
910

1011
from src.api.endpoints.agencies.routes import agencies_router
@@ -52,12 +53,9 @@ async def lifespan(app: FastAPI):
5253
env.read_env()
5354

5455
# Initialize shared dependencies
55-
db_client = DatabaseClient(
56-
db_url=env_var_manager.get_postgres_connection_string()
57-
)
58-
adb_client = AsyncDatabaseClient(
59-
db_url=env_var_manager.get_postgres_connection_string(is_async=True)
60-
)
56+
57+
db_client = DatabaseClient()
58+
adb_client = AsyncDatabaseClient()
6159
await setup_database(db_client)
6260
core_logger = AsyncCoreLogger(adb_client=adb_client)
6361

src/db/client/async_.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from functools import wraps
33
from typing import Optional, Type, Any, List, Sequence
44

5-
from sqlalchemy import select, func, Select, and_, update, Row, text
6-
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
5+
from sqlalchemy import select, func, Select, and_, update, Row, text, Engine
6+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker, AsyncEngine
77
from sqlalchemy.orm import selectinload
88

99
from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse
@@ -103,15 +103,15 @@
103103

104104

105105
class AsyncDatabaseClient:
106-
def __init__(self, db_url: str | None = None):
107-
if db_url is None:
106+
def __init__(self, engine: AsyncEngine | None = None):
107+
if engine is None:
108108
db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True)
109-
self.db_url = db_url
110-
echo = ConfigManager.get_sqlalchemy_echo()
111-
self.engine = create_async_engine(
112-
url=db_url,
113-
echo=echo,
114-
)
109+
echo = ConfigManager.get_sqlalchemy_echo()
110+
engine = create_async_engine(
111+
url=db_url,
112+
echo=echo,
113+
)
114+
self.engine = engine
115115
self.session_maker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
116116
self.statement_composer = StatementComposer()
117117

src/db/client/sync.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from functools import wraps
22
from typing import List
33

4-
from sqlalchemy import create_engine, Select
4+
from sqlalchemy import create_engine, Select, Engine
55
from sqlalchemy.exc import IntegrityError
66
from sqlalchemy.orm import sessionmaker, scoped_session, Session
77

@@ -28,15 +28,19 @@
2828

2929
# Database Client
3030
class DatabaseClient:
31-
def __init__(self, db_url: str | None = None):
31+
def __init__(
32+
self,
33+
engine: Engine | None = None
34+
):
3235
"""Initialize the DatabaseClient."""
33-
if db_url is None:
36+
if engine is None:
3437
db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True)
38+
engine = create_engine(
39+
url=db_url,
40+
echo=ConfigManager.get_sqlalchemy_echo(),
41+
)
3542

36-
self.engine = create_engine(
37-
url=db_url,
38-
echo=ConfigManager.get_sqlalchemy_echo(),
39-
)
43+
self.engine = engine
4044
self.session_maker = scoped_session(sessionmaker(bind=self.engine))
4145
self.session = None
4246

tests/automated/integration/readonly/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
import pytest_asyncio
6+
from sqlalchemy import Engine
67
from starlette.testclient import TestClient
78

89
from src.db.helpers.connect import get_postgres_connection_string
@@ -33,8 +34,10 @@ async def california_readonly(
3334
async def readonly_helper(
3435
event_loop,
3536
client: TestClient,
37+
engine: Engine
38+
3639
) -> AsyncGenerator[ReadOnlyTestHelper, Any]:
37-
wipe_database(get_postgres_connection_string())
40+
wipe_database(engine)
3841
db_data_creator = DBDataCreator()
3942
api_test_helper = APITestHelper(
4043
request_validator=RequestValidator(client=client),

tests/conftest.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
import pytest_asyncio
88
from aiohttp import ClientSession
99
from alembic.config import Config
10-
from sqlalchemy import create_engine, inspect, MetaData
10+
from sqlalchemy import create_engine, inspect, MetaData, Engine
11+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
1112
from sqlalchemy.orm import scoped_session, sessionmaker
1213

1314
from src.core.env_var_manager import EnvVarManager
15+
from src.db.client.async_ import AsyncDatabaseClient
16+
from src.db.client.sync import DatabaseClient
17+
from src.db.helpers.connect import get_postgres_connection_string
18+
from src.db.models.impl.log.sqlalchemy import Log # noqa: F401
1419
# Below are to prevent import errors
1520
from src.db.models.impl.missing import Missing # noqa: F401
16-
from src.db.models.impl.log.sqlalchemy import Log # noqa: F401
1721
from src.db.models.impl.task.error import TaskError # noqa: F401
1822
from src.db.models.impl.url.checked_for_duplicate import URLCheckedForDuplicate # noqa: F401
19-
from src.db.client.async_ import AsyncDatabaseClient
20-
from src.db.client.sync import DatabaseClient
21-
from src.db.helpers.connect import get_postgres_connection_string
2223
from src.util.helper_functions import load_from_environment
2324
from tests.helpers.alembic_runner import AlembicRunner
2425
from tests.helpers.data_creator.core import DBDataCreator
@@ -99,33 +100,55 @@ def setup_and_teardown():
99100
live_connection.close()
100101
engine.dispose()
101102

103+
@pytest.fixture(scope="session")
104+
def engine():
105+
conn = get_postgres_connection_string()
106+
engine = create_engine(conn)
107+
yield engine
108+
engine.dispose()
109+
110+
@pytest.fixture(scope="session")
111+
def async_engine():
112+
conn = get_postgres_connection_string(is_async=True)
113+
engine = create_async_engine(conn)
114+
yield engine
115+
engine.dispose()
116+
102117
@pytest.fixture
103-
def wiped_database():
118+
def wiped_database(
119+
engine: Engine
120+
):
104121
"""Wipe all data from database."""
105-
wipe_database(get_postgres_connection_string())
122+
wipe_database(engine)
106123

107124

108125

109126
@pytest.fixture
110-
def db_client_test(wiped_database) -> Generator[DatabaseClient, Any, None]:
127+
def db_client_test(
128+
wiped_database,
129+
engine
130+
) -> Generator[DatabaseClient, Any, None]:
111131
# Drop pre-existing table
112-
conn = get_postgres_connection_string()
113-
db_client = DatabaseClient(db_url=conn)
132+
db_client = DatabaseClient(engine)
114133
yield db_client
115134
db_client.engine.dispose()
116135

117136
@pytest_asyncio.fixture
118-
async def populated_database(wiped_database) -> None:
119-
conn = get_postgres_connection_string(is_async=True)
120-
adb_client = AsyncDatabaseClient(db_url=conn)
137+
async def populated_database(
138+
wiped_database,
139+
async_engine: AsyncEngine
140+
) -> None:
141+
adb_client = AsyncDatabaseClient(async_engine)
121142
await populate_database(adb_client)
122143

123144
@pytest_asyncio.fixture
124-
async def adb_client_test(wiped_database) -> AsyncGenerator[AsyncDatabaseClient, Any]:
125-
conn = get_postgres_connection_string(is_async=True)
126-
adb_client = AsyncDatabaseClient(db_url=conn)
145+
async def adb_client_test(
146+
wiped_database,
147+
async_engine: AsyncEngine
148+
) -> AsyncGenerator[AsyncDatabaseClient, Any]:
149+
adb_client = AsyncDatabaseClient(async_engine)
127150
yield adb_client
128-
adb_client.engine.dispose()
151+
await adb_client.engine.dispose()
129152

130153
@pytest.fixture
131154
def db_data_creator(

tests/helpers/setup/wipe.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from sqlalchemy import create_engine
1+
from sqlalchemy import create_engine, Engine
22

33
from src.db.models.templates_.base import Base
44

55

6-
def wipe_database(connection_string: str) -> None:
6+
def wipe_database(engine: Engine) -> None:
77
"""Wipe all data from database."""
8-
engine = create_engine(connection_string)
98
with engine.connect() as connection:
109
for table in reversed(Base.metadata.sorted_tables):
1110
if table.info == "view":

0 commit comments

Comments
 (0)