Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/msgpack_rpc.rs
3309 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
use bytes::Buf;
7
use serde::de::DeserializeOwned;
8
use tokio::{
9
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
10
pin,
11
sync::mpsc,
12
};
13
use tokio_util::codec::Decoder;
14
15
use crate::{
16
rpc::{self, MaybeSync, Serialization},
17
util::{
18
errors::{AnyError, InvalidRpcDataError},
19
sync::{Barrier, Receivable},
20
},
21
};
22
use std::io::{self, Cursor, ErrorKind};
23
24
#[derive(Copy, Clone)]
25
pub struct MsgPackSerializer {}
26
27
impl Serialization for MsgPackSerializer {
28
fn serialize(&self, value: impl serde::Serialize) -> Vec<u8> {
29
rmp_serde::to_vec_named(&value).expect("expected to serialize")
30
}
31
32
fn deserialize<P: serde::de::DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError> {
33
rmp_serde::from_slice(b).map_err(|e| InvalidRpcDataError(e.to_string()).into())
34
}
35
}
36
37
pub type MsgPackCaller = rpc::RpcCaller<MsgPackSerializer>;
38
39
/// Creates a new RPC Builder that serializes to msgpack.
40
pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
41
rpc::RpcBuilder::new(MsgPackSerializer {})
42
}
43
44
/// Starting processing msgpack rpc over the given i/o. It's recommended that
45
/// the reader be passed in as a BufReader for efficiency.
46
pub async fn start_msgpack_rpc<
47
C: Send + Sync + 'static,
48
X: Clone,
49
S: Send + Sync + Serialization,
50
Read: AsyncRead + Unpin,
51
Write: AsyncWrite + Unpin,
52
>(
53
dispatcher: rpc::RpcDispatcher<S, C>,
54
mut read: Read,
55
mut write: Write,
56
mut msg_rx: impl Receivable<Vec<u8>>,
57
mut shutdown_rx: Barrier<X>,
58
) -> io::Result<(Option<X>, Read, Write)> {
59
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
60
let mut decoder = MsgPackCodec::new();
61
let mut decoder_buf = bytes::BytesMut::new();
62
63
let shutdown_fut = shutdown_rx.wait();
64
pin!(shutdown_fut);
65
66
loop {
67
tokio::select! {
68
r = read.read_buf(&mut decoder_buf) => {
69
r?;
70
71
while let Some(frame) = decoder.decode(&mut decoder_buf)? {
72
match dispatcher.dispatch_with_partial(&frame.vec, frame.obj) {
73
MaybeSync::Sync(Some(v)) => {
74
let _ = write_tx.send(v).await;
75
},
76
MaybeSync::Sync(None) => continue,
77
MaybeSync::Future(fut) => {
78
let write_tx = write_tx.clone();
79
tokio::spawn(async move {
80
if let Some(v) = fut.await {
81
let _ = write_tx.send(v).await;
82
}
83
});
84
}
85
MaybeSync::Stream((stream, fut)) => {
86
if let Some(stream) = stream {
87
dispatcher.register_stream(write_tx.clone(), stream).await;
88
}
89
let write_tx = write_tx.clone();
90
tokio::spawn(async move {
91
if let Some(v) = fut.await {
92
let _ = write_tx.send(v).await;
93
}
94
});
95
}
96
}
97
};
98
},
99
Some(m) = write_rx.recv() => {
100
write.write_all(&m).await?;
101
},
102
Some(m) = msg_rx.recv_msg() => {
103
write.write_all(&m).await?;
104
},
105
r = &mut shutdown_fut => return Ok((r.ok(), read, write)),
106
}
107
108
write.flush().await?;
109
}
110
}
111
112
/// Decoder that reads msgpack object messages in a cancellation-safe way using
/// Tokio's codecs.
///
/// rmp_serde does not support async reads, and does not plan to. But we know
/// every type in the protocol is some kind of object, so by asking to
/// deserialize the requested object from the buffered bytes (repeatedly, if
/// incomplete) we can accomplish streaming.
///
/// `T` is the type each frame is deserialized into.
pub struct MsgPackCodec<T> {
	_marker: std::marker::PhantomData<T>,
}

