use std::io;
use std::sync::Arc;
use log::error;
use log::warn;
use serde::ser::SerializeStruct;
use serde::Deserialize;
use serde::Serialize;
use serde::Serializer;
use sync::Mutex;
use super::named_pipes;
use super::named_pipes::PipeConnection;
use super::MultiProcessMutex;
use super::RawDescriptor;
use super::Result;
use crate::descriptor::AsRawDescriptor;
use crate::CloseNotifier;
use crate::Event;
use crate::ReadNotifier;
/// How data is framed on the underlying named pipe of a [`StreamChannel`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FramingMode {
    /// Maps to `named_pipes::FramingMode::Message`. Writes larger than the channel's send
    /// buffer size are rejected by `StreamChannel::write_immutable`.
    Message,
    /// Maps to `named_pipes::FramingMode::Byte`.
    Byte,
}
/// Whether operations on a [`StreamChannel`]'s pipe wait or return immediately.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum BlockingMode {
    /// Maps to `named_pipes::BlockingMode::Wait`.
    Blocking,
    /// Maps to `named_pipes::BlockingMode::NoWait`. For example, reading an empty pipe returns
    /// an error instead of waiting.
    Nonblocking,
}
impl From<FramingMode> for named_pipes::FramingMode {
    /// Translates this module's framing mode into the equivalent `named_pipes` variant.
    fn from(mode: FramingMode) -> Self {
        match mode {
            FramingMode::Byte => Self::Byte,
            FramingMode::Message => Self::Message,
        }
    }
}
impl From<BlockingMode> for named_pipes::BlockingMode {
    /// Translates this module's blocking mode into the `named_pipes` wait/no-wait flag.
    fn from(mode: BlockingMode) -> Self {
        match mode {
            BlockingMode::Nonblocking => Self::NoWait,
            BlockingMode::Blocking => Self::Wait,
        }
    }
}
/// Pipe buffer size in bytes (50 KiB) used by `StreamChannel::pair`.
///
/// Channels created by `pair` also use this as their send buffer size, so in message framing
/// mode they reject writes larger than this (see `StreamChannel::write_immutable`).
pub const DEFAULT_BUFFER_SIZE: usize = 50 * 1024;
/// One end of a bidirectional channel built on a named pipe connection.
///
/// Besides the pipe itself, each end carries the events and cross-process locks used to keep
/// a "data available" read notifier in sync with the pipe contents (see `inner_read` and
/// `write_immutable`).
///
/// Field order matters: the hand-written `Serialize` impl emits the serialized fields in this
/// declaration order. For formats that encode structs as sequences, the derived `Deserialize`
/// reads them back in declaration order. Keep the two in sync.
#[derive(Deserialize, Debug)]
pub struct StreamChannel {
    /// The named pipe connection that carries the data.
    pipe_conn: named_pipes::PipeConnection,
    /// Signaled after each successful `write_immutable`. The peer end holds this same event as
    /// its `read_notify`.
    write_notify: Event,
    /// Signaled while data is pending for this end. `inner_read` resets it once the pipe is
    /// found empty.
    read_notify: Event,
    /// Event shared by both ends. It is signaled when an end is dropped, unless that end was
    /// successfully serialized.
    pipe_closed: Event,
    /// The peer's write lock. `inner_read` holds it briefly so the peer cannot write while
    /// this end decides to reset `read_notify`.
    remote_write_lock: MultiProcessMutex,
    /// This end's write lock, held for the duration of each `write_immutable`.
    local_write_lock: MultiProcessMutex,
    /// Serializes reads among this handle and its `try_clone` copies (the `Arc` is shared).
    /// It is not serialized; a fresh lock is created on deserialization.
    #[serde(skip)]
    #[serde(default = "create_read_lock")]
    read_lock: Arc<Mutex<()>>,
    /// Whether dropping this value signals `pipe_closed`. It is cleared after successful
    /// serialization and starts as `true` on deserialization.
    #[serde(skip)]
    #[serde(default = "create_true_mutex")]
    is_channel_closed_on_drop: Mutex<bool>,
    /// Largest write accepted in message framing mode; also the size passed to the pipe by
    /// `pair_with_buffer_size`.
    send_buffer_size: usize,
}
/// Builds a fresh, unshared read lock.
///
/// Used as the serde default for `StreamChannel::read_lock`.
fn create_read_lock() -> Arc<Mutex<()>> {
    let lock = Mutex::new(());
    Arc::from(lock)
}
/// Builds the initial "signal `pipe_closed` on drop" flag, which is always `true`.
///
/// Used as the serde default for `StreamChannel::is_channel_closed_on_drop`, and by the
/// constructors.
fn create_true_mutex() -> Mutex<bool> {
    let close_on_drop = true;
    Mutex::new(close_on_drop)
}
impl Serialize for StreamChannel {
    /// Serializes every transferable field, omitting the in-process-only `read_lock` and
    /// `is_channel_closed_on_drop`.
    ///
    /// On success, this instance will no longer signal `pipe_closed` when dropped, since the
    /// serialized copy now represents this end of the channel.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        const SERIALIZED_FIELD_COUNT: usize = 7;

