use serde::Serialize;
use tokio::sync::mpsc;
use crate::msgpack_rpc::MsgPackCaller;
use super::{
protocol::{ClientRequestMethod, RefServerMessageParams, ServerClosedParams, ToClientRequest},
server_multiplexer::ServerMultiplexer,
};
/// Human-readable explanation carried by a [`SocketSignal::CloseWith`] signal.
pub struct CloseReason(pub String);

/// A signal delivered over an `mpsc` channel of `SocketSignal`s: either a
/// payload to send, or a request to close along with the reason.
pub enum SocketSignal {
    /// Encoded payload bytes to send.
    Send(Vec<u8>),
    /// Close, reporting the given reason.
    CloseWith(CloseReason),
}

impl From<Vec<u8>> for SocketSignal {
    /// Wraps an owned byte buffer in [`SocketSignal::Send`] (the buffer is moved, not copied).
    fn from(bytes: Vec<u8>) -> Self {
        Self::Send(bytes)
    }
}
impl SocketSignal {
    /// Serializes `msg` to MessagePack with `rmp_serde::to_vec_named` (structs
    /// are encoded as maps keyed by field name) and wraps the bytes in
    /// [`SocketSignal::Send`].
    ///
    /// # Panics
    ///
    /// Panics if `rmp_serde` cannot serialize `msg`. In this file it is only
    /// called with `ToClientRequest` protocol values, so a failure indicates a
    /// bug in the message types rather than a runtime condition.
    pub fn from_message<T>(msg: &T) -> Self
    where
        T: Serialize + ?Sized,
    {
        let encoded = rmp_serde::to_vec_named(msg)
            .expect("socket protocol messages should always serialize to msgpack");
        SocketSignal::Send(encoded)
    }
}
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file;
// presumably one of the variants is not constructed in every build — confirm.
/// Where a [`ServerMessageSink`] delivers server messages and close notifications.
#[allow(dead_code)]
pub enum ServerMessageDestination {
/// Messages are encoded as [`SocketSignal`]s and pushed onto this channel.
Channel(mpsc::Sender<SocketSignal>),
/// Messages are sent as `servermsg` / `serverclose` notifications through this caller.
Rpc(MsgPackCaller),
}
/// Forwards messages for the server connection identified by `id` to a
/// [`ServerMessageDestination`], optionally deflate-compressing them first.
/// Removes `id` from its multiplexer when dropped.
pub struct ServerMessageSink {
/// Connection id, sent as the `i` field of every message and close notification.
id: u16,
/// Delivery target; both constructors set it to `Some`.
tx: Option<ServerMessageDestination>,
/// Multiplexer this sink's `id` is removed from on drop.
multiplexer: ServerMultiplexer,
/// Raw-deflate compressor for compressed sinks; `None` for plain sinks.
flate: Option<FlateStream<CompressFlateAlgorithm>>,
}
impl ServerMessageSink {
pub fn new_plain(
multiplexer: ServerMultiplexer,
id: u16,
tx: ServerMessageDestination,
) -> Self {
Self {
tx: Some(tx),
id,
multiplexer,
flate: None,
}
}
pub fn new_compressed(
multiplexer: ServerMultiplexer,
id: u16,
tx: ServerMessageDestination,
) -> Self {
Self {
tx: Some(tx),
id,
multiplexer,
flate: Some(FlateStream::new(CompressFlateAlgorithm(
flate2::Compress::new(flate2::Compression::new(2), false),
))),
}
}
pub async fn server_closed(&mut self) -> Result<(), mpsc::error::SendError<SocketSignal>> {
self.server_message_or_closed(None).await
}
pub async fn server_message(
&mut self,
body: &[u8],
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
self.server_message_or_closed(Some(body)).await
}
async fn server_message_or_closed(
&mut self,
body_or_end: Option<&[u8]>,
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
let i = self.id;
let mut tx = self.tx.take().unwrap();
if let Some(b) = body_or_end {
let body = self.get_server_msg_content(b, false);
let r =
send_data_or_close_if_none(i, &mut tx, Some(RefServerMessageParams { i, body }))
.await;
self.tx = Some(tx);
return r;
}
let tail = self.get_server_msg_content(&[], true);
if !tail.is_empty() {
let _ = send_data_or_close_if_none(
i,
&mut tx,
Some(RefServerMessageParams { i, body: tail }),
)
.await;
}
let r = send_data_or_close_if_none(i, &mut tx, None).await;
self.tx = Some(tx);
r
}
pub(crate) fn get_server_msg_content<'a: 'b, 'b>(
&'a mut self,
body: &'b [u8],
finish: bool,
) -> &'b [u8] {
if let Some(flate) = &mut self.flate {
if let Ok(compressed) = flate.process(body, finish) {
return compressed;
}
}
body
}
}
/// Delivers `msg` for connection `i` to `tx`, or a close notification for `i`
/// when `msg` is `None`.
///
/// Channel destinations receive a `ToClientRequest` (with no `id`) encoded by
/// [`SocketSignal::from_message`]. RPC destinations receive a `servermsg` or
/// `serverclose` notification.
///
/// # Errors
///
/// Only the channel path can fail, returning the channel's `SendError`. The RPC
/// path always returns `Ok(())`, and whatever `notify` returns is discarded.
async fn send_data_or_close_if_none(
    i: u16,
    tx: &mut ServerMessageDestination,
    msg: Option<RefServerMessageParams<'_>>,
) -> Result<(), mpsc::error::SendError<SocketSignal>> {
    let sender = match tx {
        ServerMessageDestination::Rpc(caller) => {
            let _ = if let Some(data) = msg {
                caller.notify("servermsg", data)
            } else {
                caller.notify("serverclose", ServerClosedParams { i })
            };
            return Ok(());
        }
        ServerMessageDestination::Channel(sender) => sender,
    };

    let params = if let Some(data) = msg {
        ClientRequestMethod::servermsg(data)
    } else {
        ClientRequestMethod::serverclose(ServerClosedParams { i })
    };
    let request = ToClientRequest { id: None, params };
    sender.send(SocketSignal::from_message(&request)).await
}
impl Drop for ServerMessageSink {
    /// Unregisters this sink's id from its multiplexer.
    fn drop(&mut self) {
        let id = self.id;
        let multiplexer = &mut self.multiplexer;
        multiplexer.remove(id);
    }
}
/// Decodes message payloads received from the client, inflating them when the
/// decoder was built with [`ClientMessageDecoder::new_compressed`].
pub struct ClientMessageDecoder {
    /// Inflater for compressed decoders; `None` means messages pass through unchanged.
    dec: Option<FlateStream<DecompressFlateAlgorithm>>,
}

