3737#include < Common/FailPoint.h>
3838
3939#include < Common/config_version.h>
40+ #include < Common/scope_guard_safe.h>
4041#include < Core/Types.h>
4142#include " config.h"
4243
@@ -220,7 +221,7 @@ void Connection::connect(const ConnectionTimeouts & timeouts)
220221 connected = true ;
221222 setDescription ();
222223
223- sendHello ();
224+ sendHello (timeouts. handshake_timeout );
224225 receiveHello (timeouts.handshake_timeout );
225226
226227 if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_CHUNKED_PACKETS)
@@ -371,7 +372,7 @@ void Connection::disconnect()
371372}
372373
373374
374- void Connection::sendHello ()
375+ void Connection::sendHello ([[maybe_unused]] const Poco::Timespan & handshake_timeout )
375376{
376377 /* * Disallow control characters in user controlled parameters
377378 * to mitigate the possibility of SSRF.
@@ -424,7 +425,7 @@ void Connection::sendHello()
424425 writeStringBinary (String (EncodedUserInfo::SSH_KEY_AUTHENTICAION_MARKER) + user, *out);
425426 writeStringBinary (password, *out);
426427
427- performHandshakeForSSHAuth ();
428+ performHandshakeForSSHAuth (handshake_timeout );
428429 }
429430#endif
430431 else if (!jwt.empty ())
@@ -461,8 +462,10 @@ void Connection::sendAddendum()
461462
462463
463464#if USE_SSH
464- void Connection::performHandshakeForSSHAuth ()
465+ void Connection::performHandshakeForSSHAuth (const Poco::Timespan & handshake_timeout )
465466{
467+ TimeoutSetter timeout_setter (*socket, handshake_timeout, handshake_timeout);
468+
466469 String challenge;
467470 {
468471 writeVarUInt (Protocol::Client::SSHChallengeRequest, *out);
@@ -479,11 +482,7 @@ void Connection::performHandshakeForSSHAuth()
479482 else if (packet_type == Protocol::Server::Exception)
480483 receiveException ()->rethrow ();
481484 else
482- {
483- // / Close connection, to not stay in unsynchronised state.
484- disconnect ();
485- throwUnexpectedPacket (packet_type, " SSHChallenge or Exception" );
486- }
485+ throwUnexpectedPacket (timeout_setter, packet_type, " SSHChallenge or Exception" );
487486 }
488487
489488 writeVarUInt (Protocol::Client::SSHChallengeResponse, *out);
@@ -569,15 +568,7 @@ void Connection::receiveHello(const Poco::Timespan & handshake_timeout)
569568 else if (packet_type == Protocol::Server::Exception)
570569 receiveException ()->rethrow ();
571570 else
572- {
573- // / Reset timeout_setter before disconnect,
574- // / because after disconnect socket will be invalid.
575- timeout_setter.reset ();
576-
577- // / Close connection, to not stay in unsynchronised state.
578- disconnect ();
579- throwUnexpectedPacket (packet_type, " Hello or Exception" );
580- }
571+ throwUnexpectedPacket (timeout_setter, packet_type, " Hello or Exception" );
581572}
582573
583574void Connection::setDefaultDatabase (const String & database)
@@ -702,7 +693,7 @@ bool Connection::ping(const ConnectionTimeouts & timeouts)
702693 }
703694
704695 if (pong != Protocol::Server::Pong)
705- throwUnexpectedPacket (pong, " Pong" );
696+ throwUnexpectedPacket (timeout_setter, pong, " Pong" );
706697 }
707698 catch (const Poco::Exception & e)
708699 {
@@ -741,7 +732,7 @@ TablesStatusResponse Connection::getTablesStatus(const ConnectionTimeouts & time
741732 if (response_type == Protocol::Server::Exception)
742733 receiveException ()->rethrow ();
743734 else if (response_type != Protocol::Server::TablesStatusResponse)
744- throwUnexpectedPacket (response_type, " TablesStatusResponse" );
735+ throwUnexpectedPacket (timeout_setter, response_type, " TablesStatusResponse" );
745736
746737 TablesStatusResponse response;
747738 response.read (*in, server_revision);
@@ -810,6 +801,14 @@ void Connection::sendQuery(
810801
811802 query_id = query_id_;
812803
804+ // / Avoid reusing connections that had been left in the intermediate state
805+ // / (i.e. not all packets had been sent).
806+ bool completed = false ;
807+ SCOPE_EXIT ({
808+ if (!completed)
809+ disconnect ();
810+ });
811+
813812 writeVarUInt (Protocol::Client::Query, *out);
814813 writeStringBinary (query_id, *out);
815814
@@ -910,6 +909,8 @@ void Connection::sendQuery(
910909 sendData (Block (), " " , false );
911910 out->next ();
912911 }
912+
913+ completed = true ;
913914}
914915
915916
@@ -1436,8 +1437,14 @@ InitialAllRangesAnnouncement Connection::receiveInitialParallelReadAnnouncement(
14361437}
14371438
14381439
1439- void Connection::throwUnexpectedPacket (UInt64 packet_type, const char * expected) const
1440+ void Connection::throwUnexpectedPacket (TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected)
14401441{
1442+ // / Reset timeout_setter before disconnect, because after disconnect socket will be invalid.
1443+ timeout_setter.reset ();
1444+
1445+ // / Close connection, to avoid leaving it in an unsynchronised state.
1446+ disconnect ();
1447+
14411448 throw NetException (ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER,
14421449 " Unexpected packet from server {} (expected {}, got {})" ,
14431450 getDescription (), expected, String (Protocol::Server::toString (packet_type)));
0 commit comments