39
39
40
40
41
41
if t .TYPE_CHECKING :
42
+ from ssl import SSLContext
43
+
44
+ import typing_extensions as te
45
+
42
46
from ..._deadline import Deadline
47
+ from ...addressing import (
48
+ Address ,
49
+ ResolvedAddress ,
50
+ )
43
51
44
52
45
53
log = logging .getLogger ("neo4j.io" )
@@ -63,7 +71,11 @@ def __str__(self):
63
71
64
72
65
73
class AsyncBoltSocket (AsyncBoltSocketBase ):
66
- async def _parse_handshake_response_v1 (self , ctx , response ):
74
+ async def _parse_handshake_response_v1 (
75
+ self ,
76
+ ctx : HandshakeCtx ,
77
+ response : bytes ,
78
+ ) -> tuple [int , int ]:
67
79
agreed_version = response [- 1 ], response [- 2 ]
68
80
log .debug (
69
81
"[#%04X] S: <HANDSHAKE> 0x%06X%02X" ,
@@ -73,7 +85,11 @@ async def _parse_handshake_response_v1(self, ctx, response):
73
85
)
74
86
return agreed_version
75
87
76
- async def _parse_handshake_response_v2 (self , ctx , response ):
88
+ async def _parse_handshake_response_v2 (
89
+ self ,
90
+ ctx : HandshakeCtx ,
91
+ response : bytes ,
92
+ ) -> tuple [int , int ]:
77
93
ctx .ctx = "handshake v2 offerings count"
78
94
num_offerings = await self ._read_varint (ctx )
79
95
offerings = []
@@ -85,7 +101,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
85
101
ctx .ctx = "handshake v2 capabilities"
86
102
_capabilities_offer = await self ._read_varint (ctx )
87
103
88
- if log .getEffectiveLevel () > = logging .DEBUG :
104
+ if log .getEffectiveLevel () < = logging .DEBUG :
89
105
log .debug (
90
106
"[#%04X] S: <HANDSHAKE> %s [%i] %s %s" ,
91
107
ctx .local_port ,
@@ -125,7 +141,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
125
141
126
142
return chosen_version
127
143
128
- async def _read_varint (self , ctx ) :
144
+ async def _read_varint (self , ctx : HandshakeCtx ) -> int :
129
145
next_byte = (await self ._handshake_read (ctx , 1 ))[0 ]
130
146
res = next_byte & 0x7F
131
147
i = 0
@@ -136,15 +152,15 @@ async def _read_varint(self, ctx):
136
152
return res
137
153
138
154
@staticmethod
139
- def _encode_varint (n ) :
155
+ def _encode_varint (n : int ) -> bytearray :
140
156
res = bytearray ()
141
157
while n >= 0x80 :
142
158
res .append (n & 0x7F | 0x80 )
143
159
n >>= 7
144
160
res .append (n )
145
161
return res
146
162
147
- async def _handshake_read (self , ctx , n ) :
163
+ async def _handshake_read (self , ctx : HandshakeCtx , n : int ) -> bytes :
148
164
original_timeout = self .gettimeout ()
149
165
self .settimeout (ctx .deadline .to_timeout ())
150
166
try :
@@ -193,7 +209,11 @@ async def _handshake_send(self, ctx, data):
193
209
finally :
194
210
self .settimeout (original_timeout )
195
211
196
- async def _handshake (self , resolved_address , deadline ):
212
+ async def _handshake (
213
+ self ,
214
+ resolved_address : ResolvedAddress ,
215
+ deadline : Deadline ,
216
+ ) -> tuple [tuple [int , int ], bytes , bytes ]:
197
217
"""
198
218
Perform BOLT handshake.
199
219
@@ -204,16 +224,16 @@ async def _handshake(self, resolved_address, deadline):
204
224
"""
205
225
local_port = self .getsockname ()[1 ]
206
226
207
- if log . getEffectiveLevel () >= logging . DEBUG :
208
- handshake = self . Bolt . get_handshake ()
209
- handshake = struct .unpack (">16B" , handshake )
210
- handshake = [
211
- handshake [i : i + 4 ] for i in range (0 , len (handshake ), 4 )
227
+ handshake = self . Bolt . get_handshake ()
228
+ if log . getEffectiveLevel () <= logging . DEBUG :
229
+ handshake_bytes : t . Sequence = struct .unpack (">16B" , handshake )
230
+ handshake_bytes = [
231
+ handshake [i : i + 4 ] for i in range (0 , len (handshake_bytes ), 4 )
212
232
]
213
233
214
234
supported_versions = [
215
235
f"0x{ vx [0 ]:02X} { vx [1 ]:02X} { vx [2 ]:02X} { vx [3 ]:02X} "
216
- for vx in handshake
236
+ for vx in handshake_bytes
217
237
]
218
238
219
239
log .debug (
@@ -227,7 +247,7 @@ async def _handshake(self, resolved_address, deadline):
227
247
* supported_versions ,
228
248
)
229
249
230
- request = self .Bolt .MAGIC_PREAMBLE + self . Bolt . get_handshake ()
250
+ request = self .Bolt .MAGIC_PREAMBLE + handshake
231
251
232
252
ctx = HandshakeCtx (
233
253
ctx = "handshake opening" ,
@@ -273,14 +293,14 @@ async def _handshake(self, resolved_address, deadline):
273
293
@classmethod
274
294
async def connect (
275
295
cls ,
276
- address ,
296
+ address : Address ,
277
297
* ,
278
- tcp_timeout ,
279
- deadline ,
280
- custom_resolver ,
281
- ssl_context ,
282
- keep_alive ,
283
- ):
298
+ tcp_timeout : float | None ,
299
+ deadline : Deadline ,
300
+ custom_resolver : t . Callable | None ,
301
+ ssl_context : SSLContext | None ,
302
+ keep_alive : bool ,
303
+ ) -> tuple [ te . Self , tuple [ int , int ], bytes , bytes ] :
284
304
"""
285
305
Connect and perform a handshake.
286
306
@@ -313,10 +333,10 @@ async def connect(
313
333
)
314
334
return s , agreed_version , handshake , response
315
335
except (BoltError , DriverError , OSError ) as error :
316
- try :
317
- local_port = s . getsockname ()[ 1 ]
318
- except (OSError , AttributeError , TypeError ):
319
- local_port = 0
336
+ local_port = 0
337
+ if isinstance ( s , cls ):
338
+ with suppress (OSError , AttributeError , TypeError ):
339
+ local_port = s . getsockname ()[ 1 ]
320
340
err_str = error .__class__ .__name__
321
341
if str (error ):
322
342
err_str += ": " + str (error )
@@ -331,10 +351,10 @@ async def connect(
331
351
errors .append (error )
332
352
failed_addresses .append (resolved_address )
333
353
except asyncio .CancelledError :
334
- try :
335
- local_port = s . getsockname ()[ 1 ]
336
- except (OSError , AttributeError , TypeError ):
337
- local_port = 0
354
+ local_port = 0
355
+ if isinstance ( s , cls ):
356
+ with suppress (OSError , AttributeError , TypeError ):
357
+ local_port = s . getsockname ()[ 1 ]
338
358
log .debug (
339
359
"[#%04X] C: <CANCELED> %s" , local_port , resolved_address
340
360
)
0 commit comments