diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 884110cd45..d2eed40bac 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -20,6 +20,8 @@ from test.asynchronous.utils_spec_runner import AsyncSpecRunner from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket +from pymongo.asynchronous.pool import PoolState +from pymongo.server_selectors import writable_server_selector sys.path[0:0] = [""] @@ -39,6 +41,7 @@ from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.helpers import anext from pymongo.errors import ( + AutoReconnect, CollectionInvalid, ConfigurationError, ConnectionFailure, @@ -386,6 +389,22 @@ async def find_raw_batches(*args, **kwargs): if isinstance(res, (AsyncCommandCursor, AsyncCursor)): await res.to_list() + @async_client_context.require_transactions + async def test_transaction_pool_cleared_error_labelled_transient(self): + c = await self.async_single_client() + + with self.assertRaises(AutoReconnect) as context: + async with c.start_session() as session: + async with await session.start_transaction(): + server = await c._select_server(writable_server_selector, session, "test") + # Pause the server's pool, causing it to fail connection checkout. + server.pool.state = PoolState.PAUSED + async with c._checkout(server, session): + pass + + # Verify that the TransientTransactionError label is present in the error. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) + class PatchSessionTimeout: """Patches the client_session's with_transaction timeout for testing.""" diff --git a/test/test_transactions.py b/test/test_transactions.py index 80b3e3765e..b883e88efc 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -20,6 +20,8 @@ from test.utils_spec_runner import SpecRunner from gridfs.synchronous.grid_file import GridFS, GridFSBucket +from pymongo.server_selectors import writable_server_selector +from pymongo.synchronous.pool import PoolState sys.path[0:0] = [""] @@ -34,6 +36,7 @@ from bson.raw_bson import RawBSONDocument from pymongo import WriteConcern from pymongo.errors import ( + AutoReconnect, CollectionInvalid, ConfigurationError, ConnectionFailure, @@ -378,6 +381,22 @@ def find_raw_batches(*args, **kwargs): if isinstance(res, (CommandCursor, Cursor)): res.to_list() + @client_context.require_transactions + def test_transaction_pool_cleared_error_labelled_transient(self): + c = self.single_client() + + with self.assertRaises(AutoReconnect) as context: + with c.start_session() as session: + with session.start_transaction(): + server = c._select_server(writable_server_selector, session, "test") + # Pause the server's pool, causing it to fail connection checkout. + server.pool.state = PoolState.PAUSED + with c._checkout(server, session): + pass + + # Verify that the TransientTransactionError label is present in the error. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) + class PatchSessionTimeout: """Patches the client_session's with_transaction timeout for testing."""