diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h index 65b98831b97560..6581e94c681b92 100644 --- a/rust/bindings/bindings_helper.h +++ b/rust/bindings/bindings_helper.h @@ -9,12 +9,15 @@ #include #include #include +#include #include #include +#include #include #include #include #include +#include #include #include diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs index fe415cb369d3ac..f4ab478b25e0a9 100644 --- a/rust/kernel/net.rs +++ b/rust/kernel/net.rs @@ -1,6 +1,186 @@ // SPDX-License-Identifier: GPL-2.0 -//! Networking. +//! Network subsystem. +//! +//! This module contains the kernel APIs related to networking that have been ported or wrapped in Rust. +//! +//! C header: [`include/linux/net.h`](../../../../include/linux/net.h) and related + +use crate::error::{code, Error}; +use core::cell::UnsafeCell; #[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)] pub mod phy; +pub mod addr; +pub mod ip; +pub mod socket; +pub mod tcp; +pub mod udp; + +/// The address family. +/// +/// See [`man 7 address families`](https://man7.org/linux/man-pages/man7/address_families.7.html) for more information. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum AddressFamily { + /// Unspecified address family. + Unspec = bindings::AF_UNSPEC as isize, + /// Local to host (pipes and file-domain). + Unix = bindings::AF_UNIX as isize, + /// Internetwork: UDP, TCP, etc. + Inet = bindings::AF_INET as isize, + /// Amateur radio AX.25. + Ax25 = bindings::AF_AX25 as isize, + /// IPX. + Ipx = bindings::AF_IPX as isize, + /// Appletalk DDP. + Appletalk = bindings::AF_APPLETALK as isize, + /// AX.25 packet layer protocol. + Netrom = bindings::AF_NETROM as isize, + /// Bridge link. + Bridge = bindings::AF_BRIDGE as isize, + /// ATM PVCs. + Atmpvc = bindings::AF_ATMPVC as isize, + /// X.25 (ISO-8208). + X25 = bindings::AF_X25 as isize, + /// IPv6. + Inet6 = bindings::AF_INET6 as isize, + /// ROSE protocol. + Rose = bindings::AF_ROSE as isize, + /// DECnet protocol. + Decnet = bindings::AF_DECnet as isize, + /// 802.2LLC project. + Netbeui = bindings::AF_NETBEUI as isize, + /// Firewall hooks. + Security = bindings::AF_SECURITY as isize, + /// Key management protocol. + Key = bindings::AF_KEY as isize, + /// Netlink. + Netlink = bindings::AF_NETLINK as isize, + /// Low-level packet interface. + Packet = bindings::AF_PACKET as isize, + /// Acorn Econet protocol. + Econet = bindings::AF_ECONET as isize, + /// ATM SVCs. + Atmsvc = bindings::AF_ATMSVC as isize, + /// RDS sockets. + Rds = bindings::AF_RDS as isize, + /// IRDA sockets. + Irda = bindings::AF_IRDA as isize, + /// Generic PPP. + Pppox = bindings::AF_PPPOX as isize, + /// Legacy WAN networks protocol. + Wanpipe = bindings::AF_WANPIPE as isize, + /// LLC protocol. + Llc = bindings::AF_LLC as isize, + /// Infiniband. + Ib = bindings::AF_IB as isize, + /// Multiprotocol label switching. + Mpls = bindings::AF_MPLS as isize, + /// Controller Area Network. + Can = bindings::AF_CAN as isize, + /// TIPC sockets. + Tipc = bindings::AF_TIPC as isize, + /// Bluetooth sockets. + Bluetooth = bindings::AF_BLUETOOTH as isize, + /// IUCV sockets. + Iucv = bindings::AF_IUCV as isize, + /// RxRPC sockets. + Rxrpc = bindings::AF_RXRPC as isize, + /// Modular ISDN protocol. + Isdn = bindings::AF_ISDN as isize, + /// Nokia cellular modem interface. + Phonet = bindings::AF_PHONET as isize, + /// IEEE 802.15.4 sockets. + Ieee802154 = bindings::AF_IEEE802154 as isize, + /// CAIF sockets. + Caif = bindings::AF_CAIF as isize, + /// Kernel crypto API + Alg = bindings::AF_ALG as isize, + /// VMware VSockets. + Vsock = bindings::AF_VSOCK as isize, + /// KCM sockets. + Kcm = bindings::AF_KCM as isize, + /// Qualcomm IPC router protocol. + Qipcrtr = bindings::AF_QIPCRTR as isize, + /// SMC sockets. + Smc = bindings::AF_SMC as isize, + /// Express Data Path sockets. + Xdp = bindings::AF_XDP as isize, +} + +impl From for isize { + fn from(family: AddressFamily) -> Self { + family as isize + } +} + +impl TryFrom for AddressFamily { + type Error = Error; + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::AF_UNSPEC => Ok(Self::Unspec), + bindings::AF_UNIX => Ok(Self::Unix), + bindings::AF_INET => Ok(Self::Inet), + bindings::AF_AX25 => Ok(Self::Ax25), + bindings::AF_IPX => Ok(Self::Ipx), + bindings::AF_APPLETALK => Ok(Self::Appletalk), + bindings::AF_NETROM => Ok(Self::Netrom), + bindings::AF_BRIDGE => Ok(Self::Bridge), + bindings::AF_ATMPVC => Ok(Self::Atmpvc), + bindings::AF_X25 => Ok(Self::X25), + bindings::AF_INET6 => Ok(Self::Inet6), + bindings::AF_ROSE => Ok(Self::Rose), + bindings::AF_DECnet => Ok(Self::Decnet), + bindings::AF_NETBEUI => Ok(Self::Netbeui), + bindings::AF_SECURITY => Ok(Self::Security), + bindings::AF_KEY => Ok(Self::Key), + bindings::AF_NETLINK => Ok(Self::Netlink), + bindings::AF_PACKET => Ok(Self::Packet), + bindings::AF_ECONET => Ok(Self::Econet), + bindings::AF_ATMSVC => Ok(Self::Atmsvc), + bindings::AF_RDS => Ok(Self::Rds), + bindings::AF_IRDA => Ok(Self::Irda), + bindings::AF_PPPOX => Ok(Self::Pppox), + bindings::AF_WANPIPE => Ok(Self::Wanpipe), + bindings::AF_LLC => Ok(Self::Llc), + bindings::AF_IB => Ok(Self::Ib), + bindings::AF_MPLS => Ok(Self::Mpls), + bindings::AF_CAN => Ok(Self::Can), + bindings::AF_TIPC => Ok(Self::Tipc), + bindings::AF_BLUETOOTH => Ok(Self::Bluetooth), + bindings::AF_IUCV => Ok(Self::Iucv), + bindings::AF_RXRPC => Ok(Self::Rxrpc), + bindings::AF_ISDN => Ok(Self::Isdn), + bindings::AF_PHONET => Ok(Self::Phonet), + bindings::AF_IEEE802154 => Ok(Self::Ieee802154), + bindings::AF_CAIF => Ok(Self::Caif), + bindings::AF_ALG => Ok(Self::Alg), + bindings::AF_VSOCK => Ok(Self::Vsock), + bindings::AF_KCM => Ok(Self::Kcm), + bindings::AF_QIPCRTR => Ok(Self::Qipcrtr), + bindings::AF_SMC => Ok(Self::Smc), + bindings::AF_XDP => Ok(Self::Xdp), + _ => Err(code::EINVAL), + } + } +} + +/// Network namespace. +/// +/// Wraps the `net` struct. +#[repr(transparent)] +pub struct Namespace(UnsafeCell); + +/// The global network namespace. +/// +/// This is the default and initial namespace. +/// This function replaces the C `init_net` global variable. +pub fn init_net() -> &'static Namespace { + // SAFETY: `init_net` is a global variable and is always valid. + let ptr = unsafe { core::ptr::addr_of!(bindings::init_net) }; + // SAFETY: the address of `init_net` is always valid, always points to initialized memory, + // and is always aligned. + unsafe { &*(ptr.cast()) } +} diff --git a/rust/kernel/net/addr.rs b/rust/kernel/net/addr.rs new file mode 100644 index 00000000000000..e6b1ba7320db4b --- /dev/null +++ b/rust/kernel/net/addr.rs @@ -0,0 +1,1215 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Network address types. +//! +//! This module contains the types and APIs related to network addresses. +//! The methods and types of this API are inspired by the [Rust standard library's `std::net` module](https://doc.rust-lang.org/std/net/index.html), +//! but have been ported to use the kernel's C APIs. + +use crate::error::{code, Error, Result}; +use crate::net::{init_net, AddressFamily, Namespace}; +use crate::str::{CStr, CString}; +use crate::{c_str, fmt}; +use core::cmp::Ordering; +use core::fmt::{Debug, Display, Formatter}; +use core::hash::{Hash, Hasher}; +use core::mem::MaybeUninit; +use core::ptr; +use core::str::FromStr; + +/// An IPv4 address. +/// +/// Wraps a `struct in_addr`. +#[derive(Default, Copy, Clone)] +#[repr(transparent)] +pub struct Ipv4Addr(pub(crate) bindings::in_addr); + +impl Ipv4Addr { + /// The maximum length of an IPv4 address string. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 15; + + /// Create a new IPv4 address from four 8-bit integers. + /// + /// The IP address will be `a.b.c.d`. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// ``` + pub const fn new(a: u8, b: u8, c: u8, d: u8) -> Self { + Self::from_bits(u32::from_be_bytes([a, b, c, d])) + } + + /// Get the octets of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// let expected = [192, 168, 0, 1]; + /// assert_eq!(addr.octets(), &expected); + /// ``` + pub const fn octets(&self) -> &[u8; 4] { + // SAFETY: The s_addr field is a 32-bit integer, which is the same size as the array. + unsafe { &*(&self.0.s_addr as *const _ as *const [u8; 4]) } + } + + /// Create a new IPv4 address from a 32-bit integer. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::from_bits(0xc0a80001); + /// assert_eq!(addr, Ipv4Addr::new(192, 168, 0, 1)); + /// ``` + pub const fn from_bits(bits: u32) -> Self { + Ipv4Addr(bindings::in_addr { + s_addr: bits.to_be(), + }) + } + + /// Get the 32-bit integer representation of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// assert_eq!(addr.to_bits(), 0xc0a80001); + /// ``` + pub const fn to_bits(&self) -> u32 { + u32::from_be(self.0.s_addr) + } + + /// The broadcast address: `255.255.255.255` + /// + /// Used to send a message to all hosts on the network. + pub const BROADCAST: Self = Self::new(255, 255, 255, 255); + + /// "None" address + /// + /// Can be used as return value to indicate an error. + pub const NONE: Self = Self::new(255, 255, 255, 255); + + /// The "any" address: `0.0.0.0` + /// Used to accept any incoming message. + pub const UNSPECIFIED: Self = Self::new(0, 0, 0, 0); + + /// A dummy address: `192.0.0.8` + /// Used as ICMP reply source if no address is set. + pub const DUMMY: Self = Self::new(192, 0, 0, 8); + + /// The loopback address: `127.0.0.1` + /// Used to send a message to the local host. + pub const LOOPBACK: Self = Self::new(127, 0, 0, 1); +} + +impl From<[u8; 4]> for Ipv4Addr { + /// Create a new IPv4 address from an array of 8-bit integers. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::from([192, 168, 0, 1]); + /// assert_eq!(addr, Ipv4Addr::new(192, 168, 0, 1)); + /// ``` + fn from(octets: [u8; 4]) -> Self { + Self::new(octets[0], octets[1], octets[2], octets[3]) + } +} + +impl From for u32 { + /// Get the 32-bit integer representation of the address. + /// + /// This is the same as calling [`Ipv4Addr::to_bits`]. + fn from(addr: Ipv4Addr) -> Self { + addr.to_bits() + } +} + +impl From for Ipv4Addr { + /// Create a new IPv4 address from a 32-bit integer. + /// + /// This is the same as calling [`Ipv4Addr::from_bits`]. + fn from(bits: u32) -> Self { + Self::from_bits(bits) + } +} + +impl PartialEq for Ipv4Addr { + /// Compare two IPv4 addresses. + /// + /// Returns `true` if the addresses are made up of the same bytes. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr1 = Ipv4Addr::new(192, 168, 0, 1); + /// let addr2 = Ipv4Addr::new(192, 168, 0, 1); + /// assert_eq!(addr1, addr2); + /// + /// let addr3 = Ipv4Addr::new(192, 168, 0, 2); + /// assert_ne!(addr1, addr3); + /// ``` + fn eq(&self, other: &Ipv4Addr) -> bool { + self.to_bits() == other.to_bits() + } +} + +impl Eq for Ipv4Addr {} + +impl Hash for Ipv4Addr { + /// Hash an IPv4 address. + /// + /// The trait cannot be derived because the `in_addr` struct does not implement `Hash`. + fn hash(&self, state: &mut H) { + self.to_bits().hash(state) + } +} + +impl PartialOrd for Ipv4Addr { + fn partial_cmp(&self, other: &Self) -> Option { + self.to_bits().partial_cmp(&other.to_bits()) + } +} + +impl Ord for Ipv4Addr { + fn cmp(&self, other: &Self) -> Ordering { + self.to_bits().cmp(&other.to_bits()) + } +} + +/// An IPv6 address. +/// +/// Wraps a `struct in6_addr`. +#[derive(Default, Copy, Clone)] +#[repr(transparent)] +pub struct Ipv6Addr(pub(crate) bindings::in6_addr); + +impl Ipv6Addr { + /// The maximum length of an IPv6 address string. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 45; + + /// Create a new IPv6 address from eight 16-bit integers. + /// + /// The 16-bit integers are transformed in network order. + /// + /// The IP address will be `a:b:c:d:e:f:g:h`. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// ``` + #[allow(clippy::too_many_arguments)] + pub const fn new(a: u16, b: u16, c: u16, d: u16, e: u16, f: u16, g: u16, h: u16) -> Self { + Self(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { + u6_addr16: [ + a.to_be(), + b.to_be(), + c.to_be(), + d.to_be(), + e.to_be(), + f.to_be(), + g.to_be(), + h.to_be(), + ], + }, + }) + } + + /// Get the octets of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// let expected = [0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34]; + /// assert_eq!(addr.octets(), &expected); + /// ``` + pub const fn octets(&self) -> &[u8; 16] { + // SAFETY: The u6_addr8 field is a [u8; 16] array. + unsafe { &self.0.in6_u.u6_addr8 } + } + + /// Get the segments of the address. + /// + /// A segment is a 16-bit integer. + /// The segments are in network order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// let expected = [0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334]; + /// assert_eq!(addr.segments(), &expected); + /// ``` + pub const fn segments(&self) -> &[u16; 8] { + // SAFETY: The u6_addr16 field is a [u16; 8] array. + unsafe { &self.0.in6_u.u6_addr16 } + } + + /// Create a 128-bit integer representation of the address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// assert_eq!(addr.to_bits(), 0x20010db885a3000000008a2e03707334); + /// ``` + pub fn to_bits(&self) -> u128 { + u128::from_be_bytes(*self.octets() as _) + } + + /// Create a new IPv6 address from a 128-bit integer. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::from_bits(0x20010db885a3000000008a2e03707334); + /// assert_eq!(addr, Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)); + /// ``` + pub const fn from_bits(bits: u128) -> Self { + Ipv6Addr(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { + u6_addr8: bits.to_be_bytes() as _, + }, + }) + } + + /// The "any" address: `::` + /// + /// Used to accept any incoming message. + /// Should not be used as a destination address. + pub const ANY: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 0); + + /// The loopback address: `::1` + /// + /// Used to send a message to the local host. + pub const LOOPBACK: Self = Self::new(0, 0, 0, 0, 0, 0, 0, 1); +} + +impl From<[u16; 8]> for Ipv6Addr { + fn from(value: [u16; 8]) -> Self { + Self(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { u6_addr16: value }, + }) + } +} + +impl From<[u8; 16]> for Ipv6Addr { + fn from(value: [u8; 16]) -> Self { + Self(bindings::in6_addr { + in6_u: bindings::in6_addr__bindgen_ty_1 { u6_addr8: value }, + }) + } +} + +impl From for u128 { + fn from(addr: Ipv6Addr) -> Self { + addr.to_bits() + } +} + +impl From for Ipv6Addr { + fn from(bits: u128) -> Self { + Self::from_bits(bits) + } +} + +impl PartialEq for Ipv6Addr { + fn eq(&self, other: &Self) -> bool { + self.to_bits() == other.to_bits() + } +} + +impl Eq for Ipv6Addr {} + +impl Hash for Ipv6Addr { + fn hash(&self, state: &mut H) { + self.to_bits().hash(state) + } +} + +impl PartialOrd for Ipv6Addr { + fn partial_cmp(&self, other: &Self) -> Option { + self.to_bits().partial_cmp(&other.to_bits()) + } +} + +impl Ord for Ipv6Addr { + fn cmp(&self, other: &Self) -> Ordering { + self.to_bits().cmp(&other.to_bits()) + } +} + +/// A wrapper for a generic socket address. +/// +/// Wraps a C `struct sockaddr_storage`. +/// Unlike [`SocketAddr`], this struct is meant to be used internally only, +/// as a parameter for kernel function calls. +#[repr(transparent)] +#[derive(Copy, Clone, Default)] +pub(crate) struct SocketAddrStorage(pub(crate) bindings::__kernel_sockaddr_storage); + +impl SocketAddrStorage { + /// Returns the family of the address. + pub(crate) fn family(&self) -> Result { + // SAFETY: The union access is safe because the `ss_family` field is always valid. + let val: isize = unsafe { self.0.__bindgen_anon_1.__bindgen_anon_1.ss_family as _ }; + AddressFamily::try_from(val) + } + + pub(crate) fn into(self) -> T { + // SAFETY: The `self.0` field is a `struct sockaddr_storage` which is guaranteed to be large enough to hold any socket address. + unsafe { *(&self.0 as *const _ as *const T) } + } +} + +/// A generic Socket Address. Acts like a `struct sockaddr_storage`. +/// `sockaddr_storage` is used instead of `sockaddr` because it is guaranteed to be large enough to hold any socket address. +/// +/// The purpose of this enum is to be used as a generic parameter for functions that can take any type of address. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum SocketAddr { + /// An IPv4 address. + V4(SocketAddrV4), + /// An IPv6 address. + V6(SocketAddrV6), +} + +impl SocketAddr { + /// Returns the size in bytes of the concrete address contained. + /// + /// Used in the kernel functions that take a parameter with the size of the socket address. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// assert_eq!(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80)).size(), + /// core::mem::size_of::()); + pub fn size(&self) -> usize { + match self { + SocketAddr::V4(_) => SocketAddrV4::size(), + SocketAddr::V6(_) => SocketAddrV6::size(), + } + } + + /// Returns the address family of the concrete address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// use kernel::net::AddressFamily; + /// assert_eq!(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80)).family(), + /// AddressFamily::Inet); + /// ``` + pub fn family(&self) -> AddressFamily { + match self { + SocketAddr::V4(_) => AddressFamily::Inet, + SocketAddr::V6(_) => AddressFamily::Inet6, + } + } + + /// Returns a pointer to the C `struct sockaddr_storage` contained. + /// Used in the kernel functions that take a pointer to a socket address. + pub(crate) fn as_ptr(&self) -> *const SocketAddrStorage { + match self { + SocketAddr::V4(addr) => addr as *const _ as _, + SocketAddr::V6(addr) => addr as *const _ as _, + } + } + + /// Creates a `SocketAddr` from a C `struct sockaddr_storage`. + /// The function consumes the `struct sockaddr_storage`. + /// Used in the kernel functions that return a socket address. + /// + /// # Panics + /// Panics if the address family of the `struct sockaddr_storage` is invalid. + /// This should never happen. + /// If it does, it is likely because of an invalid pointer. + pub(crate) fn try_from_raw(sockaddr: SocketAddrStorage) -> Result { + match sockaddr.family()? { + AddressFamily::Inet => Ok(SocketAddr::V4(sockaddr.into())), + AddressFamily::Inet6 => Ok(SocketAddr::V6(sockaddr.into())), + _ => Err(code::EINVAL), + } + } +} + +impl From for SocketAddr { + fn from(value: SocketAddrV4) -> Self { + SocketAddr::V4(value) + } +} + +impl From for SocketAddr { + fn from(value: SocketAddrV6) -> Self { + SocketAddr::V6(value) + } +} + +impl TryFrom for SocketAddrV4 { + type Error = Error; + + fn try_from(value: SocketAddr) -> core::result::Result { + match value { + SocketAddr::V4(addr) => Ok(addr), + _ => Err(Error::from_errno(bindings::EAFNOSUPPORT as _)), + } + } +} + +impl TryFrom for SocketAddrV6 { + type Error = Error; + + fn try_from(value: SocketAddr) -> core::result::Result { + match value { + SocketAddr::V6(addr) => Ok(addr), + _ => Err(Error::from_errno(bindings::EAFNOSUPPORT as _)), + } + } +} + +/// Generic trait for socket addresses. +/// +/// The purpose of this trait is: +/// - To force all socket addresses to have a size and an address family. +/// - Force all socket addresses to implement specific built-in traits. +pub trait GenericSocketAddr: + Sized + Copy + Clone + PartialEq + Eq + PartialOrd + Ord + Hash + Display +{ + /// Returns the size in bytes of the concrete address. + /// + /// # Examples + /// ```rust + /// use kernel::bindings; + /// use kernel::net::addr::{GenericSocketAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; + /// assert_eq!(SocketAddrV4::size(), core::mem::size_of::()); + /// ``` + fn size() -> usize + where + Self: Sized, + { + core::mem::size_of::() + } + + /// Returns the address family of the concrete address. + /// + /// # Examples + /// + /// ```rust + /// use kernel::net::addr::{GenericSocketAddr, SocketAddrV4}; + /// use kernel::net::AddressFamily; + /// assert_eq!(SocketAddrV4::family(), AddressFamily::Inet); + /// ``` + fn family() -> AddressFamily; +} + +/// IPv4 socket address. +/// +/// Wraps a C `struct sockaddr_in`. +/// +/// # Examples +/// ```rust +/// use kernel::bindings; +/// use kernel::net::addr::{GenericSocketAddr, Ipv4Addr, SocketAddr, SocketAddrV4}; +/// let addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); +/// assert_eq!(addr.ip(), &Ipv4Addr::new(192, 168, 0, 1)); +/// assert_eq!(SocketAddrV4::size(), core::mem::size_of::()); +/// ``` +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct SocketAddrV4(pub(crate) bindings::sockaddr_in); + +impl SocketAddrV4 { + /// The maximum length of a IPv4 socket address string representation. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 21; + + /// Creates a new IPv4 socket address from an IP address and a port. + /// + /// The port does not need to be in network byte order. + pub const fn new(addr: Ipv4Addr, port: u16) -> Self { + Self(bindings::sockaddr_in { + sin_family: AddressFamily::Inet as _, + sin_port: port.to_be(), + sin_addr: addr.0, + __pad: [0; 8], + }) + } + + /// Returns a reference to the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let ip = Ipv4Addr::new(192, 168, 0, 1); + /// let addr = SocketAddrV4::new(ip, 80); + /// assert_eq!(addr.ip(), &ip); + /// ``` + pub const fn ip(&self) -> &Ipv4Addr { + // SAFETY: The [Ipv4Addr] is a transparent representation of the C `struct in_addr`, + // which is the type of `sin_addr`. Therefore, the conversion is safe. + unsafe { &*(&self.0.sin_addr as *const _ as *const Ipv4Addr) } + } + + /// Change the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let mut addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); + /// addr.set_ip(Ipv4Addr::new(192, 168, 0, 2)); + /// assert_eq!(addr.ip(), &Ipv4Addr::new(192, 168, 0, 2)); + /// ``` + pub fn set_ip(&mut self, ip: Ipv4Addr) { + self.0.sin_addr = ip.0; + } + + /// Returns the port contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); + /// assert_eq!(addr.port(), 81); + /// ``` + pub const fn port(&self) -> u16 { + self.0.sin_port.to_be() + } + + /// Change the port contained. + /// + /// The port does not need to be in network byte order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// let mut addr = SocketAddrV4::new(Ipv4Addr::new(192, 168, 0, 1), 80); + /// addr.set_port(81); + /// assert_eq!(addr.port(), 81); + /// ``` + pub fn set_port(&mut self, port: u16) { + self.0.sin_port = port.to_be(); + } +} + +impl GenericSocketAddr for SocketAddrV4 { + /// Returns the family of the address. + /// + /// # Invariants + /// The family is always [AddressFamily::Inet]. + fn family() -> AddressFamily { + AddressFamily::Inet + } +} + +impl PartialEq for SocketAddrV4 { + fn eq(&self, other: &SocketAddrV4) -> bool { + self.ip() == other.ip() && self.port() == other.port() + } +} + +impl Eq for SocketAddrV4 {} + +impl Hash for SocketAddrV4 { + fn hash(&self, state: &mut H) { + (self.ip(), self.port()).hash(state) + } +} + +impl PartialOrd for SocketAddrV4 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SocketAddrV4 { + fn cmp(&self, other: &Self) -> Ordering { + (self.ip(), self.port()).cmp(&(other.ip(), other.port())) + } +} + +/// IPv6 socket address. +/// +/// Wraps a C `struct sockaddr_in6`. +/// +/// # Examples +/// ```rust +/// use kernel::bindings; +/// use kernel::net::addr::{GenericSocketAddr, Ipv6Addr, SocketAddr, SocketAddrV6}; +/// +/// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 0); +/// assert_eq!(addr.ip(), &Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); +/// assert_eq!(SocketAddrV6::size(), core::mem::size_of::()); +#[repr(transparent)] +#[derive(Copy, Clone)] +pub struct SocketAddrV6(pub(crate) bindings::sockaddr_in6); + +impl SocketAddrV6 { + /// The maximum length of a IPv6 socket address string representation. + /// + /// This is the length of the string representation of the address. + /// It does not include the null terminator. + pub const MAX_STRING_LEN: usize = 74; + + /// Creates a new IPv6 socket address from an IP address, a port, a flowinfo and a scope_id. + /// The port does not need to be in network byte order. + pub const fn new(addr: Ipv6Addr, port: u16, flowinfo: u32, scope_id: u32) -> Self { + Self(bindings::sockaddr_in6 { + sin6_family: AddressFamily::Inet6 as _, + sin6_port: port.to_be(), + sin6_flowinfo: flowinfo, + sin6_addr: addr.0, + sin6_scope_id: scope_id, + }) + } + + /// Returns a reference to the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let ip = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); + /// let addr = SocketAddrV6::new(ip, 80, 0, 0); + /// assert_eq!(addr.ip(), &ip); + /// ``` + pub const fn ip(&self) -> &Ipv6Addr { + // SAFETY: The [Ipv6Addr] is a transparent representation of the C `struct in6_addr`, + // which is the type of `sin6_addr`. Therefore, the conversion is safe. + unsafe { &*(&self.0.sin6_addr as *const _ as *const Ipv6Addr) } + } + + /// Change the IP address contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let ip1 = Ipv6Addr::LOOPBACK; + /// let ip2 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2); + /// let mut addr = SocketAddrV6::new(ip1, 80, 0, 0); + /// addr.set_ip(ip2); + /// assert_eq!(addr.ip(), &ip2); + /// ``` + pub fn set_ip(&mut self, addr: Ipv6Addr) { + self.0.sin6_addr = addr.0; + } + + /// Returns the port contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 0); + /// assert_eq!(addr.port(), 80); + /// ``` + pub const fn port(&self) -> u16 { + self.0.sin6_port.to_be() + } + + /// Change the port contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let mut addr = SocketAddrV6::new(Ipv6Addr::LOOPBACK, 80, 0, 0); + /// addr.set_port(443); + /// assert_eq!(addr.port(), 443); + /// ``` + pub fn set_port(&mut self, port: u16) { + self.0.sin6_port = port.to_be(); + } + + /// Returns the flowinfo contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 0); + /// assert_eq!(addr.flowinfo(), 0); + /// ``` + pub const fn flowinfo(&self) -> u32 { + self.0.sin6_flowinfo as _ + } + + /// Change the flowinfo contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let mut addr = SocketAddrV6::new(Ipv6Addr::LOOPBACK, 80, 0, 0); + /// addr.set_flowinfo(1); + /// assert_eq!(addr.flowinfo(), 1); + /// ``` + pub fn set_flowinfo(&mut self, flowinfo: u32) { + self.0.sin6_flowinfo = flowinfo; + } + + /// Returns the scope_id contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80, 0, 1); + /// assert_eq!(addr.scope_id(), 1); + /// ``` + pub const fn scope_id(&self) -> u32 { + self.0.sin6_scope_id as _ + } + + /// Change the scope_id contained. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let mut addr = SocketAddrV6::new(Ipv6Addr::LOOPBACK, 80, 0, 0); + /// addr.set_scope_id(1); + /// assert_eq!(addr.scope_id(), 1); + /// ``` + pub fn set_scope_id(&mut self, scope_id: u32) { + self.0.sin6_scope_id = scope_id; + } +} + +impl GenericSocketAddr for SocketAddrV6 { + /// Returns the family of the address. + /// + /// # Invariants + /// The family is always [AddressFamily::Inet6]. + fn family() -> AddressFamily { + AddressFamily::Inet6 + } +} + +impl PartialEq for SocketAddrV6 { + fn eq(&self, other: &SocketAddrV6) -> bool { + self.ip() == other.ip() + && self.port() == other.port() + && self.flowinfo() == other.flowinfo() + && self.scope_id() == other.scope_id() + } +} + +impl Eq for SocketAddrV6 {} + +impl Hash for SocketAddrV6 { + fn hash(&self, state: &mut H) { + (self.ip(), self.port(), self.flowinfo(), self.scope_id()).hash(state) + } +} + +impl PartialOrd for SocketAddrV6 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SocketAddrV6 { + fn cmp(&self, other: &Self) -> Ordering { + (self.ip(), self.port(), self.flowinfo(), self.scope_id()).cmp(&( + other.ip(), + other.port(), + other.flowinfo(), + other.scope_id(), + )) + } +} + +/// Create a Socket address from a string. +/// +/// This method is a wrapper for the `inet_pton_with_scope` C function, which transforms a string +/// to the specified sockaddr* structure. +fn address_from_string(src: &str, port: &str, net: &Namespace) -> Result { + let src = CString::try_from_fmt(fmt!("{}", src))?; + let port = CString::try_from_fmt(fmt!("{}", port))?; + let mut addr = MaybeUninit::::zeroed(); + + // SAFETY: FFI call, all pointers are valid for the duration of the call. + // The address family matches the address structure. + match unsafe { + bindings::inet_pton_with_scope( + net as *const _ as *mut bindings::net as _, + T::family() as _, + src.as_ptr() as _, + port.as_ptr() as _, + addr.as_mut_ptr() as _, + ) + } { + // SAFETY: The address was initialized by the C function. + // Whatever was not initialized, e.g. flow info or scope id for ipv6, are zeroed. + 0 => Ok(unsafe { addr.assume_init() }), + errno => Err(Error::from_errno(errno as _)), + } +} + +/// Write the string representation of the `T` address to the formatter. +/// +/// This function is used to implement the `Display` trait for each address. +/// +/// The `cfmt` parameter is the C string format used to format the address. +/// For example, the format for an IPv4 address is `"%pI4"`. +/// +/// The `BUF_LEN` parameter is the size of the buffer used to format the address, including the null terminator. +/// +/// # Safety +/// In order to have a correct output, the `cfmt` parameter must be a valid C string format for the `T` address. +/// Also, the `BUF_LEN` parameter must be at least the length of the string representation of the address. +unsafe fn write_addr( + formatter: &mut Formatter<'_>, + cfmt: &CStr, + addr: &T, +) -> core::fmt::Result { + let mut buff = [0u8; BUF_LEN]; + // SAFETY: the buffer is big enough to contain the string representation of the address. + // The format is valid for the address. + let s = match unsafe { + bindings::snprintf( + buff.as_mut_ptr() as _, + BUF_LEN as _, + cfmt.as_ptr() as _, + addr as *const T, + ) + } { + n if n < 0 => Err(()), + + // the buffer is probably bigger than the actual string: truncate at the first null byte + _ => buff + .iter() + .position(|&c| c == 0) + // SAFETY: the buffer contains a UTF-8 valid string and contains a single null terminator. + .map(|i| unsafe { core::str::from_utf8_unchecked(&buff[..i]) }) + .ok_or(()), + }; + match s { + Ok(s) => write!(formatter, "{}", s), + Err(_) => Err(core::fmt::Error), + } +} + +impl Display for Ipv4Addr { + /// Display the address as a string. + /// The bytes are in network order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv4Addr; + /// use kernel::pr_info; + /// + /// let addr = Ipv4Addr::new(192, 168, 0, 1); + /// pr_info!("{}", addr); // prints "192.168.0.1" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is the length of 255.255.255.255, the biggest Ipv4Addr string. + // +1 for the null terminator. + unsafe { + write_addr::<{ Ipv4Addr::MAX_STRING_LEN + 1 }, Ipv4Addr>(f, c_str!("%pI4"), self) + .map_err(|_| core::fmt::Error) + } + } +} + +impl Debug for Ipv4Addr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "Ipv4Addr({})", self) + } +} + +impl FromStr for Ipv4Addr { + type Err = (); + + /// Create a new IPv4 address from a string. + /// The string must be in the format `a.b.c.d`, where `a`, `b`, `c` and `d` are 8-bit integers. + /// + /// # Examples + /// Valid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv4Addr; + /// + /// let addr = Ipv4Addr::from_str("192.168.0.1"); + /// assert_eq!(addr, Ok(Ipv4Addr::new(192, 168, 0, 1))); + /// ``` + /// + /// Invalid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv4Addr; + /// + /// let mut addr = Ipv4Addr::from_str("invalid"); + /// assert_eq!(addr, Err(())); + /// + /// addr = Ipv4Addr::from_str("280.168.0.1"); + /// assert_eq!(addr, Err(())); + /// + /// addr = Ipv4Addr::from_str("0.0.0.0.0"); + /// assert_eq!(addr, Err(())); + /// ``` + fn from_str(s: &str) -> Result { + let mut buffer = [0u8; 4]; + // SAFETY: FFI call, + // there is no need to construct a NULL-terminated string, as the length is passed. + match unsafe { + bindings::in4_pton( + s.as_ptr() as *const _, + s.len() as _, + buffer.as_mut_ptr() as _, + -1, + ptr::null_mut(), + ) + } { + 1 => Ok(Ipv4Addr::from(buffer)), + _ => Err(()), + } + } +} + +impl Display for Ipv6Addr { + /// Display the address as a string. + /// The bytes are in network order. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::Ipv6Addr; + /// use kernel::pr_info; + /// + /// let addr = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334); + /// pr_info!("{}", addr); // prints "2001:db8:85a3::8a2e:370:7334" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is the length of ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff, the biggest Ipv6Addr string. + unsafe { + write_addr::<{ Ipv6Addr::MAX_STRING_LEN + 1 }, Ipv6Addr>(f, c_str!("%pI6c"), self) + } + } +} + +impl Debug for Ipv6Addr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "Ipv6Addr({})", self) + } +} + +impl FromStr for Ipv6Addr { + type Err = (); + + /// Create a new IPv6 address from a string. + /// + /// The address must follow the format described in [RFC 4291](https://tools.ietf.org/html/rfc4291#section-2.2). + /// + /// # Examples + /// Valid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv6Addr; + /// + /// let addr = Ipv6Addr::from_str("2001:db8:85a3:0:0:8a2e:370:7334").unwrap(); + /// assert_eq!(addr, Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)); + /// ``` + /// + /// Invalid addresses: + /// ```rust + /// use core::str::FromStr; + /// use kernel::net::addr::Ipv6Addr; + /// + /// let mut addr = Ipv6Addr::from_str("invalid"); + /// assert_eq!(addr, Err(())); + /// + /// addr = Ipv6Addr::from_str("2001:db8:85a3:0:0:8a2e:370:7334:1234"); + /// assert_eq!(addr, Err(())); + /// ``` + fn from_str(s: &str) -> Result { + let mut buffer = [0u8; 16]; + // SAFETY: FFI call, + // there is no need to construct a NULL-terminated string, as the length is passed. + match unsafe { + bindings::in6_pton( + s.as_ptr() as _, + s.len() as _, + buffer.as_mut_ptr() as _, + -1, + ptr::null_mut(), + ) + } { + 1 => Ok(Ipv6Addr::from(buffer)), + _ => Err(()), + } + } +} + +impl Display for SocketAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + match self { + SocketAddr::V4(addr) => Display::fmt(addr, f), + SocketAddr::V6(addr) => Display::fmt(addr, f), + } + } +} + +impl Debug for SocketAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "SocketAddr({})", self) + } +} + +impl FromStr for SocketAddr { + type Err = Error; + + fn from_str(s: &str) -> core::result::Result { + let funcs = [ + |s| SocketAddrV4::from_str(s).map(SocketAddr::V4), + |s| SocketAddrV6::from_str(s).map(SocketAddr::V6), + ]; + + funcs.iter().find_map(|f| f(s).ok()).ok_or(code::EINVAL) + } +} + +impl Display for SocketAddrV4 { + /// Display the address as a string. + /// + /// The output is of the form `address:port`, where `address` is the IP address in dotted + /// decimal notation, and `port` is the port number. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::SocketAddrV4; + /// use kernel::pr_info; + /// + /// let addr = SocketAddrV4::from_str("1.2.3.4:5678").unwrap(); + /// pr_info!("{}", addr); // prints "1.2.3.4:5678" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is the length of 255.255.255.255:12345, the biggest SocketAddrV4 string. + unsafe { + write_addr::<{ SocketAddrV4::MAX_STRING_LEN + 1 }, SocketAddrV4>( + f, + c_str!("%pISpc"), + self, + ) + } + } +} + +impl Debug for SocketAddrV4 { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "SocketAddrV4({})", self) + } +} + +impl FromStr for SocketAddrV4 { + type Err = Error; + + /// Parses a string as an IPv4 socket address. + /// + /// The string must be in the form `a.b.c.d:p`, where `a`, `b`, `c`, `d` are the four + /// components of the IPv4 address, and `p` is the port. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv4Addr, SocketAddrV4}; + /// + /// // valid + /// let addr = SocketAddrV4::from_str("192.168.1.0:80").unwrap(); + /// assert_eq!(addr.ip(), &Ipv4Addr::new(192, 168, 1, 0)); + /// assert_eq!(addr.port(), 80); + /// + /// // invalid + /// assert!(SocketAddrV4::from_str("192.168:800:80").is_err()); + /// ``` + fn from_str(s: &str) -> Result { + let (addr, port) = s.split_once(':').ok_or(code::EINVAL)?; + address_from_string(addr, port, init_net()) + } +} + +impl Display for SocketAddrV6 { + /// Display the address as a string. + /// + /// The output string is of the form `[addr]:port`, where `addr` is an IPv6 address and `port` + /// is a port number. + /// + /// Flow info and scope ID are not supported and are excluded from the output. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// let addr = SocketAddrV6::from_str("[::1]:80").unwrap(); + /// pr_info!("{}", addr); // prints "[::1]:80" + /// ``` + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + // SAFETY: MAX_STRING_LEN is big enough to hold the biggest SocketAddrV6 string. + unsafe { + write_addr::<{ SocketAddrV6::MAX_STRING_LEN + 1 }, SocketAddrV6>( + f, + c_str!("%pISpc"), + self, + ) + } + } +} + +impl Debug for SocketAddrV6 { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "SocketAddrV6({})", self) + } +} + +impl FromStr for SocketAddrV6 { + type Err = Error; + + /// Parses a string as an IPv6 socket address. + /// + /// The given string must be of the form `[addr]:port`, where `addr` is an IPv6 address and + /// `port` is a port number. + /// + /// Flow info and scope ID are not supported. + /// + /// # Examples + /// ```rust + /// use kernel::net::addr::{Ipv6Addr, SocketAddrV6}; + /// + /// // valid + /// let addr = SocketAddrV6::from_str("[2001:db8:85a3::8a2e:370:7334]:80").unwrap(); + /// assert_eq!(addr.ip(), &Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334)); + /// assert_eq!(addr.port(), 80); + /// ``` + fn from_str(s: &str) -> Result { + let (addr, port) = s.rsplit_once(':').ok_or(code::EINVAL)?; + let address = addr.trim_start_matches('[').trim_end_matches(']'); + address_from_string(address, port, init_net()) + } +} diff --git a/rust/kernel/net/ip.rs b/rust/kernel/net/ip.rs new file mode 100644 index 00000000000000..84f98d356137ec --- /dev/null +++ b/rust/kernel/net/ip.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! IP protocol definitions. +//! +//! This module contains the kernel structures and functions related to IP protocols. +//! +//! C headers: +//! - [`include/linux/in.h`](../../../../include/linux/in.h) +//! - [`include/linux/ip.h`](../../../../include/linux/ip.h) +//! - [`include/uapi/linux/ip.h`](../../../../include/uapi/linux/ip.h) + +/// The Ip protocol. +/// +/// See [`tools/include/uapi/linux/in.h`](../../../../tools/include/uapi/linux/in.h) +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum IpProtocol { + /// Dummy protocol for TCP + Ip = bindings::IPPROTO_IP as isize, + /// Internet Control Message Protocol + Icmp = bindings::IPPROTO_ICMP as isize, + /// Internet Group Management Protocol + Igmp = bindings::IPPROTO_IGMP as isize, + /// IPIP tunnels (older KA9Q tunnels use 94) + IpIp = bindings::IPPROTO_IPIP as isize, + /// Transmission Control Protocol + Tcp = bindings::IPPROTO_TCP as isize, + /// Exterior Gateway Protocol + Egp = bindings::IPPROTO_EGP as isize, + /// PUP protocol + Pup = bindings::IPPROTO_PUP as isize, + /// User Datagram Protocol + Udp = bindings::IPPROTO_UDP as isize, + /// XNS Idp protocol + Idp = bindings::IPPROTO_IDP as isize, + /// SO Transport Protocol Class 4 + Tp = bindings::IPPROTO_TP as isize, + /// Datagram Congestion Control Protocol + Dccp = bindings::IPPROTO_DCCP as isize, + /// Ipv6-in-Ipv4 tunnelling + Ipv6 = bindings::IPPROTO_IPV6 as isize, + /// Rsvp Protocol + Rsvp = bindings::IPPROTO_RSVP as isize, + /// Cisco GRE tunnels (rfc 1701,1702) + Gre = bindings::IPPROTO_GRE as isize, + /// Encapsulation Security Payload protocol + Esp = bindings::IPPROTO_ESP as isize, + /// Authentication Header protocol + Ah = bindings::IPPROTO_AH as isize, + /// Multicast Transport Protocol + Mtp = bindings::IPPROTO_MTP as isize, + /// Ip option pseudo header for BEET + Beetph = bindings::IPPROTO_BEETPH as isize, + /// Encapsulation Header + Encap = bindings::IPPROTO_ENCAP as isize, + /// Protocol Independent Multicast + Pim = bindings::IPPROTO_PIM as isize, + /// Compression Header Protocol + Comp = bindings::IPPROTO_COMP as isize, + /// Layer 2 Tunnelling Protocol + L2Tp = bindings::IPPROTO_L2TP as isize, + /// Stream Control Transport Protocol + Sctp = bindings::IPPROTO_SCTP as isize, + /// Udp-Lite (Rfc 3828) + UdpLite = bindings::IPPROTO_UDPLITE as isize, + /// Mpls in Ip (Rfc 4023) + Mpls = bindings::IPPROTO_MPLS as isize, + /// Ethernet-within-Ipv6 Encapsulation + Ethernet = bindings::IPPROTO_ETHERNET as isize, + /// Raw Ip packets + Raw = bindings::IPPROTO_RAW as isize, + /// Multipath Tcp connection + Mptcp = bindings::IPPROTO_MPTCP as isize, +} diff --git a/rust/kernel/net/socket.rs b/rust/kernel/net/socket.rs new file mode 100644 index 00000000000000..1a7b3f7d8fc084 --- /dev/null +++ b/rust/kernel/net/socket.rs @@ -0,0 +1,641 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Socket API. +//! +//! This module contains the Socket layer kernel APIs that have been wrapped for usage by Rust code +//! in the kernel. +//! +//! C header: [`include/linux/socket.h`](../../../../include/linux/socket.h) +//! +//! This API is inspired by the Rust std::net Socket API, but is not a direct port. +//! The main difference is that the Rust std::net API is designed for user-space, while this API +//! is designed for kernel-space. +//! Rust net API: + +use super::*; +use crate::error::{to_result, Result}; +use crate::net::addr::*; +use crate::net::ip::IpProtocol; +use crate::net::socket::opts::{OptionsLevel, WritableOption}; +use core::cmp::max; +use core::marker::PhantomData; +use flags::*; +use kernel::net::socket::opts::SocketOption; + +pub mod flags; +pub mod opts; + +/// The socket type. +pub enum SockType { + /// Stream socket (e.g. TCP) + Stream = bindings::sock_type_SOCK_STREAM as isize, + /// Connectionless socket (e.g. UDP) + Datagram = bindings::sock_type_SOCK_DGRAM as isize, + /// Raw socket + Raw = bindings::sock_type_SOCK_RAW as isize, + /// Reliably-delivered message + Rdm = bindings::sock_type_SOCK_RDM as isize, + /// Sequenced packet stream + Seqpacket = bindings::sock_type_SOCK_SEQPACKET as isize, + /// Datagram Congestion Control Protocol socket + Dccp = bindings::sock_type_SOCK_DCCP as isize, + /// Packet socket + Packet = bindings::sock_type_SOCK_PACKET as isize, +} + +/// The socket shutdown command. +pub enum ShutdownCmd { + /// Disallow further receive operations. + Read = bindings::sock_shutdown_cmd_SHUT_RD as isize, + /// Disallow further send operations. + Write = bindings::sock_shutdown_cmd_SHUT_WR as isize, + /// Disallow further send and receive operations. + Both = bindings::sock_shutdown_cmd_SHUT_RDWR as isize, +} + +/// A generic socket. +/// +/// Wraps a `struct socket` from the kernel. +/// See [include/linux/socket.h](../../../../include/linux/socket.h) for more information. +/// +/// The wrapper offers high-level methods for common operations on the socket. +/// More fine-grained control is possible by using the C bindings directly. +/// +/// # Example +/// A simple TCP echo server: +/// ```rust +/// use kernel::flag_set; +/// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; +/// use kernel::net::{AddressFamily, init_net}; +/// use kernel::net::ip::IpProtocol; +/// use kernel::net::socket::{Socket, SockType}; +/// +/// let socket = Socket::new_kern( +/// init_net(), +/// AddressFamily::Inet, +/// SockType::Stream, +/// IpProtocol::Tcp, +/// )?; +/// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)))?; +/// socket.listen(10)?; +/// while let Ok(peer) = socket.accept(true) { +/// let mut buf = [0u8; 1024]; +/// peer.receive(&mut buf, flag_set!())?; +/// peer.send(&buf, flag_set!())?; +/// } +/// ``` +/// A simple UDP echo server: +/// ```rust +/// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; +/// use kernel::net::{AddressFamily, init_net}; +/// use kernel::net::ip::IpProtocol; +/// use kernel::net::socket::{Socket, SockType}; +/// use kernel::flag_set; +/// +/// let socket = Socket::new_kern(init_net(), AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?;/// +/// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)))?; +/// let mut buf = [0u8; 1024]; +/// while let Ok((len, sender_opt)) = socket.receive_from(&mut buf, flag_set!()) { +/// let sender: SocketAddr = sender_opt.expect("Sender address is always available for UDP"); +/// socket.send_to(&buf[..len], &sender, flag_set!())?; +/// } +/// ``` +/// +/// # Invariants +/// +/// The socket pointer is valid for the lifetime of the wrapper. +#[repr(transparent)] +pub struct Socket(*mut bindings::socket); + +/// Getters and setters of socket internal fields. +/// +/// Not all fields are currently supported: hopefully, this will be improved in the future. +impl Socket { + /// Retrieve the flags associated with the socket. + /// + /// Unfortunately, these flags cannot be represented as a [`FlagSet`], since [`SocketFlag`]s + /// are not represented as masks but as the index of the bit they represent. + /// + /// An enum could be created, containing masks instead of indexes, but this could create + /// confusion with the C side. + /// + /// The methods [`Socket::has_flag`] and [`Socket::set_flags`] can be used to check and set individual flags. + pub fn flags(&self) -> u64 { + unsafe { (*self.0).flags } + } + + /// Set the flags associated with the socket. + pub fn set_flags(&self, flags: u64) { + unsafe { + (*self.0).flags = flags; + } + } + + /// Checks if the socket has a specific flag. + /// + /// # Example + /// ``` + /// use kernel::net::socket::{Socket, flags::SocketFlag, SockType}; + /// use kernel::net::AddressFamily; + /// use kernel::net::ip::IpProtocol; + /// + /// let socket = Socket::new(AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?; + /// assert_eq!(socket.has_flag(SocketFlag::CustomSockOpt), false); + /// ``` + pub fn has_flag(&self, flag: SocketFlag) -> bool { + bindings::__BindgenBitfieldUnit::<[u8; 8], u8>::new(self.flags().to_be_bytes()) + .get_bit(flag as _) + } + + /// Sets a flag on the socket. + /// + /// # Example + /// ``` + /// use kernel::net::socket::{Socket, flags::SocketFlag, SockType}; + /// use kernel::net::AddressFamily; + /// use kernel::net::ip::IpProtocol; + /// + /// let socket = Socket::new(AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?; + /// assert_eq!(socket.has_flag(SocketFlag::CustomSockOpt), false); + /// socket.set_flag(SocketFlag::CustomSockOpt, true); + /// assert_eq!(socket.has_flag(SocketFlag::CustomSockOpt), true); + /// ``` + pub fn set_flag(&self, flag: SocketFlag, value: bool) { + let flags_width = core::mem::size_of_val(&self.flags()) * 8; + let mut flags = + bindings::__BindgenBitfieldUnit::<[u8; 8], u8>::new(self.flags().to_be_bytes()); + flags.set_bit(flag as _, value); + self.set_flags(flags.get(0, flags_width as _)); + } + + /// Consumes the socket and returns the underlying pointer. + /// + /// The pointer is valid for the lifetime of the wrapper. + /// + /// # Safety + /// The caller must ensure that the pointer is not used after the wrapper is dropped. + pub unsafe fn into_inner(self) -> *mut bindings::socket { + self.0 + } + + /// Returns the underlying pointer. + /// + /// The pointer is valid for the lifetime of the wrapper. + /// + /// # Safety + /// The caller must ensure that the pointer is not used after the wrapper is dropped. + pub unsafe fn as_inner(&self) -> *mut bindings::socket { + self.0 + } +} + +/// Socket API implementation +impl Socket { + /// Private utility function to create a new socket by calling a function. + /// The function is generic over the creation function. + /// + /// # Arguments + /// * `create_fn`: A function that initiates the socket given as parameter. + /// The function must return 0 on success and a negative error code on failure. + fn base_new(create_fn: T) -> Result + where + T: (FnOnce(*mut *mut bindings::socket) -> core::ffi::c_int), + { + let mut socket_ptr: *mut bindings::socket = core::ptr::null_mut(); + to_result(create_fn(&mut socket_ptr))?; + Ok(Self(socket_ptr)) + } + + /// Create a new socket. + /// + /// Wraps the `sock_create` function. + pub fn new(family: AddressFamily, type_: SockType, proto: IpProtocol) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + Self::base_new(|socket_ptr| unsafe { + bindings::sock_create(family as _, type_ as _, proto as _, socket_ptr) + }) + } + + /// Create a new socket in a specific namespace. + /// + /// Wraps the `sock_create_kern` function. + pub fn new_kern( + ns: &Namespace, + family: AddressFamily, + type_: SockType, + proto: IpProtocol, + ) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + Self::base_new(|socket_ptr| unsafe { + bindings::sock_create_kern(ns.0.get(), family as _, type_ as _, proto as _, socket_ptr) + }) + } + + /// Creates a new "lite" socket. + /// + /// Wraps the `sock_create_lite` function. + /// + /// This is a lighter version of `sock_create` that does not perform any sanity check. + pub fn new_lite(family: AddressFamily, type_: SockType, proto: IpProtocol) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + Self::base_new(|socket_ptr| unsafe { + bindings::sock_create_lite(family as _, type_ as _, proto as _, socket_ptr) + }) + } + + /// Binds the socket to a specific address. + /// + /// Wraps the `kernel_bind` function. + pub fn bind(&self, address: SocketAddr) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + to_result(unsafe { + bindings::kernel_bind(self.0, address.as_ptr() as _, address.size() as i32) + }) + } + + /// Connects the socket to a specific address. + /// + /// Wraps the `kernel_connect` function. + /// + /// The socket must be a connection-oriented socket. + /// If the socket is not bound, it will be bound to a random local address. + /// + /// # Example + /// ```rust + /// use kernel::net::{AddressFamily, init_net}; + /// use kernel::net::addr::{Ipv4Addr, SocketAddr, SocketAddrV4}; + /// use kernel::net::ip::IpProtocol; + /// use kernel::net::socket::{Socket, SockType}; + /// + /// let socket = Socket::new_kern(init_net(), AddressFamily::Inet, SockType::Stream, IpProtocol::Tcp)?; + /// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)))?; + /// socket.listen(10)?; + pub fn listen(&self, backlog: i32) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + to_result(unsafe { bindings::kernel_listen(self.0, backlog) }) + } + + /// Accepts a connection on a socket. + /// + /// Wraps the `kernel_accept` function. + pub fn accept(&self, block: bool) -> Result { + let mut new_sock = core::ptr::null_mut(); + let flags: i32 = if block { 0 } else { bindings::O_NONBLOCK as _ }; + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + to_result(unsafe { bindings::kernel_accept(self.0, &mut new_sock, flags as _) })?; + + Ok(Self(new_sock)) + } + + /// Returns the address the socket is bound to. + /// + /// Wraps the `kernel_getsockname` function. + pub fn sockname(&self) -> Result { + let mut addr: SocketAddrStorage = SocketAddrStorage::default(); + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + to_result(bindings::kernel_getsockname( + self.0, + &mut addr as *mut _ as _, + )) + } + .and_then(|_| SocketAddr::try_from_raw(addr)) + } + + /// Returns the address the socket is connected to. + /// + /// Wraps the `kernel_getpeername` function. + /// + /// The socket must be connected. + pub fn peername(&self) -> Result { + let mut addr: SocketAddrStorage = SocketAddrStorage::default(); + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + to_result(bindings::kernel_getpeername( + self.0, + &mut addr as *mut _ as _, + )) + } + .and_then(|_| SocketAddr::try_from_raw(addr)) + } + + /// Connects the socket to a specific address. + /// + /// Wraps the `kernel_connect` function. + pub fn connect(&self, address: &SocketAddr, flags: i32) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + to_result(bindings::kernel_connect( + self.0, + address.as_ptr() as _, + address.size() as _, + flags, + )) + } + } + + /// Shuts down the socket. + /// + /// Wraps the `kernel_sock_shutdown` function. + pub fn shutdown(&self, how: ShutdownCmd) -> Result { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { to_result(bindings::kernel_sock_shutdown(self.0, how as _)) } + } + + /// Receive a message from the socket. + /// + /// This function is the lowest-level receive function. It is used by the other receive functions. + /// + /// The `flags` parameter is a set of flags that control the behavior of the function. + /// The flags are described in the [`ReceiveFlag`] enum. + /// + /// The returned Message is a wrapper for `msghdr` and it contains the header information about the message, + /// including the sender address (if present) and the flags. + /// + /// The data message is written to the provided buffer and the number of bytes written is returned together with the header. + /// + /// Wraps the `kernel_recvmsg` function. + pub fn receive_msg( + &self, + bytes: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, MessageHeader)> { + let addr = SocketAddrStorage::default(); + + let mut msg = bindings::msghdr { + msg_name: &addr as *const _ as _, + ..Default::default() + }; + + let mut vec = bindings::kvec { + iov_base: bytes.as_mut_ptr() as _, + iov_len: bytes.len() as _, + }; + + // SAFETY: FFI call; the socket address is valid for the lifetime of the wrapper. + let size = unsafe { + bindings::kernel_recvmsg( + self.0, + &mut msg as _, + &mut vec, + 1, + bytes.len() as _, + flags.value() as _, + ) + }; + to_result(size)?; + + let addr: Option = SocketAddr::try_from_raw(addr).ok(); + + Ok((size as _, MessageHeader::new(msg, addr))) + } + + /// Receives data from a remote socket and returns the bytes read and the sender address. + /// + /// Used by connectionless sockets to retrieve the sender of the message. + /// If the socket is connection-oriented, the sender address will be `None`. + /// + /// The function abstracts the usage of the `struct msghdr` type. + /// See [Socket::receive_msg] for more information. + pub fn receive_from( + &self, + bytes: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, Option)> { + self.receive_msg(bytes, flags) + .map(|(size, hdr)| (size, hdr.into())) + } + + /// Receives data from a remote socket and returns only the bytes read. + /// + /// Used by connection-oriented sockets, where the sender address is the connected peer. + pub fn receive(&self, bytes: &mut [u8], flags: FlagSet) -> Result { + let (size, _) = self.receive_from(bytes, flags)?; + Ok(size) + } + + /// Sends a message to a remote socket. + /// + /// Wraps the `kernel_sendmsg` function. + /// + /// Crate-public to allow its usage only in the kernel crate. + /// In the future, this function could be made public, accepting a [`Message`] as input, + /// but with the current API, it does not give any advantage. + pub(crate) fn send_msg( + &self, + bytes: &[u8], + mut message: bindings::msghdr, + flags: FlagSet, + ) -> Result { + let mut vec = bindings::kvec { + iov_base: bytes.as_ptr() as _, + iov_len: bytes.len() as _, + }; + message.msg_flags = flags.value() as _; + + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + let size = unsafe { + bindings::kernel_sendmsg( + self.0, + &message as *const _ as _, + &mut vec, + 1, + bytes.len() as _, + ) + }; + to_result(size)?; + Ok(size as _) + } + + /// Sends a message to a remote socket and returns the bytes sent. + /// + /// The `flags` parameter is a set of flags that control the behavior of the function. + /// The flags are described in the [`SendFlag`] enum. + pub fn send(&self, bytes: &[u8], flags: FlagSet) -> Result { + self.send_msg(bytes, bindings::msghdr::default(), flags) + } + + /// Sends a message to a specific remote socket address and returns the bytes sent. + /// + /// The `flags` parameter is a set of flags that control the behavior of the function. + /// The flags are described in the [`SendFlag`] enum. + pub fn send_to( + &self, + bytes: &[u8], + address: &SocketAddr, + flags: FlagSet, + ) -> Result { + let message = bindings::msghdr { + msg_name: address.as_ptr() as _, + msg_namelen: address.size() as _, + ..Default::default() + }; + self.send_msg(bytes, message, flags) + } + + /// Sets an option on the socket. + /// + /// Wraps the `sock_setsockopt` function. + /// + /// The generic type `T` is the type of the option value. + /// See the [options module](opts) for the type and extra information about each option. + /// + /// Unfortunately, options can only be set but not retrieved. + /// This is because the kernel functions to retrieve options are not exported by the kernel. + /// The only exported functions accept user-space pointers, and are therefore not usable in the kernel. + /// + /// # Example + /// ``` + /// use kernel::net::AddressFamily; + /// use kernel::net::ip::IpProtocol;use kernel::net::socket::{Socket, SockType}; + /// use kernel::net::socket::opts; + /// + /// let socket = Socket::new(AddressFamily::Inet, SockType::Datagram, IpProtocol::Udp)?; + /// socket.set_option::(true)?; + /// ``` + pub fn set_option(&self, value: impl Into) -> Result + where + O: SocketOption + WritableOption, + { + let value_ptr = SockPtr::new(&value); + + // The minimum size is the size of an integer. + let min_size = core::mem::size_of::(); + let size = max(core::mem::size_of::(), min_size); + + if O::level() == OptionsLevel::Socket && !self.has_flag(SocketFlag::CustomSockOpt) { + // SAFETY: FFI call; + // the address is valid for the lifetime of the wrapper; + // the size is at least the size of an integer; + // the level and name of the option are valid and coherent. + to_result(unsafe { + bindings::sock_setsockopt( + self.0, + O::level() as isize as _, + O::value() as _, + value_ptr.to_raw() as _, + size as _, + ) + }) + } else { + // SAFETY: FFI call; + // the address is valid for the lifetime of the wrapper; + // the size is at least the size of an integer; + // the level and name of the option are valid and coherent. + to_result(unsafe { + (*(*self.0).ops) + .setsockopt + .map(|f| { + f( + self.0, + O::level() as _, + O::value() as _, + value_ptr.to_raw() as _, + size as _, + ) + }) + .unwrap_or(-(bindings::EOPNOTSUPP as i32)) + }) + } + } +} + +impl Drop for Socket { + /// Closes and releases the socket. + /// + /// Wraps the `sock_release` function. + fn drop(&mut self) { + // SAFETY: FFI call; the address is valid for the lifetime of the wrapper. + unsafe { + bindings::sock_release(self.0); + } + } +} + +// SAFETY: sockets are thread-safe; synchronization is handled by the kernel. +unsafe impl Send for Socket {} +unsafe impl Sync for Socket {} + +/// Socket header message. +/// +/// Wraps the `msghdr` structure. +/// This struct provides a safe interface to the `msghdr` structure. +/// +/// The instances of this struct are only created by the `receive` methods of the [`Socket`] struct. +/// +/// # Invariants +/// The `msg_name` in the wrapped `msghdr` object is always null; the address is stored in the `MessageHeader` object +/// and can be retrieved with the [`MessageHeader::address`] method. +#[derive(Clone, Copy)] +pub struct MessageHeader(pub(crate) bindings::msghdr, pub(crate) Option); + +impl MessageHeader { + /// Returns the address of the message. + pub fn address(&self) -> Option<&SocketAddr> { + self.1.as_ref() + } + + /// Returns the flags of the message. + pub fn flags(&self) -> FlagSet { + FlagSet::from(self.0.msg_flags as isize) + } + + /// Consumes the message header and returns the underlying `msghdr` structure. + /// + /// The returned msghdr will have a null pointer for the address. + pub fn into_raw(self) -> bindings::msghdr { + self.0 + } + + /// Creates a new message header. + /// + /// The `msg_name` of the field gets replaced with a NULL pointer. + pub(crate) fn new(mut hdr: bindings::msghdr, addr: Option) -> Self { + hdr.msg_name = core::ptr::null_mut(); + Self(hdr, addr) + } +} + +impl From for Option { + /// Consumes the message header and returns the contained address. + fn from(hdr: MessageHeader) -> Self { + hdr.1 + } +} + +impl From for bindings::msghdr { + /// Consumes the message header and returns the underlying `msghdr` structure. + /// + /// The returned msghdr will have a null pointer for the address. + /// + /// This function is actually supposed to be crate-public, since bindings are not supposed to be + /// used outside the kernel library. + /// However, until the support for `msghdr` is not complete, specific needs might be satisfied + /// only by using directly the underlying `msghdr` structure. + fn from(hdr: MessageHeader) -> Self { + hdr.0 + } +} + +#[derive(Clone, Copy)] +#[repr(transparent)] +struct SockPtr<'a>(bindings::sockptr_t, PhantomData<&'a ()>); + +impl<'a> SockPtr<'a> { + fn new(value: &'a T) -> Self + where + T: Sized, + { + let mut sockptr = bindings::sockptr_t::default(); + sockptr.__bindgen_anon_1.kernel = value as *const T as _; + sockptr._bitfield_1 = bindings::__BindgenBitfieldUnit::new([1; 1usize]); // kernel ptr + SockPtr(sockptr, PhantomData) + } + + fn to_raw(self) -> bindings::sockptr_t { + self.0 + } +} diff --git a/rust/kernel/net/socket/flags.rs b/rust/kernel/net/socket/flags.rs new file mode 100644 index 00000000000000..fe98e09a8d46e1 --- /dev/null +++ b/rust/kernel/net/socket/flags.rs @@ -0,0 +1,467 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Socket-related flags and utilities. +use crate::bindings; +use core::fmt::Debug; +use core::ops::{BitOr, BitOrAssign}; + +/// Generic socket flag trait. +/// +/// This trait represents any kind of flag with "bitmask" values (i.e. 0x1, 0x2, 0x4, 0x8, etc.) +pub trait Flag: + Into + TryFrom + Debug + Copy + Clone + Send + Sync + 'static +{ +} + +/// Socket send operation flags. +/// +/// See for more. +#[derive(Debug, Copy, Clone)] +pub enum SendFlag { + /// Got a successful reply. + /// + /// Only valid for datagram and raw sockets. + /// Only valid for IPv4 and IPv6. + Confirm = bindings::MSG_CONFIRM as isize, + + /// Don't use a gateway to send out the packet. + DontRoute = bindings::MSG_DONTROUTE as isize, + + /// Enables nonblocking operation. + /// + /// If the operation would block, return immediately with an error. + DontWait = bindings::MSG_DONTWAIT as isize, + + /// Terminates a record. + EOR = bindings::MSG_EOR as isize, + + /// More data will be sent. + /// + /// Only valid for TCP and UDP sockets. + More = bindings::MSG_MORE as isize, + + /// Don't send SIGPIPE error if the socket is shut down. + NoSignal = bindings::MSG_NOSIGNAL as isize, + + /// Send out-of-band data on supported sockets. + OOB = bindings::MSG_OOB as isize, +} + +impl From for isize { + fn from(value: SendFlag) -> Self { + value as isize + } +} + +impl TryFrom for SendFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::MSG_CONFIRM => Ok(SendFlag::Confirm), + bindings::MSG_DONTROUTE => Ok(SendFlag::DontRoute), + bindings::MSG_DONTWAIT => Ok(SendFlag::DontWait), + bindings::MSG_EOR => Ok(SendFlag::EOR), + bindings::MSG_MORE => Ok(SendFlag::More), + bindings::MSG_NOSIGNAL => Ok(SendFlag::NoSignal), + bindings::MSG_OOB => Ok(SendFlag::OOB), + _ => Err(()), + } + } +} + +impl Flag for SendFlag {} + +/// Socket receive operation flags. +/// +/// See for more. +#[derive(Debug, Copy, Clone)] +pub enum ReceiveFlag { + /// Enables nonblocking operation. + /// + /// If the operation would block, return immediately with an error. + DontWait = bindings::MSG_DONTWAIT as isize, + + /// Specifies that queued errors should be received from the socket error queue. + ErrQueue = bindings::MSG_ERRQUEUE as isize, + + /// Enables out-of-band reception. + OOB = bindings::MSG_OOB as isize, + + /// Peeks at an incoming message. + /// + /// The data is treated as unread and the next recv() or similar function shall still return this data. + Peek = bindings::MSG_PEEK as isize, + + /// Returns the real length of the packet, even when it was longer than the passed buffer. + /// + /// Only valid for raw, datagram, netlink and UNIX datagram sockets. + Trunc = bindings::MSG_TRUNC as isize, + + /// Waits for the full request to be satisfied. + WaitAll = bindings::MSG_WAITALL as isize, +} + +impl From for isize { + fn from(value: ReceiveFlag) -> Self { + value as isize + } +} + +impl TryFrom for ReceiveFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::MSG_DONTWAIT => Ok(ReceiveFlag::DontWait), + bindings::MSG_ERRQUEUE => Ok(ReceiveFlag::ErrQueue), + bindings::MSG_OOB => Ok(ReceiveFlag::OOB), + bindings::MSG_PEEK => Ok(ReceiveFlag::Peek), + bindings::MSG_TRUNC => Ok(ReceiveFlag::Trunc), + bindings::MSG_WAITALL => Ok(ReceiveFlag::WaitAll), + _ => Err(()), + } + } +} + +impl Flag for ReceiveFlag {} + +/// Socket `flags` field flags. +/// +/// These flags are used internally by the kernel. +/// However, they are exposed here for completeness. +/// +/// This enum does not implement the `Flag` trait, since it is not actually a flag. +/// Flags are often defined as a mask that can be used to retrieve the flag value; the socket flags, +/// instead, are defined as the index of the bit that they occupy in the `flags` field. +/// This means that they cannot be used as a mask, just like all the other flags that implement `Flag` do. +/// +/// For example, SOCK_PASSCRED has value 3, meaning that it is represented by the 3rd bit of the `flags` field; +/// a normal flag would represent it as a mask, i.e. 1 << 3 = 0b1000. +/// +/// See [include/linux/net.h](../../../../include/linux/net.h) for more. +pub enum SocketFlag { + /// Undocumented. + NoSpace = bindings::SOCK_NOSPACE as isize, + /// Undocumented. + PassCred = bindings::SOCK_PASSCRED as isize, + /// Undocumented. + PassSecurity = bindings::SOCK_PASSSEC as isize, + /// Undocumented. + SupportZeroCopy = bindings::SOCK_SUPPORT_ZC as isize, + /// Undocumented. + CustomSockOpt = bindings::SOCK_CUSTOM_SOCKOPT as isize, + /// Undocumented. + PassPidFd = bindings::SOCK_PASSPIDFD as isize, +} + +impl From for isize { + fn from(value: SocketFlag) -> Self { + value as isize + } +} + +impl TryFrom for SocketFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::SOCK_NOSPACE => Ok(SocketFlag::NoSpace), + bindings::SOCK_PASSCRED => Ok(SocketFlag::PassCred), + bindings::SOCK_PASSSEC => Ok(SocketFlag::PassSecurity), + bindings::SOCK_SUPPORT_ZC => Ok(SocketFlag::SupportZeroCopy), + bindings::SOCK_CUSTOM_SOCKOPT => Ok(SocketFlag::CustomSockOpt), + bindings::SOCK_PASSPIDFD => Ok(SocketFlag::PassPidFd), + _ => Err(()), + } + } +} + +/// Flags associated with a received message. +/// +/// Represents the flag contained in the `msg_flags` field of a `msghdr` struct. +#[derive(Debug, Copy, Clone)] +pub enum MessageFlag { + /// End of record. + Eor = bindings::MSG_EOR as isize, + /// Trailing portion of the message is discarded. + Trunc = bindings::MSG_TRUNC as isize, + /// Control data was discarded due to lack of space. + Ctrunc = bindings::MSG_CTRUNC as isize, + /// Out-of-band data was received. + Oob = bindings::MSG_OOB as isize, + /// An error was received instead of data. + ErrQueue = bindings::MSG_ERRQUEUE as isize, +} + +impl From for isize { + fn from(value: MessageFlag) -> Self { + value as isize + } +} + +impl TryFrom for MessageFlag { + type Error = (); + + fn try_from(value: isize) -> Result { + let val = value as u32; + match val { + bindings::MSG_EOR => Ok(MessageFlag::Eor), + bindings::MSG_TRUNC => Ok(MessageFlag::Trunc), + bindings::MSG_CTRUNC => Ok(MessageFlag::Ctrunc), + bindings::MSG_OOB => Ok(MessageFlag::Oob), + bindings::MSG_ERRQUEUE => Ok(MessageFlag::ErrQueue), + _ => Err(()), + } + } +} + +impl Flag for MessageFlag {} + +/// Structure representing a set of flags. +/// +/// This structure is used to represent a set of flags, such as the flags passed to `send` or `recv`. +/// It is generic over the type of flag that it contains. +/// +/// # Invariants +/// The value of the flags must be a valid combination of the flags that it contains. +/// +/// This means that the value must be the bitwise OR of the values of the flags, and that it +/// must be possible to retrieve the value of the flags from the value. +/// +/// # Example +/// ``` +/// use kernel::net::socket::flags::{SendFlag, FlagSet}; +/// +/// let mut flags = FlagSet::::empty(); +/// flags.insert(SendFlag::DontWait); +/// flags.insert(SendFlag::More); +/// assert!(flags.contains(SendFlag::DontWait)); +/// assert!(flags.contains(SendFlag::More)); +/// flags.clear(); +/// assert_eq!(flags.value(), 0); +/// +/// flags = FlagSet::::from(SendFlag::More); +/// flags |= SendFlag::DontWait; +/// assert!(flags.contains(SendFlag::DontWait)); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct FlagSet { + value: isize, + _phantom: core::marker::PhantomData, +} + +impl FlagSet { + /// Create a new empty set of flags. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let flags = FlagSet::::empty(); + /// assert_eq!(flags.value(), 0); + /// ``` + pub fn empty() -> Self { + FlagSet { + value: 0, + _phantom: core::marker::PhantomData, + } + } + + /// Clear all the flags set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::from(SendFlag::More); + /// assert!(flags.contains(SendFlag::More)); + /// flags.clear(); + /// assert_eq!(flags.value(), 0); + /// ``` + pub fn clear(&mut self) { + self.value = 0; + } + + /// Add a flag to the set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::empty(); + /// assert!(!flags.contains(SendFlag::DontWait)); + /// flags.insert(SendFlag::DontWait); + /// assert!(flags.contains(SendFlag::DontWait)); + /// ``` + pub fn insert(&mut self, flag: T) { + self.value |= flag.into(); + } + + /// Remove a flag from the set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::from(SendFlag::DontWait); + /// assert!(flags.contains(SendFlag::DontWait)); + /// flags.remove(SendFlag::DontWait); + /// assert!(!flags.contains(SendFlag::DontWait)); + /// ``` + pub fn remove(&mut self, flag: T) { + self.value &= !flag.into(); + } + + /// Check if a flag is set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let mut flags = FlagSet::::from(SendFlag::DontWait); + /// assert!(flags.contains(SendFlag::DontWait)); + /// ``` + pub fn contains(&self, flag: T) -> bool { + self.value & flag.into() != 0 + } + + /// Get the integer value of the flags set. + /// + /// # Example + /// ``` + /// use kernel::net::socket::flags::{SendFlag, FlagSet}; + /// + /// let flags = FlagSet::::from(SendFlag::DontWait); + /// assert_eq!(flags.value(), SendFlag::DontWait as isize); + /// ``` + pub fn value(&self) -> isize { + self.value + } +} + +impl BitOr for FlagSet { + type Output = FlagSet; + + fn bitor(self, rhs: T) -> Self::Output { + FlagSet { + value: self.value | rhs.into(), + _phantom: core::marker::PhantomData, + } + } +} + +impl BitOrAssign for FlagSet { + fn bitor_assign(&mut self, rhs: T) { + self.value |= rhs.into(); + } +} + +// impl from isize for any flags +impl From for FlagSet { + fn from(value: isize) -> Self { + FlagSet { + value, + _phantom: core::marker::PhantomData, + } + } +} + +impl From for FlagSet { + fn from(value: T) -> Self { + Self::from(value.into()) + } +} + +impl FromIterator for FlagSet { + fn from_iter>(iter: I) -> Self { + let mut flags = FlagSet::empty(); + for flag in iter { + flags.insert(flag); + } + flags + } +} + +impl From> for isize { + fn from(value: FlagSet) -> Self { + value.value + } +} + +impl IntoIterator for FlagSet { + type Item = T; + type IntoIter = FlagSetIterator; + + fn into_iter(self) -> Self::IntoIter { + FlagSetIterator { + flags: self, + current: 0, + } + } +} + +/// Iterator over the flags in a set. +/// +/// This iterator iterates over the flags in a set, in order of increasing value. +/// +/// # Example +/// ``` +/// use kernel::net::socket::flags::{SendFlag, FlagSet}; +/// +/// let mut flags = FlagSet::from_iter([SendFlag::DontWait, SendFlag::More]); +/// for flag in flags.into_iter() { +/// println!("Flag: {:?}", flag); +/// } +/// ``` +pub struct FlagSetIterator { + flags: FlagSet, + current: usize, +} + +impl Iterator for FlagSetIterator { + type Item = T; + + fn next(&mut self) -> Option { + let mut value = 1 << self.current; + while value <= self.flags.value { + self.current += 1; + if self.flags.value & value != 0 { + if let Ok(flag) = T::try_from(value) { + return Some(flag); + } + } + value = 1 << self.current; + } + None + } +} + +/// Create a set of flags from a list of flags. +/// +/// This macro provides a compact way to create empty sets and sets from a list of flags. +/// +/// # Example +/// ``` +/// use kernel::net::socket::flags::SendFlag; +/// use kernel::flag_set; +/// +/// let mut flags = flag_set!(SendFlag::DontWait, SendFlag::More); +/// assert!(flags.contains(SendFlag::DontWait)); +/// assert!(flags.contains(SendFlag::More)); +/// +/// let mut empty_flags = flag_set!(); +/// assert_eq!(empty_flags.value(), 0); +/// ``` +#[macro_export] +macro_rules! flag_set { + () => { + $crate::net::socket::flags::FlagSet::empty() + }; + ($($flag:expr),+) => { + $crate::net::socket::flags::FlagSet::from_iter([$($flag),+]) + }; +} diff --git a/rust/kernel/net/socket/opts.rs b/rust/kernel/net/socket/opts.rs new file mode 100644 index 00000000000000..6ca8ac35b305b6 --- /dev/null +++ b/rust/kernel/net/socket/opts.rs @@ -0,0 +1,1222 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Socket options. +//! +//! This module contains the types related to socket options. +//! It is meant to be used together with the [`Socket`](kernel::net::socket::Socket) type. +//! +//! Socket options have more sense in the user space than in the kernel space: the kernel can +//! directly access the socket data structures, so it does not need to use socket options. +//! However, that level of freedom is currently not available in the Rust kernel API; therefore, +//! having socket options is a good compromise. +//! +//! When Rust wrappers for the structures related to the socket (and required by the options, +//! e.g. `tcp_sock`, `inet_sock`, etc.) are available, the socket options will be removed, +//! and substituted by direct methods inside the socket types. + +use kernel::bindings; + +/// Options level to retrieve and set socket options. +/// See `man 7 socket` for more information. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum OptionsLevel { + /// IP level socket options. + /// See `man 7 ip` for more information. + Ip = bindings::IPPROTO_IP as isize, + + /// Socket level socket options. + /// See `man 7 socket` for more information. + Socket = bindings::SOL_SOCKET as isize, + + /// IPv6 level socket options. + /// See `man 7 ipv6` for more information. + Ipv6 = bindings::IPPROTO_IPV6 as isize, + + /// Raw level socket options. + /// See `man 7 raw` for more information. + Raw = bindings::IPPROTO_RAW as isize, + + /// TCP level socket options. + /// See `man 7 tcp` for more information. + Tcp = bindings::IPPROTO_TCP as isize, +} + +/// Generic socket option type. +/// +/// This trait is implemented by each individual socket option. +/// +/// Having socket options as structs instead of enums allows: +/// - Type safety, making sure that the correct type is used for each option. +/// - Read/write enforcement, making sure that only readable options +/// are read and only writable options are written. +pub trait SocketOption { + /// Rust type of the option value. + /// + /// This type is used to store the value of the option. + /// It is also used to enforce type safety. + /// + /// For example, the [`ip::Mtu`] option has a value of type `u32`. + type Type; + + /// Retrieve the C value of the option. + /// + /// This value is used to pass the option to the kernel. + fn value() -> isize; + + /// Retrieve the level of the option. + /// + /// This value is used to pass the option to the kernel. + fn level() -> OptionsLevel; +} + +/// Generic readable socket option type. +/// +/// This trait is implemented by each individual readable socket option. +/// Can be combined with [`WritableOption`] to create a readable and writable socket option. +pub trait WritableOption: SocketOption {} + +/// Generic writable socket option type. +/// +/// This trait is implemented by each individual writable socket option. +/// Can be combined with [`ReadableOption`] to create a readable and writable socket option. +pub trait ReadableOption: SocketOption {} + +/// Generates the code for the implementation of a socket option. +/// +/// # Parameters +/// * `$opt`: Name of the socket option. +/// * `$value`: C value of the socket option. +/// * `$level`: Level of the socket option, like [`OptionsLevel::Ip`]. +/// * `$rtyp`: Rust type of the socket option. +/// * `$($tr:ty),*`: Traits that the socket option implements, like [`WritableOption`]. +macro_rules! impl_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $level:expr, + unimplemented, + $($tr:ty),*) => {}; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $level:expr, + $rtyp:ty, + $($tr:ty),*) => { + $(#[$meta])* + #[repr(transparent)] + #[derive(Default)] + pub struct $opt; + impl SocketOption for $opt { + type Type = $rtyp; + fn value() -> isize { + $value as isize + } + fn level() -> OptionsLevel { + $level + } + } + $( + impl $tr for $opt {} + )* + }; +} + +pub mod ip { + //! IP socket options. + use super::{OptionsLevel, ReadableOption, SocketOption, WritableOption}; + use crate::net::addr::Ipv4Addr; + + macro_rules! impl_ip_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ip, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ip, + $rtyp, + $($tr),* + ); + }; + } + + impl_ip_opt!( + /// Join a multicast group. + /// + /// C value type: `struct ip_mreqn`. + AddMembership = bindings::IP_ADD_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Join a multicast group with source filtering. + /// + /// C value type: `struct ip_mreq_source` + AddSourceMembership = bindings::IP_ADD_SOURCE_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Don't reserve a port when binding with port number 0. + /// + /// C value type: `int` + BindAddressNoPort = bindings::IP_BIND_ADDRESS_NO_PORT, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Block packets from a specific source. + /// + /// C value type: `struct ip_mreq_source` + BlockSource = bindings::IP_BLOCK_SOURCE, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Leave a multicast group. + /// + /// C value type: `struct ip_mreqn` + DropMembership = bindings::IP_DROP_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Stop receiving packets from a specific source. + /// + /// C value type: `struct ip_mreq_source` + DropSourceMembership = bindings::IP_DROP_SOURCE_MEMBERSHIP, + unimplemented, + WritableOption + ); + impl_ip_opt!( + /// Allow binding to a non-local address. + /// + /// C value type: `int` + FreeBind = bindings::IP_FREEBIND, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Receive the IP header with the packet. + /// + /// C value type: `int` + Header = bindings::IP_HDRINCL, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Full-state multicast filtering API. + /// + /// C value type: `struct ip_msfilter` + MsFilter = bindings::IP_MSFILTER, + unimplemented, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Retrieve the MTU of the socket. + /// + /// C value type: `int` + Mtu = bindings::IP_MTU, + u32, + ReadableOption + ); + impl_ip_opt!( + /// Discover the MTU of the path to a destination. + /// + /// C value type: `int` + MtuDiscover = bindings::IP_MTU_DISCOVER, + unimplemented, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Modify delivery policy of messages. + /// + /// C value type: `int` + MulticastAll = bindings::IP_MULTICAST_ALL, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the interface for outgoing multicast packets. + /// + /// C value type: `struct in_addr` + MulticastInterface = bindings::IP_MULTICAST_IF, + Ipv4Addr, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set whether multicast packets are looped back to the sender. + /// + /// C value type: `int` + MulticastLoop = bindings::IP_MULTICAST_LOOP, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the TTL of outgoing multicast packets. + /// + /// C value type: `int` + MulticastTtl = bindings::IP_MULTICAST_TTL, + u8, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Whether to disable reassembling of fragmented packets. + /// + /// C value type: `int` + NoDefrag = bindings::IP_NODEFRAG, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the options to be included in outgoing packets. + /// + /// C value type: `char *` + IpOptions = bindings::IP_OPTIONS, + unimplemented, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Enable receiving security context with the packet. + /// + /// C value type: `int` + PassSec = bindings::IP_PASSSEC, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Enable extended reliable error message passing. + /// + /// C value type: `int` + RecvErr = bindings::IP_RECVERR, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Pass all IP Router Alert messages to this socket. + /// + /// C value type: `int` + RouterAlert = bindings::IP_ROUTER_ALERT, + bool, + WritableOption + ); + impl_ip_opt!( + /// Set the TOS field of outgoing packets. + /// + /// C value type: `int` + Tos = bindings::IP_TOS, + u8, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set transparent proxying. + /// + /// C value type: `int` + Transparent = bindings::IP_TRANSPARENT, + bool, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Set the TTL of outgoing packets. + /// + /// C value type: `int` + Ttl = bindings::IP_TTL, + u8, + ReadableOption, + WritableOption + ); + impl_ip_opt!( + /// Unblock packets from a specific source. + /// + /// C value type: `struct ip_mreq_source` + UnblockSource = bindings::IP_UNBLOCK_SOURCE, + unimplemented, + WritableOption + ); +} + +pub mod sock { + //! Socket options. + use super::*; + use crate::net::ip::IpProtocol; + use crate::net::socket::SockType; + use crate::net::AddressFamily; + macro_rules! impl_sock_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Socket, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Socket, + $rtyp, + $($tr),* + ); + }; + } + + impl_sock_opt!( + /// Get whether the socket is accepting connections. + /// + /// C value type: `int` + AcceptConn = bindings::SO_ACCEPTCONN, + bool, + ReadableOption + ); + + impl_sock_opt!( + /// Attach a filter to the socket. + /// + /// C value type: `struct sock_fprog` + AttachFilter = bindings::SO_ATTACH_FILTER, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Attach a eBPF program to the socket. + /// + /// C value type: `struct sock_fprog` + AttachBpf = bindings::SO_ATTACH_BPF, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Bind the socket to a specific network device. + /// + /// C value type: `char *` + BindToDevice = bindings::SO_BINDTODEVICE, + &'static str, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Set the broadcast flag on the socket. + /// + /// Only valid for datagram sockets. + /// + /// C value type: `int` + Broadcast = bindings::SO_BROADCAST, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Enable BSD compatibility. + /// + /// C value type: `int` + BsdCompatible = bindings::SO_BSDCOMPAT, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Enable socket debugging. + /// + /// C value type: `int` + Debug = bindings::SO_DEBUG, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Remove BPF or eBPF program from the socket. + /// + /// The argument is ignored. + /// + /// C value type: `int` + DetachFilter = bindings::SO_DETACH_FILTER, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Get the domain of the socket. + /// + /// C value type: `int` + Domain = bindings::SO_DOMAIN, + AddressFamily, + ReadableOption + ); + impl_sock_opt!( + /// Get and clear pending errors. + /// + /// C value type: `int` + Error = bindings::SO_ERROR, + u32, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Only send packets to directly connected peers. + /// + /// C value type: `int` + DontRoute = bindings::SO_DONTROUTE, + bool, + ReadableOption, + WritableOption + ); + impl_sock_opt!( + /// Set or get the CPU affinity of a socket. + /// + /// C value type: `int` + IncomingCpu = bindings::SO_INCOMING_CPU, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Enable keep-alive packets. + /// + /// C value type: `int` + KeepAlive = bindings::SO_KEEPALIVE, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the linger timeout. + /// + /// C value type: `struct linger` + Linger = bindings::SO_LINGER, + Linger, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Prevent changing the filters attached to the socket. + /// + /// C value type: `int` + LockFilter = bindings::SO_LOCK_FILTER, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the mark of the socket. + /// + /// C value type: `int` + Mark = bindings::SO_MARK, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set whether out-of-band data is received in the normal data stream. + /// + /// C value type: `int` + OobInline = bindings::SO_OOBINLINE, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Enable the receiving of SCM credentials. + /// + /// C value type: `int` + PassCred = bindings::SO_PASSCRED, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set the peek offset for MSG_PEEK reads. + /// + /// Only valid for UNIX sockets. + /// + /// C value type: `int` + PeekOff = bindings::SO_PEEK_OFF, + i32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the protocol-defined priority for all packets. + /// + /// C value type: `int` + Priority = bindings::SO_PRIORITY, + u8, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Retrieve the socket protocol + /// + /// C value type: `int` + Protocol = bindings::SO_PROTOCOL, + IpProtocol, + ReadableOption + ); + + impl_sock_opt!( + /// Set or get the receive buffer size. + /// + /// C value type: `int` + RcvBuf = bindings::SO_RCVBUF, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the receive low watermark. + /// + /// C value type: `int` + RcvLowat = bindings::SO_RCVLOWAT, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the receive timeout. + /// + /// C value type: `struct timeval` + RcvTimeo = bindings::SO_RCVTIMEO_NEW, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the reuse address flag. + /// + /// C value type: `int` + ReuseAddr = bindings::SO_REUSEADDR, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the reuse port flag. + /// + /// C value type: `int` + ReusePort = bindings::SO_REUSEPORT, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the send buffer size. + /// + /// C value type: `int` + SndBuf = bindings::SO_SNDBUF, + u32, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the send timeout. + /// + /// C value type: `struct timeval` + SndTimeo = bindings::SO_SNDTIMEO_NEW, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set whether the timestamp control messages are received. + /// + /// C value type: `int` + Timestamp = bindings::SO_TIMESTAMP_NEW, + bool, + ReadableOption, + WritableOption + ); + + impl_sock_opt!( + /// Set or get the socket type. + /// + /// C value type: `int` + Type = bindings::SO_TYPE, + SockType, + ReadableOption + ); +} + +pub mod ipv6 { + //! IPv6 socket options. + use super::*; + use crate::net::AddressFamily; + macro_rules! impl_ipv6_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ipv6, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Ipv6, + $rtyp, + $($tr),* + ); + }; + } + + impl_ipv6_opt!( + /// Modify the address family used by the socket. + /// + /// C value type: `int` + AddrForm = bindings::IPV6_ADDRFORM, + AddressFamily, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Join a multicast group. + /// + /// C value type: `struct ipv6_mreq` + AddMembership = bindings::IPV6_ADD_MEMBERSHIP, + unimplemented, + WritableOption + ); + + impl_ipv6_opt!( + /// Leave a multicast group. + /// + /// C value type: `struct ipv6_mreq` + DropMembership = bindings::IPV6_DROP_MEMBERSHIP, + unimplemented, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the MTU of the socket. + /// + /// C value type: `int` + Mtu = bindings::IPV6_MTU, + u32, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or retrieve the MTU discovery settings. + /// + /// C value type: `int` (macros) + MtuDiscover = bindings::IPV6_MTU_DISCOVER, + unimplemented, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the multicast hop limit. + /// + /// Range is -1 to 255. + /// + /// C value type: `int` + MulticastHops = bindings::IPV6_MULTICAST_HOPS, + i16, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the multicast interface. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + MulticastInterface = bindings::IPV6_MULTICAST_IF, + u32, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or read whether multicast packets are looped back + /// + /// C value type: `int` + MulticastLoop = bindings::IPV6_MULTICAST_LOOP, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_PKTINFO is enabled. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + ReceivePktInfo = bindings::IPV6_PKTINFO, + bool, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_RTHDR messages are delivered. + /// + /// Only valid for raw sockets. + /// + /// C value type: `int` + RouteHdr = bindings::IPV6_RTHDR, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_DSTOPTS messages are delivered. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + DestOptions = bindings::IPV6_DSTOPTS, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_HOPOPTS messages are delivered. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + HopOptions = bindings::IPV6_HOPOPTS, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get whether IPV6_FLOWINFO messages are delivered. + /// + /// Only valid for datagram and raw sockets. + /// + /// C value type: `int` + FlowInfo = bindings::IPV6_FLOWINFO, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Enable extended reliable error message reporting. + /// + /// C value type: `int` + RecvErr = bindings::IPV6_RECVERR, + bool, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Pass all Router Alert enabled messages to the socket. + /// + /// Only valid for raw sockets. + /// + /// C value type: `int` + RouterAlert = bindings::IPV6_ROUTER_ALERT, + bool, + WritableOption + ); + + impl_ipv6_opt!( + /// Set or get the unicast hop limit. + /// + /// Range is -1 to 255. + /// + /// C value type: `int` + UnicastHops = bindings::IPV6_UNICAST_HOPS, + i16, + ReadableOption, + WritableOption + ); + + impl_ipv6_opt!( + /// Set whether the socket can only send and receive IPv6 packets. + /// + /// C value type: `int` + V6Only = bindings::IPV6_V6ONLY, + bool, + ReadableOption, + WritableOption + ); +} + +pub mod raw { + //! Raw socket options. + //! + //! These options are only valid for sockets with type [`SockType::Raw`](kernel::net::socket::SockType::Raw). + macro_rules! impl_raw_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Raw, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Raw, + $rtyp, + $($tr),* + ); + }; + } + + impl_raw_opt!( + /// Enable a filter for IPPROTO_ICMP raw sockets. + /// The filter has a bit set for each ICMP type to be filtered out. + /// + /// C value type: `struct icmp_filter` + Filter = bindings::ICMP_FILTER as isize, + unimplemented, + ReadableOption, + WritableOption + ); +} + +pub mod tcp { + //! TCP socket options. + //! + //! These options are only valid for sockets with type [`SockType::Stream`](kernel::net::socket::SockType::Stream) + //! and protocol [`IpProtocol::Tcp`](kernel::net::ip::IpProtocol::Tcp). + use super::*; + macro_rules! impl_tcp_opt { + ($(#[$meta:meta])* + $opt:ident = $value:expr, + unimplemented, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Tcp, + unimplemented, + $($tr),* + ); + }; + + ($(#[$meta:meta])* + $opt:ident = $value:expr, + $rtyp:ty, + $($tr:ty),*) => { + impl_opt!( + $(#[$meta])* + $opt = $value, + OptionsLevel::Tcp, + $rtyp, + $($tr),* + ); + }; + } + + impl_tcp_opt!( + /// Set or get the congestion control algorithm to be used. + /// + /// C value type: `char *` + Congestion = bindings::TCP_CONGESTION, + unimplemented, // &[u8]? what about lifetime? + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// If true, don't send partial frames. + /// + /// C value type: `int` + Cork = bindings::TCP_CORK, + bool, + WritableOption, + ReadableOption + ); + + impl_tcp_opt!( + /// Allow a listener to be awakened only when data arrives. + /// The value is the time to wait for data in milliseconds. + /// + /// C value type: `int` + DeferAccept = bindings::TCP_DEFER_ACCEPT, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Collect information about this socket. + /// + /// C value type: `struct tcp_info` + Info = bindings::TCP_INFO, + unimplemented, + ReadableOption + ); + + impl_tcp_opt!( + /// Set or get maximum number of keepalive probes to send. + /// + /// C value type: `int` + KeepCount = bindings::TCP_KEEPCNT, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the time in seconds to idle before sending keepalive probes. + /// + /// C value type: `int` + KeepIdle = bindings::TCP_KEEPIDLE, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the time in seconds between keepalive probes. + /// + /// C value type: `int` + KeepInterval = bindings::TCP_KEEPINTVL, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the lifetime or orphaned FIN_WAIT2 sockets. + /// + /// C value type: `int` + Linger2 = bindings::TCP_LINGER2, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the maximum segment size for outgoing TCP packets. + /// + /// C value type: `int` + MaxSeg = bindings::TCP_MAXSEG, + i32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// If true, Nagle algorithm is disabled, i.e. segments are send as soon as possible. + /// + /// C value type: `int` + NoDelay = bindings::TCP_NODELAY, + bool, + WritableOption, + ReadableOption + ); + + impl_tcp_opt!( + /// Set or get whether QuickAck mode is on. + /// If true, ACKs are sent immediately, rather than delayed. + /// + /// C value type: `int` + QuickAck = bindings::TCP_QUICKACK, + bool, + WritableOption, + ReadableOption + ); + + impl_tcp_opt!( + /// Set or get the number of SYN retransmits before the connection is dropped. + /// + /// C value type: `int` + SynCount = bindings::TCP_SYNCNT, + u8, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get how long sent packets can remain unacknowledged before timing out. + /// The value is in milliseconds; 0 means to use the system default. + /// + /// C value type: `unsigned int` + UserTimeout = bindings::TCP_USER_TIMEOUT, + u32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Set or get the maximum window size for TCP sockets. + /// + /// C value type: `int` + WindowClamp = bindings::TCP_WINDOW_CLAMP, + u32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Enable Fast Open on the listener socket (RFC 7413). + /// The value is the maximum length of pending SYNs. + /// + /// C value type: `int` + FastOpen = bindings::TCP_FASTOPEN, + u32, + ReadableOption, + WritableOption + ); + + impl_tcp_opt!( + /// Enable Fast Open on the client socket (RFC 7413). + /// + /// C value type: `int` + FastOpenConnect = bindings::TCP_FASTOPEN_CONNECT, + bool, + ReadableOption, + WritableOption + ); +} + +/// Linger structure to set and get the [sock::Linger] option. +/// This is a wrapper around the C struct `linger`. +#[repr(transparent)] +pub struct Linger(bindings::linger); + +impl Linger { + /// Create a "on" Linger object with the given linger time. + /// This is equivalent to `linger { l_onoff: 1, l_linger: linger_time }`. + /// The linger time is in seconds. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::on(10); + /// assert!(linger.is_on()); + /// assert_eq!(linger.linger_time(), 10); + pub fn on(linger: i32) -> Self { + Linger(bindings::linger { + l_onoff: 1 as _, + l_linger: linger as _, + }) + } + + /// Create an "off" Linger object. + /// This is equivalent to `linger { l_onoff: 0, l_linger: 0 }`. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::off(); + /// assert!(!linger.is_on()); + pub fn off() -> Self { + Linger(bindings::linger { + l_onoff: 0 as _, + l_linger: 0 as _, + }) + } + + /// Get whether the linger option is on. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::on(10); + /// assert!(linger.is_on()); + /// ``` + /// + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::off(); + /// assert!(!linger.is_on()); + /// ``` + pub fn is_on(&self) -> bool { + self.0.l_onoff != 0 + } + + /// Get the linger time in seconds. + /// If the linger option is off, this will return 0. + /// + /// # Example + /// ``` + /// use kernel::net::socket::opts::Linger; + /// let linger = Linger::on(10); + /// assert_eq!(linger.linger_time(), 10); + /// ``` + pub fn linger_time(&self) -> i32 { + self.0.l_linger as _ + } +} diff --git a/rust/kernel/net/tcp.rs b/rust/kernel/net/tcp.rs new file mode 100644 index 00000000000000..86a42ac3e36710 --- /dev/null +++ b/rust/kernel/net/tcp.rs @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! TCP socket wrapper. +//! +//! This module contains wrappers for a TCP Socket ([`TcpListener`]) and an active +//! TCP connection ([`TcpStream`]). +//! The wrappers are just convenience structs around the generic [`Socket`] type. +//! +//! The API is inspired by the Rust standard library's [`TcpListener`](https://doc.rust-lang.org/std/net/struct.TcpListener.html) and [`TcpStream`](https://doc.rust-lang.org/std/net/struct.TcpStream.html). + +use crate::error::Result; +use crate::net::addr::SocketAddr; +use crate::net::ip::IpProtocol; +use crate::net::socket::flags::{FlagSet, ReceiveFlag, SendFlag}; +use crate::net::socket::opts::{SocketOption, WritableOption}; +use crate::net::socket::{ShutdownCmd, SockType, Socket}; +use crate::net::AddressFamily; +use kernel::net::socket::MessageHeader; + +/// A TCP listener. +/// +/// Wraps the [`Socket`] type to create a TCP-specific interface. +/// +/// The wrapper abstracts away the generic Socket methods that a connection-oriented +/// protocol like TCP does not need. +/// +/// # Examples +/// ```rust +/// use kernel::net::tcp::TcpListener; +/// use kernel::net::addr::*; +/// +/// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); +/// while let Ok(stream) = listener.accept() { +/// // ... +/// } +pub struct TcpListener(pub(crate) Socket); + +impl TcpListener { + /// Create a new TCP listener bound to the given address. + /// + /// The listener will be ready to accept connections. + pub fn new(address: SocketAddr) -> Result { + let socket = Socket::new(AddressFamily::Inet, SockType::Stream, IpProtocol::Tcp)?; + socket.bind(address)?; + socket.listen(128)?; + Ok(Self(socket)) + } + + /// Returns the local address that this listener is bound to. + /// + /// See [`Socket::sockname()`] for more. + pub fn sockname(&self) -> Result { + self.0.sockname() + } + + /// Returns an iterator over incoming connections. + /// + /// Each iteration will return a [`Result`] containing a [`TcpStream`] on success. + /// See [`TcpIncoming`] for more. + /// + /// # Examples + /// ```rust + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// for stream in listener.incoming() { + /// // ... + /// } + /// ``` + pub fn incoming(&self) -> TcpIncoming<'_> { + TcpIncoming { listener: self } + } + + /// Accepts an incoming connection. + /// + /// Returns a [`TcpStream`] on success. + pub fn accept(&self) -> Result { + Ok(TcpStream(self.0.accept(true)?)) + } + + /// Sets the value of the given option. + /// + /// See [`Socket::set_option()`](Socket::set_option) for more. + pub fn set_option(&self, value: impl Into) -> Result + where + O: SocketOption + WritableOption, + { + self.0.set_option::(value) + } +} + +/// An iterator over incoming connections from a [`TcpListener`]. +/// +/// Each iteration will return a [`Result`] containing a [`TcpStream`] on success. +/// The iterator will never return [`None`]. +/// +/// This struct is created by the [`TcpListener::incoming()`] method. +pub struct TcpIncoming<'a> { + listener: &'a TcpListener, +} + +impl Iterator for TcpIncoming<'_> { + /// The item type of the iterator. + type Item = Result; + + /// Get the next connection from the listener. + fn next(&mut self) -> Option { + Some(self.listener.accept()) + } +} + +/// A TCP stream. +/// +/// Represents an active TCP connection between two sockets. +/// The stream can be opened by the listener, with [`TcpListener::accept()`], or by +/// connecting to a remote address with [`TcpStream::connect()`]. +/// The stream can be used to send and receive data. +/// +/// See [`TcpListener`] for an example of how to create a [`TcpStream`]. +pub struct TcpStream(pub(crate) Socket); + +impl TcpStream { + /// Opens a TCP stream by connecting to the given address. + /// + /// Returns a [`TcpStream`] on success. + /// + /// # Examples + /// ```rust + /// use kernel::net::tcp::TcpStream; + /// use kernel::net::addr::*; + /// + /// let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// let stream = TcpStream::connect(&peer_addr).unwrap(); + /// ``` + pub fn connect(address: &SocketAddr) -> Result { + let socket = Socket::new(AddressFamily::Inet, SockType::Stream, IpProtocol::Tcp)?; + socket.connect(address, 0)?; + Ok(Self(socket)) + } + + /// Returns the address of the remote peer of this connection. + /// + /// See [`Socket::peername()`] for more. + pub fn peername(&self) -> Result { + self.0.peername() + } + + /// Returns the address of the local socket of this connection. + /// + /// See [`Socket::sockname()`] for more. + pub fn sockname(&self) -> Result { + self.0.sockname() + } + + /// Receive data from the stream. + /// The given flags are used to modify the behavior of the receive operation. + /// See [`ReceiveFlag`] for more. + /// + /// Returns the number of bytes received. + /// + /// # Examples + /// ```rust + /// use kernel::flag_set; + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// while let Ok(stream) = listener.accept() { + /// let mut buf = [0u8; 1024]; + /// while let Ok(len) = stream.receive(&mut buf, flag_set!()) { + /// // ... + /// } + /// } + /// ``` + pub fn receive(&self, buf: &mut [u8], flags: FlagSet) -> Result { + self.0.receive(buf, flags) + } + + /// Receive data from the stream and return the message header. + /// + /// The given flags are used to modify the behavior of the receive operation. + /// + /// Returns the number of bytes received and the message header, which contains + /// information about the sender and the message. + /// + /// See [`Socket::receive_msg()`] for more. + pub fn receive_msg( + &self, + buf: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, MessageHeader)> { + self.0.receive_msg(buf, flags) + } + + /// Send data to the stream. + /// + /// The given flags are used to modify the behavior of the send operation. + /// See [`SendFlag`] for more. + /// + /// Returns the number of bytes sent. + /// + /// # Examples + /// ```rust + /// use kernel::flag_set; + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// while let Ok(stream) = listener.accept() { + /// let mut buf = [0u8; 1024]; + /// while let Ok(len) = stream.receive(&mut buf, flag_set!()) { + /// stream.send(&buf[..len], flag_set!())?; + /// } + /// } + /// ``` + pub fn send(&self, buf: &[u8], flags: FlagSet) -> Result { + self.0.send(buf, flags) + } + + /// Manually shutdown some portion of the stream. + /// See [`ShutdownCmd`] for more. + /// + /// This method is not required to be called, as the stream will be shutdown + /// automatically when it is dropped. + /// + /// # Examples + /// ```rust + /// use kernel::net::tcp::TcpListener; + /// use kernel::net::addr::*; + /// use kernel::net::socket::ShutdownCmd; + /// + /// let listener = TcpListener::new(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); + /// while let Ok(stream) = listener.accept() { + /// // ... + /// stream.shutdown(ShutdownCmd::Both)?; + /// } + /// ``` + pub fn shutdown(&self, how: ShutdownCmd) -> Result { + self.0.shutdown(how) + } +} + +impl Drop for TcpStream { + /// Shutdown the stream. + /// + /// This method ignores the outcome of the shutdown operation: whether the stream + /// is successfully shutdown or not, the stream will be dropped anyways. + fn drop(&mut self) { + self.0.shutdown(ShutdownCmd::Both).ok(); + } +} diff --git a/rust/kernel/net/udp.rs b/rust/kernel/net/udp.rs new file mode 100644 index 00000000000000..9193292a30f64b --- /dev/null +++ b/rust/kernel/net/udp.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! UDP socket wrapper. +//! +//! This module contains wrappers for a UDP Socket ([`UdpSocket`]). +//! The wrapper is just convenience structs around the generic [`Socket`] type. +//! +//! The API is inspired by the Rust standard library's [`UdpSocket`](https://doc.rust-lang.org/std/net/struct.UdpSocket.html). + +use crate::error::Result; +use crate::net::addr::SocketAddr; +use crate::net::ip::IpProtocol; +use crate::net::socket::flags::{FlagSet, ReceiveFlag, SendFlag}; +use crate::net::socket::{opts::SocketOption, MessageHeader, SockType, Socket}; +use crate::net::AddressFamily; +use kernel::net::socket::opts::WritableOption; + +/// A UDP socket. +/// +/// Provides an interface to send and receive UDP packets, removing +/// all the socket functionality that is not needed for UDP. +/// +/// # Examples +/// ```rust +/// use kernel::flag_set; +/// use kernel::net::udp::UdpSocket; +/// use kernel::net::addr::*; +/// +/// let socket = UdpSocket::new().unwrap(); +/// socket.bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000))).unwrap(); +/// let mut buf = [0u8; 1024]; +/// while let Ok((len, addr)) = socket.receive_from(&mut buf, flag_set!()) { +/// socket.send_to(&buf[..len], &addr, flag_set!()).unwrap(); +/// } +/// ``` +pub struct UdpSocket(pub(crate) Socket); + +impl UdpSocket { + /// Creates a UDP socket. + /// + /// Returns a [`UdpSocket`] on success. + pub fn new() -> Result { + Ok(Self(Socket::new( + AddressFamily::Inet, + SockType::Datagram, + IpProtocol::Udp, + )?)) + } + + /// Binds the socket to the given address. + pub fn bind(&self, address: SocketAddr) -> Result { + self.0.bind(address) + } + + /// Returns the socket's local address. + /// + /// This function assumes the socket is bound, + /// i.e. it must be called after [`bind()`](UdpSocket::bind). + /// + /// # Examples + /// ```rust + /// use kernel::net::udp::UdpSocket; + /// use kernel::net::addr::*; + /// + /// let socket = UdpSocket::new().unwrap(); + /// let local_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// socket.bind(local_addr).unwrap(); + /// assert_eq!(socket.sockname().unwrap(), local_addr); + pub fn sockname(&self) -> Result { + self.0.sockname() + } + + /// Returns the socket's peer address. + /// + /// This function assumes the socket is connected, + /// i.e. it must be called after [`connect()`](UdpSocket::connect). + /// + /// # Examples + /// ```rust + /// use kernel::net::udp::UdpSocket; + /// use kernel::net::addr::*; + /// + /// let socket = UdpSocket::new().unwrap(); + /// let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// socket.connect(&peer_addr).unwrap(); + /// assert_eq!(socket.peername().unwrap(), peer_addr); + pub fn peername(&self) -> Result { + self.0.peername() + } + + /// Receive a message from the socket. + /// + /// The given flags are used to modify the behavior of the receive operation. + /// See [`ReceiveFlag`] for more. + /// + /// The returned [`MessageHeader`] contains metadata about the received message. + /// + /// See [`Socket::receive_msg()`] for more. + pub fn receive_msg( + &self, + buf: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, MessageHeader)> { + self.0.receive_msg(buf, flags) + } + + /// Receives data from another socket. + /// + /// The given flags are used to modify the behavior of the receive operation. + /// See [`ReceiveFlag`] for more. + /// + /// Returns the number of bytes received and the address of the sender. + pub fn receive_from( + &self, + buf: &mut [u8], + flags: FlagSet, + ) -> Result<(usize, SocketAddr)> { + self.0 + .receive_from(buf, flags) + .map(|(size, addr)| (size, addr.unwrap())) + } + + /// Sends data to another socket. + /// + /// The given flags are used to modify the behavior of the send operation. + /// See [`SendFlag`] for more. + /// + /// Returns the number of bytes sent. + pub fn send_to( + &self, + buf: &[u8], + address: &SocketAddr, + flags: FlagSet, + ) -> Result { + self.0.send_to(buf, address, flags) + } + + /// Connects the socket to the given address. + /// + /// # Examples + /// ```rust + /// use kernel::net::udp::UdpSocket; + /// use kernel::net::addr::*; + /// + /// let socket = UdpSocket::new().unwrap(); + /// let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOOPBACK, 8000)); + /// socket.connect(&peer_addr).unwrap(); + /// ``` + pub fn connect(&self, address: &SocketAddr) -> Result { + self.0.connect(address, 0) + } + + /// Receives data from the connected socket. + /// + /// This function assumes the socket is connected, + /// i.e. it must be called after [`connect()`](UdpSocket::connect). + /// + /// Returns the number of bytes received. + pub fn receive(&self, buf: &mut [u8], flags: FlagSet) -> Result { + self.0.receive(buf, flags) + } + + /// Sends data to the connected socket. + /// + /// This function assumes the socket is connected, + /// i.e. it must be called after [`connect()`](UdpSocket::connect). + /// + /// Returns the number of bytes sent. + pub fn send(&self, buf: &[u8], flags: FlagSet) -> Result { + self.0.send(buf, flags) + } + + /// Sets the value of the given option. + /// + /// See [`Socket::set_option()`](Socket::set_option) for more. + pub fn set_option(&self, value: impl Into) -> Result + where + O: SocketOption + WritableOption, + { + self.0.set_option::(value) + } +}