Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/tunnels/control_server.rs
3316 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
use crate::async_pipe::get_socket_rw_stream;
6
use crate::constants::{CONTROL_PORT, PRODUCT_NAME_LONG};
7
use crate::log;
8
use crate::msgpack_rpc::{new_msgpack_rpc, start_msgpack_rpc, MsgPackCodec, MsgPackSerializer};
9
use crate::options::Quality;
10
use crate::rpc::{MaybeSync, RpcBuilder, RpcCaller, RpcDispatcher};
11
use crate::self_update::SelfUpdate;
12
use crate::state::LauncherPaths;
13
use crate::tunnels::protocol::{HttpRequestParams, PortPrivacy, METHOD_CHALLENGE_ISSUE};
14
use crate::tunnels::socket_signal::CloseReason;
15
use crate::update_service::{Platform, Release, TargetKind, UpdateService};
16
use crate::util::command::new_tokio_command;
17
use crate::util::errors::{
18
wrap, AnyError, CodeError, MismatchedLaunchModeError, NoAttachedServerError,
19
};
20
use crate::util::http::{
21
DelegatedHttpRequest, DelegatedSimpleHttp, FallbackSimpleHttp, ReqwestSimpleHttp,
22
};
23
use crate::util::io::SilentCopyProgress;
24
use crate::util::is_integrated_cli;
25
use crate::util::machine::kill_pid;
26
use crate::util::os::os_release;
27
use crate::util::sync::{new_barrier, Barrier, BarrierOpener};
28
29
use futures::stream::FuturesUnordered;
30
use futures::FutureExt;
31
use opentelemetry::trace::SpanKind;
32
use opentelemetry::KeyValue;
33
use std::collections::HashMap;
34
use std::path::PathBuf;
35
use std::process::Stdio;
36
use tokio::net::TcpStream;
37
use tokio::pin;
38
use tokio::process::{ChildStderr, ChildStdin};
39
use tokio_util::codec::Decoder;
40
41
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
42
use std::sync::Arc;
43
use std::time::Instant;
44
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, DuplexStream};
45
use tokio::sync::{mpsc, Mutex};
46
47
use super::challenge::{create_challenge, sign_challenge, verify_challenge};
48
use super::code_server::{
49
download_cli_into_cache, AnyCodeServer, CodeServerArgs, ServerBuilder, ServerParamsRaw,
50
SocketCodeServer,
51
};
52
use super::dev_tunnels::ActiveTunnel;
53
use super::paths::prune_stopped_servers;
54
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
55
use super::protocol::{
56
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams,
57
ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams,
58
ForwardResult, FsReadDirEntry, FsReadDirResponse, FsRenameRequest, FsSinglePathRequest,
59
FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams, HttpHeadersParams,
60
NetConnectRequest, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
61
SysKillRequest, SysKillResponse, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult,
62
VersionResponse, METHOD_CHALLENGE_VERIFY,
63
};
64
use super::server_bridge::ServerBridge;
65
use super::server_multiplexer::ServerMultiplexer;
66
use super::shutdown_signal::ShutdownSignal;
67
use super::socket_signal::{
68
ClientMessageDecoder, ServerMessageDestination, ServerMessageSink, SocketSignal,
69
};
70
71
/// Delegated HTTP requests currently in flight over this connection, keyed by
/// the `req_id` that was sent to the client with the request.
type HttpRequestsMap = Arc<std::sync::Mutex<HashMap<u32, DelegatedHttpRequest>>>;
/// Handle to the VS Code server started by `serve` on this connection; starts
/// out empty and is filled in once a server has been set up.
type CodeServerCell = Arc<Mutex<Option<SocketCodeServer>>>;

/// State shared by every RPC method registered for one control-socket connection.
struct HandlerContext {
    /// Log handle for the server
    log: log::Logger,
    /// Whether the CLI was updated during this handler session. When set, the
    /// launcher is asked to respawn once the connection closes.
    did_update: Arc<AtomicBool>,
    /// Authentication progress for the socket. Methods guarded by `ensure_auth`
    /// fail until this reaches `AuthState::Authenticated`.
    auth_state: Arc<std::sync::Mutex<AuthState>>,
    /// A loopback channel to talk to the socket server task.
    socket_tx: mpsc::Sender<SocketSignal>,
    /// Configured launcher paths.
    launcher_paths: LauncherPaths,
    /// Connected VS Code Server
    code_server: CodeServerCell,
    /// Potentially many "websocket" connections to client
    server_bridges: ServerMultiplexer,
    /// The CLI arguments used to start the code server.
    code_server_args: CodeServerArgs,
    /// Port forwarding functionality; `None` for connections served through `serve_stream`.
    port_forwarding: Option<PortForwarding>,
    /// Install platform for the VS Code server.
    platform: Platform,
    /// HTTP client used for download/update requests. It wraps both a direct
    /// reqwest client and the client-delegated one.
    http: Arc<FallbackSimpleHttp>,
    /// Delegated requests currently being performed by the client on our behalf.
    http_requests: HttpRequestsMap,
}
100
101
/// Handler auth state.
enum AuthState {
    /// Auth is required and no challenge has been issued yet. If a token is
    /// stored here, the client must present that same token when it asks for
    /// a challenge.
    WaitingForChallenge(Option<String>),
    /// A challenge (the stored string) has been issued; waiting for the
    /// client's verification response.
    ChallengeIssued(String),
    /// Auth is no longer required.
    Authenticated,
}
110
111
/// Process-wide counter backing [`next_message_id`].
static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);

/// Returns the next value of a process-wide incrementing counter, wrapping on
/// overflow. Used to tag log lines and as the id of delegated HTTP requests.
pub fn next_message_id() -> u32 {
    AtomicU32::fetch_add(&MESSAGE_ID_COUNTER, 1, Ordering::SeqCst)
}
117
118
impl HandlerContext {
119
async fn dispose(&self) {
120
self.server_bridges.dispose().await;
121
info!(self.log, "Disposed of connection to running server.");
122
}
123
}
124
125
/// Signals sent from per-connection tasks back to the `serve` loop.
enum ServerSignal {
    /// Signalled when the server has been updated and we want to respawn.
    /// We'd generally need to stop and then restart the launcher, but the
    /// program might be managed by a supervisor like systemd. Instead, we
    /// will stop the TCP listener and spawn the launcher again as a subprocess
    /// with the same arguments we used.
    Respawn,
}

/// What the caller of `serve` should do after it returns.
pub enum Next {
    /// The server should be respawned in a new binary (see `ServerSignal::Respawn`).
    Respawn,
    /// The tunnel should be restarted.
    Restart,
    /// The process should exit.
    Exit,
}

