@@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD):
321321 return server
322322
323323
324- def loopback (server_factory = None , client_factory = None ):
324+ def loopback (server_factory = None , client_factory = None , blocking = True ):
325325 """
326326 Create a connected socket pair and force two connected SSL sockets
327327 to talk to each other via memory BIOs.
@@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None):
337337
338338 handshake (client , server )
339339
340- server .setblocking (True )
341- client .setblocking (True )
340+ server .setblocking (blocking )
341+ client .setblocking (blocking )
342342 return server , client
343343
344344
@@ -3292,11 +3292,134 @@ def test_memoryview_really_doesnt_overfill(self):
32923292 self ._doesnt_overfill_test (_make_memoryview )
32933293
32943294
3295+ @pytest .fixture
3296+ def nonblocking_tls_connections_pair ():
3297+ """Return a non-blocking TLS loopback connections pair."""
3298+ return loopback (blocking = False )
3299+
3300+
3301+ @pytest .fixture
3302+ def nonblocking_tls_server_connection (nonblocking_tls_connections_pair ):
3303+ """Return a non-blocking TLS server socket connected to loopback."""
3304+ return nonblocking_tls_connections_pair [0 ]
3305+
3306+
3307+ @pytest .fixture
3308+ def nonblocking_tls_client_connection (nonblocking_tls_connections_pair ):
3309+ """Return a non-blocking TLS client socket connected to loopback."""
3310+ return nonblocking_tls_connections_pair [1 ]
3311+
3312+
32953313class TestConnectionSendall :
32963314 """
32973315 Tests for `Connection.sendall`.
32983316 """
32993317
3318+ def test_want_write (
3319+ self ,
3320+ monkeypatch ,
3321+ nonblocking_tls_server_connection ,
3322+ nonblocking_tls_client_connection ,
3323+ ):
3324+ msg = b"x"
3325+ garbage_size = 1024 * 1024 * 64
3326+ large_payload = b"p" * garbage_size * 2
3327+ payload_size = len (large_payload )
3328+
3329+ sent_garbage_size = 0
3330+ try :
3331+ sent_garbage_size += nonblocking_tls_client_connection .send (
3332+ msg * garbage_size ,
3333+ )
3334+ except WantWriteError :
3335+ pass
3336+ for i in range (garbage_size ):
3337+ try :
3338+ sent_garbage_size += nonblocking_tls_client_connection .send (
3339+ msg ,
3340+ )
3341+ except WantWriteError :
3342+ break
3343+ else :
3344+ pytest .fail (
3345+ "Failed to fill socket buffer, cannot test "
3346+ "'want write' in `sendall()`"
3347+ )
3348+ garbage_payload = sent_garbage_size * msg
3349+
3350+ def consume_garbage (conn ):
3351+ assert patched_ssl_write .want_write_counter >= 1
3352+ assert not consume_garbage .garbage_consumed
3353+
3354+ while len (consume_garbage .consumed ) < sent_garbage_size :
3355+ try :
3356+ consume_garbage .consumed += conn .recv (
3357+ sent_garbage_size - len (consume_garbage .consumed ),
3358+ )
3359+ except WantReadError :
3360+ pass
3361+
3362+ assert consume_garbage .consumed == garbage_payload
3363+
3364+ consume_garbage .garbage_consumed = True
3365+
3366+ consume_garbage .garbage_consumed = False
3367+ consume_garbage .consumed = b""
3368+
3369+ def consume_payload (conn ):
3370+ try :
3371+ consume_payload .consumed += conn .recv (payload_size )
3372+ except WantReadError :
3373+ pass
3374+
3375+ consume_payload .consumed = b""
3376+
3377+ original_ssl_write = _lib .SSL_write
3378+
3379+ def patched_ssl_write (ctx , data , size ):
3380+ write_result = original_ssl_write (ctx , data , size )
3381+ try :
3382+ nonblocking_tls_client_connection ._raise_ssl_error (
3383+ ctx ,
3384+ write_result ,
3385+ )
3386+ except WantWriteError :
3387+ patched_ssl_write .want_write_counter += 1
3388+ consume_data_on_server = (
3389+ consume_payload
3390+ if consume_garbage .garbage_consumed
3391+ else consume_garbage
3392+ )
3393+
3394+ consume_data_on_server (nonblocking_tls_server_connection )
3395+ # NOTE: We don't re-raise it as the calling code will do
3396+ # NOTE: the same after the call.
3397+ return write_result
3398+
3399+ patched_ssl_write .want_write_counter = 0
3400+
3401+ # NOTE: Make the client think it needs a handshake so that it'll
3402+ # NOTE: attempt to `do_handshake()` on the next `SSL_write()`
3403+ # NOTE: that originates from `sendall()`:
3404+ nonblocking_tls_client_connection .set_connect_state ()
3405+ try :
3406+ nonblocking_tls_client_connection .do_handshake ()
3407+ except WantWriteError :
3408+ assert True # Sanity check
3409+ except :
3410+ assert False # This should never happen (see the note above)
3411+
3412+ with monkeypatch .context () as mp_ctx :
3413+ mp_ctx .setattr (_lib , "SSL_write" , patched_ssl_write )
3414+ nonblocking_tls_client_connection .sendall (large_payload )
3415+
3416+ assert consume_garbage .garbage_consumed
3417+
3418+ # NOTE: Read the leftover data from the very last `SSL_write()`
3419+ consume_payload (nonblocking_tls_server_connection )
3420+
3421+ assert consume_payload .consumed == large_payload
3422+
33003423 def test_wrong_args (self ):
33013424 """
33023425 When called with arguments other than a string argument for its first
0 commit comments