Skip to content

Commit 3be2ae7

Browse files
authored
Merge pull request #686 from Altinity/backports/24.8/74749_fix_reuse_connections
24.8 Backport of 74749 - fix reuse connections
2 parents 40f384d + 6505945 commit 3be2ae7

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

src/Client/Connection.cpp

+28-21
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include <Common/FailPoint.h>
3838

3939
#include <Common/config_version.h>
40+
#include <Common/scope_guard_safe.h>
4041
#include <Core/Types.h>
4142
#include "config.h"
4243

@@ -220,7 +221,7 @@ void Connection::connect(const ConnectionTimeouts & timeouts)
220221
connected = true;
221222
setDescription();
222223

223-
sendHello();
224+
sendHello(timeouts.handshake_timeout);
224225
receiveHello(timeouts.handshake_timeout);
225226

226227
if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_CHUNKED_PACKETS)
@@ -371,7 +372,7 @@ void Connection::disconnect()
371372
}
372373

373374

374-
void Connection::sendHello()
375+
void Connection::sendHello([[maybe_unused]] const Poco::Timespan & handshake_timeout)
375376
{
376377
/** Disallow control characters in user controlled parameters
377378
* to mitigate the possibility of SSRF.
@@ -424,7 +425,7 @@ void Connection::sendHello()
424425
writeStringBinary(String(EncodedUserInfo::SSH_KEY_AUTHENTICAION_MARKER) + user, *out);
425426
writeStringBinary(password, *out);
426427

427-
performHandshakeForSSHAuth();
428+
performHandshakeForSSHAuth(handshake_timeout);
428429
}
429430
#endif
430431
else if (!jwt.empty())
@@ -461,8 +462,10 @@ void Connection::sendAddendum()
461462

462463

463464
#if USE_SSH
464-
void Connection::performHandshakeForSSHAuth()
465+
void Connection::performHandshakeForSSHAuth(const Poco::Timespan & handshake_timeout)
465466
{
467+
TimeoutSetter timeout_setter(*socket, handshake_timeout, handshake_timeout);
468+
466469
String challenge;
467470
{
468471
writeVarUInt(Protocol::Client::SSHChallengeRequest, *out);
@@ -479,11 +482,7 @@ void Connection::performHandshakeForSSHAuth()
479482
else if (packet_type == Protocol::Server::Exception)
480483
receiveException()->rethrow();
481484
else
482-
{
483-
/// Close connection, to not stay in unsynchronised state.
484-
disconnect();
485-
throwUnexpectedPacket(packet_type, "SSHChallenge or Exception");
486-
}
485+
throwUnexpectedPacket(timeout_setter, packet_type, "SSHChallenge or Exception");
487486
}
488487

489488
writeVarUInt(Protocol::Client::SSHChallengeResponse, *out);
@@ -569,15 +568,7 @@ void Connection::receiveHello(const Poco::Timespan & handshake_timeout)
569568
else if (packet_type == Protocol::Server::Exception)
570569
receiveException()->rethrow();
571570
else
572-
{
573-
/// Reset timeout_setter before disconnect,
574-
/// because after disconnect socket will be invalid.
575-
timeout_setter.reset();
576-
577-
/// Close connection, to not stay in unsynchronised state.
578-
disconnect();
579-
throwUnexpectedPacket(packet_type, "Hello or Exception");
580-
}
571+
throwUnexpectedPacket(timeout_setter, packet_type, "Hello or Exception");
581572
}
582573

583574
void Connection::setDefaultDatabase(const String & database)
@@ -702,7 +693,7 @@ bool Connection::ping(const ConnectionTimeouts & timeouts)
702693
}
703694

704695
if (pong != Protocol::Server::Pong)
705-
throwUnexpectedPacket(pong, "Pong");
696+
throwUnexpectedPacket(timeout_setter, pong, "Pong");
706697
}
707698
catch (const Poco::Exception & e)
708699
{
@@ -741,7 +732,7 @@ TablesStatusResponse Connection::getTablesStatus(const ConnectionTimeouts & time
741732
if (response_type == Protocol::Server::Exception)
742733
receiveException()->rethrow();
743734
else if (response_type != Protocol::Server::TablesStatusResponse)
744-
throwUnexpectedPacket(response_type, "TablesStatusResponse");
735+
throwUnexpectedPacket(timeout_setter, response_type, "TablesStatusResponse");
745736

746737
TablesStatusResponse response;
747738
response.read(*in, server_revision);
@@ -810,6 +801,14 @@ void Connection::sendQuery(
810801

811802
query_id = query_id_;
812803

804+
/// Avoid reusing connections that had been left in the intermediate state
805+
/// (i.e. not all packets had been sent).
806+
bool completed = false;
807+
SCOPE_EXIT({
808+
if (!completed)
809+
disconnect();
810+
});
811+
813812
writeVarUInt(Protocol::Client::Query, *out);
814813
writeStringBinary(query_id, *out);
815814

@@ -910,6 +909,8 @@ void Connection::sendQuery(
910909
sendData(Block(), "", false);
911910
out->next();
912911
}
912+
913+
completed = true;
913914
}
914915

915916

@@ -1436,8 +1437,14 @@ InitialAllRangesAnnouncement Connection::receiveInitialParallelReadAnnouncement(
14361437
}
14371438

14381439

1439-
void Connection::throwUnexpectedPacket(UInt64 packet_type, const char * expected) const
1440+
void Connection::throwUnexpectedPacket(TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected)
14401441
{
1442+
/// Reset timeout_setter before disconnect, because after disconnect socket will be invalid.
1443+
timeout_setter.reset();
1444+
1445+
/// Close connection, to avoid leaving it in an unsynchronised state.
1446+
disconnect();
1447+
14411448
throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER,
14421449
"Unexpected packet from server {} (expected {}, got {})",
14431450
getDescription(), expected, String(Protocol::Server::toString(packet_type)));

src/Client/Connection.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace DB
2626
{
2727

2828
struct Settings;
29+
struct TimeoutSetter;
2930

3031
class Connection;
3132
struct ConnectionParameters;
@@ -275,10 +276,10 @@ class Connection : public IServerConnection
275276
AsyncCallback async_callback = {};
276277

277278
void connect(const ConnectionTimeouts & timeouts);
278-
void sendHello();
279+
void sendHello(const Poco::Timespan & handshake_timeout);
279280

280281
#if USE_SSH
281-
void performHandshakeForSSHAuth();
282+
void performHandshakeForSSHAuth(const Poco::Timespan & handshake_timeout);
282283
#endif
283284

284285
void sendAddendum();
@@ -306,7 +307,7 @@ class Connection : public IServerConnection
306307
void initBlockLogsInput();
307308
void initBlockProfileEventsInput();
308309

309-
[[noreturn]] void throwUnexpectedPacket(UInt64 packet_type, const char * expected) const;
310+
[[noreturn]] void throwUnexpectedPacket(TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected);
310311
};
311312

312313
template <typename Conn>

0 commit comments

Comments
 (0)