Skip to content

Commit 92ea0de

Browse files
authored
Fix exception in bolt handshake when log level <10 (#1141)
* Fix exception in bolt handshake when log level <10 * Fix another log level check being backwards * Add type hints to internal functions
1 parent 1714bc2 commit 92ea0de

File tree

10 files changed

+436
-63
lines changed

10 files changed

+436
-63
lines changed

src/neo4j/_async/io/_bolt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def __init_subclass__(cls: type[te.Self], **kwargs: t.Any) -> None:
292292

293293
# [bolt-version-bump] search tag when changing bolt version support
294294
@classmethod
295-
def get_handshake(cls):
295+
def get_handshake(cls) -> bytes:
296296
"""
297297
Return the supported Bolt versions as bytes.
298298

src/neo4j/_async/io/_bolt_socket.py

+49-29
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@
3939

4040

4141
if t.TYPE_CHECKING:
42+
from ssl import SSLContext
43+
44+
import typing_extensions as te
45+
4246
from ..._deadline import Deadline
47+
from ...addressing import (
48+
Address,
49+
ResolvedAddress,
50+
)
4351

4452

4553
log = logging.getLogger("neo4j.io")
@@ -63,7 +71,11 @@ def __str__(self):
6371

6472

6573
class AsyncBoltSocket(AsyncBoltSocketBase):
66-
async def _parse_handshake_response_v1(self, ctx, response):
74+
async def _parse_handshake_response_v1(
75+
self,
76+
ctx: HandshakeCtx,
77+
response: bytes,
78+
) -> tuple[int, int]:
6779
agreed_version = response[-1], response[-2]
6880
log.debug(
6981
"[#%04X] S: <HANDSHAKE> 0x%06X%02X",
@@ -73,7 +85,11 @@ async def _parse_handshake_response_v1(self, ctx, response):
7385
)
7486
return agreed_version
7587

76-
async def _parse_handshake_response_v2(self, ctx, response):
88+
async def _parse_handshake_response_v2(
89+
self,
90+
ctx: HandshakeCtx,
91+
response: bytes,
92+
) -> tuple[int, int]:
7793
ctx.ctx = "handshake v2 offerings count"
7894
num_offerings = await self._read_varint(ctx)
7995
offerings = []
@@ -85,7 +101,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
85101
ctx.ctx = "handshake v2 capabilities"
86102
_capabilities_offer = await self._read_varint(ctx)
87103

88-
if log.getEffectiveLevel() >= logging.DEBUG:
104+
if log.getEffectiveLevel() <= logging.DEBUG:
89105
log.debug(
90106
"[#%04X] S: <HANDSHAKE> %s [%i] %s %s",
91107
ctx.local_port,
@@ -125,7 +141,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
125141

126142
return chosen_version
127143

128-
async def _read_varint(self, ctx):
144+
async def _read_varint(self, ctx: HandshakeCtx) -> int:
129145
next_byte = (await self._handshake_read(ctx, 1))[0]
130146
res = next_byte & 0x7F
131147
i = 0
@@ -136,15 +152,15 @@ async def _read_varint(self, ctx):
136152
return res
137153

138154
@staticmethod
139-
def _encode_varint(n):
155+
def _encode_varint(n: int) -> bytearray:
140156
res = bytearray()
141157
while n >= 0x80:
142158
res.append(n & 0x7F | 0x80)
143159
n >>= 7
144160
res.append(n)
145161
return res
146162

147-
async def _handshake_read(self, ctx, n):
163+
async def _handshake_read(self, ctx: HandshakeCtx, n: int) -> bytes:
148164
original_timeout = self.gettimeout()
149165
self.settimeout(ctx.deadline.to_timeout())
150166
try:
@@ -193,7 +209,11 @@ async def _handshake_send(self, ctx, data):
193209
finally:
194210
self.settimeout(original_timeout)
195211

196-
async def _handshake(self, resolved_address, deadline):
212+
async def _handshake(
213+
self,
214+
resolved_address: ResolvedAddress,
215+
deadline: Deadline,
216+
) -> tuple[tuple[int, int], bytes, bytes]:
197217
"""
198218
Perform BOLT handshake.
199219
@@ -204,16 +224,16 @@ async def _handshake(self, resolved_address, deadline):
204224
"""
205225
local_port = self.getsockname()[1]
206226

207-
if log.getEffectiveLevel() >= logging.DEBUG:
208-
handshake = self.Bolt.get_handshake()
209-
handshake = struct.unpack(">16B", handshake)
210-
handshake = [
211-
handshake[i : i + 4] for i in range(0, len(handshake), 4)
227+
handshake = self.Bolt.get_handshake()
228+
if log.getEffectiveLevel() <= logging.DEBUG:
229+
handshake_bytes: t.Sequence = struct.unpack(">16B", handshake)
230+
handshake_bytes = [
231+
handshake[i : i + 4] for i in range(0, len(handshake_bytes), 4)
212232
]
213233

214234
supported_versions = [
215235
f"0x{vx[0]:02X}{vx[1]:02X}{vx[2]:02X}{vx[3]:02X}"
216-
for vx in handshake
236+
for vx in handshake_bytes
217237
]
218238

219239
log.debug(
@@ -227,7 +247,7 @@ async def _handshake(self, resolved_address, deadline):
227247
*supported_versions,
228248
)
229249

230-
request = self.Bolt.MAGIC_PREAMBLE + self.Bolt.get_handshake()
250+
request = self.Bolt.MAGIC_PREAMBLE + handshake
231251

232252
ctx = HandshakeCtx(
233253
ctx="handshake opening",
@@ -273,14 +293,14 @@ async def _handshake(self, resolved_address, deadline):
273293
@classmethod
274294
async def connect(
275295
cls,
276-
address,
296+
address: Address,
277297
*,
278-
tcp_timeout,
279-
deadline,
280-
custom_resolver,
281-
ssl_context,
282-
keep_alive,
283-
):
298+
tcp_timeout: float | None,
299+
deadline: Deadline,
300+
custom_resolver: t.Callable | None,
301+
ssl_context: SSLContext | None,
302+
keep_alive: bool,
303+
) -> tuple[te.Self, tuple[int, int], bytes, bytes]:
284304
"""
285305
Connect and perform a handshake.
286306
@@ -313,10 +333,10 @@ async def connect(
313333
)
314334
return s, agreed_version, handshake, response
315335
except (BoltError, DriverError, OSError) as error:
316-
try:
317-
local_port = s.getsockname()[1]
318-
except (OSError, AttributeError, TypeError):
319-
local_port = 0
336+
local_port = 0
337+
if isinstance(s, cls):
338+
with suppress(OSError, AttributeError, TypeError):
339+
local_port = s.getsockname()[1]
320340
err_str = error.__class__.__name__
321341
if str(error):
322342
err_str += ": " + str(error)
@@ -331,10 +351,10 @@ async def connect(
331351
errors.append(error)
332352
failed_addresses.append(resolved_address)
333353
except asyncio.CancelledError:
334-
try:
335-
local_port = s.getsockname()[1]
336-
except (OSError, AttributeError, TypeError):
337-
local_port = 0
354+
local_port = 0
355+
if isinstance(s, cls):
356+
with suppress(OSError, AttributeError, TypeError):
357+
local_port = s.getsockname()[1]
338358
log.debug(
339359
"[#%04X] C: <CANCELED> %s", local_port, resolved_address
340360
)

src/neo4j/_async_compat/network/_bolt_socket.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _sanitize_deadline(deadline):
7979
class AsyncBoltSocketBase(abc.ABC):
8080
Bolt: te.Final[type[AsyncBolt]] = None # type: ignore[assignment]
8181

82-
def __init__(self, reader, protocol, writer):
82+
def __init__(self, reader, protocol, writer) -> None:
8383
self._reader = reader # type: asyncio.StreamReader
8484
self._protocol = protocol # type: asyncio.StreamReaderProtocol
8585
self._writer = writer # type: asyncio.StreamWriter
@@ -171,7 +171,7 @@ def kill(self):
171171
@classmethod
172172
async def _connect_secure(
173173
cls, resolved_address, timeout, keep_alive, ssl_context
174-
):
174+
) -> te.Self:
175175
"""
176176
Connect to the address and return the socket.
177177
@@ -202,7 +202,7 @@ async def _connect_secure(
202202
keep_alive = 1 if keep_alive else 0
203203
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
204204

205-
ssl_kwargs = {}
205+
ssl_kwargs: dict[str, t.Any] = {}
206206

207207
if ssl_context is not None:
208208
hostname = resolved_address._host_name or None

src/neo4j/_sync/io/_bolt.py

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j/_sync/io/_bolt_socket.py

+49-29
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)