Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
microsoft
GitHub Repository: microsoft/vscode
Path: blob/main/cli/src/tunnels/dev_tunnels.rs
3314 views
1
/*---------------------------------------------------------------------------------------------
2
* Copyright (c) Microsoft Corporation. All rights reserved.
3
* Licensed under the MIT License. See License.txt in the project root for license information.
4
*--------------------------------------------------------------------------------------------*/
5
use super::protocol::{self, PortPrivacy, PortProtocol};
6
use crate::auth;
7
use crate::constants::{IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT};
8
use crate::state::{LauncherPaths, PersistedState};
9
use crate::util::errors::{
10
wrap, AnyError, CodeError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed,
11
WrappedError,
12
};
13
use crate::util::input::prompt_placeholder;
14
use crate::{debug, info, log, spanf, trace, warning};
15
use async_trait::async_trait;
16
use futures::future::BoxFuture;
17
use futures::{FutureExt, TryFutureExt};
18
use lazy_static::lazy_static;
19
use rand::prelude::IteratorRandom;
20
use regex::Regex;
21
use reqwest::StatusCode;
22
use serde::{Deserialize, Serialize};
23
use std::sync::{Arc, Mutex};
24
use std::time::Duration;
25
use tokio::sync::{mpsc, watch};
26
use tunnels::connections::{ForwardedPortConnection, RelayTunnelHost};
27
use tunnels::contracts::{
28
Tunnel, TunnelAccessControl, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN,
29
TUNNEL_ACCESS_SCOPES_CONNECT, TUNNEL_PROTOCOL_AUTO,
30
};
31
use tunnels::management::{
32
new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
33
NO_REQUEST_OPTIONS,
34
};
35
36
/// Substring looked for in the detail of a 429 response from the tunnel service
/// to recognise that the per-user tunnel count limit was hit (which makes
/// `create_tunnel` try to recycle an unused tunnel).
const TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
37
38
#[allow(dead_code)]
39
mod tunnel_flags {
40
use crate::{log, tunnels::wsl_detect::is_wsl_installed};
41
42
pub const IS_WSL_INSTALLED: u32 = 1 << 0;
43
pub const IS_WINDOWS: u32 = 1 << 1;
44
pub const IS_LINUX: u32 = 1 << 2;
45
pub const IS_MACOS: u32 = 1 << 3;
46
47
/// Creates a flag string for the tunnel
48
pub fn create(log: &log::Logger) -> String {
49
let mut flags = 0;
50
51
#[cfg(windows)]
52
{
53
flags |= IS_WINDOWS;
54
}
55
#[cfg(target_os = "linux")]
56
{
57
flags |= IS_LINUX;
58
}
59
#[cfg(target_os = "macos")]
60
{
61
flags |= IS_MACOS;
62
}
63
64
if is_wsl_installed(log) {
65
flags |= IS_WSL_INSTALLED;
66
}
67
68
format!("_flag{flags}")
69
}
70
}
71
72
#[derive(Clone, Serialize, Deserialize)]
73
pub struct PersistedTunnel {
74
pub name: String,
75
pub id: String,
76
pub cluster: String,
77
}
78
79
impl PersistedTunnel {
80
pub fn into_locator(self) -> TunnelLocator {
81
TunnelLocator::ID {
82
cluster: self.cluster,
83
id: self.id,
84
}
85
}
86
pub fn locator(&self) -> TunnelLocator {
87
TunnelLocator::ID {
88
cluster: self.cluster.clone(),
89
id: self.id.clone(),
90
}
91
}
92
}
93
94
#[async_trait]
95
trait AccessTokenProvider: Send + Sync {
96
/// Gets the current access token.
97
async fn refresh_token(&self) -> Result<String, WrappedError>;
98
99
/// Maintains the stored credential by refreshing it against the service
100
/// to ensure its stays current. Returns a future that should be polled and
101
/// only completes if a refresh fails in a consistent way.
102
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>>;
103
}
104
105
/// Access token provider that provides a fixed token without refreshing.
106
struct StaticAccessTokenProvider(String);
107
108
impl StaticAccessTokenProvider {
109
pub fn new(token: String) -> Self {
110
Self(token)
111
}
112
}
113
114
#[async_trait]
115
impl AccessTokenProvider for StaticAccessTokenProvider {
116
async fn refresh_token(&self) -> Result<String, WrappedError> {
117
Ok(self.0.clone())
118
}
119
120
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
121
futures::future::pending().boxed()
122
}
123
}
124
125
/// Access token provider that looks up the token from the tunnels API.
126
struct LookupAccessTokenProvider {
127
auth: auth::Auth,
128
client: TunnelManagementClient,
129
locator: TunnelLocator,
130
log: log::Logger,
131
initial_token: Arc<Mutex<Option<String>>>,
132
}
133
134
impl LookupAccessTokenProvider {
135
pub fn new(
136
auth: auth::Auth,
137
client: TunnelManagementClient,
138
locator: TunnelLocator,
139
log: log::Logger,
140
initial_token: Option<String>,
141
) -> Self {
142
Self {
143
auth,
144
client,
145
locator,
146
log,
147
initial_token: Arc::new(Mutex::new(initial_token)),
148
}
149
}
150
}
151
152
#[async_trait]
153
impl AccessTokenProvider for LookupAccessTokenProvider {
154
async fn refresh_token(&self) -> Result<String, WrappedError> {
155
if let Some(token) = self.initial_token.lock().unwrap().take() {
156
return Ok(token);
157
}
158
159
let tunnel_lookup = spanf!(
160
self.log,
161
self.log.span("dev-tunnel.tag.get"),
162
self.client.get_tunnel(
163
&self.locator,
164
&TunnelRequestOptions {
165
token_scopes: vec!["host".to_string()],
166
..Default::default()
167
}
168
)
169
);
170
171
trace!(self.log, "Successfully refreshed access token");
172
173
match tunnel_lookup {
174
Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
175
Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
176
}
177
}
178
179
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
180
let auth = self.auth.clone();
181
auth.keep_token_alive().boxed()
182
}
183
}
184
185
#[derive(Clone)]
186
pub struct DevTunnels {
187
auth: auth::Auth,
188
log: log::Logger,
189
launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
190
client: TunnelManagementClient,
191
tag: &'static str,
192
}
193
194
/// Representation of a tunnel returned from the `start` methods.
195
pub struct ActiveTunnel {
196
/// Name of the tunnel
197
pub name: String,
198
/// Underlying dev tunnels ID
199
pub id: String,
200
manager: ActiveTunnelManager,
201
}
202
203
impl ActiveTunnel {
204
/// Closes and unregisters the tunnel.
205
pub async fn close(&mut self) -> Result<(), AnyError> {
206
self.manager.kill().await?;
207
Ok(())
208
}
209
210
/// Forwards a port to local connections.
211
pub async fn add_port_direct(
212
&mut self,
213
port_number: u16,
214
) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
215
let port = self.manager.add_port_direct(port_number).await?;
216
Ok(port)
217
}
218
219
/// Forwards a port over TCP.
220
pub async fn add_port_tcp(
221
&self,
222
port_number: u16,
223
privacy: PortPrivacy,
224
protocol: PortProtocol,
225
) -> Result<(), AnyError> {
226
self.manager
227
.add_port_tcp(port_number, privacy, protocol)
228
.await?;
229
Ok(())
230
}
231
232
/// Removes a forwarded port TCP.
233
pub async fn remove_port(&self, port_number: u16) -> Result<(), AnyError> {
234
self.manager.remove_port(port_number).await?;
235
Ok(())
236
}
237
238
/// Gets the template string for forming forwarded port web URIs..
239
pub fn get_port_format(&self) -> Result<String, AnyError> {
240
if let Some(details) = &*self.manager.endpoint_rx.borrow() {
241
return details
242
.as_ref()
243
.map(|r| {
244
r.base
245
.port_uri_format
246
.clone()
247
.expect("expected to have port format")
248
})
249
.map_err(|e| e.clone().into());
250
}
251
252
Err(CodeError::NoTunnelEndpoint.into())
253
}
254
255
/// Gets the public URI on which a forwarded port can be access in browser.
256
pub fn get_port_uri(&self, port: u16) -> Result<String, AnyError> {
257
self.get_port_format()
258
.map(|f| f.replace(PORT_TOKEN, &port.to_string()))
259
}
260
261
/// Gets an object to read the current tunnel status.
262
pub fn status(&self) -> StatusLock {
263
self.manager.get_status()
264
}
265
}
266
267
/// Label applied to tunnels hosting the VS Code Server for Remote Tunnels.
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
/// Label applied to tunnels used for port forwarding.
const VSCODE_CLI_FORWARDING_TAG: &str = "vscode-port-forward";
/// Every label that marks a tunnel as created by this CLI (and therefore
/// eligible for recycling).
const OWNED_TUNNEL_TAGS: &[&str] = &[VSCODE_CLI_TUNNEL_TAG, VSCODE_CLI_FORWARDING_TAG];
/// Upper bound enforced by `is_valid_name` on tunnel names.
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
271
272
fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
273
tunnel
274
.access_tokens
275
.as_ref()
276
.expect("expected to have access tokens")
277
.get("host")
278
.expect("expected to have host token")
279
.to_string()
280
}
281
282
fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
283
if name.len() > MAX_TUNNEL_NAME_LENGTH {
284
return Err(InvalidTunnelName(format!(
285
"Names cannot be longer than {MAX_TUNNEL_NAME_LENGTH} characters. Please try a different name."
286
)));
287
}
288
289
let re = Regex::new(r"^([\w-]+)$").unwrap();
290
291
if !re.is_match(name) {
292
return Err(InvalidTunnelName(
293
"Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
294
));
295
}
296
297
Ok(())
298
}
299
300
lazy_static! {
    /// Request options used when acting as the tunnel host: asks for the
    /// `host` token scope and includes the tunnel's ports in responses.
    static ref HOST_TUNNEL_REQUEST_OPTIONS: TunnelRequestOptions = TunnelRequestOptions {
        token_scopes: vec![String::from("host")],
        include_ports: true,
        ..Default::default()
    };
}
307
308
/// Details of a tunnel that already exists, passed to `start_existing_tunnel`
/// so the CLI can host it instead of creating its own.
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
    /// Name to give the tunnel; a hostname-derived placeholder is used when absent.
    pub tunnel_name: Option<String>,

    /// Host token used to authenticate against the preexisting tunnel.
    pub host_token: String,

    /// ID of the preexisting tunnel.
    pub tunnel_id: String,

    /// Cluster the preexisting tunnel lives in.
    pub cluster: String,
}
323
324
impl DevTunnels {
325
/// Creates a new DevTunnels client used for port forwarding.
326
pub fn new_port_forwarding(
327
log: &log::Logger,
328
auth: auth::Auth,
329
paths: &LauncherPaths,
330
) -> DevTunnels {
331
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
332
client.authorization_provider(auth.clone());
333
334
DevTunnels {
335
auth,
336
log: log.clone(),
337
client: client.into(),
338
launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
339
tag: VSCODE_CLI_FORWARDING_TAG,
340
}
341
}
342
343
/// Creates a new DevTunnels client used for the Remote Tunnels extension to access the VS Code Server.
344
pub fn new_remote_tunnel(
345
log: &log::Logger,
346
auth: auth::Auth,
347
paths: &LauncherPaths,
348
) -> DevTunnels {
349
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
350
client.authorization_provider(auth.clone());
351
352
DevTunnels {
353
auth,
354
log: log.clone(),
355
client: client.into(),
356
launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
357
tag: VSCODE_CLI_TUNNEL_TAG,
358
}
359
}
360
361
pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
362
let tunnel = match self.launcher_tunnel.load() {
363
Some(t) => t,
364
None => {
365
return Ok(());
366
}
367
};
368
369
spanf!(
370
self.log,
371
self.log.span("dev-tunnel.delete"),
372
self.client
373
.delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
374
)
375
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
376
377
self.launcher_tunnel.save(None)?;
378
Ok(())
379
}
380
381
/// Renames the current tunnel to the new name.
382
pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
383
self.update_tunnel_name(self.launcher_tunnel.load(), name)
384
.await
385
.map(|_| ())
386
}
387
388
/// Updates the name of the existing persisted tunnel to the new name.
389
/// Gracefully creates a new tunnel if the previous one was deleted.
390
async fn update_tunnel_name(
391
&mut self,
392
persisted: Option<PersistedTunnel>,
393
name: &str,
394
) -> Result<(Tunnel, PersistedTunnel), AnyError> {
395
let name = name.to_ascii_lowercase();
396
397
let (mut full_tunnel, mut persisted, is_new) = match persisted {
398
Some(persisted) => {
399
debug!(
400
self.log,
401
"Found a persisted tunnel, seeing if the name matches..."
402
);
403
self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
404
.await
405
}
406
None => {
407
debug!(self.log, "Creating a new tunnel with the requested name");
408
self.create_tunnel(&name, NO_REQUEST_OPTIONS)
409
.await
410
.map(|(pt, t)| (t, pt, true))
411
}
412
}?;
413
414
let desired_tags = self.get_labels(&name);
415
if is_new || vec_eq_as_set(&full_tunnel.labels, &desired_tags) {
416
return Ok((full_tunnel, persisted));
417
}
418
419
debug!(self.log, "Tunnel name changed, applying updates...");
420
421
full_tunnel.labels = desired_tags;
422
423
let updated_tunnel = spanf!(
424
self.log,
425
self.log.span("dev-tunnel.tag.update"),
426
self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
427
)
428
.map_err(|e| wrap(e, "failed to rename tunnel"))?;
429
430
persisted.name = name;
431
self.launcher_tunnel.save(Some(persisted.clone()))?;
432
433
Ok((updated_tunnel, persisted))
434
}
435
436
/// Gets the persisted tunnel from the service, or creates a new one.
437
/// If `create_with_new_name` is given, the new tunnel has that name
438
/// instead of the one previously persisted.
439
async fn get_or_create_tunnel(
440
&mut self,
441
persisted: PersistedTunnel,
442
create_with_new_name: Option<&str>,
443
options: &TunnelRequestOptions,
444
) -> Result<(Tunnel, PersistedTunnel, /* is_new */ bool), AnyError> {
445
let tunnel_lookup = spanf!(
446
self.log,
447
self.log.span("dev-tunnel.tag.get"),
448
self.client.get_tunnel(&persisted.locator(), options)
449
);
450
451
match tunnel_lookup {
452
Ok(ft) => Ok((ft, persisted, false)),
453
Err(HttpError::ResponseError(e))
454
if e.status_code == StatusCode::NOT_FOUND
455
|| e.status_code == StatusCode::FORBIDDEN =>
456
{
457
let (persisted, tunnel) = self
458
.create_tunnel(create_with_new_name.unwrap_or(&persisted.name), options)
459
.await?;
460
Ok((tunnel, persisted, true))
461
}
462
Err(e) => Err(wrap(e, "failed to lookup tunnel").into()),
463
}
464
}
465
466
/// Starts a new tunnel for the code server on the port. Unlike `start_new_tunnel`,
467
/// this attempts to reuse or create a tunnel of a preferred name or of a generated friendly tunnel name.
468
pub async fn start_new_launcher_tunnel(
469
&mut self,
470
preferred_name: Option<&str>,
471
use_random_name: bool,
472
preserve_ports: &[u16],
473
) -> Result<ActiveTunnel, AnyError> {
474
let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
475
Some(mut persisted) => {
476
if let Some(preferred_name) = preferred_name.map(|n| n.to_ascii_lowercase()) {
477
if persisted.name.to_ascii_lowercase() != preferred_name {
478
(_, persisted) = self
479
.update_tunnel_name(Some(persisted), &preferred_name)
480
.await?;
481
}
482
}
483
484
let (tunnel, persisted, _) = self
485
.get_or_create_tunnel(persisted, None, &HOST_TUNNEL_REQUEST_OPTIONS)
486
.await?;
487
(tunnel, persisted)
488
}
489
None => {
490
debug!(self.log, "No code server tunnel found, creating new one");
491
let name = self
492
.get_name_for_tunnel(preferred_name, use_random_name)
493
.await?;
494
let (persisted, full_tunnel) = self
495
.create_tunnel(&name, &HOST_TUNNEL_REQUEST_OPTIONS)
496
.await?;
497
(full_tunnel, persisted)
498
}
499
};
500
501
tunnel = self
502
.sync_tunnel_tags(
503
&self.client,
504
&persisted.name,
505
tunnel,
506
&HOST_TUNNEL_REQUEST_OPTIONS,
507
)
508
.await?;
509
510
let locator = TunnelLocator::try_from(&tunnel).unwrap();
511
let host_token = get_host_token_from_tunnel(&tunnel);
512
513
for port_to_delete in tunnel
514
.ports
515
.iter()
516
.filter(|p: &&TunnelPort| !preserve_ports.contains(&p.port_number))
517
{
518
let output_fut = self.client.delete_tunnel_port(
519
&locator,
520
port_to_delete.port_number,
521
NO_REQUEST_OPTIONS,
522
);
523
spanf!(
524
self.log,
525
self.log.span("dev-tunnel.port.delete"),
526
output_fut
527
)
528
.map_err(|e| wrap(e, "failed to delete port"))?;
529
}
530
531
// cleanup any old trailing tunnel endpoints
532
for endpoint in tunnel.endpoints {
533
let fut = self.client.delete_tunnel_endpoints(
534
&locator,
535
&endpoint.host_id,
536
NO_REQUEST_OPTIONS,
537
);
538
539
spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
540
.map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
541
}
542
543
self.start_tunnel(
544
locator.clone(),
545
&persisted,
546
self.client.clone(),
547
LookupAccessTokenProvider::new(
548
self.auth.clone(),
549
self.client.clone(),
550
locator,
551
self.log.clone(),
552
Some(host_token),
553
),
554
)
555
.await
556
}
557
558
async fn create_tunnel(
559
&mut self,
560
name: &str,
561
options: &TunnelRequestOptions,
562
) -> Result<(PersistedTunnel, Tunnel), AnyError> {
563
info!(self.log, "Creating tunnel with the name: {}", name);
564
565
let tunnel = match self.get_existing_tunnel_with_name(name).await? {
566
Some(e) => {
567
if tunnel_has_host_connection(&e) {
568
return Err(CodeError::TunnelActiveAndInUse(name.to_string()).into());
569
}
570
571
let loc = TunnelLocator::try_from(&e).unwrap();
572
info!(self.log, "Adopting existing tunnel (ID={:?})", loc);
573
spanf!(
574
self.log,
575
self.log.span("dev-tunnel.tag.get"),
576
self.client.get_tunnel(&loc, &HOST_TUNNEL_REQUEST_OPTIONS)
577
)
578
.map_err(|e| wrap(e, "failed to lookup tunnel"))?
579
}
580
None => loop {
581
let result = spanf!(
582
self.log,
583
self.log.span("dev-tunnel.create"),
584
self.client.create_tunnel(
585
Tunnel {
586
labels: self.get_labels(name),
587
..Default::default()
588
},
589
options
590
)
591
);
592
593
match result {
594
Err(HttpError::ResponseError(e))
595
if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
596
{
597
if let Some(d) = e.get_details() {
598
let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
599
if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
600
&& self.try_recycle_tunnel().await?
601
{
602
continue;
603
}
604
605
return Err(AnyError::from(TunnelCreationFailed(
606
name.to_string(),
607
detail,
608
)));
609
}
610
611
return Err(AnyError::from(TunnelCreationFailed(
612
name.to_string(),
613
"You have exceeded a limit for the port fowarding service. Please remove other machines before trying to add this machine.".to_string(),
614
)));
615
}
616
Err(e) => {
617
return Err(AnyError::from(TunnelCreationFailed(
618
name.to_string(),
619
format!("{e:?}"),
620
)))
621
}
622
Ok(t) => break t,
623
}
624
},
625
};
626
627
let pt = PersistedTunnel {
628
cluster: tunnel.cluster_id.clone().unwrap(),
629
id: tunnel.tunnel_id.clone().unwrap(),
630
name: name.to_string(),
631
};
632
633
self.launcher_tunnel.save(Some(pt.clone()))?;
634
Ok((pt, tunnel))
635
}
636
637
/// Gets the expected tunnel tags
638
fn get_labels(&self, name: &str) -> Vec<String> {
639
vec![
640
name.to_string(),
641
PROTOCOL_VERSION_TAG.to_string(),
642
self.tag.to_string(),
643
tunnel_flags::create(&self.log),
644
]
645
}
646
647
/// Ensures the tunnel contains a tag for the current PROTCOL_VERSION, and no
648
/// other version tags.
649
async fn sync_tunnel_tags(
650
&self,
651
client: &TunnelManagementClient,
652
name: &str,
653
tunnel: Tunnel,
654
options: &TunnelRequestOptions,
655
) -> Result<Tunnel, AnyError> {
656
let new_labels = self.get_labels(name);
657
if vec_eq_as_set(&tunnel.labels, &new_labels) {
658
return Ok(tunnel);
659
}
660
661
debug!(
662
self.log,
663
"Updating tunnel tags {} -> {}",
664
tunnel.labels.join(", "),
665
new_labels.join(", ")
666
);
667
668
let tunnel_update = Tunnel {
669
labels: new_labels,
670
tunnel_id: tunnel.tunnel_id.clone(),
671
cluster_id: tunnel.cluster_id.clone(),
672
..Default::default()
673
};
674
675
let result = spanf!(
676
self.log,
677
self.log.span("dev-tunnel.protocol-tag-update"),
678
client.update_tunnel(&tunnel_update, options)
679
);
680
681
result.map_err(|e| wrap(e, "tunnel tag update failed").into())
682
}
683
684
/// Tries to delete an unused tunnel, and then creates a tunnel with the
685
/// given `new_name`.
686
async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
687
trace!(
688
self.log,
689
"Tunnel limit hit, trying to recycle an old tunnel"
690
);
691
692
let existing_tunnels = self.list_tunnels_with_tag(OWNED_TUNNEL_TAGS).await?;
693
694
let recyclable = existing_tunnels
695
.iter()
696
.filter(|t| !tunnel_has_host_connection(t))
697
.choose(&mut rand::thread_rng());
698
699
match recyclable {
700
Some(tunnel) => {
701
trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
702
spanf!(
703
self.log,
704
self.log.span("dev-tunnel.delete"),
705
self.client
706
.delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
707
)
708
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
709
Ok(true)
710
}
711
None => {
712
trace!(self.log, "No tunnels available to recycle");
713
Ok(false)
714
}
715
}
716
}
717
718
async fn list_tunnels_with_tag(
719
&mut self,
720
tags: &[&'static str],
721
) -> Result<Vec<Tunnel>, AnyError> {
722
let tunnels = spanf!(
723
self.log,
724
self.log.span("dev-tunnel.listall"),
725
self.client.list_all_tunnels(&TunnelRequestOptions {
726
labels: tags.iter().map(|t| t.to_string()).collect(),
727
..Default::default()
728
})
729
)
730
.map_err(|e| wrap(e, "error listing current tunnels"))?;
731
732
Ok(tunnels)
733
}
734
735
async fn get_existing_tunnel_with_name(&self, name: &str) -> Result<Option<Tunnel>, AnyError> {
736
let existing: Vec<Tunnel> = spanf!(
737
self.log,
738
self.log.span("dev-tunnel.rename.search"),
739
self.client.list_all_tunnels(&TunnelRequestOptions {
740
labels: vec![self.tag.to_string(), name.to_string()],
741
require_all_labels: true,
742
limit: 1,
743
include_ports: true,
744
token_scopes: vec!["host".to_string()],
745
..Default::default()
746
})
747
)
748
.map_err(|e| wrap(e, "failed to list existing tunnels"))?;
749
750
Ok(existing.into_iter().next())
751
}
752
753
fn get_placeholder_name() -> String {
754
let mut n = clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
755
n.make_ascii_lowercase();
756
n.truncate(MAX_TUNNEL_NAME_LENGTH);
757
n
758
}
759
760
async fn get_name_for_tunnel(
761
&mut self,
762
preferred_name: Option<&str>,
763
mut use_random_name: bool,
764
) -> Result<String, AnyError> {
765
let existing_tunnels = self.list_tunnels_with_tag(&[self.tag]).await?;
766
let is_name_free = |n: &str| {
767
!existing_tunnels
768
.iter()
769
.any(|v| tunnel_has_host_connection(v) && v.labels.iter().any(|t| t == n))
770
};
771
772
if let Some(machine_name) = preferred_name {
773
let name = machine_name.to_ascii_lowercase();
774
if let Err(e) = is_valid_name(&name) {
775
info!(self.log, "{} is an invalid name", e);
776
return Err(AnyError::from(wrap(e, "invalid name")));
777
}
778
if is_name_free(&name) {
779
return Ok(name);
780
}
781
info!(
782
self.log,
783
"{} is already taken, using a random name instead", &name
784
);
785
use_random_name = true;
786
}
787
788
let mut placeholder_name = Self::get_placeholder_name();
789
if !is_name_free(&placeholder_name) {
790
for i in 2.. {
791
let fixed_name = format!("{placeholder_name}{i}");
792
if is_name_free(&fixed_name) {
793
placeholder_name = fixed_name;
794
break;
795
}
796
}
797
}
798
799
if use_random_name || !*IS_INTERACTIVE_CLI {
800
return Ok(placeholder_name);
801
}
802
803
loop {
804
let mut name = prompt_placeholder(
805
"What would you like to call this machine?",
806
&placeholder_name,
807
)?;
808
809
name.make_ascii_lowercase();
810
811
if let Err(e) = is_valid_name(&name) {
812
info!(self.log, "{}", e);
813
continue;
814
}
815
816
if is_name_free(&name) {
817
return Ok(name);
818
}
819
820
info!(self.log, "The name {} is already in use", name);
821
}
822
}
823
824
/// Hosts an existing tunnel, where the tunnel ID and host token are given.
825
pub async fn start_existing_tunnel(
826
&mut self,
827
tunnel: ExistingTunnel,
828
) -> Result<ActiveTunnel, AnyError> {
829
let tunnel_details = PersistedTunnel {
830
name: match tunnel.tunnel_name {
831
Some(n) => n,
832
None => Self::get_placeholder_name(),
833
},
834
id: tunnel.tunnel_id,
835
cluster: tunnel.cluster,
836
};
837
838
let mut mgmt = self.client.build();
839
mgmt.authorization(tunnels::management::Authorization::Tunnel(
840
tunnel.host_token.clone(),
841
));
842
843
let client = mgmt.into();
844
self.sync_tunnel_tags(
845
&client,
846
&tunnel_details.name,
847
Tunnel {
848
cluster_id: Some(tunnel_details.cluster.clone()),
849
tunnel_id: Some(tunnel_details.id.clone()),
850
..Default::default()
851
},
852
&HOST_TUNNEL_REQUEST_OPTIONS,
853
)
854
.await?;
855
856
self.start_tunnel(
857
tunnel_details.locator(),
858
&tunnel_details,
859
client,
860
StaticAccessTokenProvider::new(tunnel.host_token),
861
)
862
.await
863
}
864
865
async fn start_tunnel(
866
&mut self,
867
locator: TunnelLocator,
868
tunnel_details: &PersistedTunnel,
869
client: TunnelManagementClient,
870
access_token: impl AccessTokenProvider + 'static,
871
) -> Result<ActiveTunnel, AnyError> {
872
let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);
873
874
let endpoint_result = spanf!(
875
self.log,
876
self.log.span("dev-tunnel.serve.callback"),
877
manager.get_endpoint()
878
);
879
880
let endpoint = match endpoint_result {
881
Ok(endpoint) => endpoint,
882
Err(e) => {
883
error!(self.log, "Error connecting to tunnel endpoint: {}", e);
884
manager.kill().await.ok();
885
return Err(e);
886
}
887
};
888
889
debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);
890
891
Ok(ActiveTunnel {
892
name: tunnel_details.name.clone(),
893
id: tunnel_details.id.clone(),
894
manager,
895
})
896
}
897
}
898
899
#[derive(Clone, Default)]
900
pub struct StatusLock(Arc<std::sync::Mutex<protocol::singleton::Status>>);
901
902
impl StatusLock {
903
fn succeed(&self) {
904
let mut status = self.0.lock().unwrap();
905
status.tunnel = protocol::singleton::TunnelState::Connected;
906
status.last_connected_at = Some(chrono::Utc::now());
907
}
908
909
fn fail(&self, reason: String) {
910
let mut status = self.0.lock().unwrap();
911
if let protocol::singleton::TunnelState::Connected = status.tunnel {
912
status.last_disconnected_at = Some(chrono::Utc::now());
913
status.tunnel = protocol::singleton::TunnelState::Disconnected;
914
}
915
status.last_fail_reason = Some(reason);
916
}
917
918
pub fn read(&self) -> protocol::singleton::Status {
919
let status = self.0.lock().unwrap();
920
status.clone()
921
}
922
}
923
924
struct ActiveTunnelManager {
925
close_tx: Option<mpsc::Sender<()>>,
926
endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
927
relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
928
status: StatusLock,
929
}
930
931
impl ActiveTunnelManager {
932
pub fn new(
933
log: log::Logger,
934
mgmt: TunnelManagementClient,
935
locator: TunnelLocator,
936
access_token: impl AccessTokenProvider + 'static,
937
) -> ActiveTunnelManager {
938
let (endpoint_tx, endpoint_rx) = watch::channel(None);
939
let (close_tx, close_rx) = mpsc::channel(1);
940
941
let relay = Arc::new(tokio::sync::Mutex::new(RelayTunnelHost::new(locator, mgmt)));
942
let relay_spawned = relay.clone();
943
944
let status = StatusLock::default();
945
946
let status_spawned = status.clone();
947
tokio::spawn(async move {
948
ActiveTunnelManager::spawn_tunnel(
949
log,
950
relay_spawned,
951
close_rx,
952
endpoint_tx,
953
access_token,
954
status_spawned,
955
)
956
.await;
957
});
958
959
ActiveTunnelManager {
960
endpoint_rx,
961
relay,
962
close_tx: Some(close_tx),
963
status,
964
}
965
}
966
967
/// Gets a copy of the current tunnel status information
968
pub fn get_status(&self) -> StatusLock {
969
self.status.clone()
970
}
971
972
/// Adds a port for TCP/IP forwarding.
973
pub async fn add_port_tcp(
974
&self,
975
port_number: u16,
976
privacy: PortPrivacy,
977
protocol: PortProtocol,
978
) -> Result<(), WrappedError> {
979
self.relay
980
.lock()
981
.await
982
.add_port(&TunnelPort {
983
port_number,
984
protocol: Some(protocol.to_contract_str().to_string()),
985
access_control: Some(privacy_to_tunnel_acl(privacy)),
986
..Default::default()
987
})
988
.await
989
.map_err(|e| wrap(e, "error adding port to relay"))?;
990
Ok(())
991
}
992
993
/// Adds a port for TCP/IP forwarding.
994
pub async fn add_port_direct(
995
&self,
996
port_number: u16,
997
) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
998
self.relay
999
.lock()
1000
.await
1001
.add_port_raw(&TunnelPort {
1002
port_number,
1003
protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
1004
access_control: Some(privacy_to_tunnel_acl(PortPrivacy::Private)),
1005
..Default::default()
1006
})
1007
.await
1008
.map_err(|e| wrap(e, "error adding port to relay"))
1009
}
1010
1011
/// Removes a port from TCP/IP forwarding.
1012
pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
1013
self.relay
1014
.lock()
1015
.await
1016
.remove_port(port_number)
1017
.await
1018
.map_err(|e| wrap(e, "error remove port from relay"))
1019
}
1020
1021
/// Gets the most recent details from the tunnel process. Returns None if
1022
/// the process exited before providing details.
1023
pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
1024
loop {
1025
if let Some(details) = &*self.endpoint_rx.borrow() {
1026
return details.clone().map_err(AnyError::from);
1027
}
1028
1029
if self.endpoint_rx.changed().await.is_err() {
1030
return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
1031
}
1032
}
1033
}
1034
1035
/// Kills the process, and waits for it to exit.
1036
/// See https://tokio.rs/tokio/topics/shutdown#waiting-for-things-to-finish-shutting-down for how this works
1037
pub async fn kill(&mut self) -> Result<(), AnyError> {
1038
if let Some(tx) = self.close_tx.take() {
1039
drop(tx);
1040
}
1041
1042
self.relay
1043
.lock()
1044
.await
1045
.unregister()
1046
.await
1047
.map_err(|e| wrap(e, "error unregistering relay"))?;
1048
1049
while self.endpoint_rx.changed().await.is_ok() {}
1050
1051
Ok(())
1052
}
1053
1054
async fn spawn_tunnel(
1055
log: log::Logger,
1056
relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
1057
mut close_rx: mpsc::Receiver<()>,
1058
endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
1059
access_token_provider: impl AccessTokenProvider + 'static,
1060
status: StatusLock,
1061
) {
1062
let mut token_ka = access_token_provider.keep_alive();
1063
let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));
1064
1065
macro_rules! fail {
1066
($e: expr, $msg: expr) => {
1067
let fmt = format!("{}: {}", $msg, $e);
1068
warning!(log, &fmt);
1069
status.fail(fmt);
1070
endpoint_tx.send(Some(Err($e))).ok();
1071
backoff.delay().await;
1072
};
1073
}
1074
1075
loop {
1076
debug!(log, "Starting tunnel to server...");
1077
1078
let access_token = match access_token_provider.refresh_token().await {
1079
Ok(t) => t,
1080
Err(e) => {
1081
fail!(e, "Error refreshing access token, will retry");
1082
continue;
1083
}
1084
};
1085
1086
// we don't bother making a client that can refresh the token, since
1087
// the tunnel won't be able to host as soon as the access token expires.
1088
let handle_res = {
1089
let mut relay = relay.lock().await;
1090
relay
1091
.connect(&access_token)
1092
.await
1093
.map_err(|e| wrap(e, "error connecting to tunnel"))
1094
};
1095
1096
let mut handle = match handle_res {
1097
Ok(handle) => handle,
1098
Err(e) => {
1099
fail!(e, "Error connecting to relay, will retry");
1100
continue;
1101
}
1102
};
1103
1104
backoff.reset();
1105
status.succeed();
1106
endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();
1107
1108
tokio::select! {
1109
// error is mapped like this prevent it being used across an await,
1110
// which Rust dislikes since there's a non-sendable dyn Error in there
1111
res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
1112
if let Err(e) = res {
1113
fail!(e, "Tunnel exited unexpectedly, reconnecting");
1114
} else {
1115
warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
1116
backoff.delay().await;
1117
}
1118
},
1119
Err(e) = &mut token_ka => {
1120
error!(log, "access token is no longer valid, exiting: {}", e);
1121
return;
1122
},
1123
_ = close_rx.recv() => {
1124
trace!(log, "Tunnel closing gracefully");
1125
trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
1126
break;
1127
}
1128
}
1129
}
1130
}
1131
}
1132
1133
struct Backoff {
1134
failures: u32,
1135
base_duration: Duration,
1136
max_duration: Duration,
1137
}
1138
1139
impl Backoff {
1140
pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
1141
Self {
1142
failures: 0,
1143
base_duration,
1144
max_duration,
1145
}
1146
}
1147
1148
pub async fn delay(&mut self) {
1149
tokio::time::sleep(self.next()).await
1150
}
1151
1152
pub fn next(&mut self) -> Duration {
1153
self.failures += 1;
1154
let duration = self
1155
.base_duration
1156
.checked_mul(self.failures)
1157
.unwrap_or(self.max_duration);
1158
std::cmp::min(duration, self.max_duration)
1159
}
1160
1161
pub fn reset(&mut self) {
1162
self.failures = 0;
1163
}
1164
}
1165
1166
/// Cleans up the hostname so it can be used as a tunnel name.
/// See TUNNEL_NAME_PATTERN in the tunnels SDK for the rules we try to use.
///
/// Looks at no more than the first 60 characters. It keeps ASCII letters and
/// digits, maps '-', '_' and ' ' to '-', drops everything else, and trims
/// leading and trailing '-'. Falls back to `"remote-machine"` when fewer than
/// two characters remain.
fn clean_hostname_for_tunnel(hostname: &str) -> String {
    let cleaned: String = hostname
        .chars()
        .take(60)
        .filter_map(|c| match c {
            '-' | '_' | ' ' => Some('-'),
            c if c.is_ascii_alphanumeric() => Some(c),
            _ => None,
        })
        .collect();

    match cleaned.trim_matches('-') {
        usable if usable.len() >= 2 => usable.to_owned(),
        _ => "remote-machine".to_string(),
    }
}
1189
1190
/// Returns whether `a` and `b` have the same length and contain the same
/// elements, ignoring order.
///
/// Containment is checked in both directions. Checking only that every item
/// of `a` is in `b` wrongly reports `["x", "x"]` and `["x", "y"]` as equal.
fn vec_eq_as_set(a: &[String], b: &[String]) -> bool {
    a.len() == b.len() && a.iter().all(|x| b.contains(x)) && b.iter().all(|y| a.contains(y))
}
1203
1204
fn privacy_to_tunnel_acl(privacy: PortPrivacy) -> TunnelAccessControl {
1205
TunnelAccessControl {
1206
entries: vec![match privacy {
1207
PortPrivacy::Public => tunnels::contracts::TunnelAccessControlEntry {
1208
kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
1209
provider: None,
1210
is_inherited: false,
1211
is_deny: false,
1212
is_inverse: false,
1213
organization: None,
1214
expiration: None,
1215
subjects: vec![],
1216
scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
1217
},
1218
// Ensure private ports are actually private and do not inherit any
1219
// default visibility that may be set on the tunnel:
1220
PortPrivacy::Private => tunnels::contracts::TunnelAccessControlEntry {
1221
kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
1222
provider: None,
1223
is_inherited: false,
1224
is_deny: true,
1225
is_inverse: false,
1226
organization: None,
1227
expiration: None,
1228
subjects: vec![],
1229
scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
1230
},
1231
}],
1232
}
1233
}
1234
1235
fn tunnel_has_host_connection(tunnel: &Tunnel) -> bool {
1236
tunnel
1237
.status
1238
.as_ref()
1239
.and_then(|s| s.host_connection_count.as_ref().map(|c| c.get_count() > 0))
1240
.unwrap_or_default()
1241
}
1242
1243
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_clean_hostname_for_tunnel() {
        // (input hostname, expected tunnel name)
        let cases = [
            ("hello123", "hello123"),
            ("-cool-name-", "cool-name"),
            ("cool!name with_chars", "coolname-with-chars"),
            ("z", "remote-machine"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(
                clean_hostname_for_tunnel(input),
                expected.to_string(),
                "input: {input:?}"
            );
        }
    }
}
1264
1265