impl ClientMessageDecoder {
    /// Creates a decoder that returns every message unchanged.
    pub fn new_plain() -> Self {
        Self { dec: None }
    }

    /// Creates a decoder for raw-deflate data (no zlib header).
    pub fn new_compressed() -> Self {
        let inflater = flate2::Decompress::new(false);
        Self {
            dec: Some(FlateStream::new(DecompressFlateAlgorithm(inflater))),
        }
    }

    /// Decodes one message. The returned slice borrows either `message` itself
    /// (plain decoder) or the decoder's internal output buffer (compressed decoder).
    ///
    /// # Errors
    ///
    /// For compressed decoders, returns an error if the data cannot be inflated
    /// or if the deflate stream reports that it has ended.
    pub fn decode<'a: 'b, 'b>(&'a mut self, message: &'b [u8]) -> std::io::Result<&'b [u8]> {
        if let Some(stream) = self.dec.as_mut() {
            stream.process(message, false)
        } else {
            Ok(message)
        }
    }
}
/// Common interface over `flate2`'s streaming compressor and decompressor, so a
/// single [`FlateStream`] driver can run either one.
trait FlateAlgorithm {
    /// Total number of input bytes consumed so far.
    fn total_in(&self) -> u64;

    /// Total number of output bytes produced so far.
    fn total_out(&self) -> u64;

