use std::cmp::Ordering;
use std::convert::TryFrom;
use std::ffi::OsString;
use std::fs::remove_file;
use std::io;
use std::mem;
use std::mem::size_of;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::net::SocketAddrV4;
use std::net::SocketAddrV6;
use std::net::TcpListener;
use std::net::TcpStream;
use std::net::ToSocketAddrs;
use std::ops::Deref;
use std::os::fd::OwnedFd;
use std::os::unix::ffi::OsStringExt;
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null_mut;
use std::time::Duration;
use std::time::Instant;
use libc::c_int;
use libc::recvfrom;
use libc::sa_family_t;
use libc::sockaddr;
use libc::sockaddr_in;
use libc::sockaddr_in6;
use libc::socklen_t;
use libc::AF_INET;
use libc::AF_INET6;
use libc::MSG_PEEK;
use libc::MSG_TRUNC;
use log::warn;
use serde::Deserialize;
use serde::Serialize;
use crate::descriptor::AsRawDescriptor;
use crate::descriptor::FromRawDescriptor;
use crate::descriptor::IntoRawDescriptor;
use crate::handle_eintr_errno;
use crate::sys::sockaddr_un;
use crate::sys::sockaddrv4_to_lib_c;
use crate::sys::sockaddrv6_to_lib_c;
use crate::Error;
use crate::RawDescriptor;
use crate::SafeDescriptor;
/// IP protocol version of a socket address.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    /// IPv4 (maps to `AF_INET`).
    V4,
    /// IPv6 (maps to `AF_INET6`).
    V6,
}

impl InetVersion {
    /// Reports whether `s` is an IPv4 or an IPv6 socket address.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if s.is_ipv4() {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
impl From<InetVersion> for sa_family_t {
    /// Maps an IP version to its socket address family constant.
    fn from(version: InetVersion) -> sa_family_t {
        let family = match version {
            InetVersion::V4 => AF_INET,
            InetVersion::V6 => AF_INET6,
        };
        family as sa_family_t
    }
}
/// Creates a socket via `socket(2)` and wraps the new descriptor in a [`SafeDescriptor`].
///
/// # Errors
/// The OS error when `socket(2)` returns -1.
pub(in crate::sys) fn socket(
    domain: c_int,
    sock_type: c_int,
    protocol: c_int,
) -> io::Result<SafeDescriptor> {
    // SAFETY: socket(2) only takes integer arguments.
    let fd = unsafe { libc::socket(domain, sock_type, protocol) };
    if fd == -1 {
        return Err(io::Error::last_os_error());
    }
    // SAFETY: `fd` was just returned by socket(2), so nothing else owns it.
    Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) })
}
/// Creates a connected pair of sockets via `socketpair(2)`.
///
/// # Errors
/// The OS error when `socketpair(2)` returns -1.
pub(in crate::sys) fn socketpair(
    domain: c_int,
    sock_type: c_int,
    protocol: c_int,
) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
    let mut fds: [c_int; 2] = [0; 2];
    // SAFETY: `fds` has room for the two descriptors socketpair(2) writes.
    let ret = unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) };
    if ret == -1 {
        return Err(io::Error::last_os_error());
    }
    let [first, second] = fds;
    // SAFETY: on success both entries are fresh descriptors that nothing else owns.
    let pair = unsafe {
        (
            SafeDescriptor::from_raw_descriptor(first),
            SafeDescriptor::from_raw_descriptor(second),
        )
    };
    Ok(pair)
}
/// A TCP socket that has not yet been connected or put into the listening state.
///
/// `connect` and `listen` consume it, handing the descriptor over to a std `TcpStream` or
/// `TcpListener` respectively.
#[derive(Debug)]
pub struct TcpSocket {
// IP version the socket was created for; `local_port` uses it to choose between the
// `sockaddr_in` and `sockaddr_in6` layouts.
pub(in crate::sys) inet_version: InetVersion,
// The owned socket descriptor.
pub(in crate::sys) descriptor: SafeDescriptor,
}
impl TcpSocket {
    /// Resolves `addr` and returns the first socket address it yields.
    ///
    /// A resolution failure and a resolution that yields no address at all are both reported
    /// as `EINVAL` instead of panicking.
    fn resolve_first_addr<A: ToSocketAddrs>(addr: A) -> io::Result<SocketAddr> {
        addr.to_socket_addrs()
            .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
            .next()
            .ok_or_else(|| io::Error::from_raw_os_error(libc::EINVAL))
    }