impl<T> MsgPackCodec<T> {
	/// Creates a new codec. The codec holds no state between frames.
	pub fn new() -> Self {
		Self {
			_marker: std::marker::PhantomData,
		}
	}
}

// Implemented by hand rather than derived so that no `T: Default` bound is
// required.
impl<T> Default for MsgPackCodec<T> {
	fn default() -> Self {
		Self::new()
	}
}

/// A single frame produced by [`MsgPackCodec`].
pub struct MsgPackDecoded<T> {
	/// The deserialized object.
	pub obj: T,
	/// A copy of the raw msgpack bytes the object was decoded from.
	pub vec: Vec<u8>,
}
134
135
impl<T: DeserializeOwned> tokio_util::codec::Decoder for MsgPackCodec<T> {
136
type Item = MsgPackDecoded<T>;
137
type Error = io::Error;
138
139
fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
140
let bytes_ref = src.as_ref();
141
let mut cursor = Cursor::new(bytes_ref);
142
143
match rmp_serde::decode::from_read::<_, T>(&mut cursor) {
144
Err(
145
rmp_serde::decode::Error::InvalidDataRead(e)
146
| rmp_serde::decode::Error::InvalidMarkerRead(e),
147
) if e.kind() == ErrorKind::UnexpectedEof => {
148
src.reserve(1024);
149
Ok(None)
150
}
151
Err(e) => Err(std::io::Error::new(
152
std::io::ErrorKind::InvalidData,
153
e.to_string(),
154
)),
155
Ok(obj) => {
156
let len = cursor.position() as usize;
157
let vec = src[..len].to_vec();
158
src.advance(len);
159
Ok(Some(MsgPackDecoded { obj, vec }))
160
}
161
}
162
}
163
}
164
165
#[cfg(test)]
mod tests {
	use serde::{Deserialize, Serialize};

	use super::*;

	#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)]
	pub struct Msg {
		pub x: i32,
	}

	#[test]
	fn test_protocol() {
		let mut c = MsgPackCodec::<Msg>::new();
		let mut buf = bytes::BytesMut::new();

		assert!(c.decode(&mut buf).unwrap().is_none());

		buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 1 }).unwrap().as_slice());
		buf.extend_from_slice(rmp_serde::to_vec_named(&Msg { x: 2 }).unwrap().as_slice());

		assert_eq!(
			c.decode(&mut buf).unwrap().expect("expected msg1").obj,
			Msg { x: 1 }
		);
		assert_eq!(
			c.decode(&mut buf).unwrap().expect("expected msg2").obj,
			Msg { x: 2 }
		);
		assert!(buf.is_empty());
	}

	#[test]
	fn test_partial_frame() {
		let mut c = MsgPackCodec::<Msg>::new();
		let mut buf = bytes::BytesMut::new();
		let encoded = rmp_serde::to_vec_named(&Msg { x: 1 }).unwrap();
		let (head, tail) = encoded.split_at(encoded.len() - 1);

		buf.extend_from_slice(head);
		assert!(c.decode(&mut buf).unwrap().is_none());
		// An incomplete frame must stay buffered for the next attempt.
		assert_eq!(&buf[..], head);

		buf.extend_from_slice(tail);
		let frame = c.decode(&mut buf).unwrap().expect("expected complete msg");
		assert_eq!(frame.obj, Msg { x: 1 });
		assert_eq!(frame.vec, encoded);
		assert!(buf.is_empty());
	}

	#[test]
	fn test_invalid_data() {
		let mut c = MsgPackCodec::<Msg>::new();
		let mut buf = bytes::BytesMut::new();
		// 0xc1 is the msgpack "never used" marker, so it can never start a value.
		buf.extend_from_slice(&[0xc1]);

		let err = c.decode(&mut buf).err().expect("expected a decode error");
		assert_eq!(err.kind(), ErrorKind::InvalidData);
	}
}
196
197