diff --git a/Cargo.toml b/Cargo.toml index f179d0f..b2bf291 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "rust-sctp" -version = "0.0.5" +version = "0.0.6" description = "High level SCTP networking library" repository = "https://github.com/phsym/rust-sctp" documentation = "http://phsym.github.io/rust-sctp" diff --git a/src/lib.rs b/src/lib.rs index e3335ad..a79471d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,6 +19,10 @@ use std::net::{ToSocketAddrs, SocketAddr, Shutdown}; #[cfg(target_os="linux")] use std::os::unix::io::{AsRawFd, RawFd, FromRawFd}; + +#[cfg(target_os="linux")] +pub mod mio_unix; + #[cfg(target_os="windows")] use std::os::windows::io::{AsRawHandle, RawHandle, FromRawHandle}; @@ -60,9 +64,9 @@ impl SctpStream { /// Create a new stream by connecting it to a remote endpoint pub fn connect(address: A) -> Result { - let raw_addr = try!(SocketAddr::from_addr(&address)); - let sock = try!(SctpSocket::new(raw_addr.family(), SOCK_STREAM)); - try!(sock.connect(raw_addr)); + let raw_addr = SocketAddr::from_addr(&address)?; + let sock = SctpSocket::new(raw_addr.family(), SOCK_STREAM)?; + sock.connect(raw_addr)?; return Ok(SctpStream(sock)); } @@ -72,26 +76,32 @@ impl SctpStream { let mut vec = Vec::with_capacity(addresses.len()); let mut family = AF_INET; for address in addresses { - let a = try!(SocketAddr::from_addr(address)); + let a = SocketAddr::from_addr(address)?; if a.family() == AF_INET6 { family = AF_INET6; } vec.push(a); } - let sock = try!(SctpSocket::new(family, SOCK_STREAM)); - try!(sock.connectx(&vec)); + let sock = SctpSocket::new(family, SOCK_STREAM)?; + sock.connectx(&vec)?; return Ok(SctpStream(sock)); } /// Send bytes on the specified SCTP stream. On success, returns the /// quantity of bytes read pub fn sendmsg(&self, msg: &[u8], stream: u16) -> Result { - return self.0.sendmsg::(msg, None, stream, 0); + return self.0.sendmsg::(msg, None, 0, stream, 0); + } + + /// Send bytes on the specified SCTP stream. On success, returns the + /// quantity of bytes read + pub fn sendmsg_ppid(&self, msg: &[u8], ppid: u32, stream: u16) -> Result { + return self.0.sendmsg::(msg, None, ppid, stream, 0); } /// Read bytes. On success, return a tuple with the quantity of /// bytes received and the stream they were recived on pub fn recvmsg(&self, msg: &mut [u8]) -> Result<(usize, u16)> { - let (size, stream, _) = try!(self.0.recvmsg(msg)); + let (size, stream, _) = self.0.recvmsg(msg)?; return Ok((size, stream)); } @@ -118,7 +128,7 @@ impl SctpStream { /// Verify if SCTP_NODELAY option is activated for this socket pub fn has_nodelay(&self) -> Result { - let val: libc::c_int = try!(self.0.sctp_opt_info(sctp_sys::SCTP_NODELAY, 0)); + let val: libc::c_int = self.0.sctp_opt_info(sctp_sys::SCTP_NODELAY, 0)?; return Ok(val == 1); } @@ -129,8 +139,8 @@ impl SctpStream { } /// Get the socket buffer size for the direction specified by `dir` - pub fn get_buffer_size(&self, dir: SoDirection) -> Result<(usize)> { - let val: u32 = try!(self.0.getsockopt(SOL_SOCKET, dir.buffer_opt())); + pub fn get_buffer_size(&self, dir: SoDirection) -> Result { + let val: u32 = self.0.getsockopt(SOL_SOCKET, dir.buffer_opt())?; return Ok(val as usize); } @@ -144,7 +154,7 @@ impl SctpStream { /// Try to clone the SctpStream. On success, returns a new stream /// wrapping a new socket handler pub fn try_clone(&self) -> Result { - return Ok(SctpStream(try!(self.0.try_clone()))); + return Ok(SctpStream(self.0.try_clone()?)); } } @@ -200,10 +210,10 @@ impl SctpEndpoint { /// Create a one-to-many SCTP endpoint bound to a single address pub fn bind(address: A) -> Result { - let raw_addr = try!(SocketAddr::from_addr(&address)); - let sock = try!(SctpSocket::new(raw_addr.family(), SOCK_SEQPACKET)); - try!(sock.bind(raw_addr)); - try!(sock.listen(-1)); + let raw_addr = SocketAddr::from_addr(&address)?; + let sock = SctpSocket::new(raw_addr.family(), SOCK_SEQPACKET)?; + sock.bind(raw_addr)?; + sock.listen(-1)?; return Ok(SctpEndpoint(sock)); } @@ -213,14 +223,14 @@ impl SctpEndpoint { let mut vec = Vec::with_capacity(addresses.len()); let mut family = AF_INET; for address in addresses { - let a = try!(SocketAddr::from_addr(address)); + let a = SocketAddr::from_addr(address)?; if a.family() == AF_INET6 { family = AF_INET6; } vec.push(a); } - let sock = try!(SctpSocket::new(family, SOCK_SEQPACKET)); - try!(sock.bindx(&vec, BindOp::AddAddr)); - try!(sock.listen(-1)); + let sock = SctpSocket::new(family, SOCK_SEQPACKET)?; + sock.bindx(&vec, BindOp::AddAddr)?; + sock.listen(-1)?; return Ok(SctpEndpoint(sock)); } @@ -234,7 +244,7 @@ impl SctpEndpoint { /// Send data in Sctp style, to the provided address on the stream `stream`. /// On success, returns the quantity on bytes sent pub fn send_to(&self, msg: &mut [u8], address: A, stream: u16) -> Result { - return self.0.sendmsg(msg, Some(address), stream, 0); + return self.0.sendmsg(msg, Some(address), 0, stream, 0); } /// Get local socket addresses to which this socket is bound @@ -255,7 +265,7 @@ impl SctpEndpoint { /// Verify if SCTP_NODELAY option is activated for this socket pub fn has_nodelay(&self) -> Result { - let val: libc::c_int = try!(self.0.sctp_opt_info(sctp_sys::SCTP_NODELAY, 0)); + let val: libc::c_int = self.0.sctp_opt_info(sctp_sys::SCTP_NODELAY, 0)?; return Ok(val == 1); } @@ -266,8 +276,8 @@ impl SctpEndpoint { } /// Get the socket buffer size for the direction specified by `dir` - pub fn get_buffer_size(&self, dir: SoDirection) -> Result<(usize)> { - let val: u32 = try!(self.0.getsockopt(SOL_SOCKET, dir.buffer_opt())); + pub fn get_buffer_size(&self, dir: SoDirection) -> Result { + let val: u32 = self.0.getsockopt(SOL_SOCKET, dir.buffer_opt())?; return Ok(val as usize); } @@ -280,7 +290,7 @@ impl SctpEndpoint { /// Try to clone this socket pub fn try_clone(&self) -> Result { - return Ok(SctpEndpoint(try!(self.0.try_clone()))); + return Ok(SctpEndpoint(self.0.try_clone()?)); } } @@ -336,10 +346,10 @@ impl SctpListener { /// Create a listener bound to a single address pub fn bind(address: A) -> Result { - let raw_addr = try!(SocketAddr::from_addr(&address)); - let sock = try!(SctpSocket::new(raw_addr.family(), SOCK_STREAM)); - try!(sock.bind(raw_addr)); - try!(sock.listen(-1)); + let raw_addr = SocketAddr::from_addr(&address)?; + let sock = SctpSocket::new(raw_addr.family(), SOCK_STREAM)?; + sock.bind(raw_addr)?; + sock.listen(-1)?; return Ok(SctpListener(sock)); } @@ -349,20 +359,20 @@ impl SctpListener { let mut vec = Vec::with_capacity(addresses.len()); let mut family = AF_INET; for address in addresses { - let a = try!(SocketAddr::from_addr(address)); + let a = SocketAddr::from_addr(address)?; if a.family() == AF_INET6 { family = AF_INET6; } vec.push(a); } - let sock = try!(SctpSocket::new(family, SOCK_STREAM)); - try!(sock.bindx(&vec, BindOp::AddAddr)); - try!(sock.listen(-1)); + let sock = SctpSocket::new(family, SOCK_STREAM)?; + sock.bindx(&vec, BindOp::AddAddr)?; + sock.listen(-1)?; return Ok(SctpListener(sock)); } /// Accept a new connection pub fn accept(&self) -> Result<(SctpStream, SocketAddr)> { - let (sock, addr) = try!(self.0.accept()); + let (sock, addr) = self.0.accept()?; return Ok((SctpStream(sock), addr)); } @@ -385,7 +395,7 @@ impl SctpListener { /// Try to clone this listener pub fn try_clone(&self) -> Result { - return Ok(SctpListener(try!(self.0.try_clone()))); + return Ok(SctpListener(self.0.try_clone()?)); } } diff --git a/src/mio_unix.rs b/src/mio_unix.rs new file mode 100644 index 0000000..c099605 --- /dev/null +++ b/src/mio_unix.rs @@ -0,0 +1,135 @@ +#[macro_export] +// copy from mio::sys::unix::mod +// https://github.com/faern/mio/blob/master/src/sys/unix/mod.rs +/// Helper macro to execute a system call that returns an `io::Result`. +// +// Macro must be defined before any modules that uses them. +#[allow(unused_macros)] +macro_rules! syscall { + ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{ + let res = unsafe { libc::$fn($($arg, )*) }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} + +#[macro_export] +#[allow(unused_macros)] +macro_rules! sctp_syscall { + ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{ + let res = unsafe { sctp_sys::$fn($($arg, )*) }; + if res == -1 { + Err(std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} + +use std::mem; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; + +// copy from mio::sys::unix::net +// https://github.com/faern/mio/blob/master/src/sys/unix/net.rs +/// A type with the same memory layout as `libc::sockaddr`. Used in converting Rust level +/// SocketAddr* types into their system representation. The benefit of this specific +/// type over using `libc::sockaddr_storage` is that this type is exactly as large as it +/// needs to be and not a lot larger. And it can be initialized cleaner from Rust. +#[repr(C)] +pub(crate) union SocketAddrCRepr { + v4: libc::sockaddr_in, + v6: libc::sockaddr_in6, +} + +impl SocketAddrCRepr { + pub(crate) fn as_ptr(&self) -> *const libc::sockaddr { + self as *const _ as *const libc::sockaddr + } +} + +/// Converts a Rust `SocketAddr` into the system representation. +pub(crate) fn socket_addr(addr: &SocketAddr) -> (SocketAddrCRepr, libc::socklen_t) { + match addr { + SocketAddr::V4(ref addr) => { + // `s_addr` is stored as BE on all machine and the array is in BE order. + // So the native endian conversion method is used so that it's never swapped. + let sin_addr = libc::in_addr { s_addr: u32::from_ne_bytes(addr.ip().octets()) }; + + let sockaddr_in = libc::sockaddr_in { + sin_family: libc::AF_INET as libc::sa_family_t, + sin_port: addr.port().to_be(), + sin_addr, + sin_zero: [0; 8], + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + ))] + sin_len: 0, + }; + + let sockaddr = SocketAddrCRepr { v4: sockaddr_in }; + (sockaddr, mem::size_of::() as libc::socklen_t) + } + SocketAddr::V6(ref addr) => { + let sockaddr_in6 = libc::sockaddr_in6 { + sin6_family: libc::AF_INET6 as libc::sa_family_t, + sin6_port: addr.port().to_be(), + sin6_addr: libc::in6_addr { s6_addr: addr.ip().octets() }, + sin6_flowinfo: addr.flowinfo(), + sin6_scope_id: addr.scope_id(), + #[cfg(any( + target_os = "dragonfly", + target_os = "freebsd", + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + ))] + sin6_len: 0, + #[cfg(any(target_os = "solaris", target_os = "illumos"))] + __sin6_src_id: 0, + }; + + let sockaddr = SocketAddrCRepr { v6: sockaddr_in6 }; + (sockaddr, mem::size_of::() as libc::socklen_t) + } + } +} + +/// Converts a `libc::sockaddr` compatible struct into a native Rust `SocketAddr`. +/// +/// # Safety +/// +/// `storage` must have the `ss_family` field correctly initialized. +/// `storage` must be initialised to a `sockaddr_in` or `sockaddr_in6`. +pub(crate) unsafe fn to_socket_addr( + storage: *const libc::sockaddr_storage, +) -> std::io::Result { + match (*storage).ss_family as libc::c_int { + libc::AF_INET => { + // Safety: if the ss_family field is AF_INET then storage must be a sockaddr_in. + let addr: &libc::sockaddr_in = &*(storage as *const libc::sockaddr_in); + let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()); + let port = u16::from_be(addr.sin_port); + Ok(SocketAddr::V4(SocketAddrV4::new(ip, port))) + }, + libc::AF_INET6 => { + // Safety: if the ss_family field is AF_INET6 then storage must be a sockaddr_in6. + let addr: &libc::sockaddr_in6 = &*(storage as *const libc::sockaddr_in6); + let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr); + let port = u16::from_be(addr.sin6_port); + Ok(SocketAddr::V6(SocketAddrV6::new(ip, port, addr.sin6_flowinfo, addr.sin6_scope_id))) + }, + _ => Err(std::io::ErrorKind::InvalidInput.into()), + } +} +// +// CODE COPY ENDS HERE +// diff --git a/src/sctpsock.rs b/src/sctpsock.rs index 3be3394..36a7787 100644 --- a/src/sctpsock.rs +++ b/src/sctpsock.rs @@ -3,11 +3,20 @@ use libc; use sctp_sys; use std::io::{Result, Error, ErrorKind, Read, Write}; -use std::net::{ToSocketAddrs, SocketAddr, Shutdown}; -use std::mem::{transmute, size_of, zeroed}; +use std::net::{ToSocketAddrs, SocketAddr, SocketAddrV4, SocketAddrV6, Ipv4Addr, Ipv6Addr, Shutdown}; +use std::mem::{size_of, MaybeUninit}; + +// import macros from lib +#[cfg(target_os="linux")] +use crate::{syscall, sctp_syscall}; + +// import sockaddr helpers from lib +#[cfg(target_os="linux")] +use mio_unix::{socket_addr, to_socket_addr}; #[cfg(target_os="linux")] use std::os::unix::io::{AsRawFd, RawFd, FromRawFd}; + #[cfg(target_os="windows")] use std::os::windows::io::{AsRawHandle, RawHandle, FromRawHandle}; @@ -16,8 +25,8 @@ mod win { use std::io::{Result, Error}; use libc; use winapi; - - pub use ws2_32::{socket, connect, bind, listen, accept, recv, send, shutdown, setsockopt, closesocket}; + + pub use ws2_32::{socket, closesocket}; pub use winapi::{SOCKADDR as sockaddr, SOCKADDR_IN as sockaddr_in, sockaddr_in6, socklen_t, AF_INET, AF_INET6, SOCKET}; pub type RWlen = i32; @@ -37,8 +46,8 @@ mod linux { use std::io::{Result, Error}; use libc; - pub use libc::{sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET,AF_INET6, socket, connect, bind, listen, accept, recv, send, shutdown, setsockopt, SHUT_RDWR, SHUT_RD, SHUT_WR}; - + pub use libc::{sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET, AF_INET6, socket, SHUT_RDWR, SHUT_RD, SHUT_WR, EINPROGRESS}; + pub type SOCKET = libc::c_int; pub type RWlen = libc::size_t; @@ -57,14 +66,6 @@ use self::win::*; #[cfg(target_os="linux")] use self::linux::*; -// XXX: Until getsockopt is available in libc crate -extern "system" { - #[cfg(target_os="linux")] - fn getsockopt(sock: SOCKET, level: libc::c_int, optname: libc::c_int, optval: *mut libc::c_void, optlen: *mut socklen_t) -> libc::c_int; - #[cfg(target_os="windows")] - fn getsockopt(sock: SOCKET, level: libc::c_int, optname: libc::c_int, optval: *mut libc::c_char, optlen: *mut libc::c_int) -> libc::c_int; -} - /// SCTP bind operation #[allow(dead_code)] pub enum BindOp { @@ -104,23 +105,14 @@ impl SctpAddrType { } } - /// Manage low level socket address structure pub trait RawSocketAddr: Sized { /// Get the address family for this socket address fn family(&self) -> i32; - /// Get the raw socket address structure size - fn addr_len(&self) -> socklen_t; - /// Create from a raw socket address unsafe fn from_raw_ptr(addr: *const sockaddr, len: socklen_t) -> Result; - /// Return an immutable pointer to the raw socket address structure - fn as_ptr(&self) -> *const sockaddr; - - /// Return a mutable pointer to the raw socket address structure - fn as_mut_ptr(&mut self) -> *mut sockaddr; /// Create from a ToSocketAddrs fn from_addr(address: A) -> Result; @@ -134,41 +126,32 @@ impl RawSocketAddr for SocketAddr { }; } - fn addr_len(&self) -> socklen_t { - return match *self { - SocketAddr::V4(..) => size_of::(), - SocketAddr::V6(..) => size_of::() - } as socklen_t; - } - unsafe fn from_raw_ptr(addr: *const sockaddr, len: socklen_t) -> Result { if len < size_of::() as socklen_t { return Err(Error::new(ErrorKind::InvalidInput, "Invalid address length")); } - return match (*addr).sa_family as libc::c_int { - AF_INET if len >= size_of::() as socklen_t => Ok(SocketAddr::V4(transmute(*(addr as *const sockaddr_in)))), - AF_INET6 if len >= size_of::() as socklen_t => Ok(SocketAddr::V6(transmute(*(addr as *const sockaddr_in6)))), - _ => Err(Error::new(ErrorKind::InvalidInput, "Invalid socket address")) - }; - } - - fn as_ptr(&self) -> *const sockaddr { - return match *self { - SocketAddr::V4(ref a) => unsafe { transmute(a) }, - SocketAddr::V6(ref a) => unsafe { transmute(a) } - }; - } - - fn as_mut_ptr(&mut self) -> *mut sockaddr { - return match *self { - SocketAddr::V4(ref mut a) => unsafe { transmute(a) }, - SocketAddr::V6(ref mut a) => unsafe { transmute(a) } - }; + match (*addr).sa_family as libc::c_int { + AF_INET => { + let in_addr = std::ptr::read(addr as *const sockaddr_in); + let ip_addr = Ipv4Addr::from(in_addr.sin_addr.s_addr.to_be()); + let socket_addr_v4 = SocketAddrV4::new(ip_addr, u16::from_be(in_addr.sin_port)); + return Ok(SocketAddr::V4(socket_addr_v4)); + } + AF_INET6 if len >= size_of::() as socklen_t => { + let in6_addr = std::ptr::read(addr as *const sockaddr_in6); + let ip6_addr = Ipv6Addr::from(in6_addr.sin6_addr.s6_addr); + let socket_addr_v6 = SocketAddrV6::new(ip6_addr, u16::from_be(in6_addr.sin6_port), in6_addr.sin6_flowinfo, in6_addr.sin6_scope_id); + return Ok(SocketAddr::V6(socket_addr_v6)); + } + _ => Err(Error::new(ErrorKind::InvalidInput, "Invalid socket address")), + } } fn from_addr(address: A) -> Result { - return try!(address.to_socket_addrs().or(Err(Error::new(ErrorKind::InvalidInput, "Address is not valid")))) - .next().ok_or(Error::new(ErrorKind::InvalidInput, "Address is not valid")); + return address + .to_socket_addrs()? + .next() + .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "Address is not valid")); } } @@ -180,128 +163,133 @@ impl SctpSocket { /// Create a new SCTP socket pub fn new(family: libc::c_int, sock_type: libc::c_int) -> Result { unsafe { - return Ok(SctpSocket(try!(check_socket(socket(family, sock_type, sctp_sys::IPPROTO_SCTP))))); + return Ok(SctpSocket(check_socket(socket(family, sock_type, sctp_sys::IPPROTO_SCTP))?)); } } /// Connect the socket to `address` pub fn connect(&self, address: A) -> Result<()> { - let raw_addr = try!(SocketAddr::from_addr(&address)); - unsafe { - return match connect(self.0, raw_addr.as_ptr(), raw_addr.addr_len()) { - 0 => Ok(()), - _ => Err(Error::last_os_error()) - }; + let addrobj = SocketAddr::from_addr(&address)?; + let (raw_addr, raw_addr_length) = socket_addr(&addrobj); + match syscall!(connect(self.0, raw_addr.as_ptr(), raw_addr_length)) { + Err(err) if err.raw_os_error() != Some(EINPROGRESS) => { + Err(err) + } + _ => { + Ok(()) + } } } /// Connect the socket to multiple addresses pub fn connectx(&self, addresses: &[A]) -> Result { if addresses.len() == 0 { return Err(Error::new(ErrorKind::InvalidInput, "No addresses given")); } - unsafe { - let buf: *mut u8 = libc::malloc((addresses.len() * size_of::()) as libc::size_t) as *mut u8; - if buf.is_null() { - return Err(Error::new(ErrorKind::Other, "Out of memory")); - } - let mut offset = 0isize; - for address in addresses { - let raw = try!(SocketAddr::from_addr(address)); - let len = raw.addr_len(); - std::ptr::copy_nonoverlapping(raw.as_ptr() as *mut u8, buf.offset(offset), len as usize); - offset += len as isize; - } + + let buf: *mut u8 = unsafe { libc::malloc((addresses.len() * size_of::()) as libc::size_t) as *mut u8 }; + if buf.is_null() { + return Err(Error::new(ErrorKind::Other, "Out of memory")); + } + let mut offset = 0isize; + for address in addresses { + let addrobj = SocketAddr::from_addr(&address)?; + let (raw_addr, raw_addr_length) = socket_addr(&addrobj); + unsafe { std::ptr::copy_nonoverlapping(raw_addr.as_ptr() as *mut u8, buf.offset(offset), raw_addr_length as usize) }; + offset += raw_addr_length as isize; + } - let mut assoc: sctp_sys::sctp_assoc_t = 0; - let ret = match sctp_sys::sctp_connectx(self.0, buf as *mut sockaddr, addresses.len() as i32, &mut assoc) { - 0 => Ok(assoc), - _ => Err(Error::last_os_error()), - }; - libc::free(buf as *mut libc::c_void); - return ret; + let mut assoc: sctp_sys::sctp_assoc_t = 0; + + match sctp_syscall!(sctp_connectx(self.0, buf as *mut sockaddr, addresses.len() as i32, &mut assoc)) + { + Err(err) => { unsafe{libc::free(buf as *mut libc::c_void)};Err(err) }, + Ok(_) => { unsafe{libc::free(buf as *mut libc::c_void)};Ok(assoc) }, } } /// Bind the socket to a single address pub fn bind(&self, address: A) -> Result<()> { - let raw_addr = try!(SocketAddr::from_addr(&address)); - unsafe { - return match bind(self.0, raw_addr.as_ptr(), raw_addr.addr_len()) { - 0 => Ok(()), - _ => Err(Error::last_os_error()) - }; - } + let addrobj = SocketAddr::from_addr(&address)?; + let (raw_addr, raw_addr_length) = socket_addr(&addrobj); + syscall!(bind(self.0, raw_addr.as_ptr(), raw_addr_length))?; + Ok(()) } /// Bind the socket on multiple addresses pub fn bindx(&self, addresses: &[A], op: BindOp) -> Result<()> { if addresses.len() == 0 { return Err(Error::new(ErrorKind::InvalidInput, "No addresses given")); } - unsafe { - let buf: *mut u8 = libc::malloc((addresses.len() * size_of::()) as libc::size_t) as *mut u8; - if buf.is_null() { - return Err(Error::new(ErrorKind::Other, "Out of memory")); - } - let mut offset = 0isize; - for address in addresses { - let raw = try!(SocketAddr::from_addr(address)); - let len = raw.addr_len(); - std::ptr::copy_nonoverlapping(raw.as_ptr() as *mut u8, buf.offset(offset), len as usize); - offset += len as isize; - } - let ret = match sctp_sys::sctp_bindx(self.0, buf as *mut sockaddr, addresses.len() as i32, op.flag()) { - 0 => Ok(()), - _ => Err(Error::last_os_error()) - }; - libc::free(buf as *mut libc::c_void); - return ret; + let buf: *mut u8 = unsafe { libc::malloc((addresses.len() * size_of::()) as libc::size_t) as *mut u8 }; + if buf.is_null() { + return Err(Error::new(ErrorKind::Other, "Out of memory")); + } + let mut offset = 0isize; + for address in addresses { + let addrobj = SocketAddr::from_addr(&address)?; + let (raw_addr, raw_addr_length) = socket_addr(&addrobj); + unsafe { std::ptr::copy_nonoverlapping(raw_addr.as_ptr() as *mut u8, buf.offset(offset), raw_addr_length as usize) }; + offset += raw_addr_length as isize; + } + + match sctp_syscall!(sctp_bindx(self.0, buf as *mut sockaddr, addresses.len() as i32, op.flag())) + { + Err(err) => { unsafe{libc::free(buf as *mut libc::c_void)};Err(err) }, + Ok(_) => { Ok(()) }, } } /// Listen pub fn listen(&self, backlog: libc::c_int) -> Result<()> { - unsafe { - return match listen(self.0, backlog) { - 0 => Ok(()), - _ => Err(Error::last_os_error()) - }; - } + syscall!(listen(self.0, backlog))?; + Ok(()) } /// Accept connection to this socket pub fn accept(&self) -> Result<(SctpSocket, SocketAddr)> { - let mut addr: sockaddr_in6 = unsafe { std::mem::zeroed() }; - let mut len: socklen_t = size_of::() as socklen_t; - unsafe { - let addr_ptr: *mut sockaddr = transmute(&mut addr); - let sock = try!(check_socket(accept(self.0, addr_ptr, &mut len))); - let addr = try!(SocketAddr::from_raw_ptr(addr_ptr, len)); - return Ok((SctpSocket(sock), addr)); - } + // prepare buffer to store client address + // TODO: this will not be compatible with windows environments as we use libc structs + let mut addr_storage: MaybeUninit = MaybeUninit::uninit(); + let mut addr_storage_length = size_of::() as libc::socklen_t; + + let stream = { + syscall!(accept(self.0, addr_storage.as_mut_ptr() as *mut _, &mut addr_storage_length)) + .map(|socket| SctpSocket(socket)) + }?; + + unsafe { to_socket_addr(addr_storage.as_ptr()) }.map(|addr| (stream, addr)) } fn addrs(&self, id: sctp_sys::sctp_assoc_t, what: SctpAddrType) -> Result> { unsafe { - let mut addrs: *mut u8 = std::ptr::null_mut(); - let len = what.get(self.0, id, transmute(&mut addrs)); + // Initialize a pointer that will hold the addresses + let mut addrs: *mut sockaddr = std::ptr::null_mut(); + let len = what.get(self.0, id, &mut addrs); + if len < 0 { return Err(Error::new(ErrorKind::Other, "Cannot retrieve addresses")); } if len == 0 { return Err(Error::new(ErrorKind::AddrNotAvailable, "Socket is unbound")); } + // Prepare a vector to hold the addresses let mut vec = Vec::with_capacity(len as usize); let mut offset = 0; for _ in 0..len { - let sockaddr = addrs.offset(offset) as *const sockaddr; - let len = match (*sockaddr).sa_family as i32 { - AF_INET => size_of::(), - AF_INET6 => size_of::(), - f => { - what.free(addrs as *mut sockaddr); - return Err(Error::new(ErrorKind::Other, format!("Unsupported address family : {}", f))); + let sockaddr_ptr = addrs.offset(offset) as *const sockaddr; + let family = (*sockaddr_ptr).sa_family as i32; + let sockaddr_len = match family { + AF_INET => size_of::() as socklen_t, + AF_INET6 => size_of::() as socklen_t, + _ => { + what.free(addrs); + return Err(Error::new(ErrorKind::Other, format!("Unsupported address family : {}", family))); } - } as socklen_t; - vec.push(try!(SocketAddr::from_raw_ptr(sockaddr, len))); - offset += len as isize; + }; + + // convert raw pointer to `SocketAddr` + vec.push(SocketAddr::from_raw_ptr(sockaddr_ptr, sockaddr_len)?); + offset += sockaddr_len as isize; } - what.free(addrs as *mut sockaddr); + + // free allocated addresses + what.free(addrs); + return Ok(vec); } } @@ -318,23 +306,23 @@ impl SctpSocket { /// Receive data in TCP style. Only works for a connected one to one socket pub fn recv(&mut self, buf: &mut [u8]) -> Result { - unsafe { - let len = buf.len() as RWlen; - return match recv(self.0, transmute(buf.as_mut_ptr()), len, 0) { - res if res >= 0 => Ok(res as usize), - _ => Err(Error::last_os_error()) - }; + let len = buf.len() as RWlen; + + match syscall!(recv(self.0, buf.as_mut_ptr() as *mut libc::c_void, len, 0)) + { + Err(err) => Err(err), + Ok(recvlen) => Ok(recvlen as usize), } } - /// Send data in TCP style. Only works for a connected one to one socket + /// Send data in TCP style. Only wmmatorks for a connected one to one socket pub fn send(&mut self, buf: &[u8]) -> Result { - unsafe { - let len = buf.len() as RWlen; - return match send(self.0, transmute(buf.as_ptr()), len, 0) { - res if res >= 0 => Ok(res as usize), - _ => Err(Error::last_os_error()) - }; + let len = buf.len() as RWlen; + + match syscall!(send(self.0, buf.as_ptr() as *const libc::c_void, len, 0)) + { + Err(err) => Err(err), + Ok(recvlen) => Ok(recvlen as usize), } } @@ -343,35 +331,57 @@ impl SctpSocket { /// the socket address used by the peer to send the data pub fn recvmsg(&self, msg: &mut [u8]) -> Result<(usize, u16, SocketAddr)> { let len = msg.len() as libc::size_t; - let mut addr: sockaddr_in6 = unsafe { std::mem::zeroed() }; - let mut addr_len: socklen_t = size_of::() as socklen_t; + let mut flags: libc::c_int = 0; - unsafe { - let addr_ptr: *mut sockaddr = transmute(&mut addr); - let mut info: sctp_sys::sctp_sndrcvinfo = std::mem::zeroed(); - return match sctp_sys::sctp_recvmsg(self.0, msg.as_mut_ptr() as *mut libc::c_void, len, addr_ptr, &mut addr_len, &mut info, &mut flags) { - res if res > 0 => Ok((res as usize, info.sinfo_stream, try!(SocketAddr::from_raw_ptr(addr_ptr, addr_len)))), - _ => Err(Error::last_os_error()) - }; - } + let mut info: sctp_sys::sctp_sndrcvinfo = unsafe { std::mem::zeroed() }; + + // prepare buffer to store client address + // TODO: this will not be compatible with windows environments as we use libc structs + let mut addr_storage: MaybeUninit = MaybeUninit::uninit(); + let mut addr_storage_length = size_of::() as libc::socklen_t; + + let recvlen = sctp_syscall!(sctp_recvmsg( + self.0, + msg.as_mut_ptr() as *mut _, + len, + addr_storage.as_mut_ptr() as *mut _, + &mut addr_storage_length, + &mut info, + &mut flags + ))?; + + unsafe { to_socket_addr(addr_storage.as_ptr()) }.map(|addr| (recvlen as usize, info.sinfo_stream, addr)) } /// Send data in Sctp style, to the provided address (may be `None` if the socket is connected), on the stream `stream`, with the TTL `ttl`. /// On success, returns the quantity on bytes sent - pub fn sendmsg(&self, msg: &[u8], address: Option, stream: u16, ttl: libc::c_ulong) -> Result { + pub fn sendmsg(&self, msg: &[u8], address: Option, ppid: u32, stream: u16, ttl: libc::c_ulong) -> Result { let len = msg.len() as libc::size_t; let (raw_addr, addr_len) = match address { Some(a) => { - let mut addr = try!(SocketAddr::from_addr(a)); - (addr.as_mut_ptr(), addr.addr_len()) + let addrobj = SocketAddr::from_addr(a)?; + let (addr_c_struct, addr_c_struct_len) = socket_addr(&addrobj); + (addr_c_struct.as_ptr() as *mut sockaddr, addr_c_struct_len) }, None => (std::ptr::null_mut(), 0) }; - unsafe { - return match sctp_sys::sctp_sendmsg(self.0, msg.as_ptr() as *const libc::c_void, len, raw_addr, addr_len, 0, 0, stream, ttl, 0) { - res if res > 0 => Ok(res as usize), - _ => Err(Error::last_os_error()) - }; + let ppid = ppid.to_be(); + + match sctp_syscall!(sctp_sendmsg( + self.0, + msg.as_ptr() as *const libc::c_void, + len, + raw_addr, + addr_len, + ppid as libc::c_ulong, + 0, + stream, + ttl, + 0 + )) + { + Err(err) => Err(err), + Ok(sendlen) => Ok(sendlen as usize), } } @@ -382,51 +392,55 @@ impl SctpSocket { Shutdown::Write => SHUT_WR, Shutdown::Both => SHUT_RDWR }; - return match unsafe { shutdown(self.0, side) } { - 0 => Ok(()), - _ => Err(Error::last_os_error()) - }; + match syscall!(shutdown(self.0, side)) + { + Err(err) => Err(err), + Ok(_) => Ok(()), + } } /// Set socket option pub fn setsockopt(&self, level: libc::c_int, optname: libc::c_int, optval: &T) -> Result<()> { - unsafe { - return match setsockopt(self.0, level, optname, transmute(optval), size_of::() as socklen_t) { - 0 => Ok(()), - _ => Err(Error::last_os_error()) - }; + let optval_ptr = optval as *const T as *const libc::c_void; + let optlen = size_of::() as socklen_t; + + match syscall!(setsockopt(self.0, level, optname, optval_ptr, optlen)) + { + Err(err) => Err(err), + Ok(_) => Ok(()), } } /// Get socket option pub fn getsockopt(&self, level: libc::c_int, optname: libc::c_int) -> Result { - unsafe { - let mut val: T = zeroed(); - let mut len = size_of::() as socklen_t; - return match getsockopt(self.0, level, optname, transmute(&mut val), &mut len) { - 0 => Ok(val), - _ => Err(Error::last_os_error()) - }; + let mut val: T = unsafe { std::mem::zeroed() }; + + let mut len = size_of::() as socklen_t; + + match syscall!(getsockopt(self.0, level, optname, &mut val as *mut T as *mut libc::c_void, &mut len)) + { + Err(err) => Err(err), + Ok(_) => Ok(val), } } /// Get SCTP socket option pub fn sctp_opt_info(&self, optname: libc::c_int, assoc: sctp_sys::sctp_assoc_t) -> Result { - unsafe { - let mut val: T = zeroed(); - let mut len = size_of::() as socklen_t; - return match sctp_sys::sctp_opt_info(self.0, assoc, optname, transmute(&mut val), &mut len) { - 0 => Ok(val), - _ => Err(Error::last_os_error()) - }; + let mut val: T = unsafe { std::mem::zeroed() }; + let mut len = size_of::() as socklen_t; + + match sctp_syscall!(sctp_opt_info(self.0, assoc, optname, &mut val as *mut T as *mut libc::c_void, &mut len)) + { + Err(err) => Err(err), + Ok(_) => Ok(val), } } /// Try to clone this socket pub fn try_clone(&self) -> Result { - unsafe { - let new_sock = try!(check_socket(libc::dup(self.0 as i32) as SOCKET)); - return Ok(SctpSocket(new_sock)); + match syscall!(dup(self.0 as i32)) { + Err(err) => Err(err), + Ok(new_sock) => Ok(SctpSocket(new_sock as SOCKET)), } } }