    /// Runs one step over `input`, writing into `output`. When `finish` is true
    /// the implementation uses its finishing flush mode.
    fn process(
        &mut self,
        input: &[u8],
        output: &mut [u8],
        finish: bool,
    ) -> std::io::Result<flate2::Status>;
}
/// [`FlateAlgorithm`] adapter around a [`flate2::Decompress`] inflater.
struct DecompressFlateAlgorithm(flate2::Decompress);

impl FlateAlgorithm for DecompressFlateAlgorithm {
    fn total_in(&self) -> u64 {
        let Self(inner) = self;
        inner.total_in()
    }

    fn total_out(&self) -> u64 {
        let Self(inner) = self;
        inner.total_out()
    }

    /// Inflates `contents` into `output`, using `FlushDecompress::Finish` when
    /// `finish` is set and `FlushDecompress::None` otherwise. Decompression
    /// errors are reported as `io::ErrorKind::InvalidInput`.
    fn process(
        &mut self,
        contents: &[u8],
        output: &mut [u8],
        finish: bool,
    ) -> Result<flate2::Status, std::io::Error> {
        let flush = if finish {
            flate2::FlushDecompress::Finish
        } else {
            flate2::FlushDecompress::None
        };
        match self.0.decompress(contents, output, flush) {
            Ok(status) => Ok(status),
            Err(err) => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err)),
        }
    }
}
/// [`FlateAlgorithm`] adapter around a [`flate2::Compress`] deflater.
struct CompressFlateAlgorithm(flate2::Compress);

impl FlateAlgorithm for CompressFlateAlgorithm {
    fn total_in(&self) -> u64 {
        let Self(inner) = self;
        inner.total_in()
    }

    fn total_out(&self) -> u64 {
        let Self(inner) = self;
        inner.total_out()
    }

    /// Deflates `contents` into `output`. Uses `FlushCompress::Finish` when
    /// `finish` is set; otherwise uses `FlushCompress::Sync`, so each call's output
    /// ends on a flush point. Compression errors are reported as
    /// `io::ErrorKind::InvalidInput`.
    fn process(
        &mut self,
        contents: &[u8],
        output: &mut [u8],
        finish: bool,
    ) -> Result<flate2::Status, std::io::Error> {
        let flush = if finish {
            flate2::FlushCompress::Finish
        } else {
            flate2::FlushCompress::Sync
        };
        match self.0.compress(contents, output, flush) {
            Ok(status) => Ok(status),
            Err(err) => Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, err)),
        }
    }
}
/// Drives a [`FlateAlgorithm`] over whole messages, collecting the output in a
/// reusable buffer that grows as needed.
struct FlateStream<A>
where
    A: FlateAlgorithm,
{
    /// The underlying compressor or decompressor.
    flate: A,
    /// Scratch output buffer; starts at 4096 bytes and doubles whenever it fills.
    output: Vec<u8>,
}

