37
37
#include < Common/FailPoint.h>
38
38
39
39
#include < Common/config_version.h>
40
+ #include < Common/scope_guard_safe.h>
40
41
#include < Core/Types.h>
41
42
#include " config.h"
42
43
@@ -220,7 +221,7 @@ void Connection::connect(const ConnectionTimeouts & timeouts)
220
221
connected = true ;
221
222
setDescription ();
222
223
223
- sendHello ();
224
+ sendHello (timeouts. handshake_timeout );
224
225
receiveHello (timeouts.handshake_timeout );
225
226
226
227
if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_CHUNKED_PACKETS)
@@ -371,7 +372,7 @@ void Connection::disconnect()
371
372
}
372
373
373
374
374
- void Connection::sendHello ()
375
+ void Connection::sendHello ([[maybe_unused]] const Poco::Timespan & handshake_timeout )
375
376
{
376
377
/* * Disallow control characters in user controlled parameters
377
378
* to mitigate the possibility of SSRF.
@@ -424,7 +425,7 @@ void Connection::sendHello()
424
425
writeStringBinary (String (EncodedUserInfo::SSH_KEY_AUTHENTICAION_MARKER) + user, *out);
425
426
writeStringBinary (password, *out);
426
427
427
- performHandshakeForSSHAuth ();
428
+ performHandshakeForSSHAuth (handshake_timeout );
428
429
}
429
430
#endif
430
431
else if (!jwt.empty ())
@@ -461,8 +462,10 @@ void Connection::sendAddendum()
461
462
462
463
463
464
#if USE_SSH
464
- void Connection::performHandshakeForSSHAuth ()
465
+ void Connection::performHandshakeForSSHAuth (const Poco::Timespan & handshake_timeout )
465
466
{
467
+ TimeoutSetter timeout_setter (*socket, handshake_timeout, handshake_timeout);
468
+
466
469
String challenge;
467
470
{
468
471
writeVarUInt (Protocol::Client::SSHChallengeRequest, *out);
@@ -479,11 +482,7 @@ void Connection::performHandshakeForSSHAuth()
479
482
else if (packet_type == Protocol::Server::Exception)
480
483
receiveException ()->rethrow ();
481
484
else
482
- {
483
- // / Close connection, to not stay in unsynchronised state.
484
- disconnect ();
485
- throwUnexpectedPacket (packet_type, " SSHChallenge or Exception" );
486
- }
485
+ throwUnexpectedPacket (timeout_setter, packet_type, " SSHChallenge or Exception" );
487
486
}
488
487
489
488
writeVarUInt (Protocol::Client::SSHChallengeResponse, *out);
@@ -569,15 +568,7 @@ void Connection::receiveHello(const Poco::Timespan & handshake_timeout)
569
568
else if (packet_type == Protocol::Server::Exception)
570
569
receiveException ()->rethrow ();
571
570
else
572
- {
573
- // / Reset timeout_setter before disconnect,
574
- // / because after disconnect socket will be invalid.
575
- timeout_setter.reset ();
576
-
577
- // / Close connection, to not stay in unsynchronised state.
578
- disconnect ();
579
- throwUnexpectedPacket (packet_type, " Hello or Exception" );
580
- }
571
+ throwUnexpectedPacket (timeout_setter, packet_type, " Hello or Exception" );
581
572
}
582
573
583
574
void Connection::setDefaultDatabase (const String & database)
@@ -702,7 +693,7 @@ bool Connection::ping(const ConnectionTimeouts & timeouts)
702
693
}
703
694
704
695
if (pong != Protocol::Server::Pong)
705
- throwUnexpectedPacket (pong, " Pong" );
696
+ throwUnexpectedPacket (timeout_setter, pong, " Pong" );
706
697
}
707
698
catch (const Poco::Exception & e)
708
699
{
@@ -741,7 +732,7 @@ TablesStatusResponse Connection::getTablesStatus(const ConnectionTimeouts & time
741
732
if (response_type == Protocol::Server::Exception)
742
733
receiveException ()->rethrow ();
743
734
else if (response_type != Protocol::Server::TablesStatusResponse)
744
- throwUnexpectedPacket (response_type, " TablesStatusResponse" );
735
+ throwUnexpectedPacket (timeout_setter, response_type, " TablesStatusResponse" );
745
736
746
737
TablesStatusResponse response;
747
738
response.read (*in, server_revision);
@@ -810,6 +801,14 @@ void Connection::sendQuery(
810
801
811
802
query_id = query_id_;
812
803
804
+ // / Avoid reusing connections that had been left in the intermediate state
805
+ // / (i.e. not all packets had been sent).
806
+ bool completed = false ;
807
+ SCOPE_EXIT ({
808
+ if (!completed)
809
+ disconnect ();
810
+ });
811
+
813
812
writeVarUInt (Protocol::Client::Query, *out);
814
813
writeStringBinary (query_id, *out);
815
814
@@ -910,6 +909,8 @@ void Connection::sendQuery(
910
909
sendData (Block (), " " , false );
911
910
out->next ();
912
911
}
912
+
913
+ completed = true ;
913
914
}
914
915
915
916
@@ -1436,8 +1437,14 @@ InitialAllRangesAnnouncement Connection::receiveInitialParallelReadAnnouncement(
1436
1437
}
1437
1438
1438
1439
1439
- void Connection::throwUnexpectedPacket (UInt64 packet_type, const char * expected) const
1440
+ void Connection::throwUnexpectedPacket (TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected)
1440
1441
{
1442
+ // / Reset timeout_setter before disconnect, because after disconnect socket will be invalid.
1443
+ timeout_setter.reset ();
1444
+
1445
+ // / Close connection, to avoid leaving it in an unsynchronised state.
1446
+ disconnect ();
1447
+
1441
1448
throw NetException (ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER,
1442
1449
" Unexpected packet from server {} (expected {}, got {})" ,
1443
1450
getDescription (), expected, String (Protocol::Server::toString (packet_type)));
0 commit comments