    /// Binds the socket to the first address `addr` resolves to.
    ///
    /// # Errors
    /// `EINVAL` if `addr` cannot be resolved or resolves to no address; otherwise the OS error
    /// reported by `bind(2)`.
    pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
        let ret = match Self::resolve_first_addr(addr)? {
            SocketAddr::V4(a) => {
                let sin = sockaddrv4_to_lib_c(&a);
                // SAFETY: `sin` lives across the call and the length passed is exactly its size.
                unsafe {
                    libc::bind(
                        self.as_raw_descriptor(),
                        &sin as *const sockaddr_in as *const sockaddr,
                        size_of::<sockaddr_in>() as socklen_t,
                    )
                }
            }
            SocketAddr::V6(a) => {
                let sin6 = sockaddrv6_to_lib_c(&a);
                // SAFETY: `sin6` lives across the call and the length passed is exactly its size.
                unsafe {
                    libc::bind(
                        self.as_raw_descriptor(),
                        &sin6 as *const sockaddr_in6 as *const sockaddr,
                        size_of::<sockaddr_in6>() as socklen_t,
                    )
                }
            }
        };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(())
        }
    }

    /// Connects the socket to the first address `addr` resolves to and converts it into a std
    /// `TcpStream`.
    ///
    /// # Errors
    /// `EINVAL` if `addr` cannot be resolved or resolves to no address; otherwise the OS error
    /// reported by `connect(2)`.
    pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
        let ret = match Self::resolve_first_addr(addr)? {
            SocketAddr::V4(a) => {
                let sin = sockaddrv4_to_lib_c(&a);
                // SAFETY: `sin` lives across the call and the length passed is exactly its size.
                unsafe {
                    libc::connect(
                        self.as_raw_descriptor(),
                        &sin as *const sockaddr_in as *const sockaddr,
                        size_of::<sockaddr_in>() as socklen_t,
                    )
                }
            }
            SocketAddr::V6(a) => {
                let sin6 = sockaddrv6_to_lib_c(&a);
                // SAFETY: `sin6` lives across the call. The length must be that of
                // `sockaddr_in6`; passing the smaller `sockaddr_in` size would truncate the
                // IPv6 address.
                unsafe {
                    libc::connect(
                        self.as_raw_descriptor(),
                        &sin6 as *const sockaddr_in6 as *const sockaddr,
                        size_of::<sockaddr_in6>() as socklen_t,
                    )
                }
            }
        };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(TcpStream::from(self.descriptor))
        }
    }

    /// Puts the socket into the listening state (backlog of 1) and converts it into a std
    /// `TcpListener`.
    ///
    /// # Errors
    /// The OS error reported by `listen(2)`.
    pub fn listen(self) -> io::Result<TcpListener> {
        // SAFETY: listen(2) takes only the descriptor and an integer backlog.
        let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(TcpListener::from(self.descriptor))
        }
    }

    /// Returns the local port the socket is bound to, as reported by `getsockname(2)`.
    ///
    /// # Errors
    /// The OS error reported by `getsockname(2)`.
    ///
    /// # Panics
    /// If the returned address length does not match the socket's address family.
    pub fn local_port(&self) -> io::Result<u16> {
        match self.inet_version {
            InetVersion::V4 => {
                let mut sin = sockaddrv4_to_lib_c(&SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0));
                let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
                // SAFETY: `sin`/`addrlen` are live locals and `addrlen` holds the buffer size, so
                // the kernel cannot write past `sin`.
                let ret = unsafe {
                    libc::getsockname(
                        self.as_raw_descriptor(),
                        &mut sin as *mut sockaddr_in as *mut sockaddr,
                        &mut addrlen as *mut socklen_t,
                    )
                };
                if ret < 0 {
                    Err(io::Error::last_os_error())
                } else {
                    assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
                    Ok(u16::from_be(sin.sin_port))
                }
            }
            InetVersion::V6 => {
                let mut sin6 = sockaddrv6_to_lib_c(&SocketAddrV6::new(
                    Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
                    0,
                    0,
                    0,
                ));
                let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
                // SAFETY: `sin6`/`addrlen` are live locals and `addrlen` holds the buffer size,
                // so the kernel cannot write past `sin6`.
                let ret = unsafe {
                    libc::getsockname(
                        self.as_raw_descriptor(),
                        &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
                        &mut addrlen as *mut socklen_t,
                    )
                };
                if ret < 0 {
                    Err(io::Error::last_os_error())
                } else {
                    // An IPv6 socket reports a `sockaddr_in6`-sized address; comparing against
                    // `sockaddr_in` would make this assertion fail for every IPv6 socket.
                    assert_eq!(addrlen as usize, size_of::<sockaddr_in6>());
                    Ok(u16::from_be(sin6.sin6_port))
                }
            }
        }
    }
}
impl AsRawDescriptor for TcpSocket {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.descriptor.as_raw_descriptor()
}
}
pub(in crate::sys) fn sun_path_offset() -> usize {
std::mem::offset_of!(libc::sockaddr_un, sun_path)
}
/// A `SOCK_SEQPACKET` Unix domain socket that owns its descriptor.
///
/// It is created by `UnixSeqpacket::connect`, or wrapped around an existing descriptor through
/// the `From<SafeDescriptor>` / `FromRawDescriptor` impls.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnixSeqpacket(SafeDescriptor);
impl UnixSeqpacket {
    /// Creates a `SOCK_SEQPACKET` Unix domain socket and connects it to the socket at `path`.
    ///
    /// # Errors
    /// Any error from creating the socket, building the `sockaddr_un` for `path`, or
    /// `connect(2)`.
    pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
        let (addr, len) = sockaddr_un(path.as_ref())?;
        // SAFETY: `addr` outlives the call and `len` is the length `sockaddr_un` produced for it;
        // connect(2) only reads through the pointer.
        let rc = unsafe {
            libc::connect(
                descriptor.as_raw_descriptor(),
                &addr as *const _ as *const _,
                len,
            )
        };
        if rc < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(UnixSeqpacket(descriptor))
        }
    }

    /// Returns a second handle to the same socket by cloning the underlying `SafeDescriptor`
    /// (presumably via `dup` — see `SafeDescriptor::try_clone`).
    pub fn try_clone(&self) -> io::Result<Self> {
        let duplicate = self.0.try_clone()?;
        Ok(Self(duplicate))
    }

    /// Returns the number of bytes currently queued for reading, as reported by `FIONREAD`.
    pub fn get_readable_bytes(&self) -> io::Result<usize> {
        let mut pending = 0i32;
        // SAFETY: `pending` is a live i32 the ioctl writes its result into.
        let rc = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut pending) };
        if rc >= 0 {
            Ok(pending as usize)
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Returns the size of the next pending packet without removing it from the queue.
    ///
    /// This peeks (`MSG_PEEK`) with a zero-length buffer. `MSG_TRUNC` makes `recvfrom` report
    /// the packet's full length rather than the number of bytes copied (see recv(2)).
    pub fn next_packet_size(&self) -> io::Result<usize> {
        #[cfg(not(debug_assertions))]
        let buf = null_mut();
        // Debug builds hand over a real (one-int) buffer instead of null. NOTE(review):
        // presumably a workaround for an emulator/tooling issue with null buffers — confirm.
        #[cfg(debug_assertions)]
        let buf = &mut 0 as *mut _ as *mut _;
        // SAFETY: a zero length is passed, so nothing is written through `buf`; the address
        // out-parameters are null, so no peer address is stored.
        let rc = unsafe {
            recvfrom(
                self.as_raw_descriptor(),
                buf,
                0,
                MSG_TRUNC | MSG_PEEK,
                null_mut(),
                null_mut(),
            )
        };
        // A negative return signals failure, with the cause left in errno.
        usize::try_from(rc).map_err(|_| io::Error::last_os_error())
    }

    /// Writes `buf` to the socket as a single `write(2)` call and returns the bytes written.
    pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
        // SAFETY: `buf` is valid for `buf.len()` bytes and write(2) only reads from it.
        let written = unsafe {
            libc::write(
                self.as_raw_descriptor(),
                buf.as_ptr() as *const _,
                buf.len(),
            )
        };
        usize::try_from(written).map_err(|_| io::Error::last_os_error())
    }

    /// Reads from the socket into `buf` with a single `read(2)` call and returns the bytes read.
    pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
        // SAFETY: `buf` is valid and exclusively borrowed for `buf.len()` writable bytes.
        let received = unsafe {
            libc::read(
                self.as_raw_descriptor(),
                buf.as_mut_ptr() as *mut _,
                buf.len(),
            )
        };
        usize::try_from(received).map_err(|_| io::Error::last_os_error())
    }

    /// Receives the next packet into `buf`, resizing it to exactly the received length.
    pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
        let expected = self.next_packet_size()?;
        buf.resize(expected, 0);
        let got = self.recv(buf)?;
        buf.resize(got, 0);
        Ok(())
    }

    /// Receives the next packet into a freshly allocated vector.
    pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
        let mut packet = Vec::new();
        self.recv_to_vec(&mut packet).map(|()| packet)
    }

    /// Sets the `SOL_SOCKET` timeout option `kind` (`SO_RCVTIMEO` or `SO_SNDTIMEO`).
    ///
    /// `None` stores a zero `timeval`. A `Some` duration under one microsecond is rejected
    /// with `InvalidInput`, because it would round to that same zero value.
    // `suseconds_t` may already be `i32` on some targets, making the `From` a no-op there.
    #[allow(clippy::useless_conversion)]
    fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
        let timeval = match timeout {
            None => libc::timeval {
                tv_sec: 0,
                tv_usec: 0,
            },
            Some(t) if t.as_secs() == 0 && t.subsec_micros() == 0 => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "zero timeout duration is invalid",
                ));
            }
            Some(t) => libc::timeval {
                tv_sec: t.as_secs() as libc::time_t,
                tv_usec: libc::suseconds_t::from(t.subsec_micros() as i32),
            },
        };
        // SAFETY: `timeval` is a live local and the length passed is exactly its size.
        let rc = unsafe {
            libc::setsockopt(
                self.as_raw_descriptor(),
                libc::SOL_SOCKET,
                kind,
                &timeval as *const libc::timeval as *const libc::c_void,
                mem::size_of::<libc::timeval>() as libc::socklen_t,
            )
        };
        if rc >= 0 {
            Ok(())
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Sets (or clears, with `None`) the receive timeout.
    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
        self.set_timeout(timeout, libc::SO_RCVTIMEO)
    }

    /// Sets (or clears, with `None`) the send timeout.
    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
        self.set_timeout(timeout, libc::SO_SNDTIMEO)
    }

    /// Switches the socket between blocking and non-blocking mode via `FIONBIO`.
    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        let mut flag = libc::c_int::from(nonblocking);
        // SAFETY: `flag` is a live c_int the ioctl reads the requested mode from.
        let rc = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut flag) };
        if rc >= 0 {
            Ok(())
        } else {
            Err(io::Error::last_os_error())
        }
    }
}
impl From<UnixSeqpacket> for SafeDescriptor {
fn from(s: UnixSeqpacket) -> Self {
s.0
}
}
impl From<SafeDescriptor> for UnixSeqpacket {
fn from(s: SafeDescriptor) -> Self {
Self(s)
}
}
impl FromRawDescriptor for UnixSeqpacket {
unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
Self(SafeDescriptor::from_raw_descriptor(descriptor))
}
}
impl AsRawDescriptor for UnixSeqpacket {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.0.as_raw_descriptor()
}
}
impl IntoRawDescriptor for UnixSeqpacket {
fn into_raw_descriptor(self) -> RawDescriptor {
self.0.into_raw_descriptor()
}
}
impl io::Read for UnixSeqpacket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.recv(buf)
}
}
impl io::Write for UnixSeqpacket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.send(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
/// A listening `SOCK_SEQPACKET` Unix domain socket.
pub struct UnixSeqpacketListener {
// The owned listening socket descriptor.
descriptor: SafeDescriptor,
// Set when `bind` adopted an existing descriptor from a `/proc/self/fd/N` path instead of
// binding a filesystem path; `path()` then returns an error.
no_path: bool,
}
impl UnixSeqpacketListener {
pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
if path.as_ref().starts_with("/proc/self/fd/") {
let fd = path
.as_ref()
.file_name()
.expect("Failed to get fd filename")
.to_str()
.expect("fd filename should be unicode")
.parse::<i32>()
.expect("fd should be an integer");
let mut result: c_int = 0;
let mut result_len = size_of::<c_int>() as libc::socklen_t;
let ret = unsafe {
libc::getsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_ACCEPTCONN,
&mut result as *mut _ as *mut libc::c_void,
&mut result_len,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
if result != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"specified descriptor is not a listening socket",
));
}
let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
return Ok(UnixSeqpacketListener {
descriptor,
no_path: true,
});
}
let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
let (addr, len) = sockaddr_un(path.as_ref())?;
unsafe {
let ret = handle_eintr_errno!(libc::bind(
descriptor.as_raw_descriptor(),
&addr as *const _ as *const _,
len
));
if ret < 0 {
return Err(io::Error::last_os_error());
}
let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
if ret < 0 {
return Err(io::Error::last_os_error());
}
}
Ok(UnixSeqpacketListener {
descriptor,
no_path: false,
})
}
pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
let start = Instant::now();
loop {
let mut fds = libc::pollfd {
fd: self.as_raw_descriptor(),
events: libc::POLLIN,
revents: 0,
};
let elapsed = Instant::now().saturating_duration_since(start);
let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
Ordering::Greater => return self.accept(),
Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
Ordering::Less => {
if Error::last() != Error::new(libc::EINTR) {
return Err(io::Error::last_os_error());
}
}
}
}
}
pub fn path(&self) -> io::Result<PathBuf> {
let mut addr = sockaddr_un(Path::new(""))?.0;
if self.no_path {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"socket has no path",
));
}
let sun_path_offset = (&addr.sun_path as *const _ as usize
- &addr.sun_family as *const _ as usize)
as libc::socklen_t;
let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
let ret = unsafe {
handle_eintr_errno!(libc::getsockname(
self.as_raw_descriptor(),
&mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
&mut len
))
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
|| addr.sun_path[0] == 0
|| len < 1 + sun_path_offset
{
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"getsockname on socket returned invalid value",
));
}
let path_os_str = OsString::from_vec(
addr.sun_path[..(len - sun_path_offset - 1) as usize]
.iter()
.map(|&c| c as _)
.collect(),
);
Ok(path_os_str.into())
}
pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
let mut nonblocking = nonblocking as libc::c_int;
let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
if ret < 0 {
Err(io::Error::last_os_error())
} else {
Ok(())
}
}
}
impl AsRawDescriptor for UnixSeqpacketListener {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.descriptor.as_raw_descriptor()
}
}
impl From<UnixSeqpacketListener> for OwnedFd {
fn from(val: UnixSeqpacketListener) -> Self {
val.descriptor.into()
}
}
/// Wrapper around a [`UnixSeqpacketListener`] that, when dropped, removes the socket file at
/// the listener's path (if it has one). Removal failures are only logged.
pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.0.as_raw_descriptor()
}
}
impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
fn as_ref(&self) -> &UnixSeqpacketListener {
&self.0
}
}
impl Deref for UnlinkUnixSeqpacketListener {
type Target = UnixSeqpacketListener;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Drop for UnlinkUnixSeqpacketListener {
    /// Removes the listener's socket file on a best-effort basis.
    ///
    /// A listener whose path cannot be determined is skipped silently; a failed removal is only
    /// logged, because `drop` cannot report errors.
    fn drop(&mut self) {
        let Ok(path) = self.0.path() else {
            return;
        };
        if let Err(e) = remove_file(path) {
            warn!("failed to remove control socket file: {:?}", e);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path consisting of `n` repetitions of `'a'`.
    fn path_of_len(n: usize) -> String {
        "a".repeat(n)
    }

    #[test]
    fn sockaddr_un_zero_length_input() {
        let _ = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    // `sun_path` holds 108 bytes here (see `sockaddr_un_pass`), so a 108-byte path leaves no
    // room for the terminator.
    #[test]
    fn sockaddr_un_long_input_err() {
        let result = sockaddr_un(Path::new(&path_of_len(108)));
        assert!(result.is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let _ = sockaddr_un(Path::new(&path_of_len(107))).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let n = 50;
        let (_addr, len) = sockaddr_un(Path::new(&path_of_len(n))).expect("sockaddr_un failed");
        // Expected length = offset of `sun_path` + path bytes + terminating NUL.
        assert_eq!(len, (sun_path_offset() + n + 1) as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&path_of_len(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);
        // Expected contents: `path_size` copies of 'a', then NUL padding up to 108 bytes.
        let expected = (0..108).map(|i| {
            if i < path_size {
                'a' as libc::c_char
            } else {
                0 as libc::c_char
            }
        });
        for (actual, want) in addr.sun_path.iter().zip(expected) {
            assert_eq!(*actual, want);
        }
    }
}