Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/tunnels/local_forwarding.rs
3314 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
6
use std::{
7
collections::HashMap,
8
ops::{Index, IndexMut},
9
sync::{Arc, Mutex},
10
};
11
12
use tokio::{
13
pin,
14
sync::{mpsc, watch},
15
};
16
17
use crate::{
18
async_pipe::{socket_stream_split, AsyncPipe},
19
json_rpc::{new_json_rpc, start_json_rpc},
20
log,
21
singleton::SingletonServer,
22
util::{errors::CodeError, sync::Barrier},
23
};
24
25
use super::{
26
dev_tunnels::ActiveTunnel,
27
protocol::{
28
self,
29
forward_singleton::{PortList, SetPortsResponse},
30
PortPrivacy, PortProtocol,
31
},
32
shutdown_signal::ShutdownSignal,
33
};
34
35
#[derive(Default, Clone)]
36
struct PortCount {
37
public: u32,
38
private: u32,
39
}
40
41
impl Index<PortPrivacy> for PortCount {
42
type Output = u32;
43
44
fn index(&self, privacy: PortPrivacy) -> &Self::Output {
45
match privacy {
46
PortPrivacy::Public => &self.public,
47
PortPrivacy::Private => &self.private,
48
}
49
}
50
}
51
52
impl IndexMut<PortPrivacy> for PortCount {
53
fn index_mut(&mut self, privacy: PortPrivacy) -> &mut Self::Output {
54
match privacy {
55
PortPrivacy::Public => &mut self.public,
56
PortPrivacy::Private => &mut self.private,
57
}
58
}
59
}
60
61
impl PortCount {
62
fn is_empty(&self) -> bool {
63
self.public == 0 && self.private == 0
64
}
65
66
fn primary_privacy(&self) -> PortPrivacy {
67
if self.public > 0 {
68
PortPrivacy::Public
69
} else {
70
PortPrivacy::Private
71
}
72
}
73
}
74
#[derive(Clone)]
75
struct PortMapRec {
76
count: PortCount,
77
protocol: PortProtocol,
78
}
79
80
type PortMap = HashMap<u16, PortMapRec>;
81
82
/// The PortForwardingHandle is given out to multiple consumers to allow
83
/// them to set_ports that they want to be forwarded.
84
struct PortForwardingSender {
85
/// Todo: when `SyncUnsafeCell` is no longer nightly, we can use it here with
86
/// the following comment:
87
///
88
/// SyncUnsafeCell is used and safe here because PortForwardingSender is used
89
/// exclusively in synchronous dispatch *and* we create a new sender in the
90
/// context for each connection, in `serve_singleton_rpc`.
91
///
92
/// If PortForwardingSender is ever used in a different context, this should
93
/// be refactored, e.g. to use locks or `&mut self` in set_ports`
94
///
95
/// see https://doc.rust-lang.org/stable/std/cell/struct.SyncUnsafeCell.html
96
current: Mutex<PortList>,
97
sender: Arc<Mutex<watch::Sender<PortMap>>>,
98
}
99
100
impl PortForwardingSender {
101
pub fn set_ports(&self, ports: PortList) {
102
let mut current = self.current.lock().unwrap();
103
self.sender.lock().unwrap().send_modify(|v| {
104
for p in current.iter() {
105
if !ports.contains(p) {
106
let n = v.get_mut(&p.number).expect("expected port in map");
107
n.count[p.privacy] -= 1;
108
if n.count.is_empty() {
109
v.remove(&p.number);
110
}
111
}
112
}
113
114
for p in ports.iter() {
115
if !current.contains(p) {
116
match v.get_mut(&p.number) {
117
Some(n) => {
118
n.count[p.privacy] += 1;
119
n.protocol = p.protocol;
120
}
121
None => {
122
let mut count = PortCount::default();
123
count[p.privacy] += 1;
124
v.insert(
125
p.number,
126
PortMapRec {
127
count,
128
protocol: p.protocol,
129
},
130
);
131
}
132
};
133
}
134
}
135
136
current.splice(.., ports);
137
});
138
}
139
}
140
141
impl Clone for PortForwardingSender {
142
fn clone(&self) -> Self {
143
Self {
144
current: Mutex::new(vec![]),
145
sender: self.sender.clone(),
146
}
147
}
148
}
149
150
impl Drop for PortForwardingSender {
151
fn drop(&mut self) {
152
self.set_ports(vec![]);
153
}
154
}
155
156
struct PortForwardingReceiver {
157
receiver: watch::Receiver<PortMap>,
158
}
159
160
impl PortForwardingReceiver {
161
pub fn new() -> (PortForwardingSender, Self) {
162
let (sender, receiver) = watch::channel(HashMap::new());
163
let handle = PortForwardingSender {
164
current: Mutex::new(vec![]),
165
sender: Arc::new(Mutex::new(sender)),
166
};
167
168
let tracker = Self { receiver };
169
170
(handle, tracker)
171
}
172
173
/// Applies all changes from PortForwardingHandles to the tunnel.
174
pub async fn apply_to(&mut self, log: log::Logger, tunnel: Arc<ActiveTunnel>) {
175
let mut current: PortMap = HashMap::new();
176
while self.receiver.changed().await.is_ok() {
177
let next = self.receiver.borrow().clone();
178
179
for (port, rec) in current.iter() {
180
let privacy = rec.count.primary_privacy();
181
if !matches!(next.get(port), Some(n) if n.count.primary_privacy() == privacy) {
182
match tunnel.remove_port(*port).await {
183
Ok(_) => info!(
184
log,
185
"stopped forwarding {} port {} at {:?}", rec.protocol, *port, privacy
186
),
187
Err(e) => error!(
188
log,
189
"failed to stop forwarding {} port {}: {}", rec.protocol, port, e
190
),
191
}
192
}
193
}
194
195
for (port, rec) in next.iter() {
196
let privacy = rec.count.primary_privacy();
197
if !matches!(current.get(port), Some(n) if n.count.primary_privacy() == privacy) {
198
match tunnel.add_port_tcp(*port, privacy, rec.protocol).await {
199
Ok(_) => info!(
200
log,
201
"forwarding {} port {} at {:?}", rec.protocol, port, privacy
202
),
203
Err(e) => error!(
204
log,
205
"failed to forward {} port {}: {}", rec.protocol, port, e
206
),
207
}
208
}
209
}
210
211
current = next;
212
}
213
}
214
}
215
216
pub struct SingletonClientArgs {
217
pub log: log::Logger,
218
pub stream: AsyncPipe,
219
pub shutdown: Barrier<ShutdownSignal>,
220
pub port_requests: watch::Receiver<PortList>,
221
}
222
223
#[derive(Clone)]
224
struct SingletonServerContext {
225
log: log::Logger,
226
handle: PortForwardingSender,
227
tunnel: Arc<ActiveTunnel>,
228
}
229
230
/// Serves a client singleton for port forwarding.
231
pub async fn client(args: SingletonClientArgs) -> Result<(), std::io::Error> {
232
let mut rpc = new_json_rpc();
233
let (msg_tx, msg_rx) = mpsc::unbounded_channel();
234
let SingletonClientArgs {
235
log,
236
shutdown,
237
stream,
238
mut port_requests,
239
} = args;
240
241
debug!(
242
log,
243
"An existing port forwarding process is running on this machine, connecting to it..."
244
);
245
246
let caller = rpc.get_caller(msg_tx);
247
let rpc = rpc.methods(()).build(log.clone());
248
let (read, write) = socket_stream_split(stream);
249
250
let serve = start_json_rpc(rpc, read, write, msg_rx, shutdown);
251
let forward = async move {
252
while port_requests.changed().await.is_ok() {
253
let ports = port_requests.borrow().clone();
254
let r = caller
255
.call::<_, _, protocol::forward_singleton::SetPortsResponse>(
256
protocol::forward_singleton::METHOD_SET_PORTS,
257
protocol::forward_singleton::SetPortsParams { ports },
258
)
259
.await
260
.unwrap();
261
262
match r {
263
Err(e) => error!(log, "failed to set ports: {:?}", e),
264
Ok(r) => print_forwarding_addr(&r),
265
};
266
}
267
};
268
269
tokio::select! {
270
r = serve => r.map(|_| ()),
271
_ = forward => Ok(()),
272
}
273
}
274
275
/// Serves a port-forwarding singleton.
276
pub async fn server(
277
log: log::Logger,
278
tunnel: ActiveTunnel,
279
server: SingletonServer,
280
mut port_requests: watch::Receiver<PortList>,
281
shutdown_rx: Barrier<ShutdownSignal>,
282
) -> Result<(), CodeError> {
283
let tunnel = Arc::new(tunnel);
284
let (forward_tx, mut forward_rx) = PortForwardingReceiver::new();
285
286
let forward_own_tunnel = tunnel.clone();
287
let forward_own_tx = forward_tx.clone();
288
let forward_own = async move {
289
while port_requests.changed().await.is_ok() {
290
forward_own_tx.set_ports(port_requests.borrow().clone());
291
print_forwarding_addr(&SetPortsResponse {
292
port_format: forward_own_tunnel.get_port_format().ok(),
293
});
294
}
295
};
296
297
tokio::select! {
298
_ = forward_own => Ok(()),
299
_ = forward_rx.apply_to(log.clone(), tunnel.clone()) => Ok(()),
300
r = serve_singleton_rpc(server, log, tunnel, forward_tx, shutdown_rx) => r,
301
}
302
}
303
304
async fn serve_singleton_rpc(
305
mut server: SingletonServer,
306
log: log::Logger,
307
tunnel: Arc<ActiveTunnel>,
308
forward_tx: PortForwardingSender,
309
shutdown_rx: Barrier<ShutdownSignal>,
310
) -> Result<(), CodeError> {
311
let mut own_shutdown = shutdown_rx.clone();
312
let shutdown_fut = own_shutdown.wait();
313
pin!(shutdown_fut);
314
315
loop {
316
let cnx = tokio::select! {
317
c = server.accept() => c?,
318
_ = &mut shutdown_fut => return Ok(()),
319
};
320
321
let (read, write) = socket_stream_split(cnx);
322
let shutdown_rx = shutdown_rx.clone();
323
324
let handle = forward_tx.clone();
325
let log = log.clone();
326
let tunnel = tunnel.clone();
327
tokio::spawn(async move {
328
// we make an rpc for the connection instead of re-using a dispatcher
329
// so that we can have the "handle" drop when the connection drops.
330
let rpc = new_json_rpc();
331
let mut rpc = rpc.methods(SingletonServerContext {
332
log: log.clone(),
333
handle,
334
tunnel,
335
});
336
337
rpc.register_sync(
338
protocol::forward_singleton::METHOD_SET_PORTS,
339
|p: protocol::forward_singleton::SetPortsParams, ctx| {
340
info!(ctx.log, "client setting ports to {:?}", p.ports);
341
ctx.handle.set_ports(p.ports);
342
Ok(SetPortsResponse {
343
port_format: ctx.tunnel.get_port_format().ok(),
344
})
345
},
346
);
347
348
let _ = start_json_rpc(rpc.build(log), read, write, (), shutdown_rx).await;
349
});
350
}
351
}
352
353
fn print_forwarding_addr(r: &SetPortsResponse) {
354
eprintln!("{}\n", serde_json::to_string(r).unwrap());
355
}
356
357