diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 1d50b53ee2..72f5d75dac 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -323,6 +323,8 @@ async def connect_check_health( raise TimeoutError("Timeout connecting to server") except OSError as e: raise ConnectionError(self._error_message(e)) + except ConnectionError: + raise except Exception as exc: raise ConnectionError(exc) from exc diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py index d0455ab6eb..9d8c170556 100644 --- a/redis/asyncio/sentinel.py +++ b/redis/asyncio/sentinel.py @@ -39,31 +39,22 @@ def __repr__(self): s += host_info return s + ")>" - async def connect_to(self, address): - self.host, self.port = address - await self.connect_check_health( - check_health=self.connection_pool.check_connection, - retry_socket_connect=False, - ) - - async def _connect_retry(self): - if self._reader: - return # already connected + async def _connect(self): + if self.is_connected: + await super()._connect() + return None if self.connection_pool.is_master: - await self.connect_to(await self.connection_pool.get_master_address()) - else: - async for slave in self.connection_pool.rotate_slaves(): - try: - return await self.connect_to(slave) - except ConnectionError: - continue - raise SlaveNotFoundError # Never be here - - async def connect(self): - return await self.retry.call_with_retry( - self._connect_retry, - lambda error: asyncio.sleep(0), - ) + self.host, self.port = await self.connection_pool.get_master_address() + await super()._connect() + return None + async for slave in self.connection_pool.rotate_slaves(): + try: + self.host, self.port = slave + await super()._connect() + return None + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here async def read_response( self, diff --git a/redis/sentinel.py b/redis/sentinel.py index f12bd8dd5d..edc47b5540 100644 --- a/redis/sentinel.py +++ b/redis/sentinel.py @@ -37,29 +37,19 @@ def __repr__(self): s = s % host_info return s - def connect_to(self, address): - self.host, self.port = address - - self.connect_check_health( - check_health=self.connection_pool.check_connection, - retry_socket_connect=False, - ) - - def _connect_retry(self): + def _connect(self): if self._sock: - return # already connected + return super()._connect() # already connected if self.connection_pool.is_master: - self.connect_to(self.connection_pool.get_master_address()) - else: - for slave in self.connection_pool.rotate_slaves(): - try: - return self.connect_to(slave) - except ConnectionError: - continue - raise SlaveNotFoundError # Never be here - - def connect(self): - return self.retry.call_with_retry(self._connect_retry, lambda error: None) + self.host, self.port = self.connection_pool.get_master_address() + return super()._connect() + for slave in self.connection_pool.rotate_slaves(): + try: + self.host, self.port = slave + return super()._connect() + except ConnectionError: + continue + raise SlaveNotFoundError # Never be here def read_response( self, diff --git a/tests/test_asyncio/test_sentinel_managed_connection.py b/tests/test_asyncio/test_sentinel_managed_connection.py index 5a511b2793..cec55cdc1c 100644 --- a/tests/test_asyncio/test_sentinel_managed_connection.py +++ b/tests/test_asyncio/test_sentinel_managed_connection.py @@ -2,6 +2,8 @@ from unittest import mock import pytest + +from redis.asyncio import Connection from redis.asyncio.retry import Retry from redis.asyncio.sentinel import SentinelManagedConnection from redis.backoff import NoBackoff @@ -20,18 +22,50 @@ async def test_connect_retry_on_timeout_error(connect_args): retry=Retry(NoBackoff(), 3), connection_pool=connection_pool, ) - origin_connect = conn._connect - conn._connect = mock.AsyncMock() - - async def mock_connect(): - # connect only on the last retry - if conn._connect.call_count <= 2: - raise socket.timeout - else: - return await origin_connect() - - conn._connect.side_effect = mock_connect - await conn.connect() - assert conn._connect.call_count == 3 - assert connection_pool.get_master_address.call_count == 3 - await conn.disconnect() + original_super_connect = Connection._connect.__get__(conn, Connection) + + with mock.patch.object( + Connection, "_connect", new_callable=mock.AsyncMock + ) as mock_super_connect: + + async def side_effect(*args, **kwargs): + if mock_super_connect.await_count <= 2: + raise socket.timeout() + return await original_super_connect(*args, **kwargs) + + mock_super_connect.side_effect = side_effect + + await conn.connect() + assert mock_super_connect.await_count == 3 + assert connection_pool.get_master_address.call_count == 3 + await conn.disconnect() + + +async def test_connect_check_health_retry_on_timeout_error(connect_args): + """Test that the _connect function is retried in case of a timeout""" + connection_pool = mock.AsyncMock() + connection_pool.get_master_address = mock.AsyncMock( + return_value=(connect_args["host"], connect_args["port"]) + ) + conn = SentinelManagedConnection( + retry_on_timeout=True, + retry=Retry(NoBackoff(), 3), + connection_pool=connection_pool, + ) + original_super_connect = Connection._connect.__get__(conn, Connection) + + with mock.patch.object( + Connection, "_connect", new_callable=mock.AsyncMock + ) as mock_super_connect: + + async def side_effect(*args, **kwargs): + if mock_super_connect.await_count <= 2: + raise socket.timeout() + return await original_super_connect(*args, **kwargs) + + mock_super_connect.side_effect = side_effect + + await conn.connect_check_health() + assert mock_super_connect.await_count == 3 + assert connection_pool.get_master_address.call_count == 3 + await conn.disconnect() diff --git a/tests/test_sentinel_managed_connection.py b/tests/test_sentinel_managed_connection.py index 6fe5f7cd5b..6ac16c76dd 100644 --- a/tests/test_sentinel_managed_connection.py +++ b/tests/test_sentinel_managed_connection.py @@ -1,9 +1,10 @@ import socket +from unittest import mock +from redis import Connection from redis.retry import Retry from redis.sentinel import SentinelManagedConnection from redis.backoff import NoBackoff -from unittest import mock def test_connect_retry_on_timeout_error(master_host): @@ -17,18 +18,50 @@ def test_connect_retry_on_timeout_error(master_host): retry=Retry(NoBackoff(), 3), connection_pool=connection_pool, ) - origin_connect = conn._connect - conn._connect = mock.Mock() - - def mock_connect(): - # connect only on the last retry - if conn._connect.call_count <= 2: - raise socket.timeout - else: - return origin_connect() - - conn._connect.side_effect = mock_connect - conn.connect() - assert conn._connect.call_count == 3 - assert connection_pool.get_master_address.call_count == 3 - conn.disconnect() + original_super_connect = Connection._connect.__get__(conn, Connection) + + with mock.patch.object( + Connection, "_connect", new_callable=mock.Mock + ) as mock_super_connect: + + def side_effect(*args, **kwargs): + if mock_super_connect.call_count <= 2: + raise socket.timeout() + return original_super_connect(*args, **kwargs) + + mock_super_connect.side_effect = side_effect + + conn.connect() + assert mock_super_connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 + conn.disconnect() + + +def test_connect_check_health_retry_on_timeout_error(master_host): + """Test that the _connect function is retried in case of a timeout""" + connection_pool = mock.Mock() + connection_pool.get_master_address = mock.Mock( + return_value=(master_host[0], master_host[1]) + ) + conn = SentinelManagedConnection( + retry_on_timeout=True, + retry=Retry(NoBackoff(), 3), + connection_pool=connection_pool, + ) + original_super_connect = Connection._connect.__get__(conn, Connection) + + with mock.patch.object( + Connection, "_connect", new_callable=mock.Mock + ) as mock_super_connect: + + def side_effect(*args, **kwargs): + if mock_super_connect.call_count <= 2: + raise socket.timeout() + return original_super_connect(*args, **kwargs) + + mock_super_connect.side_effect = side_effect + + conn.connect_check_health() + assert mock_super_connect.call_count == 3 + assert connection_pool.get_master_address.call_count == 3 + conn.disconnect()