Skip to content

Commit a961ad1

Browse files
committed
Merge bitcoin#30202: netbase: extend CreateSock() to support creating arbitrary sockets
1245d13 netbase: extend CreateSock() to support creating arbitrary sockets (Vasil Dimov) Pull request description: Allow the callers of `CreateSock()` to pass all 3 arguments to the `socket(2)` syscall. This makes it possible to create sockets of any domain/type/protocol. In addition to extending arguments, some extra safety checks were put in place. The need for this came up during the discussion in bitcoin#30043 (comment) ACKs for top commit: achow101: ACK 1245d13 tdb3: re ACK 1245d13 theStack: re-ACK 1245d13 Tree-SHA512: cc86b56121293ac98959aed0ed77812d20702ed7029b5a043586f46e74295779c5354bb0d5f9e80be6c29e535df980d34c1dbf609064fb7ea3e5ca0f0ed54d6b
2 parents 21656e9 + 1245d13 commit a961ad1

File tree

6 files changed

+33
-33
lines changed

6 files changed

+33
-33
lines changed

src/net.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -3029,7 +3029,7 @@ bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError,
30293029
return false;
30303030
}
30313031

3032-
std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily());
3032+
std::unique_ptr<Sock> sock = CreateSock(addrBind.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP);
30333033
if (!sock) {
30343034
strError = strprintf(Untranslated("Couldn't open socket for incoming connections (socket returned error %s)"), NetworkErrorString(WSAGetLastError()));
30353035
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "%s\n", strError.original);

src/netbase.cpp

+18-16
Original file line numberDiff line numberDiff line change
@@ -487,24 +487,23 @@ bool Socks5(const std::string& strDest, uint16_t port, const ProxyCredentials* a
487487
}
488488
}
489489

490-
std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family)
490+
std::unique_ptr<Sock> CreateSockOS(int domain, int type, int protocol)
491491
{
492492
// Not IPv4, IPv6 or UNIX
493-
if (address_family == AF_UNSPEC) return nullptr;
494-
495-
int protocol{IPPROTO_TCP};
496-
#if HAVE_SOCKADDR_UN
497-
if (address_family == AF_UNIX) protocol = 0;
498-
#endif
493+
if (domain == AF_UNSPEC) return nullptr;
499494

500495
// Create a socket in the specified address family.
501-
SOCKET hSocket = socket(address_family, SOCK_STREAM, protocol);
496+
SOCKET hSocket = socket(domain, type, protocol);
502497
if (hSocket == INVALID_SOCKET) {
503498
return nullptr;
504499
}
505500

506501
auto sock = std::make_unique<Sock>(hSocket);
507502

503+
if (domain != AF_INET && domain != AF_INET6 && domain != AF_UNIX) {
504+
return sock;
505+
}
506+
508507
// Ensure that waiting for I/O on this socket won't result in undefined
509508
// behavior.
510509
if (!sock->IsSelectable()) {
@@ -529,18 +528,21 @@ std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family)
529528
}
530529

531530
#if HAVE_SOCKADDR_UN
532-
if (address_family == AF_UNIX) return sock;
531+
if (domain == AF_UNIX) return sock;
533532
#endif
534533

535-
// Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
536-
const int on{1};
537-
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
538-
LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
534+
if (protocol == IPPROTO_TCP) {
535+
// Set the no-delay option (disable Nagle's algorithm) on the TCP socket.
536+
const int on{1};
537+
if (sock->SetSockOpt(IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)) == SOCKET_ERROR) {
538+
LogPrint(BCLog::NET, "Unable to set TCP_NODELAY on a newly created socket, continuing anyway\n");
539+
}
539540
}
541+
540542
return sock;
541543
}
542544

543-
std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock = CreateSockOS;
545+
std::function<std::unique_ptr<Sock>(int, int, int)> CreateSock = CreateSockOS;
544546

