use std::io;
use std::io::Cursor;
use std::io::Read;
use std::io::Write;
use std::mem;
use std::os::windows::io::AsRawHandle;
use std::os::windows::io::RawHandle;
use std::time::Duration;
use log::warn;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde::Serialize;
use serde::Serializer;
use zerocopy::FromBytes;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;
use crate::descriptor::AsRawDescriptor;
use crate::descriptor::FromRawDescriptor;
use crate::descriptor::SafeDescriptor;
use crate::descriptor_reflection::deserialize_with_descriptors;
use crate::descriptor_reflection::SerializeDescriptors;
use crate::tube::Error;
use crate::tube::RecvTube;
use crate::tube::Result;
use crate::tube::SendTube;
use crate::BlockingMode;
use crate::CloseNotifier;
use crate::Event;
use crate::EventToken;
use crate::FramingMode;
use crate::PipeConnection;
use crate::RawDescriptor;
use crate::ReadNotifier;
use crate::StreamChannel;
/// A channel over a [`StreamChannel`] that carries serde (JSON) messages together with any
/// Windows handles embedded in them. `Tube::pair` creates blocking, message-framed channels.
#[derive(Serialize, Deserialize, Debug)]
pub struct Tube {
    // Underlying transport for all reads and writes.
    socket: StreamChannel,
    // Process into which descriptors embedded in sent messages are duplicated (passed through
    // to `duplicate_handle`). When `None`, `win_util::duplicate_handle` is used instead
    // (presumably duplicating within the current process). When a `Tube` is itself
    // serialized, this is always written as a `u32`: the set PID, else the alias PID, else
    // the current process ID (see `set_tube_pid_on_serialize`).
    #[serde(serialize_with = "set_tube_pid_on_serialize")]
    target_pid: Option<u32>,
}
/// `serialize_with` hook for `Tube::target_pid`.
///
/// Always emits a `u32`. It uses the tube's own target PID if one is set, otherwise the
/// PID registered through `set_alias_pid`, otherwise the ID of the current process.
fn set_tube_pid_on_serialize<S>(
    existing_pid_value: &Option<u32>,
    serializer: S,
) -> std::result::Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let pid = match *existing_pid_value {
        Some(explicit_pid) => explicit_pid,
        None => ALIAS_PID.lock().unwrap_or_else(std::process::id),
    };
    serializer.serialize_u32(pid)
}
/// Fixed-size header that precedes every message written by `serialize_and_send`.
///
/// It is written with `as_bytes()` and read back with `as_mut_bytes()` as raw memory, so the
/// sender and receiver must agree on the width and endianness of `usize`.
#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
struct MsgHeader {
    /// Length in bytes of the JSON-encoded message body that follows the header.
    msg_json_size: usize,
    /// Length in bytes of the JSON array of duplicated handle values that follows the body;
    /// 0 when the message carries no descriptors.
    descriptor_json_size: usize,
}
/// Optional process-wide broker used by `duplicate_handle` to duplicate handles into a
/// target process instead of calling `win_util` directly.
static DH_TUBE: sync::Mutex<Option<DuplicateHandleTube>> = sync::Mutex::new(None);
/// Optional PID reported when serializing a `Tube` whose `target_pid` is unset.
static ALIAS_PID: sync::Mutex<Option<u32>> = sync::Mutex::new(None);

/// Installs the process-wide duplicate-handle broker. Any previously installed broker is
/// replaced and dropped while the lock is held.
pub fn set_duplicate_handle_tube(dh_tube: DuplicateHandleTube) {
    let mut slot = DH_TUBE.lock();
    *slot = Some(dh_tube);
}

/// Sets the PID that serialized `Tube`s report when they have no explicit target PID.
pub fn set_alias_pid(alias_pid: u32) {
    let mut slot = ALIAS_PID.lock();
    *slot = Some(alias_pid);
}
impl Tube {
    /// Creates a connected pair of tubes over a blocking, message-framed `StreamChannel` pair.
    pub fn pair() -> Result<(Tube, Tube)> {
        let (socket1, socket2) = StreamChannel::pair(BlockingMode::Blocking, FramingMode::Message)
            .map_err(|e| Error::Pair(io::Error::from_raw_os_error(e.errno())))?;
        Ok((Tube::new(socket1), Tube::new(socket2)))
    }