        // Emit fields in `StreamChannel`'s declaration order so the derived `Deserialize`
        // reads them back correctly.
        let mut state = serializer.serialize_struct("StreamChannel", SERIALIZED_FIELD_COUNT)?;
        state.serialize_field("pipe_conn", &self.pipe_conn)?;
        state.serialize_field("write_notify", &self.write_notify)?;
        state.serialize_field("read_notify", &self.read_notify)?;
        state.serialize_field("pipe_closed", &self.pipe_closed)?;
        state.serialize_field("remote_write_lock", &self.remote_write_lock)?;
        state.serialize_field("local_write_lock", &self.local_write_lock)?;
        state.serialize_field("send_buffer_size", &self.send_buffer_size)?;

        let done = state.end()?;
        *self.is_channel_closed_on_drop.lock() = false;
        Ok(done)
    }
}
impl Drop for StreamChannel {
    /// Signals the shared `pipe_closed` event, unless this end was handed off via serialization.
    ///
    /// A signaling failure is only logged, because `drop` cannot report errors.
    fn drop(&mut self) {
        let closes_channel = *self.is_channel_closed_on_drop.lock();
        if !closes_channel {
            return;
        }
        if let Err(err) = self.pipe_closed.signal() {
            warn!("failed to notify on channel drop: {}", err);
        }
    }
}
impl StreamChannel {
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
if nonblocking {
self.pipe_conn
.set_blocking(&named_pipes::BlockingMode::NoWait)
} else {
self.pipe_conn
.set_blocking(&named_pipes::BlockingMode::Wait)
}
}
pub fn try_clone(&self) -> io::Result<Self> {
Ok(StreamChannel {
pipe_conn: self.pipe_conn.try_clone()?,
write_notify: self.write_notify.try_clone()?,
read_notify: self.read_notify.try_clone()?,
pipe_closed: self.pipe_closed.try_clone()?,
remote_write_lock: self.remote_write_lock.try_clone()?,
local_write_lock: self.local_write_lock.try_clone()?,
read_lock: self.read_lock.clone(),
is_channel_closed_on_drop: create_true_mutex(),
send_buffer_size: self.send_buffer_size,
})
}
fn get_readable_byte_count(&self) -> io::Result<u32> {
match self.pipe_conn.get_available_byte_count() {
Err(e) if e.kind() == io::ErrorKind::BrokenPipe => Ok(0),
Err(e) => {
error!("StreamChannel failed to get readable byte count: {}", e);
Err(e)
}
Ok(byte_count) => Ok(byte_count),
}
}
pub(super) fn inner_read(&self, buf: &mut [u8]) -> io::Result<usize> {
let _read_lock = self.read_lock.lock();
let res = unsafe { self.pipe_conn.read(buf) };
loop {
let byte_count = self.get_readable_byte_count()?;
if byte_count > 0 {
self.read_notify.signal().map_err(|e| {
io::Error::other(format!("failed to write to read notifier: {e:?}"))
})?;
return res;
}
if let Some(_write_lock) = self.remote_write_lock.try_lock( 1) {
let byte_count = self.get_readable_byte_count()?;
if byte_count > 0 {
self.read_notify.signal().map_err(|e| {
io::Error::other(format!("failed to write to read notifier: {e:?}"))
})?;
} else {
self.read_notify.reset().map_err(|e| {
io::Error::other(format!("failed to reset read notifier: {e:?}"))
})?;
}
return res;
}
}
}
pub fn write_immutable(&self, buf: &[u8]) -> io::Result<usize> {
if self.pipe_conn.get_framing_mode() == named_pipes::FramingMode::Message
&& buf.len() > self.send_buffer_size
{
return Err(io::Error::other(format!(
"StreamChannel forbids message mode writes larger than the \
default buffer size of {}.",
self.send_buffer_size,
)));
}
let _lock = self.local_write_lock.lock();
let res = self.pipe_conn.write(buf);
if res.is_ok() {
self.write_notify.signal().map_err(|e| {
io::Error::other(format!("failed to write to read notifier: {e:?}"))
})?;
}
res
}
pub fn from_pipes(
pipe_a: PipeConnection,
pipe_b: PipeConnection,
send_buffer_size: usize,
) -> Result<(StreamChannel, StreamChannel)> {
let (notify_a_write, notify_b_write) = (Event::new()?, Event::new()?);
let pipe_closed = Event::new()?;
let write_lock_a = MultiProcessMutex::new()?;
let write_lock_b = MultiProcessMutex::new()?;
let sock_a = StreamChannel {
pipe_conn: pipe_a,
write_notify: notify_a_write.try_clone()?,
read_notify: notify_b_write.try_clone()?,
read_lock: Arc::new(Mutex::new(())),
local_write_lock: write_lock_a.try_clone()?,
remote_write_lock: write_lock_b.try_clone()?,
pipe_closed: pipe_closed.try_clone()?,
is_channel_closed_on_drop: create_true_mutex(),
send_buffer_size,
};
let sock_b = StreamChannel {
pipe_conn: pipe_b,
write_notify: notify_b_write,
read_notify: notify_a_write,
read_lock: Arc::new(Mutex::new(())),
local_write_lock: write_lock_b,
remote_write_lock: write_lock_a,
pipe_closed,
is_channel_closed_on_drop: create_true_mutex(),
send_buffer_size,
};
Ok((sock_a, sock_b))
}
pub fn pair_with_buffer_size(
blocking_mode: BlockingMode,
framing_mode: FramingMode,
buffer_size: usize,
) -> Result<(StreamChannel, StreamChannel)> {
let (pipe_a, pipe_b) = named_pipes::pair_with_buffer_size(
&named_pipes::FramingMode::from(framing_mode),
&named_pipes::BlockingMode::from(blocking_mode),
0,
buffer_size,
false,
)?;
Self::from_pipes(pipe_a, pipe_b, buffer_size)
}
pub fn pair(
blocking_mode: BlockingMode,
framing_mode: FramingMode,
) -> Result<(StreamChannel, StreamChannel)> {
let (pipe_a, pipe_b) = named_pipes::pair_with_buffer_size(
&named_pipes::FramingMode::from(framing_mode),
&named_pipes::BlockingMode::from(blocking_mode),
0,
DEFAULT_BUFFER_SIZE,
false,
)?;
Self::from_pipes(pipe_a, pipe_b, DEFAULT_BUFFER_SIZE)
}
pub fn flush_blocking(&self) -> io::Result<()> {
self.pipe_conn.flush_data_blocking()
}
pub(crate) fn get_read_notifier_event(&self) -> &Event {
&self.read_notify
}
pub(crate) fn get_close_notifier_event(&self) -> &Event {
&self.pipe_closed
}
}
impl io::Write for StreamChannel {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.write_immutable(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl AsRawDescriptor for &StreamChannel {
    /// Exposes the descriptor of the underlying pipe connection.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        let channel: &StreamChannel = self;
        channel.pipe_conn.as_raw_descriptor()
    }
}
impl ReadNotifier for StreamChannel {
    /// Returns the event that is signaled by the peer's writes and reset by `inner_read` once
    /// this end's pipe is drained.
    fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
        self.get_read_notifier_event()
    }
}
impl CloseNotifier for StreamChannel {
    /// Returns the `pipe_closed` event shared by both ends, which is signaled when an end is
    /// dropped.
    fn get_close_notifier(&self) -> &dyn AsRawDescriptor {
        self.get_close_notifier_event()
    }
}
impl io::Read for StreamChannel {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner_read(buf)
}
}
impl io::Read for &StreamChannel {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner_read(buf)
}
}
impl AsRawDescriptor for StreamChannel {
    /// Exposes the descriptor of the underlying pipe connection.
    ///
    /// This is the same value the `&StreamChannel` impl returns.
    fn as_raw_descriptor(&self) -> RawDescriptor {
        self.pipe_conn.as_raw_descriptor()
    }
}
#[cfg(test)]
mod test {
    use std::io::Read;
    use std::io::Write;
    use std::time::Duration;