/// Returned by `serve`: the follow-up action, plus the tunnel handed back to the caller.
pub struct ServerTermination {
    /// What the caller should do next.
    pub next: Next,
    /// The tunnel `serve` was running on.
    pub tunnel: ActiveTunnel,
}
147
148
async fn preload_extensions(
149
log: &log::Logger,
150
platform: Platform,
151
mut args: CodeServerArgs,
152
launcher_paths: LauncherPaths,
153
) -> Result<(), AnyError> {
154
args.start_server = false;
155
156
let params_raw = ServerParamsRaw {
157
commit_id: None,
158
quality: Quality::Stable,
159
code_server_args: args.clone(),
160
headless: true,
161
platform,
162
};
163
164
// cannot use delegated HTTP here since there's no remote connection yet
165
let http = Arc::new(ReqwestSimpleHttp::new());
166
let resolved = params_raw.resolve(log, http.clone()).await?;
167
let sb = ServerBuilder::new(log, &resolved, &launcher_paths, http.clone());
168
169
sb.setup().await?;
170
sb.install_extensions().await
171
}
172
173
// Runs the launcher server. Exits on a ctrl+c or when requested by a user.
174
// Note that client connections may not be closed when this returns; use
175
// `close_all_clients()` on the ServerTermination to make this happen.
176
pub async fn serve(
177
log: &log::Logger,
178
mut tunnel: ActiveTunnel,
179
launcher_paths: &LauncherPaths,
180
code_server_args: &CodeServerArgs,
181
platform: Platform,
182
mut shutdown_rx: Barrier<ShutdownSignal>,
183
) -> Result<ServerTermination, AnyError> {
184
let mut port = tunnel.add_port_direct(CONTROL_PORT).await?;
185
let mut forwarding = PortForwardingProcessor::new();
186
let (tx, mut rx) = mpsc::channel::<ServerSignal>(4);
187
let (exit_barrier, signal_exit) = new_barrier();
188
189
if !code_server_args.install_extensions.is_empty() {
190
info!(
191
log,
192
"Preloading extensions using stable server: {:?}", code_server_args.install_extensions
193
);
194
let log = log.clone();
195
let code_server_args = code_server_args.clone();
196
let launcher_paths = launcher_paths.clone();
197
// This is run async to the primary tunnel setup to be speedy.
198
tokio::spawn(async move {
199
if let Err(e) =
200
preload_extensions(&log, platform, code_server_args, launcher_paths).await
201
{
202
warning!(log, "Failed to preload extensions: {:?}", e);
203
} else {
204
info!(log, "Extension install complete");
205
}
206
});
207
}
208
209
loop {
210
tokio::select! {
211
Ok(reason) = shutdown_rx.wait() => {
212
info!(log, "Shutting down: {}", reason);
213
drop(signal_exit);
214
return Ok(ServerTermination {
215
next: match reason {
216
ShutdownSignal::RpcRestartRequested => Next::Restart,
217
_ => Next::Exit,
218
},
219
tunnel,
220
});
221
},
222
c = rx.recv() => {
223
if let Some(ServerSignal::Respawn) = c {
224
drop(signal_exit);
225
return Ok(ServerTermination {
226
next: Next::Respawn,
227
tunnel,
228
});
229
}
230
},
231
Some(w) = forwarding.recv() => {
232
forwarding.process(w, &mut tunnel).await;
233
},
234
l = port.recv() => {
235
let socket = match l {
236
Some(p) => p,
237
None => {
238
warning!(log, "ssh tunnel disposed, tearing down");
239
return Ok(ServerTermination {
240
next: Next::Restart,
241
tunnel,
242
});
243
}
244
};
245
246
let own_log = log.prefixed(&log::new_rpc_prefix());
247
let own_tx = tx.clone();
248
let own_paths = launcher_paths.clone();
249
let own_exit = exit_barrier.clone();
250
let own_code_server_args = code_server_args.clone();
251
let own_forwarding = forwarding.handle();
252
253
tokio::spawn(async move {
254
use opentelemetry::trace::{FutureExt, TraceContextExt};
255
256
let span = own_log.span("server.socket").with_kind(SpanKind::Consumer).start(own_log.tracer());
257
let cx = opentelemetry::Context::current_with_span(span);
258
let serve_at = Instant::now();
259
260
debug!(own_log, "Serving new connection");
261
262
let (writehalf, readhalf) = socket.into_split();
263
let stats = process_socket(readhalf, writehalf, own_tx, Some(own_forwarding), ServeStreamParams {
264
log: own_log,
265
launcher_paths: own_paths,
266
code_server_args: own_code_server_args,
267
platform,
268
exit_barrier: own_exit,
269
requires_auth: AuthRequired::None,
270
}).with_context(cx.clone()).await;
271
272
cx.span().add_event(
273
"socket.bandwidth",
274
vec![
275
KeyValue::new("tx", stats.tx as f64),
276
KeyValue::new("rx", stats.rx as f64),
277
KeyValue::new("duration_ms", serve_at.elapsed().as_millis() as f64),
278
],
279
);
280
cx.span().end();
281
});
282
}
283
}
284
}
285
}
286
287
/// Authentication a client must complete before the guarded control-socket
/// methods are allowed.
#[derive(Clone)]
pub enum AuthRequired {
    /// No authentication: the connection starts out authenticated.
    None,
    /// The client must complete the challenge handshake
    /// (`METHOD_CHALLENGE_ISSUE`, then `METHOD_CHALLENGE_VERIFY`).
    VSDA,
    /// Like `VSDA`, but the client must also present this token when asking
    /// for the challenge.
    VSDAWithToken(String),
}
293
294
/// Everything needed to serve the control-socket protocol on one connection.
#[derive(Clone)]
pub struct ServeStreamParams {
    /// Logger for the connection.
    pub log: log::Logger,
    /// Configured launcher paths.
    pub launcher_paths: LauncherPaths,
    /// Base arguments used when the client asks to start a VS Code server.
    pub code_server_args: CodeServerArgs,
    /// Install platform for the VS Code server.
    pub platform: Platform,
    /// Authentication the client must complete before guarded methods are allowed.
    pub requires_auth: AuthRequired,
    /// Serving stops once this barrier opens.
    pub exit_barrier: Barrier<ShutdownSignal>,
}
303
304
pub async fn serve_stream(
305
readhalf: impl AsyncRead + Send + Unpin + 'static,
306
writehalf: impl AsyncWrite + Unpin,
307
params: ServeStreamParams,
308
) -> SocketStats {
309
// Currently the only server signal is respawn, that doesn't have much meaning
310
// when serving a stream, so make an ignored channel.
311
let (server_rx, server_tx) = mpsc::channel(1);
312
drop(server_tx);
313
314
process_socket(readhalf, writehalf, server_rx, None, params).await
315
}
316
317
/// Byte counters collected while serving a single connection.
#[derive(Debug, Clone, Copy, Default)]
pub struct SocketStats {
    /// Bytes read from the peer.
    rx: usize,
    /// Bytes queued for writing to the peer.
    tx: usize,
}
321
322
#[allow(clippy::too_many_arguments)]
323
fn make_socket_rpc(
324
log: log::Logger,
325
socket_tx: mpsc::Sender<SocketSignal>,
326
http_delegated: DelegatedSimpleHttp,
327
launcher_paths: LauncherPaths,
328
code_server_args: CodeServerArgs,
329
port_forwarding: Option<PortForwarding>,
330
requires_auth: AuthRequired,
331
platform: Platform,
332
http_requests: HttpRequestsMap,
333
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
334
let server_bridges = ServerMultiplexer::new();
335
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
336
did_update: Arc::new(AtomicBool::new(false)),
337
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
338
AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)),
339
AuthRequired::VSDA => AuthState::WaitingForChallenge(None),
340
AuthRequired::None => AuthState::Authenticated,
341
})),
342
socket_tx,
343
log: log.clone(),
344
launcher_paths,
345
code_server_args,
346
code_server: Arc::new(Mutex::new(None)),
347
server_bridges,
348
port_forwarding,
349
platform,
350
http: Arc::new(FallbackSimpleHttp::new(
351
ReqwestSimpleHttp::new(),
352
http_delegated,
353
)),
354
http_requests,
355
});
356
357
rpc.register_sync("ping", |_: EmptyObject, _| Ok(EmptyObject {}));
358
rpc.register_sync("gethostname", |_: EmptyObject, _| handle_get_hostname());
359
rpc.register_sync("sys_kill", |p: SysKillRequest, c| {
360
ensure_auth(&c.auth_state)?;
361
handle_sys_kill(p.pid)
362
});
363
rpc.register_sync("fs_stat", |p: FsSinglePathRequest, c| {
364
ensure_auth(&c.auth_state)?;
365
handle_stat(p.path)
366
});
367
rpc.register_duplex(
368
"fs_read",
369
1,
370
move |mut streams, p: FsSinglePathRequest, c| async move {
371
ensure_auth(&c.auth_state)?;
372
handle_fs_read(streams.remove(0), p.path).await
373
},
374
);
375
rpc.register_duplex(
376
"fs_write",
377
1,
378
move |mut streams, p: FsSinglePathRequest, c| async move {
379
ensure_auth(&c.auth_state)?;
380
handle_fs_write(streams.remove(0), p.path).await
381
},
382
);
383
rpc.register_duplex(
384
"fs_connect",
385
1,
386
move |mut streams, p: FsSinglePathRequest, c| async move {
387
ensure_auth(&c.auth_state)?;
388
handle_fs_connect(streams.remove(0), p.path).await
389
},
390
);
391
rpc.register_duplex(
392
"net_connect",
393
1,
394
move |mut streams, n: NetConnectRequest, c| async move {
395
ensure_auth(&c.auth_state)?;
396
handle_net_connect(streams.remove(0), n).await
397
},
398
);
399
rpc.register_async("fs_rm", move |p: FsSinglePathRequest, c| async move {
400
ensure_auth(&c.auth_state)?;
401
handle_fs_remove(p.path).await
402
});
403
rpc.register_sync("fs_mkdirp", |p: FsSinglePathRequest, c| {
404
ensure_auth(&c.auth_state)?;
405
handle_fs_mkdirp(p.path)
406
});
407
rpc.register_sync("fs_rename", |p: FsRenameRequest, c| {
408
ensure_auth(&c.auth_state)?;
409
handle_fs_rename(p.from_path, p.to_path)
410
});
411
rpc.register_sync("fs_readdir", |p: FsSinglePathRequest, c| {
412
ensure_auth(&c.auth_state)?;
413
handle_fs_readdir(p.path)
414
});
415
rpc.register_sync("get_env", |_: EmptyObject, c| {
416
ensure_auth(&c.auth_state)?;
417
handle_get_env()
418
});
419
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| {
420
handle_challenge_issue(p, &c.auth_state)
421
});
422
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
423
handle_challenge_verify(p.response, &c.auth_state)
424
});
425
rpc.register_async("serve", move |params: ServeParams, c| async move {
426
ensure_auth(&c.auth_state)?;
427
handle_serve(c, params).await
428
});
429
rpc.register_async("update", |p: UpdateParams, c| async move {
430
handle_update(&c.http, &c.log, &c.did_update, &p).await
431
});
432
rpc.register_sync("servermsg", |m: ServerMessageParams, c| {
433
if let Err(e) = handle_server_message(&c.log, &c.server_bridges, m) {
434
warning!(c.log, "error handling call: {:?}", e);
435
}
436
Ok(EmptyObject {})
437
});
438
rpc.register_sync("prune", |_: EmptyObject, c| handle_prune(&c.launcher_paths));
439
rpc.register_async("callserverhttp", |p: CallServerHttpParams, c| async move {
440
let code_server = c.code_server.lock().await.clone();
441
handle_call_server_http(code_server, p).await
442
});
443
rpc.register_async("forward", |p: ForwardParams, c| async move {
444
ensure_auth(&c.auth_state)?;
445
handle_forward(&c.log, &c.port_forwarding, p).await
446
});
447
rpc.register_async("unforward", |p: UnforwardParams, c| async move {
448
ensure_auth(&c.auth_state)?;
449
handle_unforward(&c.log, &c.port_forwarding, p).await
450
});
451
rpc.register_async("acquire_cli", |p: AcquireCliParams, c| async move {
452
ensure_auth(&c.auth_state)?;
453
handle_acquire_cli(&c.launcher_paths, &c.http, &c.log, p).await
454
});
455
rpc.register_duplex("spawn", 3, |mut streams, p: SpawnParams, c| async move {
456
ensure_auth(&c.auth_state)?;
457
handle_spawn(
458
&c.log,
459
p,
460
Some(streams.remove(0)),
461
Some(streams.remove(0)),
462
Some(streams.remove(0)),
463
)
464
.await
465
});
466
rpc.register_duplex(
467
"spawn_cli",
468
3,
469
|mut streams, p: SpawnParams, c| async move {
470
ensure_auth(&c.auth_state)?;
471
handle_spawn_cli(
472
&c.log,
473
p,
474
streams.remove(0),
475
streams.remove(0),
476
streams.remove(0),
477
)
478
.await
479
},
480
);
481
rpc.register_sync("httpheaders", |p: HttpHeadersParams, c| {
482
if let Some(req) = c.http_requests.lock().unwrap().get(&p.req_id) {
483
trace!(c.log, "got {} response for req {}", p.status_code, p.req_id);
484
req.initial_response(p.status_code, p.headers);
485
} else {
486
warning!(c.log, "got response for unknown req {}", p.req_id);
487
}
488
Ok(EmptyObject {})
489
});
490
rpc.register_sync("httpbody", move |p: HttpBodyParams, c| {
491
let mut reqs = c.http_requests.lock().unwrap();
492
if let Some(req) = reqs.get(&p.req_id) {
493
if !p.segment.is_empty() {
494
req.body(p.segment);
495
}
496
if p.complete {
497
trace!(c.log, "delegated request {} completed", p.req_id);
498
reqs.remove(&p.req_id);
499
}
500
}
501
Ok(EmptyObject {})
502
});
503
rpc.register_sync(
504
"version",
505
|_: EmptyObject, _| Ok(VersionResponse::default()),
506
);
507
508
rpc.build(log)
509
}
510
511
fn ensure_auth(is_authed: &Arc<std::sync::Mutex<AuthState>>) -> Result<(), AnyError> {
512
if let AuthState::Authenticated = &*is_authed.lock().unwrap() {
513
Ok(())
514
} else {
515
Err(CodeError::ServerAuthRequired.into())
516
}
517
}
518
519
/// Serves the control-socket RPC protocol on one connection until it closes,
/// returning the connection's byte counters.
///
/// A spawned reader task decodes and dispatches incoming calls through
/// `handle_socket_read`. This function owns the write half and serializes
/// everything sent to the peer: messages arriving on `socket_rx` (RPC
/// responses, logs, close signals) and delegated HTTP requests arriving on
/// `http_rx`. The write loop ends when the exit barrier opens, `socket_rx`
/// closes, a write fails, or a `CloseWith` signal is received.
// NOTE(review): this function takes five parameters, below clippy's default
// `too_many_arguments` threshold, so the allow below looks stale; confirm and remove.
#[allow(clippy::too_many_arguments)]
async fn process_socket(
    readhalf: impl AsyncRead + Send + Unpin + 'static,
    mut writehalf: impl AsyncWrite + Unpin,
    server_tx: mpsc::Sender<ServerSignal>,
    port_forwarding: Option<PortForwarding>,
    params: ServeStreamParams,
) -> SocketStats {
    let ServeStreamParams {
        mut exit_barrier,
        log,
        launcher_paths,
        code_server_args,
        platform,
        requires_auth,
    } = params;

    // Requests made through `http_delegated` surface on `http_rx` and are
    // forwarded to the client by the write loop below.
    let (http_delegated, mut http_rx) = DelegatedSimpleHttp::new(log.clone());
    // All outbound traffic is funneled through this channel into the write loop.
    let (socket_tx, mut socket_rx) = mpsc::channel(4);
    let rx_counter = Arc::new(AtomicUsize::new(0));
    let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));

    // Only connections that need no authentication announce the version up front.
    let already_authed = matches!(requires_auth, AuthRequired::None);
    let rpc = make_socket_rpc(
        log.clone(),
        socket_tx.clone(),
        http_delegated,
        launcher_paths,
        code_server_args,
        port_forwarding,
        requires_auth,
        platform,
        http_requests.clone(),
    );

    // Reader task: owns the read half and the dispatcher.
    {
        let log = log.clone();
        let rx_counter = rx_counter.clone();
        let socket_tx = socket_tx.clone();
        let exit_barrier = exit_barrier.clone();
        tokio::spawn(async move {
            if already_authed {
                send_version(&socket_tx).await;
            }

            if let Err(e) =
                handle_socket_read(&log, readhalf, exit_barrier, &socket_tx, rx_counter, &rpc).await
            {
                debug!(log, "closing socket reader: {}", e);
                socket_tx
                    .send(SocketSignal::CloseWith(CloseReason(format!("{e}"))))
                    .await
                    .ok();
            }

            let ctx = rpc.context();

            // The connection is now closed; ask the launcher to respawn if the CLI was updated.
            if ctx.did_update.load(Ordering::SeqCst) {
                server_tx.send(ServerSignal::Respawn).await.ok();
            }

            ctx.dispose().await;

            // Make sure the write loop stops too, even if reading ended cleanly.
            let _ = socket_tx
                .send(SocketSignal::CloseWith(CloseReason("eof".to_string())))
                .await;
        });
    }

    let mut tx_counter = 0;

    // Write loop.
    loop {
        tokio::select! {
            _ = exit_barrier.wait() => {
                writehalf.shutdown().await.ok();
                break;
            },
            Some(r) = http_rx.recv() => {
                // Forward a delegated HTTP request to the client and remember it
                // under `id` so `httpheaders`/`httpbody` replies can find it.
                let id = next_message_id();
                let serialized = rmp_serde::to_vec_named(&ToClientRequest {
                    id: None,
                    params: ClientRequestMethod::makehttpreq(HttpRequestParams {
                        url: &r.url,
                        method: r.method,
                        req_id: id,
                    }),
                })
                .unwrap();

                http_requests.lock().unwrap().insert(id, r);

                tx_counter += serialized.len();
                if let Err(e) = writehalf.write_all(&serialized).await {
                    debug!(log, "Closing connection: {}", e);
                    break;
                }
            }
            recv = socket_rx.recv() => match recv {
                None => break,
                Some(message) => match message {
                    SocketSignal::Send(bytes) => {
                        tx_counter += bytes.len();
                        if let Err(e) = writehalf.write_all(&bytes).await {
                            debug!(log, "Closing connection: {}", e);
                            break;
                        }
                    }
                    SocketSignal::CloseWith(reason) => {
                        debug!(log, "Closing connection: {}", reason.0);
                        break;
                    }
                }
            }
        }
    }

    SocketStats {
        tx: tx_counter,
        rx: rx_counter.load(Ordering::Acquire),
    }
}
641
642
async fn send_version(tx: &mpsc::Sender<SocketSignal>) {
643
tx.send(SocketSignal::from_message(&ToClientRequest {
644
id: None,
645
params: ClientRequestMethod::version(VersionResponse::default()),
646
}))
647
.await
648
.ok();
649
}
650
/// Reads msgpack-framed RPC calls from `readhalf` and dispatches them through
/// `rpc` until EOF, a read/decode error, or `closer` opening.
///
/// Synchronous results are sent on `socket_tx` before the next frame is
/// handled. Asynchronous and streaming results complete on spawned tasks, so
/// their responses may be sent out of order relative to later ones.
///
/// Returns `Ok(())` on EOF, or when a synchronous response can't be sent
/// because the writer side has gone away. Returns an `UnexpectedEof` error
/// when `closer` opens.
async fn handle_socket_read(
    _log: &log::Logger,
    readhalf: impl AsyncRead + Unpin,
    mut closer: Barrier<ShutdownSignal>,
    socket_tx: &mpsc::Sender<SocketSignal>,
    rx_counter: Arc<AtomicUsize>,
    rpc: &RpcDispatcher<MsgPackSerializer, HandlerContext>,
) -> Result<(), std::io::Error> {
    let mut readhalf = BufReader::new(readhalf);
    let mut decoder = MsgPackCodec::new();
    let mut decoder_buf = bytes::BytesMut::new();

    loop {
        let read_len = tokio::select! {
            r = readhalf.read_buf(&mut decoder_buf) => r,
            // An opened barrier is surfaced as an error so the caller tears down.
            _ = closer.wait() => Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "eof")),
        }?;

        if read_len == 0 {
            return Ok(());
        }

        rx_counter.fetch_add(read_len, Ordering::Relaxed);

        // One read may contain zero, one, or several complete frames.
        while let Some(frame) = decoder.decode(&mut decoder_buf)? {
            match rpc.dispatch_with_partial(&frame.vec, frame.obj) {
                MaybeSync::Sync(Some(v)) => {
                    if socket_tx.send(SocketSignal::Send(v)).await.is_err() {
                        return Ok(());
                    }
                }
                MaybeSync::Sync(None) => continue,
                MaybeSync::Future(fut) => {
                    let socket_tx = socket_tx.clone();
                    tokio::spawn(async move {
                        if let Some(v) = fut.await {
                            socket_tx.send(SocketSignal::Send(v)).await.ok();
                        }
                    });
                }
                MaybeSync::Stream((stream, fut)) => {
                    if let Some(stream) = stream {
                        rpc.register_stream(socket_tx.clone(), stream).await;
                    }
                    let socket_tx = socket_tx.clone();
                    tokio::spawn(async move {
                        if let Some(v) = fut.await {
                            socket_tx.send(SocketSignal::Send(v)).await.ok();
                        }
                    });
                }
            }
        }
    }
}
705
706
#[derive(Clone)]
707
struct ServerOutputSink {
708
tx: mpsc::Sender<SocketSignal>,
709
}
710
711
impl log::LogSink for ServerOutputSink {
712
fn write_log(&self, level: log::Level, _prefix: &str, message: &str) {
713
let s = SocketSignal::from_message(&ToClientRequest {
714
id: None,
715
params: ClientRequestMethod::serverlog(ServerLog {
716
line: message,
717
level: level.to_u8(),
718
}),
719
});
720
721
self.tx.try_send(s).ok();
722
}
723
724
fn write_result(&self, _message: &str) {}
725
}
726
727
async fn handle_serve(
728
c: Arc<HandlerContext>,
729
params: ServeParams,
730
) -> Result<EmptyObject, AnyError> {
731
// fill params.extensions into code_server_args.install_extensions
732
let mut csa = c.code_server_args.clone();
733
csa.connection_token = params.connection_token.or(csa.connection_token);
734
csa.install_extensions.extend(params.extensions.into_iter());
735
736
let params_raw = ServerParamsRaw {
737
commit_id: params.commit_id,
738
quality: params.quality,
739
code_server_args: csa,
740
headless: true,
741
platform: c.platform,
742
};
743
744
let resolved = if params.use_local_download {
745
params_raw
746
.resolve(&c.log, Arc::new(c.http.delegated()))
747
.await
748
} else {
749
params_raw.resolve(&c.log, c.http.clone()).await
750
}?;
751
752
let mut server_ref = c.code_server.lock().await;
753
let server = match &*server_ref {
754
Some(o) => o.clone(),
755
None => {
756
let install_log = c.log.tee(ServerOutputSink {
757
tx: c.socket_tx.clone(),
758
});
759
760
macro_rules! do_setup {
761
($sb:expr) => {
762
match $sb.get_running().await? {
763
Some(AnyCodeServer::Socket(s)) => ($sb, Ok(s)),
764
Some(_) => return Err(AnyError::from(MismatchedLaunchModeError())),
765
None => {
766
$sb.setup().await?;
767
let r = $sb.listen_on_default_socket().await;
768
($sb, r)
769
}
770
}
771
};
772
}
773
774
let (sb, server) = if params.use_local_download {
775
let sb = ServerBuilder::new(
776
&install_log,
777
&resolved,
778
&c.launcher_paths,
779
Arc::new(c.http.delegated()),
780
);
781
do_setup!(sb)
782
} else {
783
let sb =
784
ServerBuilder::new(&install_log, &resolved, &c.launcher_paths, c.http.clone());
785
do_setup!(sb)
786
};
787
788
let server = match server {
789
Ok(s) => s,
790
Err(e) => {
791
// we don't loop to avoid doing so infinitely: allow the client to reconnect in this case.
792
if let AnyError::CodeError(CodeError::ServerUnexpectedExit(ref e)) = e {
793
warning!(
794
c.log,
795
"({}), removing server due to possible corruptions",
796
e
797
);
798
if let Err(e) = sb.evict().await {
799
warning!(c.log, "Failed to evict server: {}", e);
800
}
801
}
802
return Err(e);
803
}
804
};
805
806
server_ref.replace(server.clone());
807
server
808
}
809
};
810
811
attach_server_bridge(
812
&c.log,
813
server,
814
c.socket_tx.clone(),
815
c.server_bridges.clone(),
816
params.socket_id,
817
params.compress,
818
)
819
.await?;
820
Ok(EmptyObject {})
821
}
822
823
async fn attach_server_bridge(
824
log: &log::Logger,
825
code_server: SocketCodeServer,
826
socket_tx: mpsc::Sender<SocketSignal>,
827
multiplexer: ServerMultiplexer,
828
socket_id: u16,
829
compress: bool,
830
) -> Result<u16, AnyError> {
831
let (server_messages, decoder) = if compress {
832
(
833
ServerMessageSink::new_compressed(
834
multiplexer.clone(),
835
socket_id,
836
ServerMessageDestination::Channel(socket_tx),
837
),
838
ClientMessageDecoder::new_compressed(),
839
)
840
} else {
841
(
842
ServerMessageSink::new_plain(
843
multiplexer.clone(),
844
socket_id,
845
ServerMessageDestination::Channel(socket_tx),
846
),
847
ClientMessageDecoder::new_plain(),
848
)
849
};
850
851
let attached_fut = ServerBridge::new(&code_server.socket, server_messages, decoder).await;
852
match attached_fut {
853
Ok(a) => {
854
multiplexer.register(socket_id, a);
855
trace!(log, "Attached to server");
856
Ok(socket_id)
857
}
858
Err(e) => Err(e),
859
}
860
}
861
862
/// Handle an incoming server message. This is synchronous and uses a 'write loop'
863
/// to ensure message order is preserved exactly, which is necessary for compression.
864
fn handle_server_message(
865
log: &log::Logger,
866
multiplexer: &ServerMultiplexer,
867
params: ServerMessageParams,
868
) -> Result<EmptyObject, AnyError> {
869
if multiplexer.write_message(log, params.i, params.body) {
870
Ok(EmptyObject {})
871
} else {
872
Err(AnyError::from(NoAttachedServerError()))
873
}
874
}
875
876
fn handle_prune(paths: &LauncherPaths) -> Result<Vec<String>, AnyError> {
877
prune_stopped_servers(paths).map(|v| {
878
v.iter()
879
.map(|p| p.server_dir.display().to_string())
880
.collect()
881
})
882
}
883
884
async fn handle_update(
885
http: &Arc<FallbackSimpleHttp>,
886
log: &log::Logger,
887
did_update: &AtomicBool,
888
params: &UpdateParams,
889
) -> Result<UpdateResult, AnyError> {
890
if matches!(is_integrated_cli(), Ok(true)) || did_update.load(Ordering::SeqCst) {
891
return Ok(UpdateResult {
892
up_to_date: true,
893
did_update: false,
894
});
895
}
896
897
let update_service = UpdateService::new(log.clone(), http.clone());
898
let updater = SelfUpdate::new(&update_service)?;
899
let latest_release = updater.get_current_release().await?;
900
let up_to_date = updater.is_up_to_date_with(&latest_release);
901
902
let _ = updater.cleanup_old_update();
903
904
if !params.do_update || up_to_date {
905
return Ok(UpdateResult {
906
up_to_date,
907
did_update: false,
908
});
909
}
910
911
if did_update
912
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
913
.is_err()
914
{
915
return Ok(UpdateResult {
916
up_to_date: true,
917
did_update: true, // well, another thread did, but same difference...
918
});
919
}
920
921
info!(log, "Updating CLI to {}", latest_release);
922
923
let r = updater
924
.do_update(&latest_release, SilentCopyProgress())
925
.await;
926
927
if let Err(e) = r {
928
did_update.store(false, Ordering::SeqCst);
929
return Err(e);
930
}
931
932
Ok(UpdateResult {
933
up_to_date: true,
934
did_update: true,
935
})
936
}
937
938
fn handle_get_hostname() -> Result<GetHostnameResponse, AnyError> {
939
Ok(GetHostnameResponse {
940
value: gethostname::gethostname().to_string_lossy().into_owned(),
941
})
942
}
943
944
fn handle_stat(path: String) -> Result<FsStatResponse, AnyError> {
945
Ok(std::fs::metadata(path)
946
.map(|m| FsStatResponse {
947
exists: true,
948
size: Some(m.len()),
949
kind: Some(m.file_type().into()),
950
})
951
.unwrap_or_default())
952
}
953
954
async fn handle_fs_read(mut out: DuplexStream, path: String) -> Result<EmptyObject, AnyError> {
955
let mut f = tokio::fs::File::open(path)
956
.await
957
.map_err(|e| wrap(e, "file not found"))?;
958
959
tokio::io::copy(&mut f, &mut out)
960
.await
961
.map_err(|e| wrap(e, "error reading file"))?;
962
963
Ok(EmptyObject {})
964
}
965
966
async fn handle_fs_write(mut input: DuplexStream, path: String) -> Result<EmptyObject, AnyError> {
967
let mut f = tokio::fs::File::create(path)
968
.await
969
.map_err(|e| wrap(e, "file not found"))?;
970
971
tokio::io::copy(&mut input, &mut f)
972
.await
973
.map_err(|e| wrap(e, "error writing file"))?;
974
975
Ok(EmptyObject {})
976
}
977
978
async fn handle_net_connect(
979
mut stream: DuplexStream,
980
req: NetConnectRequest,
981
) -> Result<EmptyObject, AnyError> {
982
let mut s = TcpStream::connect((req.host, req.port))
983
.await
984
.map_err(|e| wrap(e, "could not connect to address"))?;
985
986
tokio::io::copy_bidirectional(&mut stream, &mut s)
987
.await
988
.map_err(|e| wrap(e, "error copying stream data"))?;
989
990
Ok(EmptyObject {})
991
}
992
async fn handle_fs_connect(
993
mut stream: DuplexStream,
994
path: String,
995
) -> Result<EmptyObject, AnyError> {
996
let mut s = get_socket_rw_stream(&PathBuf::from(path))
997
.await
998
.map_err(|e| wrap(e, "could not connect to socket"))?;
999
1000
tokio::io::copy_bidirectional(&mut stream, &mut s)
1001
.await
1002
.map_err(|e| wrap(e, "error copying stream data"))?;
1003
1004
Ok(EmptyObject {})
1005
}
1006
1007
async fn handle_fs_remove(path: String) -> Result<EmptyObject, AnyError> {
1008
tokio::fs::remove_dir_all(path)
1009
.await
1010
.map_err(|e| wrap(e, "error removing directory"))?;
1011
Ok(EmptyObject {})
1012
}
1013
1014
fn handle_fs_rename(from_path: String, to_path: String) -> Result<EmptyObject, AnyError> {
1015
std::fs::rename(from_path, to_path).map_err(|e| wrap(e, "error renaming"))?;
1016
Ok(EmptyObject {})
1017
}
1018
1019
fn handle_fs_mkdirp(path: String) -> Result<EmptyObject, AnyError> {
1020
std::fs::create_dir_all(path).map_err(|e| wrap(e, "error creating directory"))?;
1021
Ok(EmptyObject {})
1022
}
1023
1024
fn handle_fs_readdir(path: String) -> Result<FsReadDirResponse, AnyError> {
1025
let mut entries = std::fs::read_dir(path).map_err(|e| wrap(e, "error listing directory"))?;
1026
1027
let mut contents = Vec::new();
1028
while let Some(Ok(child)) = entries.next() {
1029
contents.push(FsReadDirEntry {
1030
name: child.file_name().to_string_lossy().into_owned(),
1031
kind: child.file_type().ok().map(|v| v.into()),
1032
});
1033
}
1034
1035
Ok(FsReadDirResponse { contents })
1036
}
1037
1038
fn handle_sys_kill(pid: u32) -> Result<SysKillResponse, AnyError> {
1039
Ok(SysKillResponse {
1040
success: kill_pid(pid),
1041
})
1042
}
1043
1044
fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
1045
Ok(GetEnvResponse {
1046
env: std::env::vars().collect(),
1047
os_release: os_release().unwrap_or_else(|_| "unknown".to_string()),
1048
#[cfg(windows)]
1049
os_platform: "win32",
1050
#[cfg(target_os = "linux")]
1051
os_platform: "linux",
1052
#[cfg(target_os = "macos")]
1053
os_platform: "darwin",
1054
})
1055
}
1056
1057
fn handle_challenge_issue(
1058
params: ChallengeIssueParams,
1059
auth_state: &Arc<std::sync::Mutex<AuthState>>,
1060
) -> Result<ChallengeIssueResponse, AnyError> {
1061
let challenge = create_challenge();
1062
1063
let mut auth_state = auth_state.lock().unwrap();
1064
if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state {
1065
match &params.token {
1066
Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()),
1067
None => return Err(CodeError::AuthChallengeBadToken.into()),
1068
_ => {}
1069
}
1070
}
1071
1072
*auth_state = AuthState::ChallengeIssued(challenge.clone());
1073
Ok(ChallengeIssueResponse { challenge })
1074
}
1075
1076
fn handle_challenge_verify(
1077
response: String,
1078
auth_state: &Arc<std::sync::Mutex<AuthState>>,
1079
) -> Result<EmptyObject, AnyError> {
1080
let mut auth_state = auth_state.lock().unwrap();
1081
1082
match &*auth_state {
1083
AuthState::Authenticated => Ok(EmptyObject {}),
1084
AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()),
1085
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
1086
false => Err(CodeError::AuthChallengeNotIssued.into()),
1087
true => {
1088
*auth_state = AuthState::Authenticated;
1089
Ok(EmptyObject {})
1090
}
1091
},
1092
}
1093
}
1094
1095
async fn handle_forward(
1096
log: &log::Logger,
1097
port_forwarding: &Option<PortForwarding>,
1098
params: ForwardParams,
1099
) -> Result<ForwardResult, AnyError> {
1100
let port_forwarding = port_forwarding
1101
.as_ref()
1102
.ok_or(CodeError::PortForwardingNotAvailable)?;
1103
info!(
1104
log,
1105
"Forwarding port {} (public={})", params.port, params.public
1106
);
1107
let privacy = match params.public {
1108
true => PortPrivacy::Public,
1109
false => PortPrivacy::Private,
1110
};
1111
1112
let uri = port_forwarding.forward(params.port, privacy).await?;
1113
Ok(ForwardResult { uri })
1114
}
1115
1116
async fn handle_unforward(
1117
log: &log::Logger,
1118
port_forwarding: &Option<PortForwarding>,
1119
params: UnforwardParams,
1120
) -> Result<EmptyObject, AnyError> {
1121
let port_forwarding = port_forwarding
1122
.as_ref()
1123
.ok_or(CodeError::PortForwardingNotAvailable)?;
1124
info!(log, "Unforwarding port {}", params.port);
1125
port_forwarding.unforward(params.port).await?;
1126
Ok(EmptyObject {})
1127
}
1128
1129
async fn handle_call_server_http(
1130
code_server: Option<SocketCodeServer>,
1131
params: CallServerHttpParams,
1132
) -> Result<CallServerHttpResult, AnyError> {
1133
use hyper::{body, client::conn::Builder, Body, Request};
1134
1135
// We use Hyper directly here since reqwest doesn't support sockets/pipes.
1136
// See https://github.com/seanmonstar/reqwest/issues/39
1137
1138
let socket = match &code_server {
1139
Some(cs) => &cs.socket,
1140
None => return Err(AnyError::from(NoAttachedServerError())),
1141
};
1142
1143
let rw = get_socket_rw_stream(socket).await?;
1144
1145
let (mut request_sender, connection) = Builder::new()
1146
.handshake(rw)
1147
.await
1148
.map_err(|e| wrap(e, "error establishing connection"))?;
1149
1150
// start the connection processing; it's shut down when the sender is dropped
1151
tokio::spawn(connection);
1152
1153
let mut request_builder = Request::builder()
1154
.method::<&str>(params.method.as_ref())
1155
.uri(format!("http://127.0.0.1{}", params.path))
1156
.header("Host", "127.0.0.1");
1157
1158
for (k, v) in params.headers {
1159
request_builder = request_builder.header(k, v);
1160
}
1161
let request = request_builder
1162
.body(Body::from(params.body.unwrap_or_default()))
1163
.map_err(|e| wrap(e, "invalid request"))?;
1164
1165
let response = request_sender
1166
.send_request(request)
1167
.await
1168
.map_err(|e| wrap(e, "error sending request"))?;
1169
1170
Ok(CallServerHttpResult {
1171
status: response.status().as_u16(),
1172
headers: response
1173
.headers()
1174
.into_iter()
1175
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
1176
.collect(),
1177
body: body::to_bytes(response)
1178
.await
1179
.map_err(|e| wrap(e, "error reading response body"))?
1180
.to_vec(),
1181
})
1182
}
1183
1184
async fn handle_acquire_cli(
1185
paths: &LauncherPaths,
1186
http: &Arc<FallbackSimpleHttp>,
1187
log: &log::Logger,
1188
params: AcquireCliParams,
1189
) -> Result<SpawnResult, AnyError> {
1190
let update_service = UpdateService::new(log.clone(), http.clone());
1191
1192
let release = match params.commit_id {
1193
Some(commit) => Release {
1194
name: format!("{PRODUCT_NAME_LONG} CLI"),
1195
commit,
1196
platform: params.platform,
1197
quality: params.quality,
1198
target: TargetKind::Cli,
1199
},
1200
None => {
1201
update_service
1202
.get_latest_commit(params.platform, TargetKind::Cli, params.quality)
1203
.await?
1204
}
1205
};
1206
1207
let cli = download_cli_into_cache(&paths.cli_cache, &release, &update_service).await?;
1208
let file = tokio::fs::File::open(cli)
1209
.await
1210
.map_err(|e| wrap(e, "error opening cli file"))?;
1211
1212
handle_spawn::<_, DuplexStream>(log, params.spawn, Some(file), None, None).await
1213
}
1214
1215
async fn handle_spawn<Stdin, StdoutAndErr>(
1216
log: &log::Logger,
1217
params: SpawnParams,
1218
stdin: Option<Stdin>,
1219
stdout: Option<StdoutAndErr>,
1220
stderr: Option<StdoutAndErr>,
1221
) -> Result<SpawnResult, AnyError>
1222
where
1223
Stdin: AsyncRead + Unpin + Send + 'static,
1224
StdoutAndErr: AsyncWrite + Unpin + Send + 'static,
1225
{
1226
debug!(
1227
log,
1228
"requested to spawn {} with args {:?}", params.command, params.args
1229
);
1230
1231
macro_rules! pipe_if {
1232
($e: expr) => {
1233
if $e {
1234
Stdio::piped()
1235
} else {
1236
Stdio::null()
1237
}
1238
};
1239
}
1240
1241
let mut p = new_tokio_command(&params.command);
1242
p.args(&params.args);
1243
p.envs(&params.env);
1244
p.stdin(pipe_if!(stdin.is_some()));
1245
p.stdout(pipe_if!(stdin.is_some()));
1246
p.stderr(pipe_if!(stderr.is_some()));
1247
if let Some(cwd) = &params.cwd {
1248
p.current_dir(cwd);
1249
}
1250
1251
#[cfg(target_os = "windows")]
1252
p.creation_flags(winapi::um::winbase::CREATE_NO_WINDOW);
1253
1254
let mut p = p.spawn().map_err(CodeError::ProcessSpawnFailed)?;
1255
1256
let block_futs = FuturesUnordered::new();
1257
let poll_futs = FuturesUnordered::new();
1258
if let (Some(mut a), Some(mut b)) = (p.stdout.take(), stdout) {
1259
block_futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed());
1260
}
1261
if let (Some(mut a), Some(mut b)) = (p.stderr.take(), stderr) {
1262
block_futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed());
1263
}
1264
if let (Some(mut b), Some(mut a)) = (p.stdin.take(), stdin) {
1265
poll_futs.push(async move { tokio::io::copy(&mut a, &mut b).await }.boxed());
1266
}
1267
1268
wait_for_process_exit(log, &params.command, p, block_futs, poll_futs).await
1269
}
1270
1271
async fn handle_spawn_cli(
1272
log: &log::Logger,
1273
params: SpawnParams,
1274
mut protocol_in: DuplexStream,
1275
mut protocol_out: DuplexStream,
1276
mut log_out: DuplexStream,
1277
) -> Result<SpawnResult, AnyError> {
1278
debug!(
1279
log,
1280
"requested to spawn cli {} with args {:?}", params.command, params.args
1281
);
1282
1283
let mut p = new_tokio_command(&params.command);
1284
p.args(&params.args);
1285
1286
// CLI args to spawn a server; contracted with clients that they should _not_ provide these.
1287
p.arg("--verbose");
1288
p.arg("command-shell");
1289
1290
p.envs(&params.env);
1291
p.stdin(Stdio::piped());
1292
p.stdout(Stdio::piped());
1293
p.stderr(Stdio::piped());
1294
if let Some(cwd) = &params.cwd {
1295
p.current_dir(cwd);
1296
}
1297
1298
let mut p = p.spawn().map_err(CodeError::ProcessSpawnFailed)?;
1299
1300
let mut stdin = p.stdin.take().unwrap();
1301
let mut stdout = p.stdout.take().unwrap();
1302
let mut stderr = p.stderr.take().unwrap();
1303
1304
// Start handling logs while doing the handshake in case there's some kind of error
1305
let log_pump = tokio::spawn(async move { tokio::io::copy(&mut stdout, &mut log_out).await });
1306
1307
// note: intentionally do not wrap stdin in a bufreader, since we don't
1308
// want to read anything other than our handshake messages.
1309
if let Err(e) = spawn_do_child_authentication(log, &mut stdin, &mut stderr).await {
1310
warning!(log, "failed to authenticate with child process {}", e);
1311
let _ = p.kill().await;
1312
return Err(e.into());
1313
}
1314
1315
debug!(log, "cli authenticated, attaching stdio");
1316
let block_futs = FuturesUnordered::new();
1317
let poll_futs = FuturesUnordered::new();
1318
poll_futs.push(async move { tokio::io::copy(&mut protocol_in, &mut stdin).await }.boxed());
1319
block_futs.push(async move { tokio::io::copy(&mut stderr, &mut protocol_out).await }.boxed());
1320
block_futs.push(async move { log_pump.await.unwrap() }.boxed());
1321
1322
wait_for_process_exit(log, &params.command, p, block_futs, poll_futs).await
1323
}
1324
1325
/// Trait-object type for the boxed stdio copy futures. Each is a `Send` future
/// resolving to the number of bytes copied, or an I/O error (the output of
/// `tokio::io::copy`).
type TokioCopyFuture = dyn std::future::Future<Output = std::io::Result<u64>> + Send;
1326
1327
async fn get_joined_result(
1328
mut process: tokio::process::Child,
1329
block_futs: FuturesUnordered<std::pin::Pin<Box<TokioCopyFuture>>>,
1330
) -> Result<std::process::ExitStatus, std::io::Error> {
1331
let (_, r) = tokio::join!(futures::future::join_all(block_futs), process.wait());
1332
r
1333
}
1334
1335
/// Wait for the process to exit and sends the spawn result. Waits until the
1336
/// `block_futs` and the process have exited, and polls the `poll_futs` while
1337
/// doing so.
1338
async fn wait_for_process_exit(
1339
log: &log::Logger,
1340
command: &str,
1341
process: tokio::process::Child,
1342
block_futs: FuturesUnordered<std::pin::Pin<Box<TokioCopyFuture>>>,
1343
poll_futs: FuturesUnordered<std::pin::Pin<Box<TokioCopyFuture>>>,
1344
) -> Result<SpawnResult, AnyError> {
1345
let joined = get_joined_result(process, block_futs);
1346
pin!(joined);
1347
1348
let r = tokio::select! {
1349
_ = futures::future::join_all(poll_futs) => joined.await,
1350
r = &mut joined => r,
1351
};
1352
1353
let r = match r {
1354
Ok(e) => SpawnResult {
1355
message: e.to_string(),
1356
exit_code: e.code().unwrap_or(-1),
1357
},
1358
Err(e) => SpawnResult {
1359
message: e.to_string(),
1360
exit_code: -1,
1361
},
1362
};
1363
1364
debug!(
1365
log,
1366
"spawned cli {} exited with code {}", command, r.exit_code
1367
);
1368
1369
Ok(r)
1370
}
1371
1372
async fn spawn_do_child_authentication(
1373
log: &log::Logger,
1374
stdin: &mut ChildStdin,
1375
stdout: &mut ChildStderr,
1376
) -> Result<(), CodeError> {
1377
let (msg_tx, msg_rx) = mpsc::unbounded_channel();
1378
let (shutdown_rx, shutdown) = new_barrier();
1379
let mut rpc = new_msgpack_rpc();
1380
let caller = rpc.get_caller(msg_tx);
1381
1382
let challenge_response = do_challenge_response_flow(caller, shutdown);
1383
let rpc = start_msgpack_rpc(
1384
rpc.methods(()).build(log.prefixed("client-auth")),
1385
stdout,
1386
stdin,
1387
msg_rx,
1388
shutdown_rx,
1389
);
1390
pin!(rpc);
1391
1392
tokio::select! {
1393
r = &mut rpc => {
1394
match r {
1395
// means shutdown happened cleanly already, we're good
1396
Ok(_) => Ok(()),
1397
Err(e) => Err(CodeError::ProcessSpawnHandshakeFailed(e))
1398
}
1399
},
1400
r = challenge_response => {
1401
r?;
1402
rpc.await.map(|_| ()).map_err(CodeError::ProcessSpawnFailed)
1403
}
1404
}
1405
}
1406
1407
async fn do_challenge_response_flow(
1408
caller: RpcCaller<MsgPackSerializer>,
1409
shutdown: BarrierOpener<()>,
1410
) -> Result<(), CodeError> {
1411
let challenge: ChallengeIssueResponse = caller
1412
.call(METHOD_CHALLENGE_ISSUE, EmptyObject {})
1413
.await
1414
.unwrap()
1415
.map_err(CodeError::TunnelRpcCallFailed)?;
1416
1417
let _: EmptyObject = caller
1418
.call(
1419
METHOD_CHALLENGE_VERIFY,
1420
ChallengeVerifyParams {
1421
response: sign_challenge(&challenge.challenge),
1422
},
1423
)
1424
.await
1425
.unwrap()
1426
.map_err(CodeError::TunnelRpcCallFailed)?;
1427
1428
shutdown.open(());
1429
1430
Ok(())
1431
}
1432
1433