    /// Like [`Tube::pair`], but creates the underlying channels with `buffer_size`.
    pub fn pair_with_buffer_size(buffer_size: usize) -> Result<(Tube, Tube)> {
        let (socket1, socket2) = StreamChannel::pair_with_buffer_size(
            BlockingMode::Blocking,
            FramingMode::Message,
            buffer_size,
        )
        .map_err(|e| Error::Pair(io::Error::from_raw_os_error(e.errno())))?;
        let tube1 = Tube::new(socket1);
        let tube2 = Tube::new(socket2);
        Ok((tube1, tube2))
    }

    /// Wraps an existing `StreamChannel`. No target PID is set.
    pub fn new(socket: StreamChannel) -> Tube {
        Tube {
            socket,
            target_pid: None,
        }
    }

    /// Clones the underlying channel and keeps the same target PID.
    pub(crate) fn try_clone(&self) -> Result<Self> {
        Ok(Tube {
            socket: self.socket.try_clone().map_err(Error::Clone)?,
            target_pid: self.target_pid,
        })
    }

    /// Sends a protobuf message as a single write of
    /// `[body length as little-endian usize][protobuf bytes]`.
    fn send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()> {
        let bytes = msg.write_to_bytes().map_err(Error::Proto)?;
        let size_header = bytes.len();

        let mut data_packet =
            Cursor::new(Vec::with_capacity(mem::size_of::<usize>() + size_header));
        // `write_all` rather than `write`: `Write::write` is allowed to write only part of
        // the buffer, and its byte count was previously discarded.
        data_packet
            .write_all(&size_header.to_le_bytes())
            .map_err(Error::from_send_io_buf_error)?;
        data_packet
            .write_all(&bytes)
            .map_err(Error::from_send_io_buf_error)?;

        self.socket
            .write_immutable(&data_packet.into_inner())
            .map_err(Error::from_send_error)?;
        Ok(())
    }

    /// Receives a message written by `send_proto`. It reads the length prefix, then exactly
    /// that many bytes, and parses them as `M`.
    fn recv_proto<M: protobuf::Message>(&self) -> Result<M> {
        let mut header_bytes = [0u8; mem::size_of::<usize>()];
        perform_read(&mut |buf| (&self.socket).read(buf), &mut header_bytes)
            .map_err(Error::from_recv_io_error)?;
        let size_header = usize::from_le_bytes(header_bytes);

        // NOTE(review): `size_header` comes from the peer and is not bounded before allocation.
        let mut proto_bytes = vec![0u8; size_header];
        perform_read(&mut |buf| (&self.socket).read(buf), &mut proto_bytes)
            .map_err(Error::from_recv_io_error)?;
        protobuf::Message::parse_from_bytes(&proto_bytes).map_err(Error::Proto)
    }

    /// Serializes `msg` as JSON and sends it. Any descriptors it contains are duplicated
    /// for `target_pid` (see `serialize_and_send`).
    pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> {
        serialize_and_send(|buf| self.socket.write_immutable(buf), msg, self.target_pid)
    }

    /// Receives and deserializes one message (see `deserialize_and_recv`).
    pub fn recv<T: DeserializeOwned>(&self) -> Result<T> {
        deserialize_and_recv(|buf| (&self.socket).read(buf))
    }

    /// Calls `StreamChannel::flush_blocking` on the underlying channel.
    #[cfg(windows)]
    pub fn flush_blocking(&mut self) -> Result<()> {
        self.socket.flush_blocking().map_err(Error::Flush)
    }

    /// Sets the process that descriptors in subsequently sent messages are duplicated into.
    pub fn set_target_pid(&mut self, target_pid: u32) {
        self.target_pid = Some(target_pid);
    }

    /// Returns the target PID, if one has been set.
    pub fn target_pid(&self) -> Option<u32> {
        self.target_pid
    }

    /// # Panics
    /// Always panics; this method is not implemented.
    pub fn set_send_timeout(&self, _timeout: Option<Duration>) -> Result<()> {
        unimplemented!("To be removed/refactored upstream.");
    }

    /// # Panics
    /// Always panics; this method is not implemented.
    pub fn set_recv_timeout(&self, _timeout: Option<Duration>) -> Result<()> {
        unimplemented!("To be removed/refactored upstream.");
    }

    /// Returns the underlying channel's read-notifier event.
    pub fn get_read_notifier_event(&self) -> &Event {
        self.socket.get_read_notifier_event()
    }

