diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 20235362..ed42d479 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -262,6 +262,15 @@ async def connect_async( KeyError: Unsupported database driver Must be one of pymysql, asyncpg, pg8000, and pytds. """ + # check if event loop is running in current thread + if self._loop != asyncio.get_running_loop(): + raise ConnectorLoopError( + "Running event loop does not match 'connector._loop'. " + "Connector.connect_async() must be called from the event loop " + "the Connector was initialized with. If you need to connect " + "across event loops, please use a new Connector object." + ) + if self._keys is None: self._keys = asyncio.create_task(generate_keys()) if self._client is None: diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 7a59b448..95423fca 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -15,6 +15,7 @@ """ import asyncio +from threading import Thread from typing import Union from aiohttp import ClientResponseError @@ -274,6 +275,38 @@ async def test_Connector_connect_async( assert connection is True +@pytest.mark.asyncio +async def test_Connector_connect_async_multiple_event_loops( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """Test that Connector.connect_async errors when run on wrong event loop.""" + + new_loop = asyncio.new_event_loop() + thread = Thread(target=new_loop.run_forever, daemon=True) + thread.start() + + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = fake_client + with pytest.raises(ConnectorLoopError) as exc_info: + future = asyncio.run_coroutine_threadsafe( + connector.connect_async( + "test-project:test-region:test-instance", "asyncpg" + ), + loop=new_loop, + ) + future.result() + assert ( + exc_info.value.args[0] == "Running event loop does not match " + "'connector._loop'. Connector.connect_async() must be called from " + "the event loop the Connector was initialized with. If you need to " + "connect across event loops, please use a new Connector object." + ) + new_loop.call_soon_threadsafe(new_loop.stop) + thread.join() + + @pytest.mark.asyncio async def test_create_async_connector(fake_credentials: Credentials) -> None: """Test that create_async_connector properly initializes connector