Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/util/sync.rs
3314 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
use async_trait::async_trait;
6
use std::{marker::PhantomData, sync::Arc};
7
use tokio::sync::{
8
broadcast, mpsc,
9
watch::{self, error::RecvError},
10
};
11
12
#[derive(Clone)]
13
pub struct Barrier<T>(watch::Receiver<Option<T>>)
14
where
15
T: Clone;
16
17
impl<T> Barrier<T>
18
where
19
T: Clone,
20
{
21
/// Waits for the barrier to be closed, returning a value if one was sent.
22
pub async fn wait(&mut self) -> Result<T, RecvError> {
23
loop {
24
self.0.changed().await?;
25
26
if let Some(v) = self.0.borrow().clone() {
27
return Ok(v);
28
}
29
}
30
}
31
32
/// Gets whether the barrier is currently open
33
pub fn is_open(&self) -> bool {
34
self.0.borrow().is_some()
35
}
36
}
37
38
#[async_trait]
39
impl<T: Clone + Send + Sync> Receivable<T> for Barrier<T> {
40
async fn recv_msg(&mut self) -> Option<T> {
41
self.wait().await.ok()
42
}
43
}
44
45
#[derive(Clone)]
46
pub struct BarrierOpener<T: Clone>(Arc<watch::Sender<Option<T>>>);
47
48
impl<T: Clone> BarrierOpener<T> {
49
/// Opens the barrier.
50
pub fn open(&self, value: T) {
51
self.0.send_if_modified(|v| {
52
if v.is_none() {
53
*v = Some(value);
54
true
55
} else {
56
false
57
}
58
});
59
}
60
}
61
62
/// The Barrier is something that can be opened once from one side,
63
/// and is thereafter permanently closed. It can contain a value.
64
pub fn new_barrier<T>() -> (Barrier<T>, BarrierOpener<T>)
65
where
66
T: Clone,
67
{
68
let (closed_tx, closed_rx) = watch::channel(None);
69
(Barrier(closed_rx), BarrierOpener(Arc::new(closed_tx)))
70
}
71
72
/// Type that can receive messages in an async way.
73
#[async_trait]
74
pub trait Receivable<T> {
75
async fn recv_msg(&mut self) -> Option<T>;
76
}
77
78
// todo: ideally we would use an Arc in the broadcast::Receiver to avoid having
79
// to clone bytes everywhere, requires updating rpc consumers as well.
80
#[async_trait]
81
impl<T: Clone + Send> Receivable<T> for broadcast::Receiver<T> {
82
async fn recv_msg(&mut self) -> Option<T> {
83
loop {
84
match self.recv().await {
85
Ok(v) => return Some(v),
86
Err(broadcast::error::RecvError::Lagged(_)) => continue,
87
Err(broadcast::error::RecvError::Closed) => return None,
88
}
89
}
90
}
91
}
92
93
#[async_trait]
94
impl<T: Send> Receivable<T> for mpsc::UnboundedReceiver<T> {
95
async fn recv_msg(&mut self) -> Option<T> {
96
self.recv().await
97
}
98
}
99
100
#[async_trait]
101
impl<T: Send> Receivable<T> for () {
102
async fn recv_msg(&mut self) -> Option<T> {
103
futures::future::pending().await
104
}
105
}
106
107
pub struct ConcatReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
108
left: Option<A>,
109
right: B,
110
_marker: PhantomData<T>,
111
}
112
113
impl<T: Send, A: Receivable<T>, B: Receivable<T>> ConcatReceivable<T, A, B> {
114
pub fn new(left: A, right: B) -> Self {
115
Self {
116
left: Some(left),
117
right,
118
_marker: PhantomData,
119
}
120
}
121
}
122
123
#[async_trait]
124
impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
125
for ConcatReceivable<T, A, B>
126
{
127
async fn recv_msg(&mut self) -> Option<T> {
128
if let Some(left) = &mut self.left {
129
match left.recv_msg().await {
130
Some(v) => return Some(v),
131
None => {
132
self.left = None;
133
}
134
}
135
}
136
137
return self.right.recv_msg().await;
138
}
139
}
140
141
pub struct MergedReceivable<T: Send, A: Receivable<T>, B: Receivable<T>> {
142
left: Option<A>,
143
right: Option<B>,
144
_marker: PhantomData<T>,
145
}
146
147
impl<T: Send, A: Receivable<T>, B: Receivable<T>> MergedReceivable<T, A, B> {
148
pub fn new(left: A, right: B) -> Self {
149
Self {
150
left: Some(left),
151
right: Some(right),
152
_marker: PhantomData,
153
}
154
}
155
}
156
157
#[async_trait]
158
impl<T: Send, A: Send + Receivable<T>, B: Send + Receivable<T>> Receivable<T>
159
for MergedReceivable<T, A, B>
160
{
161
async fn recv_msg(&mut self) -> Option<T> {
162
loop {
163
match (&mut self.left, &mut self.right) {
164
(Some(left), Some(right)) => {
165
tokio::select! {
166
left = left.recv_msg() => match left {
167
Some(v) => return Some(v),
168
None => { self.left = None; continue; },
169
},
170
right = right.recv_msg() => match right {
171
Some(v) => return Some(v),
172
None => { self.right = None; continue; },
173
},
174
}
175
}
176
(Some(a), None) => break a.recv_msg().await,
177
(None, Some(b)) => break b.recv_msg().await,
178
(None, None) => break None,
179
}
180
}
181
}
182
}
183
184
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_barrier_close_after_spawn() {
        let (mut barrier, opener) = new_barrier::<u32>();
        let (tx, rx) = tokio::sync::oneshot::channel::<u32>();

        tokio::spawn(async move {
            tx.send(barrier.wait().await.unwrap()).unwrap();
        });

        opener.open(42);

        assert_eq!(rx.await.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_barrier_close_before_spawn() {
        let (barrier, opener) = new_barrier::<u32>();
        let (tx1, rx1) = tokio::sync::oneshot::channel::<u32>();
        let (tx2, rx2) = tokio::sync::oneshot::channel::<u32>();

        opener.open(42);
        let mut b1 = barrier.clone();
        tokio::spawn(async move {
            tx1.send(b1.wait().await.unwrap()).unwrap();
        });
        let mut b2 = barrier.clone();
        tokio::spawn(async move {
            tx2.send(b2.wait().await.unwrap()).unwrap();
        });

        assert_eq!(rx1.await.unwrap(), 42);
        assert_eq!(rx2.await.unwrap(), 42);
    }

    /// Only the first `open` takes effect, and `is_open` reflects it.
    #[tokio::test]
    async fn test_barrier_first_open_wins() {
        let (barrier, opener) = new_barrier::<u32>();
        assert!(!barrier.is_open());

        opener.open(1);
        opener.open(2);
        assert!(barrier.is_open());

        let mut waiter = barrier.clone();
        assert_eq!(waiter.wait().await.unwrap(), 1);
    }

    /// `ConcatReceivable` drains the left source fully before the right one.
    #[tokio::test]
    async fn test_concat_receivable_drains_left_first() {
        let (left_tx, left_rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
        let (right_tx, right_rx) = tokio::sync::mpsc::unbounded_channel::<u32>();
        right_tx.send(2).unwrap();
        left_tx.send(1).unwrap();
        drop(left_tx);
        drop(right_tx);

        let mut concat = ConcatReceivable::<u32, _, _>::new(left_rx, right_rx);
        assert_eq!(concat.recv_msg().await, Some(1));
        assert_eq!(concat.recv_msg().await, Some(2));
        assert_eq!(concat.recv_msg().await, None);
    }
}
222
223