diff --git a/src/connection.rs b/src/connection.rs index 23a8d4b..aaa1aab 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -23,6 +23,7 @@ use std::fmt::{Debug, Display}; use std::hash::Hash; use std::io; use std::net::{Shutdown, TcpStream, ToSocketAddrs}; +use std::os::fd::IntoRawFd; use std::os::unix::io::AsRawFd; use std::time::Duration; @@ -62,6 +63,8 @@ pub trait NetConnection: NetStream + AsRawFd + Debug { fn remote_addr(&self) -> io::Result; fn local_addr(&self) -> io::Result; + #[cfg(feature = "nonblocking")] + fn set_tcp_keepalive(&mut self, keepalive: &socket2::TcpKeepalive) -> io::Result<()>; fn set_read_timeout(&mut self, dur: Option) -> io::Result<()>; fn set_write_timeout(&mut self, dur: Option) -> io::Result<()>; fn read_timeout(&self) -> io::Result>; @@ -108,6 +111,14 @@ impl NetConnection for TcpStream { fn local_addr(&self) -> io::Result { Ok(TcpStream::local_addr(self)?.into()) } + #[cfg(feature = "nonblocking")] + fn set_tcp_keepalive(&mut self, keepalive: &socket2::TcpKeepalive) -> io::Result<()> { + use std::os::fd::FromRawFd; + let socket = unsafe { socket2::Socket::from_raw_fd(self.as_raw_fd()) }; + socket.set_tcp_keepalive(keepalive)?; + let _ = socket.into_raw_fd(); // preventing from closing the socket + Ok(()) + } fn set_read_timeout(&mut self, dur: Option) -> io::Result<()> { TcpStream::set_read_timeout(self, dur) } @@ -240,6 +251,11 @@ impl NetConnection for socket2::Socket { .into()) } + #[cfg(feature = "nonblocking")] + fn set_tcp_keepalive(&mut self, keepalive: &socket2::TcpKeepalive) -> io::Result<()> { + socket2::Socket::set_tcp_keepalive(self, keepalive) + } + fn set_read_timeout(&mut self, dur: Option) -> io::Result<()> { socket2::Socket::set_read_timeout(self, dur) }