    /// Returns the underlying channel's close-notifier event.
    pub fn get_close_notifier_event(&self) -> &Event {
        self.socket.get_close_notifier_event()
    }
}
/// Serializes `msg` and sends it with a single call to `write_fn`.
///
/// Wire layout: `[MsgHeader][message JSON][descriptor JSON (only if non-empty)]`.
/// Every descriptor found in `msg` is duplicated via `duplicate_handle(desc, target_pid)`,
/// and the resulting handle values are sent as a JSON array of integers.
///
/// # Errors
/// * `Error::Json` if serialization fails.
/// * Errors from `duplicate_handle`. NOTE(review): handles duplicated before a later failure
///   are not closed here.
/// * `Error::SendIoBuf` if building the in-memory packet fails.
/// * `Error::Disconnected` / `Error::Send` if `write_fn` fails.
pub fn serialize_and_send<T: Serialize, F: Fn(&[u8]) -> io::Result<usize>>(
    write_fn: F,
    msg: &T,
    target_pid: Option<u32>,
) -> Result<()> {
    let msg_serialize = SerializeDescriptors::new(&msg);
    let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?;
    let msg_descriptors = msg_serialize.into_descriptors();

    let mut duped_descriptors = Vec::with_capacity(msg_descriptors.len());
    for desc in msg_descriptors {
        duped_descriptors.push(duplicate_handle(desc, target_pid)? as usize)
    }

    // The descriptor section is omitted entirely (size 0 in the header) when there is nothing
    // to send.
    let descriptor_json = if duped_descriptors.is_empty() {
        None
    } else {
        Some(serde_json::to_vec(&duped_descriptors).map_err(Error::Json)?)
    };

    let header = MsgHeader {
        msg_json_size: msg_json.len(),
        descriptor_json_size: descriptor_json.as_ref().map_or(0, |json| json.len()),
    };

    let mut data_packet = Cursor::new(Vec::with_capacity(
        header.as_bytes().len() + header.msg_json_size + header.descriptor_json_size,
    ));
    // `write_all` rather than `write`: `Write::write` may legally write only part of the
    // buffer, and its byte count was being discarded.
    data_packet
        .write_all(header.as_bytes())
        .map_err(Error::SendIoBuf)?;
    data_packet
        .write_all(msg_json.as_slice())
        .map_err(Error::SendIoBuf)?;
    if let Some(descriptor_json) = descriptor_json {
        data_packet
            .write_all(descriptor_json.as_slice())
            .map_err(Error::SendIoBuf)?;
    }

    let data_bytes = data_packet.into_inner();
    write_fn(&data_bytes).map_err(Error::from_send_error)?;
    Ok(())
}
fn duplicate_handle(desc: RawHandle, target_pid: Option<u32>) -> Result<RawHandle> {
match target_pid {
Some(pid) => match &*DH_TUBE.lock() {
Some(tube) => tube.request_duplicate_handle(pid, desc),
None => {
win_util::duplicate_handle_with_target_pid(desc, pid).map_err(Error::DupDescriptor)
}
},
None => win_util::duplicate_handle(desc).map_err(Error::DupDescriptor),
}
}
/// Calls `read_fn` once and requires that it fills `buf` completely.
///
/// Returns the number of bytes read, which always equals `buf.len()`. Errors from `read_fn`
/// are propagated. A short read becomes an `io::ErrorKind::UnexpectedEof` error.
fn perform_read<F: FnMut(&mut [u8]) -> io::Result<usize>>(
    read_fn: &mut F,
    buf: &mut [u8],
) -> io::Result<usize> {
    let wanted = buf.len();
    match read_fn(buf)? {
        got if got == wanted => Ok(got),
        got => Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!("failed to fill whole buffer, expected {} got {}", wanted, got),
        )),
    }
}
/// Receives one message produced by [`serialize_and_send`] and deserializes it as `T`.
///
/// `read_fn` is called once per wire section (the header, the body, and the descriptor list
/// if present). Each call must fill its buffer completely (see `perform_read`). The received
/// handle values are wrapped as `SafeDescriptor`s and passed to
/// `deserialize_with_descriptors`.
///
/// # Errors
/// * `Error::Disconnected` / `Error::Recv` if a read fails or comes up short.
/// * `Error::RecvUnexpectedEmptyBody` if the header announces a zero-length body.
/// * `Error::Json` if either JSON section fails to parse.
pub fn deserialize_and_recv<T: DeserializeOwned, F: FnMut(&mut [u8]) -> io::Result<usize>>(
    mut read_fn: F,
) -> Result<T> {
    let mut header = MsgHeader::default();
    perform_read(&mut read_fn, header.as_mut_bytes()).map_err(Error::from_recv_io_error)?;

    // A zero-size body means the sender transmitted no payload at all. Reject it before
    // issuing a zero-length read. Such a read serves no purpose, and depending on the
    // transport it may not be a no-op (e.g. it could wait on the next message).
    if header.msg_json_size == 0 {
        return Err(Error::RecvUnexpectedEmptyBody);
    }

    let mut msg_json = vec![0u8; header.msg_json_size];
    perform_read(&mut read_fn, msg_json.as_mut_slice()).map_err(Error::from_recv_io_error)?;

    let descriptor_usizes: Vec<usize> = if header.descriptor_json_size > 0 {
        let mut msg_descriptors_json = vec![0u8; header.descriptor_json_size];
        perform_read(&mut read_fn, msg_descriptors_json.as_mut_slice())
            .map_err(Error::from_recv_io_error)?;
        serde_json::from_slice(msg_descriptors_json.as_slice()).map_err(Error::Json)?
    } else {
        Vec::new()
    };

    let msg_descriptors = descriptor_usizes.into_iter().map(|item| {
        // SAFETY: each value is a handle that the sender duplicated for this process
        // (`duplicate_handle` in `serialize_and_send`) and converted to `usize`. Ownership of
        // the duplicate is taken exactly once here.
        unsafe { SafeDescriptor::from_raw_descriptor(item as RawDescriptor) }
    });
    deserialize_with_descriptors(|| serde_json::from_slice(&msg_json), msg_descriptors)
        .map_err(Error::Json)
}
/// Event token with a single `SocketReady` variant.
// NOTE(review): not referenced in the visible part of this file; confirm it is still used.
#[derive(EventToken, Eq, PartialEq, Copy, Clone)]
enum Token {
    SocketReady,
}
impl AsRawDescriptor for Tube {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.socket.as_raw_descriptor()
}
}
impl AsRawHandle for Tube {
fn as_raw_handle(&self) -> RawHandle {
self.as_raw_descriptor()
}
}
impl ReadNotifier for Tube {
fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
self.socket.get_read_notifier()
}
}
impl CloseNotifier for Tube {
fn get_close_notifier(&self) -> &dyn AsRawDescriptor {
self.socket.get_close_notifier()
}
}
/// `SendTube` and `RecvTube` delegate descriptor access and close notification to the tube
/// they wrap.
impl AsRawDescriptor for SendTube {
    fn as_raw_descriptor(&self) -> RawDescriptor {
        let inner = &self.0;
        inner.as_raw_descriptor()
    }
}