    use super::super::EventContext;
    use super::super::EventTrigger;
    use super::*;
    use crate::EventToken;
    use crate::ReadNotifier;

    #[derive(EventToken, Debug, Eq, PartialEq, Copy, Clone)]
    enum Token {
        ReceivedData,
    }

    /// Upper bound on every event wait in these tests, so a notifier regression fails the test
    /// instead of hanging the test runner.
    const EVENT_WAIT_TIME: Duration = Duration::from_secs(10);

    #[test]
    fn test_read_notifies_multiple_writes() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Blocking, FramingMode::Byte).unwrap();
        sender.write_all(&[1, 2]).unwrap();

        let event_ctx: EventContext<Token> = EventContext::build_with(&[EventTrigger::from(
            receiver.get_read_notifier(),
            Token::ReceivedData,
        )])
        .unwrap();
        assert_eq!(event_ctx.wait_timeout(EVENT_WAIT_TIME).unwrap().len(), 1);

        let mut recv_buffer = [0u8; 1];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 1);
        assert_eq!(recv_buffer[0], 1);

        // One byte is still pending, so the notifier must remain signaled.
        assert_eq!(event_ctx.wait_timeout(EVENT_WAIT_TIME).unwrap().len(), 1);
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 1);
        assert_eq!(recv_buffer[0], 2);
    }

    #[test]
    fn test_blocked_writer_wont_deadlock() {
        let (mut writer, mut reader) =
            StreamChannel::pair_with_buffer_size(BlockingMode::Blocking, FramingMode::Byte, 100)
                .unwrap();
        const NUM_OPS: usize = 100;

        // The writer pushes NUM_OPS x 100 bytes into pipes sized to 100 bytes, so it is expected
        // to block on a full pipe while holding its write lock. The reader must still progress.
        let writer_thread = std::thread::spawn(move || {
            let buf = [0u8; 100];
            for _ in 0..NUM_OPS {
                assert_eq!(writer.write(&buf).unwrap(), buf.len());
            }
            // Hand the channel back so it is not dropped early on the writer thread.
            writer
        });

        let mut buf = [0u8; 100];
        for _ in 0..NUM_OPS {
            assert_eq!(reader.read(&mut buf).unwrap(), buf.len());
        }

        writer_thread.join().unwrap();
    }

    #[test]
    fn test_non_blocking_pair() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();

        sender.write_all(&[75, 77, 54, 82, 76, 65]).unwrap();

        let event_ctx: EventContext<Token> = EventContext::build_with(&[EventTrigger::from(
            receiver.get_read_notifier(),
            Token::ReceivedData,
        )])
        .unwrap();

        let events = event_ctx.wait_timeout(EVENT_WAIT_TIME).unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        let mut recv_buffer: [u8; 4] = [0; 4];
        let mut size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 4);
        assert_eq!(recv_buffer, [75, 77, 54, 82]);

        size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer[0..2], [76, 65]);

        // The pipe is drained, so the read notifier must have been reset.
        assert_eq!(event_ctx.wait_timeout(Duration::ZERO).unwrap().len(), 0);
    }

    #[test]
    fn test_non_blocking_pair_error_no_data() {
        let (mut sender, mut receiver) =
            StreamChannel::pair(BlockingMode::Nonblocking, FramingMode::Byte).unwrap();
        receiver
            .set_nonblocking(true)
            .expect("Failed to set receiver to nonblocking mode.");

        sender.write_all(&[75, 77]).unwrap();

        let event_ctx: EventContext<Token> = EventContext::build_with(&[EventTrigger::from(
            receiver.get_read_notifier(),
            Token::ReceivedData,
        )])
        .unwrap();

        let events = event_ctx.wait_timeout(EVENT_WAIT_TIME).unwrap();
        let tokens: Vec<Token> = events
            .iter()
            .filter(|e| e.is_readable)
            .map(|e| e.token)
            .collect();
        assert_eq!(tokens, vec![Token::ReceivedData]);

        let mut recv_buffer: [u8; 4] = [0; 4];
        let size = receiver.read(&mut recv_buffer).unwrap();
        assert_eq!(size, 2);
        assert_eq!(recv_buffer, [75, 77, 0, 0]);

        // Nonblocking read on an empty pipe must fail rather than wait.
        assert!(receiver.read(&mut recv_buffer).is_err());
    }
}