545547
template<typename... Args>
546548
static void LogConnectFailure(bool manual_connection, const char* fmt, const Args&... args) {
@@ -609,7 +611,7 @@ static bool ConnectToSocket(const Sock& sock, struct sockaddr* sockaddr, socklen
609611

610612
std::unique_ptr<Sock> ConnectDirectly(const CService& dest, bool manual_connection)
611613
{
612-
auto sock = CreateSock(dest.GetSAFamily());
614+
auto sock = CreateSock(dest.GetSAFamily(), SOCK_STREAM, IPPROTO_TCP);
613615
if (!sock) {
614616
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", dest.ToStringAddrPort());
615617
return {};
@@ -637,7 +639,7 @@ std::unique_ptr<Sock> Proxy::Connect() const
637639
if (!m_is_unix_socket) return ConnectDirectly(proxy, /*manual_connection=*/true);
638640

639641
#if HAVE_SOCKADDR_UN
640-
auto sock = CreateSock(AF_UNIX);
642+
auto sock = CreateSock(AF_UNIX, SOCK_STREAM, 0);
641643
if (!sock) {
642644
LogPrintLevel(BCLog::NET, BCLog::Level::Error, "Cannot create a socket for connecting to %s\n", m_unix_socket_path);
643645
return {};

src/netbase.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -262,16 +262,18 @@ CService LookupNumeric(const std::string& name, uint16_t portDefault = 0, DNSLoo
262262
CSubNet LookupSubNet(const std::string& subnet_str);
263263

264264
/**
265-
* Create a TCP or UNIX socket in the given address family.
266-
* @param[in] address_family to use for the socket.
265+
* Create a real socket from the operating system.
266+
* @param[in] domain Communications domain, first argument to the socket(2) syscall.
267+
* @param[in] type Type of the socket, second argument to the socket(2) syscall.
268+
* @param[in] protocol The particular protocol to be used with the socket, third argument to the socket(2) syscall.
267269
* @return pointer to the created Sock object or unique_ptr that owns nothing in case of failure
268270
*/
269-
std::unique_ptr<Sock> CreateSockOS(sa_family_t address_family);
271+
std::unique_ptr<Sock> CreateSockOS(int domain, int type, int protocol);
270272

271273
/**
272274
* Socket factory. Defaults to `CreateSockOS()`, but can be overridden by unit tests.
273275
*/
274-
extern std::function<std::unique_ptr<Sock>(const sa_family_t&)> CreateSock;
276+
extern std::function<std::unique_ptr<Sock>(int, int, int)> CreateSock;
275277

276278
/**
277279
* Create a socket and try to connect to the specified service.

src/test/fuzz/fuzz.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,9 @@ void ResetCoverageCounters() {}
101101

102102
void initialize()
103103
{
104-
// Terminate immediately if a fuzzing harness ever tries to create a TCP socket.
105-
CreateSock = [](const sa_family_t&) -> std::unique_ptr<Sock> { std::terminate(); };
104+
// Terminate immediately if a fuzzing harness ever tries to create a socket.
105+
// Individual tests can override this by pointing CreateSock to a mocked alternative.
106+
CreateSock = [](int, int, int) -> std::unique_ptr<Sock> { std::terminate(); };
106107

107108
// Terminate immediately if a fuzzing harness ever tries to perform a DNS lookup.
108109
g_dns_lookup = [](const std::string& name, bool allow_lookup) {

src/test/fuzz/i2p.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ FUZZ_TARGET(i2p, .init = initialize_i2p)
2727

2828
// Mock CreateSock() to create FuzzedSock.
2929
auto CreateSockOrig = CreateSock;
30-
CreateSock = [&fuzzed_data_provider](const sa_family_t&) {
30+
CreateSock = [&fuzzed_data_provider](int, int, int) {
3131
return std::make_unique<FuzzedSock>(fuzzed_data_provider);
3232
};
3333

src/test/i2p_tests.cpp

+4-9
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,14 @@ class EnvTestingSetup : public BasicTestingSetup
3939

4040
private:
4141
const BCLog::Level m_prev_log_level;
42-
const std::function<std::unique_ptr<Sock>(const sa_family_t&)> m_create_sock_orig;
42+
const decltype(CreateSock) m_create_sock_orig;
4343
};
4444

4545
BOOST_FIXTURE_TEST_SUITE(i2p_tests, EnvTestingSetup)
4646

4747
BOOST_AUTO_TEST_CASE(unlimited_recv)
4848
{
49-
// Mock CreateSock() to create MockSock.
50-
CreateSock = [](const sa_family_t&) {
49+
CreateSock = [](int, int, int) {
5150
return std::make_unique<StaticContentsSock>(std::string(i2p::sam::MAX_MSG_SIZE + 1, 'a'));
5251
};
5352

@@ -69,7 +68,7 @@ BOOST_AUTO_TEST_CASE(unlimited_recv)
6968
BOOST_AUTO_TEST_CASE(listen_ok_accept_fail)
7069
{
7170
size_t num_sockets{0};
72-
CreateSock = [&num_sockets](const sa_family_t&) {
71+
CreateSock = [&num_sockets](int, int, int) {
7372
// clang-format off
7473
++num_sockets;
7574
// First socket is the control socket for creating the session.
@@ -133,9 +132,7 @@ BOOST_AUTO_TEST_CASE(listen_ok_accept_fail)
133132

134133
BOOST_AUTO_TEST_CASE(damaged_private_key)
135134
{
136-
const auto CreateSockOrig = CreateSock;
137-
138-
CreateSock = [](const sa_family_t&) {
135+
CreateSock = [](int, int, int) {
139136
return std::make_unique<StaticContentsSock>("HELLO REPLY RESULT=OK VERSION=3.1\n"
140137
"SESSION STATUS RESULT=OK DESTINATION=\n");
141138
};
@@ -172,8 +169,6 @@ BOOST_AUTO_TEST_CASE(damaged_private_key)
172169
BOOST_CHECK(!session.Connect(CService{}, conn, proxy_error));
173170
}
174171
}
175-
176-
CreateSock = CreateSockOrig;
177172
}
178173

179174
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)