impl AsRawDescriptor for RecvTube {
    fn as_raw_descriptor(&self) -> RawDescriptor {
        let inner = &self.0;
        inner.as_raw_descriptor()
    }
}

impl CloseNotifier for SendTube {
    fn get_close_notifier(&self) -> &dyn AsRawDescriptor {
        let inner = &self.0;
        inner.get_close_notifier()
    }
}

impl CloseNotifier for RecvTube {
    fn get_close_notifier(&self) -> &dyn AsRawDescriptor {
        let inner = &self.0;
        inner.get_close_notifier()
    }
}
/// Request sent over a `DuplicateHandleTube` asking the peer to duplicate `handle` into the
/// process identified by `target_alias_pid`.
#[derive(Serialize, Deserialize, Debug)]
pub struct DuplicateHandleRequest {
    /// PID of the destination process as known to the broker.
    /// NOTE(review): named "alias" PID; confirm how the broker resolves it.
    pub target_alias_pid: u32,
    /// Source handle value, carried as an integer.
    pub handle: usize,
}

/// Reply to a `DuplicateHandleRequest`. `request_duplicate_handle` treats `None` as failure
/// (`Error::BrokerDupDescriptor`).
#[derive(Serialize, Deserialize, Debug)]
pub struct DuplicateHandleResponse {
    /// Duplicated handle value in the destination process, if duplication succeeded.
    pub handle: Option<usize>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct DuplicateHandleTube(Tube);
impl DuplicateHandleTube {
pub fn new(tube: Tube) -> Self {
Self(tube)
}
pub fn request_duplicate_handle(
&self,
target_alias_pid: u32,
handle: RawHandle,
) -> Result<RawHandle> {
let req = DuplicateHandleRequest {
target_alias_pid,
handle: handle as usize,
};
self.0.send(&req)?;
let res: DuplicateHandleResponse = self.0.recv()?;
res.handle
.map(|h| h as RawHandle)
.ok_or(Error::BrokerDupDescriptor)
}
}
/// A `Tube` that carries length-prefixed protobuf messages instead of JSON.
#[derive(Serialize, Deserialize)]
pub struct ProtoTube(Tube);

impl ProtoTube {
    /// Creates a connected pair (see `Tube::pair`).
    pub fn pair() -> Result<(ProtoTube, ProtoTube)> {
        let (first, second) = Tube::pair()?;
        Ok((ProtoTube(first), ProtoTube(second)))
    }

