Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ async def connect_check_health(
raise TimeoutError("Timeout connecting to server")
except OSError as e:
raise ConnectionError(self._error_message(e))
except ConnectionError:
raise
except Exception as exc:
raise ConnectionError(exc) from exc

Expand Down
39 changes: 15 additions & 24 deletions redis/asyncio/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,22 @@ def __repr__(self):
s += host_info
return s + ")>"

async def connect_to(self, address):
self.host, self.port = address
await self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)

async def _connect_retry(self):
if self._reader:
return # already connected
async def _connect(self):
if self.is_connected:
await super()._connect()
return None
Comment on lines +44 to +45
Copy link

Copilot AI Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here is incorrect. When self.is_connected is True, the connection is already established and the method should simply return without calling super()._connect(). The check should just return early:

if self.is_connected:
    return None  # already connected

Calling await super()._connect() when already connected could cause issues as it attempts to set self._reader and self._writer again.

Suggested change
await super()._connect()
return None
return None # already connected

Copilot uses AI. Check for mistakes.
if self.connection_pool.is_master:
await self.connect_to(await self.connection_pool.get_master_address())
else:
async for slave in self.connection_pool.rotate_slaves():
try:
return await self.connect_to(slave)
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here

async def connect(self):
return await self.retry.call_with_retry(
self._connect_retry,
lambda error: asyncio.sleep(0),
)
self.host, self.port = await self.connection_pool.get_master_address()
await super()._connect()
return None
async for slave in self.connection_pool.rotate_slaves():
try:
self.host, self.port = slave
await super()._connect()
return None
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here

async def read_response(
self,
Expand Down
32 changes: 11 additions & 21 deletions redis/sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,19 @@ def __repr__(self):
s = s % host_info
return s

def connect_to(self, address):
self.host, self.port = address

self.connect_check_health(
check_health=self.connection_pool.check_connection,
retry_socket_connect=False,
)

def _connect_retry(self):
def _connect(self):
if self._sock:
return # already connected
return super()._connect() # already connected
if self.connection_pool.is_master:
self.connect_to(self.connection_pool.get_master_address())
else:
for slave in self.connection_pool.rotate_slaves():
try:
return self.connect_to(slave)
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here

def connect(self):
return self.retry.call_with_retry(self._connect_retry, lambda error: None)
self.host, self.port = self.connection_pool.get_master_address()
return super()._connect()
for slave in self.connection_pool.rotate_slaves():
try:
self.host, self.port = slave
return super()._connect()
except ConnectionError:
continue
raise SlaveNotFoundError # Never be here

def read_response(
self,
Expand Down
64 changes: 49 additions & 15 deletions tests/test_asyncio/test_sentinel_managed_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from unittest import mock

import pytest

from redis.asyncio import Connection
from redis.asyncio.retry import Retry
from redis.asyncio.sentinel import SentinelManagedConnection
from redis.backoff import NoBackoff
Expand All @@ -20,18 +22,50 @@ async def test_connect_retry_on_timeout_error(connect_args):
retry=Retry(NoBackoff(), 3),
connection_pool=connection_pool,
)
origin_connect = conn._connect
conn._connect = mock.AsyncMock()

async def mock_connect():
# connect only on the last retry
if conn._connect.call_count <= 2:
raise socket.timeout
else:
return await origin_connect()

conn._connect.side_effect = mock_connect
await conn.connect()
assert conn._connect.call_count == 3
assert connection_pool.get_master_address.call_count == 3
await conn.disconnect()
original_super_connect = Connection._connect.__get__(conn, Connection)

with mock.patch.object(
Connection, "_connect", new_callable=mock.AsyncMock
) as mock_super_connect:

async def side_effect(*args, **kwargs):
if mock_super_connect.await_count <= 2:
raise socket.timeout()
return await original_super_connect(*args, **kwargs)

mock_super_connect.side_effect = side_effect

await conn.connect()
assert mock_super_connect.await_count == 3
assert connection_pool.get_master_address.call_count == 3
await conn.disconnect()


async def test_connect_check_health_retry_on_timeout_error(connect_args):
"""Test that the _connect function is retried in case of a timeout"""
connection_pool = mock.AsyncMock()
connection_pool.get_master_address = mock.AsyncMock(
return_value=(connect_args["host"], connect_args["port"])
)
conn = SentinelManagedConnection(
retry_on_timeout=True,
retry=Retry(NoBackoff(), 3),
connection_pool=connection_pool,
)
original_super_connect = Connection._connect.__get__(conn, Connection)

with mock.patch.object(
Connection, "_connect", new_callable=mock.AsyncMock
) as mock_super_connect:

async def side_effect(*args, **kwargs):
if mock_super_connect.await_count <= 2:
raise socket.timeout()
return await original_super_connect(*args, **kwargs)

mock_super_connect.side_effect = side_effect

await conn.connect_check_health()
assert mock_super_connect.await_count == 3
assert connection_pool.get_master_address.call_count == 3
await conn.disconnect()
65 changes: 49 additions & 16 deletions tests/test_sentinel_managed_connection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import socket
from unittest import mock

from redis import Connection
from redis.retry import Retry
from redis.sentinel import SentinelManagedConnection
from redis.backoff import NoBackoff
from unittest import mock


def test_connect_retry_on_timeout_error(master_host):
Expand All @@ -17,18 +18,50 @@ def test_connect_retry_on_timeout_error(master_host):
retry=Retry(NoBackoff(), 3),
connection_pool=connection_pool,
)
origin_connect = conn._connect
conn._connect = mock.Mock()

def mock_connect():
# connect only on the last retry
if conn._connect.call_count <= 2:
raise socket.timeout
else:
return origin_connect()

conn._connect.side_effect = mock_connect
conn.connect()
assert conn._connect.call_count == 3
assert connection_pool.get_master_address.call_count == 3
conn.disconnect()
original_super_connect = Connection._connect.__get__(conn, Connection)

with mock.patch.object(
Connection, "_connect", new_callable=mock.Mock
) as mock_super_connect:

def side_effect(*args, **kwargs):
if mock_super_connect.call_count <= 2:
raise socket.timeout()
return original_super_connect(*args, **kwargs)

mock_super_connect.side_effect = side_effect

conn.connect()
assert mock_super_connect.call_count == 3
assert connection_pool.get_master_address.call_count == 3
conn.disconnect()


def test_connect_check_health_retry_on_timeout_error(master_host):
"""Test that the _connect function is retried in case of a timeout"""
connection_pool = mock.Mock()
connection_pool.get_master_address = mock.Mock(
return_value=(master_host[0], master_host[1])
)
conn = SentinelManagedConnection(
retry_on_timeout=True,
retry=Retry(NoBackoff(), 3),
connection_pool=connection_pool,
)
original_super_connect = Connection._connect.__get__(conn, Connection)

with mock.patch.object(
Connection, "_connect", new_callable=mock.Mock
) as mock_super_connect:

def side_effect(*args, **kwargs):
if mock_super_connect.call_count <= 2:
raise socket.timeout()
return original_super_connect(*args, **kwargs)

mock_super_connect.side_effect = side_effect

conn.connect_check_health()
assert mock_super_connect.call_count == 3
assert connection_pool.get_master_address.call_count == 3
conn.disconnect()