Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/rpc.rs
3309 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
use std::{
7
collections::HashMap,
8
future,
9
sync::{
10
atomic::{AtomicU32, Ordering},
11
Arc, Mutex,
12
},
13
};
14
15
use crate::log;
16
use futures::{future::BoxFuture, Future, FutureExt};
17
use serde::{de::DeserializeOwned, Deserialize, Serialize};
18
use tokio::{
19
io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf},
20
sync::{mpsc, oneshot},
21
};
22
23
use crate::util::errors::AnyError;
24
25
pub type SyncMethod = Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> Option<Vec<u8>>>;
26
pub type AsyncMethod =
27
Arc<dyn Send + Sync + Fn(Option<u32>, &[u8]) -> BoxFuture<'static, Option<Vec<u8>>>>;
28
pub type Duplex = Arc<
29
dyn Send
30
+ Sync
31
+ Fn(Option<u32>, &[u8]) -> (Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>),
32
>;
33
34
/// A registered RPC handler, tagged by how it produces its response.
pub enum Method {
    /// Computes the serialized response inline.
    Sync(SyncMethod),
    /// Produces the serialized response from a future.
    Async(AsyncMethod),
    /// Produces a response future plus optional duplex streams for the caller.
    Duplex(Duplex),
}
39
40
/// Serialization is given to the RpcBuilder and defines how data gets serialized
41
/// when callinth methods.
42
pub trait Serialization: Send + Sync + 'static {
43
fn serialize(&self, value: impl Serialize) -> Vec<u8>;
44
fn deserialize<P: DeserializeOwned>(&self, b: &[u8]) -> Result<P, AnyError>;
45
}
46
47
/// RPC is a basic, transport-agnostic builder for RPC methods. Call
/// `.methods()` to get a method builder, register methods on it, then call
/// `.build()` to get a "dispatcher" type.
pub struct RpcBuilder<S> {
    /// Serializer shared with every caller and method builder created from here.
    serializer: Arc<S>,
    /// Method table handed to the `RpcMethodBuilder` created by `methods()`;
    /// starts empty.
    methods: HashMap<&'static str, Method>,
    /// Pending outbound calls keyed by request id. Callers insert into it and
    /// the dispatcher built from the same builder completes entries.
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}
54
55
impl<S: Serialization> RpcBuilder<S> {
56
/// Creates a new empty RPC builder.
57
pub fn new(serializer: S) -> Self {
58
Self {
59
serializer: Arc::new(serializer),
60
methods: HashMap::new(),
61
calls: Arc::new(std::sync::Mutex::new(HashMap::new())),
62
}
63
}
64
65
/// Creates a caller that will be connected to any eventual dispatchers,
66
/// and that sends data to the "tx" channel.
67
pub fn get_caller(&mut self, sender: mpsc::UnboundedSender<Vec<u8>>) -> RpcCaller<S> {
68
RpcCaller {
69
serializer: self.serializer.clone(),
70
calls: self.calls.clone(),
71
sender,
72
}
73
}
74
75
/// Gets a method builder.
76
pub fn methods<C: Send + Sync + 'static>(self, context: C) -> RpcMethodBuilder<S, C> {
77
RpcMethodBuilder {
78
context: Arc::new(context),
79
serializer: self.serializer,
80
methods: self.methods,
81
calls: self.calls,
82
}
83
}
84
}
85
86
/// Builder returned by `RpcBuilder::methods`: register handlers on it, then
/// call `build` to obtain an `RpcDispatcher`.
pub struct RpcMethodBuilder<S, C> {
    /// Context handed to every registered handler.
    context: Arc<C>,
    /// Serializer used to decode requests and encode responses.
    serializer: Arc<S>,
    /// Registered handlers, keyed by method name.
    methods: HashMap<&'static str, Method>,
    /// Pending outbound calls shared with callers from the originating builder.
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
}
92
93
/// Parameters of the `streams_started` notification sent by
/// `RpcDispatcher::register_stream`. Field order is kept as-is because the
/// derived `Serialize` emits fields in declaration order.
#[derive(Serialize)]
struct DuplexStreamStarted {
    /// Id of the request the streams were opened for.
    pub for_request_id: u32,
    /// Ids of the opened streams, in creation order.
    pub stream_ids: Vec<u32>,
}
98
99
impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {
100
/// Registers a synchronous rpc call that returns its result directly.
101
pub fn register_sync<P, R, F>(&mut self, method_name: &'static str, callback: F)
102
where
103
P: DeserializeOwned,
104
R: Serialize,
105
F: Fn(P, &C) -> Result<R, AnyError> + Send + Sync + 'static,
106
{
107
if self.methods.contains_key(method_name) {
108
panic!("Method already registered: {method_name}");
109
}
110
111
let serial = self.serializer.clone();
112
let context = self.context.clone();
113
self.methods.insert(
114
method_name,
115
Method::Sync(Arc::new(move |id, body| {
116
let param = match serial.deserialize::<RequestParams<P>>(body) {
117
Ok(p) => p,
118
Err(err) => {
119
return id.map(|id| {
120
serial.serialize(ErrorResponse {
121
id,
122
error: ResponseError {
123
code: 0,
124
message: format!("{err:?}"),
125
},
126
})
127
})
128
}
129
};
130
131
match callback(param.params, &context) {
132
Ok(result) => id.map(|id| serial.serialize(&SuccessResponse { id, result })),
133
Err(err) => id.map(|id| {
134
serial.serialize(ErrorResponse {
135
id,
136
error: ResponseError {
137
code: -1,
138
message: format!("{err:?}"),
139
},
140
})
141
}),
142
}
143
})),
144
);
145
}
146
147
/// Registers an async rpc call that returns a Future.
148
pub fn register_async<P, R, Fut, F>(&mut self, method_name: &'static str, callback: F)
149
where
150
P: DeserializeOwned + Send + 'static,
151
R: Serialize + Send + Sync + 'static,
152
Fut: Future<Output = Result<R, AnyError>> + Send,
153
F: (Fn(P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
154
{
155
let serial = self.serializer.clone();
156
let context = self.context.clone();
157
self.methods.insert(
158
method_name,
159
Method::Async(Arc::new(move |id, body| {
160
let param = match serial.deserialize::<RequestParams<P>>(body) {
161
Ok(p) => p,
162
Err(err) => {
163
return future::ready(id.map(|id| {
164
serial.serialize(ErrorResponse {
165
id,
166
error: ResponseError {
167
code: 0,
168
message: format!("{err:?}"),
169
},
170
})
171
}))
172
.boxed();
173
}
174
};
175
176
let callback = callback.clone();
177
let serial = serial.clone();
178
let context = context.clone();
179
let fut = async move {
180
match callback(param.params, context).await {
181
Ok(result) => {
182
id.map(|id| serial.serialize(&SuccessResponse { id, result }))
183
}
184
Err(err) => id.map(|id| {
185
serial.serialize(ErrorResponse {
186
id,
187
error: ResponseError {
188
code: -1,
189
message: format!("{err:?}"),
190
},
191
})
192
}),
193
}
194
};
195
196
fut.boxed()
197
})),
198
);
199
}
200
201
/// Registers an async rpc call that returns a Future containing a duplex
202
/// stream that should be handled by the client.
203
pub fn register_duplex<P, R, Fut, F>(
204
&mut self,
205
method_name: &'static str,
206
streams: usize,
207
callback: F,
208
) where
209
P: DeserializeOwned + Send + 'static,
210
R: Serialize + Send + Sync + 'static,
211
Fut: Future<Output = Result<R, AnyError>> + Send,
212
F: (Fn(Vec<DuplexStream>, P, Arc<C>) -> Fut) + Clone + Send + Sync + 'static,
213
{
214
let serial = self.serializer.clone();
215
let context = self.context.clone();
216
self.methods.insert(
217
method_name,
218
Method::Duplex(Arc::new(move |id, body| {
219
let param = match serial.deserialize::<RequestParams<P>>(body) {
220
Ok(p) => p,
221
Err(err) => {
222
return (
223
None,
224
future::ready(id.map(|id| {
225
serial.serialize(ErrorResponse {
226
id,
227
error: ResponseError {
228
code: 0,
229
message: format!("{err:?}"),
230
},
231
})
232
}))
233
.boxed(),
234
);
235
}
236
};
237
238
let callback = callback.clone();
239
let serial = serial.clone();
240
let context = context.clone();
241
242
let mut dto = StreamDto {
243
req_id: id.unwrap_or(0),
244
streams: Vec::with_capacity(streams),
245
};
246
let mut servers = Vec::with_capacity(streams);
247
248
for _ in 0..streams {
249
let (client, server) = tokio::io::duplex(8192);
250
servers.push(server);
251
dto.streams.push((next_message_id(), client));
252
}
253
254
let fut = async move {
255
match callback(servers, param.params, context).await {
256
Ok(r) => id.map(|id| serial.serialize(&SuccessResponse { id, result: r })),
257
Err(err) => id.map(|id| {
258
serial.serialize(ErrorResponse {
259
id,
260
error: ResponseError {
261
code: -1,
262
message: format!("{err:?}"),
263
},
264
})
265
}),
266
}
267
};
268
269
(Some(dto), fut.boxed())
270
})),
271
);
272
}
273
274
/// Builds into a usable, sync rpc dispatcher.
275
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
276
let streams = Streams::default();
277
278
let s1 = streams.clone();
279
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
280
let s1 = s1.clone();
281
async move {
282
s1.remove(m.stream).await;
283
Ok(())
284
}
285
});
286
287
let s2 = streams.clone();
288
self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
289
s2.write(m.stream, m.segment);
290
Ok(())
291
});
292
293
RpcDispatcher {
294
log,
295
context: self.context,
296
calls: self.calls,
297
serializer: self.serializer,
298
methods: Arc::new(self.methods),
299
streams,
300
}
301
}
302
}
303
304
/// Completion callback for an outbound call. It is invoked at most once, with
/// the call's outcome, when the matching response is dispatched.
type DispatchMethod = Box<dyn FnOnce(Outcome) + Send + Sync>;
305
306
/// Dispatcher returned from a Builder that provides a transport-agnostic way to
307
/// deserialize and dispatch RPC calls. This structure may get more advanced as
308
/// time goes on...
309
#[derive(Clone)]
310
pub struct RpcCaller<S: Serialization> {
311
serializer: Arc<S>,
312
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
313
sender: mpsc::UnboundedSender<Vec<u8>>,
314
}
315
316
impl<S: Serialization> RpcCaller<S> {
317
pub fn serialize_notify<M, A>(serializer: &S, method: M, params: A) -> Vec<u8>
318
where
319
S: Serialization,
320
M: AsRef<str> + serde::Serialize,
321
A: Serialize,
322
{
323
serializer.serialize(&FullRequest {
324
id: None,
325
method,
326
params,
327
})
328
}
329
330
/// Enqueues an outbound call. Returns whether the message was enqueued.
331
pub fn notify<M, A>(&self, method: M, params: A) -> bool
332
where
333
M: AsRef<str> + serde::Serialize,
334
A: Serialize,
335
{
336
self.sender
337
.send(Self::serialize_notify(&self.serializer, method, params))
338
.is_ok()
339
}
340
341
/// Enqueues an outbound call, returning its result.
342
pub fn call<M, A, R>(&self, method: M, params: A) -> oneshot::Receiver<Result<R, ResponseError>>
343
where
344
M: AsRef<str> + serde::Serialize,
345
A: Serialize,
346
R: DeserializeOwned + Send + 'static,
347
{
348
let (tx, rx) = oneshot::channel();
349
let id = next_message_id();
350
let body = self.serializer.serialize(&FullRequest {
351
id: Some(id),
352
method,
353
params,
354
});
355
356
if self.sender.send(body).is_err() {
357
drop(tx);
358
return rx;
359
}
360
361
let serializer = self.serializer.clone();
362
self.calls.lock().unwrap().insert(
363
id,
364
Box::new(move |body| {
365
match body {
366
Outcome::Error(e) => tx.send(Err(e)).ok(),
367
Outcome::Success(r) => match serializer.deserialize::<SuccessResponse<R>>(&r) {
368
Ok(r) => tx.send(Ok(r.result)).ok(),
369
Err(err) => tx
370
.send(Err(ResponseError {
371
code: 0,
372
message: err.to_string(),
373
}))
374
.ok(),
375
},
376
};
377
}),
378
);
379
380
rx
381
}
382
}
383
384
/// Dispatcher returned from a Builder that provides a transport-agnostic way to
/// deserialize and handle RPC calls. This structure may get more advanced as
/// time goes on...
#[derive(Clone)]
pub struct RpcDispatcher<S, C> {
    /// Logger used for dispatch warnings.
    log: log::Logger,
    /// Context shared with every registered handler.
    context: Arc<C>,
    /// Serializer used to decode incoming messages and encode replies.
    serializer: Arc<S>,
    /// Registered handlers keyed by method name; never mutated after `build`.
    methods: Arc<HashMap<&'static str, Method>>,
    /// Pending outbound calls, completed when their response is dispatched.
    calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
    /// Inbound duplex streams fed by `stream_data` notifications.
    streams: Streams,
}
396
397
/// Process-wide source of message ids, starting at 0.
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);