impl<A> FlateStream<A>
where
    A: FlateAlgorithm,
{
    /// Wraps `alg` with a 4096-byte initial output buffer.
    pub fn new(alg: A) -> Self {
        Self {
            flate: alg,
            output: vec![0; 4096],
        }
    }

    /// Feeds all of `contents` through the algorithm and returns the bytes it
    /// produced. The returned slice borrows the internal buffer.
    ///
    /// When `finish` is true the caller is terminating the stream, so
    /// `Status::StreamEnd` is the expected successful outcome and the final
    /// output is returned. When `finish` is false, `StreamEnd` is reported as an
    /// error.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying algorithm. Returns an
    /// `UnexpectedEof` error if the stream ends while `finish` is false.
    pub fn process(&mut self, contents: &[u8], finish: bool) -> std::io::Result<&[u8]> {
        let mut out_offset = 0;
        let mut in_offset = 0;
        loop {
            let in_before = self.flate.total_in();
            let out_before = self.flate.total_out();
            let status = self.flate.process(
                &contents[in_offset..],
                &mut self.output[out_offset..],
                finish,
            )?;
            // Absolute positions reached in `contents` and `self.output` after this step.
            let processed_len = in_offset + (self.flate.total_in() - in_before) as usize;
            let output_len = out_offset + (self.flate.total_out() - out_before) as usize;

            match status {
                flate2::Status::Ok | flate2::Status::BufError => {
                    // Continue while input remains, or while the output buffer is
                    // full: the algorithm may still hold output even after
                    // consuming all of the input.
                    if processed_len < contents.len() || output_len == self.output.len() {
                        out_offset = output_len;
                        in_offset = processed_len;
                        if output_len == self.output.len() {
                            self.output.resize(self.output.len() * 2, 0);
                        }
                        continue;
                    }
                    return Ok(&self.output[..output_len]);
                }
                // A finishing flush reports completion as `StreamEnd`. Treating it
                // as an error would discard the stream's final bytes.
                flate2::Status::StreamEnd if finish => {
                    return Ok(&self.output[..output_len]);
                }
                flate2::Status::StreamEnd => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "unexpected stream end",
                    ));
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use base64::{engine::general_purpose, Engine as _};
/// Compresses payloads of increasing size through a compressed sink and checks
/// that a compressed decoder restores each one exactly.
#[test]
fn test_round_trips_compression() {
// The receiver is dropped immediately; only `get_server_msg_content` is
// exercised, and it never sends anything.
let (tx, _) = mpsc::channel(1);
let mut sink = ServerMessageSink::new_compressed(
ServerMultiplexer::new(),
0,
ServerMessageDestination::Channel(tx),
);
let mut decompress = ClientMessageDecoder::new_compressed();
// The largest payload exceeds FlateStream's initial 4096-byte output buffer,
// so decoding it requires the buffer to grow.
for msg_len in [3, 30, 300, 3000, 30000] {
let vals = (0..msg_len).map(|v| v as u8).collect::<Vec<u8>>();
let compressed = sink.get_server_msg_content(&vals, false);
// Guards against the sink silently returning the input uncompressed.
assert_ne!(compressed, vals);
let decompressed = decompress.decode(compressed).unwrap();
assert_eq!(decompressed.len(), vals.len());
assert_eq!(decompressed, vals);
}
}
/// Base64-encoded compressed chunks, fed to the decoder one after another.
// NOTE(review): the name presumably refers to upstream issue 191501 — confirm.
const TEST_191501_BUFS: [&str; 3] = [
"TMzLSsQwFIDhfSDv0NXsYs2kubQQXIgX0IUwHVyfpCdjaSYZmkjRpxdEBnf/5vufHsZmK0PbxuwhfuRS2zmVecKVBd1rEYTUqL3gCoxBY7g2RoWOg+nE7Z4H1N3dij6nhL7OOY15wWTBeN87IVkACayTijMXcGJagevkxJ3i/e4/swFiwV1Z5ss7ukP2C9bHFc5YbF0/sXkex7eW33BK7q9maI6X0woTUvIXQ7OhK7+YkgN6dn2xF/wamhTgVM8xHl8Tr2kvvv2SymYtJZT8AAAA//8=",
"YmJAgIhqpZLKglQlK6XE0pIMJR0IZaVUlJqbX5JaXAwSSkksSQQK+WUkung5BWam6TumVaWEFhQHJBuUGrg4WUY4eQV4GOTnhwVkWJiX5lRmOdoq1QIAAAD//w==",
""
];
/// Decodes the fixture chunks in sequence with a single decoder and checks
/// the total decompressed length.
#[test]
fn test_flatestream_decodes_191501() {
let mut dec = ClientMessageDecoder::new_compressed();
let mut len = 0;
for b in TEST_191501_BUFS {
let b = general_purpose::STANDARD
.decode(b)
.expect("expected no decode error");
let s = dec.decode(&b).expect("expected no decompress error");
len += s.len();
}
assert_eq!(len, 265 + 101 + 10370);
}
}