    /// Creates a connected pair with the given buffer size (see `Tube::pair_with_buffer_size`).
    pub fn pair_with_buffer_size(size: usize) -> Result<(ProtoTube, ProtoTube)> {
        let (first, second) = Tube::pair_with_buffer_size(size)?;
        Ok((ProtoTube(first), ProtoTube(second)))
    }

    /// Sends one protobuf message.
    pub fn send_proto<M: protobuf::Message>(&self, msg: &M) -> Result<()> {
        let ProtoTube(inner) = self;
        inner.send_proto(msg)
    }

    /// Receives and parses one protobuf message.
    pub fn recv_proto<M: protobuf::Message>(&self) -> Result<M> {
        let ProtoTube(inner) = self;
        inner.recv_proto::<M>()
    }
}

impl ReadNotifier for ProtoTube {
    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
        let ProtoTube(inner) = self;
        inner.get_read_notifier()
    }
}

impl AsRawDescriptor for ProtoTube {
    fn as_raw_descriptor(&self) -> RawDescriptor {
        let ProtoTube(inner) = self;
        inner.as_raw_descriptor()
    }
}
/// Sends and receives tube-framed messages directly over a `PipeConnection`.
pub struct PipeTube {
    pipe: PipeConnection,
    target_pid: Option<u32>,
}

impl PipeTube {
    /// Wraps `pipe`. Descriptors in sent messages are duplicated for `target_pid`.
    pub fn from(pipe: PipeConnection, target_pid: Option<u32>) -> Self {
        PipeTube { target_pid, pipe }
    }

    /// Serializes and sends `msg` (see `serialize_and_send`).
    pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> {
        let pipe = &self.pipe;
        serialize_and_send(move |bytes| pipe.write(bytes), msg, self.target_pid)
    }

    /// Receives and deserializes one message (see `deserialize_and_recv`).
    pub fn recv<T: DeserializeOwned>(&self) -> Result<T> {
        let pipe = &self.pipe;
        deserialize_and_recv(move |bytes| {
            // SAFETY: `bytes` is a plain `u8` buffer, so any data read from the pipe is a valid
            // value for it. NOTE(review): confirm against `PipeConnection::read`'s documented
            // safety contract.
            unsafe { pipe.read(bytes) }
        })
    }
}
/// Wrapper that calls `Tube::flush_blocking` on its tube when dropped.
pub struct FlushOnDropTube(pub Tube);

impl FlushOnDropTube {
    /// Wraps `tube`.
    pub fn from(tube: Tube) -> Self {
        FlushOnDropTube(tube)
    }
}

impl Drop for FlushOnDropTube {
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failed flush is only logged.
        match self.0.flush_blocking() {
            Ok(()) => {}
            Err(err) => warn!("failed to flush Tube: {}", err),
        }
    }
}
impl Error {
fn map_io_error(e: io::Error, err_ctor: fn(io::Error) -> Error) -> Error {
if e.kind() == io::ErrorKind::UnexpectedEof || e.kind() == io::ErrorKind::BrokenPipe {
Error::Disconnected
} else {
err_ctor(e)
}
}
fn from_recv_io_error(e: io::Error) -> Error {
Self::map_io_error(e, Error::Recv)
}
fn from_send_error(e: io::Error) -> Error {
Self::map_io_error(e, Error::Send)
}
fn from_send_io_buf_error(e: io::Error) -> Error {
Self::map_io_error(e, Error::SendIoBuf)
}
}