/// Returns the next id from `MESSAGE_ID_COUNTER`. The same sequence is used
/// for request ids and duplex stream ids. The atomic add wraps around on
/// overflow.
fn next_message_id() -> u32 {
    AtomicU32::fetch_add(&MESSAGE_ID_COUNTER, 1, Ordering::SeqCst)
}
401
402
impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
403
/// Runs the incoming request, returning the result of the call synchronously
404
/// or in a future. (The caller can then decide whether to run the future
405
/// sequentially in its receive loop, or not.)
406
///
407
/// The future or return result will be optional bytes that should be sent
408
/// back to the socket.
409
pub fn dispatch(&self, body: &[u8]) -> MaybeSync {
410
match self.serializer.deserialize::<PartialIncoming>(body) {
411
Ok(partial) => self.dispatch_with_partial(body, partial),
412
Err(_err) => {
413
warning!(self.log, "Failed to deserialize request, hex: {:X?}", body);
414
MaybeSync::Sync(None)
415
}
416
}
417
}
418
419
/// Like dispatch, but allows passing an existing PartialIncoming.
420
pub fn dispatch_with_partial(&self, body: &[u8], partial: PartialIncoming) -> MaybeSync {
421
let id = partial.id;
422
423
if let Some(method_name) = partial.method {
424
let method = self.methods.get(method_name.as_str());
425
match method {
426
Some(Method::Sync(callback)) => MaybeSync::Sync(callback(id, body)),
427
Some(Method::Async(callback)) => MaybeSync::Future(callback(id, body)),
428
Some(Method::Duplex(callback)) => MaybeSync::Stream(callback(id, body)),
429
None => MaybeSync::Sync(id.map(|id| {
430
self.serializer.serialize(ErrorResponse {
431
id,
432
error: ResponseError {
433
code: -1,
434
message: format!("Method not found: {method_name}"),
435
},
436
})
437
})),
438
}
439
} else if let Some(err) = partial.error {
440
if let Some(cb) = self.calls.lock().unwrap().remove(&id.unwrap()) {
441
cb(Outcome::Error(err));
442
}
443
MaybeSync::Sync(None)
444
} else {
445
if let Some(cb) = self.calls.lock().unwrap().remove(&id.unwrap()) {
446
cb(Outcome::Success(body.to_vec()));
447
}
448
MaybeSync::Sync(None)
449
}
450
}
451
452
/// Registers a stream call returned from dispatch().
453
pub async fn register_stream(
454
&self,
455
write_tx: mpsc::Sender<impl 'static + From<Vec<u8>> + Send>,
456
dto: StreamDto,
457
) {
458
let r = write_tx
459
.send(
460
self.serializer
461
.serialize(&FullRequest {
462
id: None,
463
method: METHOD_STREAMS_STARTED,
464
params: DuplexStreamStarted {
465
stream_ids: dto.streams.iter().map(|(id, _)| *id).collect(),
466
for_request_id: dto.req_id,
467
},
468
})
469
.into(),
470
)
471
.await;
472
473
if r.is_err() {
474
return;
475
}
476
477
for (stream_id, duplex) in dto.streams {
478
let (mut read, write) = tokio::io::split(duplex);
479
self.streams.insert(stream_id, write);
480
481
let write_tx = write_tx.clone();
482
let serial = self.serializer.clone();
483
tokio::spawn(async move {
484
let mut buf = vec![0; 4096];
485
loop {
486
match read.read(&mut buf).await {
487
Ok(0) | Err(_) => break,
488
Ok(n) => {
489
let r = write_tx
490
.send(
491
serial
492
.serialize(&FullRequest {
493
id: None,
494
method: METHOD_STREAM_DATA,
495
params: StreamDataParams {
496
segment: &buf[..n],
497
stream: stream_id,
498
},
499
})
500
.into(),
501
)
502
.await;
503
504
if r.is_err() {
505
return;
506
}
507
}
508
}
509
}
510
511
let _ = write_tx
512
.send(
513
serial
514
.serialize(&FullRequest {
515
id: None,
516
method: METHOD_STREAM_ENDED,
517
params: StreamEndedParams { stream: stream_id },
518
})
519
.into(),
520
)
521
.await;
522
});
523
}
524
}
525
526
pub fn context(&self) -> Arc<C> {
527
self.context.clone()
528
}
529
}
530
531
/// Bookkeeping for one duplex stream, stored in `Streams`.
struct StreamRec {
    /// The write half while it is idle. It is `None` while a spawned
    /// `write_loop` owns it, and that loop puts it back once the queue is empty.
    write: Option<WriteHalf<DuplexStream>>,
    /// Segments waiting to be written by the write loop.
    q: Vec<Vec<u8>>,
    /// Set by `Streams::remove` while a write loop owns the write half. The
    /// loop then removes the record and shuts the stream down once `q` drains.
    ended: bool,
}
536
537
/// Table of duplex streams keyed by stream id. Clones share the same map.
#[derive(Clone, Default)]
struct Streams {
    map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
}
541
542
impl Streams {
543
pub async fn remove(&self, id: u32) {
544
let mut remove = None;
545
546
{
547
let mut map = self.map.lock().unwrap();
548
if let Some(s) = map.get_mut(&id) {
549
if let Some(w) = s.write.take() {
550
map.remove(&id);
551
remove = Some(w);
552
} else {
553
s.ended = true; // will shut down in write loop
554
}
555
}
556
}
557
558
// do this outside of the sync lock:
559
if let Some(mut w) = remove {
560
let _ = w.shutdown().await;
561
}
562
}
563
564
pub fn write(&self, id: u32, buf: Vec<u8>) {
565
let mut map = self.map.lock().unwrap();
566
if let Some(s) = map.get_mut(&id) {
567
s.q.push(buf);
568
569
if let Some(w) = s.write.take() {
570
tokio::spawn(write_loop(id, w, self.map.clone()));
571
}
572
}
573
}
574
575
pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
576
self.map.lock().unwrap().insert(
577
id,
578
StreamRec {
579
write: Some(stream),
580
q: Vec::new(),
581
ended: false,
582
},
583
);
584
}
585
}
586
587
/// Write loop started by `Streams::write`. It takes the WriteHalf, and
/// runs until there's no more items in the 'write queue'. At that point, if the
/// record still exists in the `streams` (i.e. we haven't shut down), it'll
/// return the WriteHalf so that the next `write` call starts
/// the loop again. Otherwise, it'll shut down the WriteHalf.
///
/// If the record was flagged `ended` while the loop ran, the loop removes the
/// record itself once the queue is empty, then shuts the WriteHalf down.
///
/// This is the equivalent of the same write_loop in the server_multiplexer.
/// I couldn't figure out a nice way to abstract it without introducing
/// performance overhead...
async fn write_loop(
    id: u32,
    mut w: WriteHalf<DuplexStream>,
    streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
) {
    // Batch buffer swapped with the record's queue, so the writes below run
    // without holding the lock. It is emptied by `drain(..)` before each swap.
    let mut items_vec = vec![];
    loop {
        {
            let mut lock = streams.lock().unwrap();
            let stream_rec = match lock.get_mut(&id) {
                Some(b) => b,
                None => break,
            };

            if stream_rec.q.is_empty() {
                if stream_rec.ended {
                    lock.remove(&id);
                    break;
                } else {
                    // Park the write half so the next `write` restarts a loop.
                    stream_rec.write = Some(w);
                    return;
                }
            }

            std::mem::swap(&mut stream_rec.q, &mut items_vec);
        }

        for item in items_vec.drain(..) {
            if w.write_all(&item).await.is_err() {
                // Abandons only the rest of this batch: dropping the `Drain`
                // discards the remaining items, and the outer loop re-checks
                // the queue.
                break;
            }
        }
    }

    // Reached only via a `break` in the locked section: either the record was
    // already gone, or the stream was ended and the record was just removed.
    // Either way, close the bridge.
    let _ = w.shutdown().await;
}
632
633
/// Notification sent by `RpcDispatcher::register_stream` announcing the
/// streams opened for a request.
const METHOD_STREAMS_STARTED: &str = "streams_started";
/// Notification carrying a segment of stream data. It is both sent (by
/// `register_stream`) and handled (registered in `build`).
const METHOD_STREAM_DATA: &str = "stream_data";
/// Notification marking the end of a stream. It is both sent (by
/// `register_stream`) and handled (registered in `build`).
const METHOD_STREAM_ENDED: &str = "stream_ended";
636
637
// Compile-time check that `RpcDispatcher` is `Sync`: the impl below only
// type-checks if the dispatcher satisfies the `Sync` supertrait.
#[allow(dead_code)] // false positive: the trait exists only for the check above
trait AssertIsSync: Sync {}
impl<S: Serialization, C: Send + Sync> AssertIsSync for RpcDispatcher<S, C> {}
640
641
/// Approximate shape that is used to determine what kind of data is incoming.
/// A message with a `method` is a request or notification. A message without
/// one is a response to one of our calls, which may carry an `error`.
#[derive(Deserialize, Debug)]
pub struct PartialIncoming {
    /// Request id; absent for notifications.
    pub id: Option<u32>,
    /// Method name; present only on requests and notifications.
    pub method: Option<String>,
    /// Error payload of an error response.
    pub error: Option<ResponseError>,
}
648
649
/// Parameters of an incoming `stream_data` notification, decoded into an owned
/// buffer. Field order mirrors `StreamDataParams`, the outgoing counterpart.
#[derive(Deserialize)]
struct StreamDataIncomingParams {
    /// Raw bytes for the stream, (de)serialized via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub segment: Vec<u8>,
    /// Target stream id.
    pub stream: u32,
}
655
656
/// Parameters of an outgoing `stream_data` notification. The segment borrows
/// from the read buffer in `register_stream`.
#[derive(Serialize, Deserialize)]
struct StreamDataParams<'a> {
    /// Raw bytes for the stream, (de)serialized via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub segment: &'a [u8],
    /// Source stream id.
    pub stream: u32,
}
662
663
/// Parameters of the `stream_ended` notification, used both when sending and
/// when handling it.
#[derive(Serialize, Deserialize)]
struct StreamEndedParams {
    /// Id of the stream that ended.
    pub stream: u32,
}
667
668
/// A complete outbound request. Calls carry `Some(id)`; notifications carry
/// `None`.
#[derive(Serialize)]
pub struct FullRequest<M: AsRef<str>, P> {
    /// Request id; `None` for notifications.
    pub id: Option<u32>,
    /// Method name.
    pub method: M,
    /// Method parameters.
    pub params: P,
}
674
675
/// Envelope used by registered handlers to extract only the `params` of an
/// incoming request body.
#[derive(Deserialize)]
struct RequestParams<P> {
    pub params: P,
}
679
680
/// Successful response to request `id`.
#[derive(Serialize, Deserialize)]
struct SuccessResponse<T> {
    pub id: u32,
    /// The handler's result.
    pub result: T,
}
685
686
/// Error response to request `id`.
#[derive(Serialize, Deserialize)]
struct ErrorResponse {
    pub id: u32,
    pub error: ResponseError,
}
691
692
/// Error payload of a response. The codes used in this file are `0` for
/// (de)serialization failures and `-1` for handler errors and unknown methods.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseError {
    pub code: i32,
    /// Human-readable description of the failure.
    pub message: String,
}
697
698
/// Result of an outbound call, as delivered to its `DispatchMethod` callback.
enum Outcome {
    /// The full response body, still to be decoded as `SuccessResponse<R>`.
    Success(Vec<u8>),
    /// The decoded error from an error response.
    Error(ResponseError),
}
702
703
/// Duplex streams opened by a `register_duplex` handler, to be passed to
/// `RpcDispatcher::register_stream`.
pub struct StreamDto {
    /// Id of the originating request (0 when the request had no id).
    req_id: u32,
    /// Stream ids paired with the client ends of the pipes.
    streams: Vec<(u32, DuplexStream)>,
}
707
708
/// Result of `RpcDispatcher::dispatch`: optional bytes to send back, available
/// either immediately or once a future resolves.
pub enum MaybeSync {
    /// A duplex call: streams to hand to `register_stream` (if any), plus the
    /// future producing the response.
    Stream((Option<StreamDto>, BoxFuture<'static, Option<Vec<u8>>>)),
    /// An async call; the future produces the response.
    Future(BoxFuture<'static, Option<Vec<u8>>>),
    /// A response available right away; `None` when nothing should be sent.
    Sync(Option<Vec<u8>>),
}
713
714
#[cfg(test)]
mod tests {
    use super::*;

    /// Removing a stream whose write half is idle drops its record and shuts
    /// the write half down, so the reading end sees EOF.
    #[tokio::test]
    async fn test_remove() {
        let streams = Streams::default();
        let (writer, mut reader) = tokio::io::duplex(1024);
        streams.insert(1, tokio::io::split(writer).1);
        streams.remove(1).await;

        assert!(streams.map.lock().unwrap().get(&1).is_none());
        let mut buffer = Vec::new();
        assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 0);
    }

    /// A single write reaches the reading end through the spawned write loop.
    #[tokio::test]
    async fn test_write() {
        let streams = Streams::default();
        let (writer, mut reader) = tokio::io::duplex(1024);
        streams.insert(1, tokio::io::split(writer).1);
        streams.write(1, vec![1, 2, 3]);

        let mut buffer = [0; 3];
        assert_eq!(reader.read_exact(&mut buffer).await.unwrap(), 3);
        assert_eq!(buffer, [1, 2, 3]);
    }

    /// Ending a stream while a write loop owns its write half must not lose
    /// queued data. The first `write` hands the half to a spawned loop, so the
    /// second write is only queued and `remove` merely flags the record as
    /// ended. The loop must still deliver both segments before shutting down.
    /// (The 1-byte pipe is presumably there to keep the loop busy; TODO
    /// confirm.)
    #[tokio::test]
    async fn test_write_with_immediate_end() {
        let streams = Streams::default();
        let (writer, mut reader) = tokio::io::duplex(1);
        streams.insert(1, tokio::io::split(writer).1);
        streams.write(1, vec![1, 2, 3]); // spawn write loop
        streams.write(1, vec![4, 5, 6]); // enqueued while writing
        streams.remove(1).await; // end stream

        let mut buffer = Vec::new();
        assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 6);
        assert_eq!(buffer, vec![1, 2, 3, 4, 5, 6]);
    }
}
756
757