|
7 | 7 | import pytest_asyncio |
8 | 8 | from aiohttp import ClientSession |
9 | 9 | from alembic.config import Config |
10 | | -from sqlalchemy import create_engine, inspect, MetaData |
| 10 | +from sqlalchemy import create_engine, inspect, MetaData, Engine |
| 11 | +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine |
11 | 12 | from sqlalchemy.orm import scoped_session, sessionmaker |
12 | 13 |
|
13 | 14 | from src.core.env_var_manager import EnvVarManager |
| 15 | +from src.db.client.async_ import AsyncDatabaseClient |
| 16 | +from src.db.client.sync import DatabaseClient |
| 17 | +from src.db.helpers.connect import get_postgres_connection_string |
| 18 | +from src.db.models.impl.log.sqlalchemy import Log # noqa: F401 |
14 | 19 | # Below are to prevent import errors |
15 | 20 | from src.db.models.impl.missing import Missing # noqa: F401 |
16 | | -from src.db.models.impl.log.sqlalchemy import Log # noqa: F401 |
17 | 21 | from src.db.models.impl.task.error import TaskError # noqa: F401 |
18 | 22 | from src.db.models.impl.url.checked_for_duplicate import URLCheckedForDuplicate # noqa: F401 |
19 | | -from src.db.client.async_ import AsyncDatabaseClient |
20 | | -from src.db.client.sync import DatabaseClient |
21 | | -from src.db.helpers.connect import get_postgres_connection_string |
22 | 23 | from src.util.helper_functions import load_from_environment |
23 | 24 | from tests.helpers.alembic_runner import AlembicRunner |
24 | 25 | from tests.helpers.data_creator.core import DBDataCreator |
@@ -99,33 +100,55 @@ def setup_and_teardown(): |
99 | 100 | live_connection.close() |
100 | 101 | engine.dispose() |
101 | 102 |
|
| 103 | +@pytest.fixture(scope="session") |
| 104 | +def engine(): |
| 105 | + conn = get_postgres_connection_string() |
| 106 | + engine = create_engine(conn) |
| 107 | + yield engine |
| 108 | + engine.dispose() |
| 109 | + |
| 110 | +@pytest.fixture(scope="session") |
| 111 | +def async_engine(): |
| 112 | + conn = get_postgres_connection_string(is_async=True) |
| 113 | + engine = create_async_engine(conn) |
| 114 | + yield engine |
| 115 | + engine.dispose() |
| 116 | + |
102 | 117 | @pytest.fixture |
103 | | -def wiped_database(): |
| 118 | +def wiped_database( |
| 119 | + engine: Engine |
| 120 | +): |
104 | 121 | """Wipe all data from database.""" |
105 | | - wipe_database(get_postgres_connection_string()) |
| 122 | + wipe_database(engine) |
106 | 123 |
|
107 | 124 |
|
108 | 125 |
|
109 | 126 | @pytest.fixture |
110 | | -def db_client_test(wiped_database) -> Generator[DatabaseClient, Any, None]: |
| 127 | +def db_client_test( |
| 128 | + wiped_database, |
| 129 | + engine |
| 130 | +) -> Generator[DatabaseClient, Any, None]: |
111 | 131 | # Drop pre-existing table |
112 | | - conn = get_postgres_connection_string() |
113 | | - db_client = DatabaseClient(db_url=conn) |
| 132 | + db_client = DatabaseClient(engine) |
114 | 133 | yield db_client |
115 | 134 | db_client.engine.dispose() |
116 | 135 |
|
117 | 136 | @pytest_asyncio.fixture |
118 | | -async def populated_database(wiped_database) -> None: |
119 | | - conn = get_postgres_connection_string(is_async=True) |
120 | | - adb_client = AsyncDatabaseClient(db_url=conn) |
| 137 | +async def populated_database( |
| 138 | + wiped_database, |
| 139 | + async_engine: AsyncEngine |
| 140 | +) -> None: |
| 141 | + adb_client = AsyncDatabaseClient(async_engine) |
121 | 142 | await populate_database(adb_client) |
122 | 143 |
|
123 | 144 | @pytest_asyncio.fixture |
124 | | -async def adb_client_test(wiped_database) -> AsyncGenerator[AsyncDatabaseClient, Any]: |
125 | | - conn = get_postgres_connection_string(is_async=True) |
126 | | - adb_client = AsyncDatabaseClient(db_url=conn) |
| 145 | +async def adb_client_test( |
| 146 | + wiped_database, |
| 147 | + async_engine: AsyncEngine |
| 148 | +) -> AsyncGenerator[AsyncDatabaseClient, Any]: |
| 149 | + adb_client = AsyncDatabaseClient(async_engine) |
127 | 150 | yield adb_client |
128 | | - adb_client.engine.dispose() |
| 151 | + await adb_client.engine.dispose() |
129 | 152 |
|
130 | 153 | @pytest.fixture |
131 | 154 | def db_data_creator( |
|
0 commit comments