use super::protocol::{self, PortPrivacy, PortProtocol};
use crate::auth;
use crate::constants::{IS_INTERACTIVE_CLI, PROTOCOL_VERSION_TAG, TUNNEL_SERVICE_USER_AGENT};
use crate::state::{LauncherPaths, PersistedState};
use crate::util::errors::{
wrap, AnyError, CodeError, DevTunnelError, InvalidTunnelName, TunnelCreationFailed,
WrappedError,
};
use crate::util::input::prompt_placeholder;
use crate::{debug, info, log, spanf, trace, warning};
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::{FutureExt, TryFutureExt};
use lazy_static::lazy_static;
use rand::prelude::IteratorRandom;
use regex::Regex;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use tunnels::connections::{ForwardedPortConnection, RelayTunnelHost};
use tunnels::contracts::{
Tunnel, TunnelAccessControl, TunnelPort, TunnelRelayTunnelEndpoint, PORT_TOKEN,
TUNNEL_ACCESS_SCOPES_CONNECT, TUNNEL_PROTOCOL_AUTO,
};
use tunnels::management::{
new_tunnel_management, HttpError, TunnelLocator, TunnelManagementClient, TunnelRequestOptions,
NO_REQUEST_OPTIONS,
};
static TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
#[allow(dead_code)]
mod tunnel_flags {
use crate::{log, tunnels::wsl_detect::is_wsl_installed};
pub const IS_WSL_INSTALLED: u32 = 1 << 0;
pub const IS_WINDOWS: u32 = 1 << 1;
pub const IS_LINUX: u32 = 1 << 2;
pub const IS_MACOS: u32 = 1 << 3;
pub fn create(log: &log::Logger) -> String {
let mut flags = 0;
#[cfg(windows)]
{
flags |= IS_WINDOWS;
}
#[cfg(target_os = "linux")]
{
flags |= IS_LINUX;
}
#[cfg(target_os = "macos")]
{
flags |= IS_MACOS;
}
if is_wsl_installed(log) {
flags |= IS_WSL_INSTALLED;
}
format!("_flag{flags}")
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct PersistedTunnel {
pub name: String,
pub id: String,
pub cluster: String,
}
impl PersistedTunnel {
pub fn into_locator(self) -> TunnelLocator {
TunnelLocator::ID {
cluster: self.cluster,
id: self.id,
}
}
pub fn locator(&self) -> TunnelLocator {
TunnelLocator::ID {
cluster: self.cluster.clone(),
id: self.id.clone(),
}
}
}
#[async_trait]
trait AccessTokenProvider: Send + Sync {
async fn refresh_token(&self) -> Result<String, WrappedError>;
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>>;
}
struct StaticAccessTokenProvider(String);
impl StaticAccessTokenProvider {
pub fn new(token: String) -> Self {
Self(token)
}
}
#[async_trait]
impl AccessTokenProvider for StaticAccessTokenProvider {
async fn refresh_token(&self) -> Result<String, WrappedError> {
Ok(self.0.clone())
}
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
futures::future::pending().boxed()
}
}
struct LookupAccessTokenProvider {
auth: auth::Auth,
client: TunnelManagementClient,
locator: TunnelLocator,
log: log::Logger,
initial_token: Arc<Mutex<Option<String>>>,
}
impl LookupAccessTokenProvider {
pub fn new(
auth: auth::Auth,
client: TunnelManagementClient,
locator: TunnelLocator,
log: log::Logger,
initial_token: Option<String>,
) -> Self {
Self {
auth,
client,
locator,
log,
initial_token: Arc::new(Mutex::new(initial_token)),
}
}
}
#[async_trait]
impl AccessTokenProvider for LookupAccessTokenProvider {
async fn refresh_token(&self) -> Result<String, WrappedError> {
if let Some(token) = self.initial_token.lock().unwrap().take() {
return Ok(token);
}
let tunnel_lookup = spanf!(
self.log,
self.log.span("dev-tunnel.tag.get"),
self.client.get_tunnel(
&self.locator,
&TunnelRequestOptions {
token_scopes: vec!["host".to_string()],
..Default::default()
}
)
);
trace!(self.log, "Successfully refreshed access token");
match tunnel_lookup {
Ok(tunnel) => Ok(get_host_token_from_tunnel(&tunnel)),
Err(e) => Err(wrap(e, "failed to lookup tunnel for host token")),
}
}
fn keep_alive(&self) -> BoxFuture<'static, Result<(), AnyError>> {
let auth = self.auth.clone();
auth.keep_token_alive().boxed()
}
}
#[derive(Clone)]
pub struct DevTunnels {
auth: auth::Auth,
log: log::Logger,
launcher_tunnel: PersistedState<Option<PersistedTunnel>>,
client: TunnelManagementClient,
tag: &'static str,
}
pub struct ActiveTunnel {
pub name: String,
pub id: String,
manager: ActiveTunnelManager,
}
impl ActiveTunnel {
pub async fn close(&mut self) -> Result<(), AnyError> {
self.manager.kill().await?;
Ok(())
}
pub async fn add_port_direct(
&mut self,
port_number: u16,
) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, AnyError> {
let port = self.manager.add_port_direct(port_number).await?;
Ok(port)
}
pub async fn add_port_tcp(
&self,
port_number: u16,
privacy: PortPrivacy,
protocol: PortProtocol,
) -> Result<(), AnyError> {
self.manager
.add_port_tcp(port_number, privacy, protocol)
.await?;
Ok(())
}
pub async fn remove_port(&self, port_number: u16) -> Result<(), AnyError> {
self.manager.remove_port(port_number).await?;
Ok(())
}
pub fn get_port_format(&self) -> Result<String, AnyError> {
if let Some(details) = &*self.manager.endpoint_rx.borrow() {
return details
.as_ref()
.map(|r| {
r.base
.port_uri_format
.clone()
.expect("expected to have port format")
})
.map_err(|e| e.clone().into());
}
Err(CodeError::NoTunnelEndpoint.into())
}
pub fn get_port_uri(&self, port: u16) -> Result<String, AnyError> {
self.get_port_format()
.map(|f| f.replace(PORT_TOKEN, &port.to_string()))
}
pub fn status(&self) -> StatusLock {
self.manager.get_status()
}
}
const VSCODE_CLI_TUNNEL_TAG: &str = "vscode-server-launcher";
const VSCODE_CLI_FORWARDING_TAG: &str = "vscode-port-forward";
const OWNED_TUNNEL_TAGS: &[&str] = &[VSCODE_CLI_TUNNEL_TAG, VSCODE_CLI_FORWARDING_TAG];
const MAX_TUNNEL_NAME_LENGTH: usize = 20;
fn get_host_token_from_tunnel(tunnel: &Tunnel) -> String {
tunnel
.access_tokens
.as_ref()
.expect("expected to have access tokens")
.get("host")
.expect("expected to have host token")
.to_string()
}
fn is_valid_name(name: &str) -> Result<(), InvalidTunnelName> {
if name.len() > MAX_TUNNEL_NAME_LENGTH {
return Err(InvalidTunnelName(format!(
"Names cannot be longer than {MAX_TUNNEL_NAME_LENGTH} characters. Please try a different name."
)));
}
let re = Regex::new(r"^([\w-]+)$").unwrap();
if !re.is_match(name) {
return Err(InvalidTunnelName(
"Names can only contain letters, numbers, and '-'. Spaces, commas, and all other special characters are not allowed. Please try a different name.".to_string()
));
}
Ok(())
}
lazy_static! {
static ref HOST_TUNNEL_REQUEST_OPTIONS: TunnelRequestOptions = TunnelRequestOptions {
include_ports: true,
token_scopes: vec!["host".to_string()],
..Default::default()
};
}
#[derive(Clone, Debug)]
pub struct ExistingTunnel {
pub tunnel_name: Option<String>,
pub host_token: String,
pub tunnel_id: String,
pub cluster: String,
}
impl DevTunnels {
pub fn new_port_forwarding(
log: &log::Logger,
auth: auth::Auth,
paths: &LauncherPaths,
) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth.clone());
DevTunnels {
auth,
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("port_forwarding_tunnel.json")),
tag: VSCODE_CLI_FORWARDING_TAG,
}
}
pub fn new_remote_tunnel(
log: &log::Logger,
auth: auth::Auth,
paths: &LauncherPaths,
) -> DevTunnels {
let mut client = new_tunnel_management(&TUNNEL_SERVICE_USER_AGENT);
client.authorization_provider(auth.clone());
DevTunnels {
auth,
log: log.clone(),
client: client.into(),
launcher_tunnel: PersistedState::new(paths.root().join("code_tunnel.json")),
tag: VSCODE_CLI_TUNNEL_TAG,
}
}
pub async fn remove_tunnel(&mut self) -> Result<(), AnyError> {
let tunnel = match self.launcher_tunnel.load() {
Some(t) => t,
None => {
return Ok(());
}
};
spanf!(
self.log,
self.log.span("dev-tunnel.delete"),
self.client
.delete_tunnel(&tunnel.into_locator(), NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
self.launcher_tunnel.save(None)?;
Ok(())
}
pub async fn rename_tunnel(&mut self, name: &str) -> Result<(), AnyError> {
self.update_tunnel_name(self.launcher_tunnel.load(), name)
.await
.map(|_| ())
}
async fn update_tunnel_name(
&mut self,
persisted: Option<PersistedTunnel>,
name: &str,
) -> Result<(Tunnel, PersistedTunnel), AnyError> {
let name = name.to_ascii_lowercase();
let (mut full_tunnel, mut persisted, is_new) = match persisted {
Some(persisted) => {
debug!(
self.log,
"Found a persisted tunnel, seeing if the name matches..."
);
self.get_or_create_tunnel(persisted, Some(&name), NO_REQUEST_OPTIONS)
.await
}
None => {
debug!(self.log, "Creating a new tunnel with the requested name");
self.create_tunnel(&name, NO_REQUEST_OPTIONS)
.await
.map(|(pt, t)| (t, pt, true))
}
}?;
let desired_tags = self.get_labels(&name);
if is_new || vec_eq_as_set(&full_tunnel.labels, &desired_tags) {
return Ok((full_tunnel, persisted));
}
debug!(self.log, "Tunnel name changed, applying updates...");
full_tunnel.labels = desired_tags;
let updated_tunnel = spanf!(
self.log,
self.log.span("dev-tunnel.tag.update"),
self.client.update_tunnel(&full_tunnel, NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to rename tunnel"))?;
persisted.name = name;
self.launcher_tunnel.save(Some(persisted.clone()))?;
Ok((updated_tunnel, persisted))
}
async fn get_or_create_tunnel(
&mut self,
persisted: PersistedTunnel,
create_with_new_name: Option<&str>,
options: &TunnelRequestOptions,
) -> Result<(Tunnel, PersistedTunnel, bool), AnyError> {
let tunnel_lookup = spanf!(
self.log,
self.log.span("dev-tunnel.tag.get"),
self.client.get_tunnel(&persisted.locator(), options)
);
match tunnel_lookup {
Ok(ft) => Ok((ft, persisted, false)),
Err(HttpError::ResponseError(e))
if e.status_code == StatusCode::NOT_FOUND
|| e.status_code == StatusCode::FORBIDDEN =>
{
let (persisted, tunnel) = self
.create_tunnel(create_with_new_name.unwrap_or(&persisted.name), options)
.await?;
Ok((tunnel, persisted, true))
}
Err(e) => Err(wrap(e, "failed to lookup tunnel").into()),
}
}
pub async fn start_new_launcher_tunnel(
&mut self,
preferred_name: Option<&str>,
use_random_name: bool,
preserve_ports: &[u16],
) -> Result<ActiveTunnel, AnyError> {
let (mut tunnel, persisted) = match self.launcher_tunnel.load() {
Some(mut persisted) => {
if let Some(preferred_name) = preferred_name.map(|n| n.to_ascii_lowercase()) {
if persisted.name.to_ascii_lowercase() != preferred_name {
(_, persisted) = self
.update_tunnel_name(Some(persisted), &preferred_name)
.await?;
}
}
let (tunnel, persisted, _) = self
.get_or_create_tunnel(persisted, None, &HOST_TUNNEL_REQUEST_OPTIONS)
.await?;
(tunnel, persisted)
}
None => {
debug!(self.log, "No code server tunnel found, creating new one");
let name = self
.get_name_for_tunnel(preferred_name, use_random_name)
.await?;
let (persisted, full_tunnel) = self
.create_tunnel(&name, &HOST_TUNNEL_REQUEST_OPTIONS)
.await?;
(full_tunnel, persisted)
}
};
tunnel = self
.sync_tunnel_tags(
&self.client,
&persisted.name,
tunnel,
&HOST_TUNNEL_REQUEST_OPTIONS,
)
.await?;
let locator = TunnelLocator::try_from(&tunnel).unwrap();
let host_token = get_host_token_from_tunnel(&tunnel);
for port_to_delete in tunnel
.ports
.iter()
.filter(|p: &&TunnelPort| !preserve_ports.contains(&p.port_number))
{
let output_fut = self.client.delete_tunnel_port(
&locator,
port_to_delete.port_number,
NO_REQUEST_OPTIONS,
);
spanf!(
self.log,
self.log.span("dev-tunnel.port.delete"),
output_fut
)
.map_err(|e| wrap(e, "failed to delete port"))?;
}
for endpoint in tunnel.endpoints {
let fut = self.client.delete_tunnel_endpoints(
&locator,
&endpoint.host_id,
NO_REQUEST_OPTIONS,
);
spanf!(self.log, self.log.span("dev-tunnel.endpoint.prune"), fut)
.map_err(|e| wrap(e, "failed to prune tunnel endpoint"))?;
}
self.start_tunnel(
locator.clone(),
&persisted,
self.client.clone(),
LookupAccessTokenProvider::new(
self.auth.clone(),
self.client.clone(),
locator,
self.log.clone(),
Some(host_token),
),
)
.await
}
async fn create_tunnel(
&mut self,
name: &str,
options: &TunnelRequestOptions,
) -> Result<(PersistedTunnel, Tunnel), AnyError> {
info!(self.log, "Creating tunnel with the name: {}", name);
let tunnel = match self.get_existing_tunnel_with_name(name).await? {
Some(e) => {
if tunnel_has_host_connection(&e) {
return Err(CodeError::TunnelActiveAndInUse(name.to_string()).into());
}
let loc = TunnelLocator::try_from(&e).unwrap();
info!(self.log, "Adopting existing tunnel (ID={:?})", loc);
spanf!(
self.log,
self.log.span("dev-tunnel.tag.get"),
self.client.get_tunnel(&loc, &HOST_TUNNEL_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to lookup tunnel"))?
}
None => loop {
let result = spanf!(
self.log,
self.log.span("dev-tunnel.create"),
self.client.create_tunnel(
Tunnel {
labels: self.get_labels(name),
..Default::default()
},
options
)
);
match result {
Err(HttpError::ResponseError(e))
if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
{
if let Some(d) = e.get_details() {
let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
&& self.try_recycle_tunnel().await?
{
continue;
}
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
detail,
)));
}
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
"You have exceeded a limit for the port fowarding service. Please remove other machines before trying to add this machine.".to_string(),
)));
}
Err(e) => {
return Err(AnyError::from(TunnelCreationFailed(
name.to_string(),
format!("{e:?}"),
)))
}
Ok(t) => break t,
}
},
};
let pt = PersistedTunnel {
cluster: tunnel.cluster_id.clone().unwrap(),
id: tunnel.tunnel_id.clone().unwrap(),
name: name.to_string(),
};
self.launcher_tunnel.save(Some(pt.clone()))?;
Ok((pt, tunnel))
}
fn get_labels(&self, name: &str) -> Vec<String> {
vec![
name.to_string(),
PROTOCOL_VERSION_TAG.to_string(),
self.tag.to_string(),
tunnel_flags::create(&self.log),
]
}
async fn sync_tunnel_tags(
&self,
client: &TunnelManagementClient,
name: &str,
tunnel: Tunnel,
options: &TunnelRequestOptions,
) -> Result<Tunnel, AnyError> {
let new_labels = self.get_labels(name);
if vec_eq_as_set(&tunnel.labels, &new_labels) {
return Ok(tunnel);
}
debug!(
self.log,
"Updating tunnel tags {} -> {}",
tunnel.labels.join(", "),
new_labels.join(", ")
);
let tunnel_update = Tunnel {
labels: new_labels,
tunnel_id: tunnel.tunnel_id.clone(),
cluster_id: tunnel.cluster_id.clone(),
..Default::default()
};
let result = spanf!(
self.log,
self.log.span("dev-tunnel.protocol-tag-update"),
client.update_tunnel(&tunnel_update, options)
);
result.map_err(|e| wrap(e, "tunnel tag update failed").into())
}
async fn try_recycle_tunnel(&mut self) -> Result<bool, AnyError> {
trace!(
self.log,
"Tunnel limit hit, trying to recycle an old tunnel"
);
let existing_tunnels = self.list_tunnels_with_tag(OWNED_TUNNEL_TAGS).await?;
let recyclable = existing_tunnels
.iter()
.filter(|t| !tunnel_has_host_connection(t))
.choose(&mut rand::thread_rng());
match recyclable {
Some(tunnel) => {
trace!(self.log, "Recycling tunnel ID {:?}", tunnel.tunnel_id);
spanf!(
self.log,
self.log.span("dev-tunnel.delete"),
self.client
.delete_tunnel(&tunnel.try_into().unwrap(), NO_REQUEST_OPTIONS)
)
.map_err(|e| wrap(e, "failed to execute `tunnel delete`"))?;
Ok(true)
}
None => {
trace!(self.log, "No tunnels available to recycle");
Ok(false)
}
}
}
async fn list_tunnels_with_tag(
&mut self,
tags: &[&'static str],
) -> Result<Vec<Tunnel>, AnyError> {
let tunnels = spanf!(
self.log,
self.log.span("dev-tunnel.listall"),
self.client.list_all_tunnels(&TunnelRequestOptions {
labels: tags.iter().map(|t| t.to_string()).collect(),
..Default::default()
})
)
.map_err(|e| wrap(e, "error listing current tunnels"))?;
Ok(tunnels)
}
async fn get_existing_tunnel_with_name(&self, name: &str) -> Result<Option<Tunnel>, AnyError> {
let existing: Vec<Tunnel> = spanf!(
self.log,
self.log.span("dev-tunnel.rename.search"),
self.client.list_all_tunnels(&TunnelRequestOptions {
labels: vec![self.tag.to_string(), name.to_string()],
require_all_labels: true,
limit: 1,
include_ports: true,
token_scopes: vec!["host".to_string()],
..Default::default()
})
)
.map_err(|e| wrap(e, "failed to list existing tunnels"))?;
Ok(existing.into_iter().next())
}
fn get_placeholder_name() -> String {
let mut n = clean_hostname_for_tunnel(&gethostname::gethostname().to_string_lossy());
n.make_ascii_lowercase();
n.truncate(MAX_TUNNEL_NAME_LENGTH);
n
}
async fn get_name_for_tunnel(
&mut self,
preferred_name: Option<&str>,
mut use_random_name: bool,
) -> Result<String, AnyError> {
let existing_tunnels = self.list_tunnels_with_tag(&[self.tag]).await?;
let is_name_free = |n: &str| {
!existing_tunnels
.iter()
.any(|v| tunnel_has_host_connection(v) && v.labels.iter().any(|t| t == n))
};
if let Some(machine_name) = preferred_name {
let name = machine_name.to_ascii_lowercase();
if let Err(e) = is_valid_name(&name) {
info!(self.log, "{} is an invalid name", e);
return Err(AnyError::from(wrap(e, "invalid name")));
}
if is_name_free(&name) {
return Ok(name);
}
info!(
self.log,
"{} is already taken, using a random name instead", &name
);
use_random_name = true;
}
let mut placeholder_name = Self::get_placeholder_name();
if !is_name_free(&placeholder_name) {
for i in 2.. {
let fixed_name = format!("{placeholder_name}{i}");
if is_name_free(&fixed_name) {
placeholder_name = fixed_name;
break;
}
}
}
if use_random_name || !*IS_INTERACTIVE_CLI {
return Ok(placeholder_name);
}
loop {
let mut name = prompt_placeholder(
"What would you like to call this machine?",
&placeholder_name,
)?;
name.make_ascii_lowercase();
if let Err(e) = is_valid_name(&name) {
info!(self.log, "{}", e);
continue;
}
if is_name_free(&name) {
return Ok(name);
}
info!(self.log, "The name {} is already in use", name);
}
}
pub async fn start_existing_tunnel(
&mut self,
tunnel: ExistingTunnel,
) -> Result<ActiveTunnel, AnyError> {
let tunnel_details = PersistedTunnel {
name: match tunnel.tunnel_name {
Some(n) => n,
None => Self::get_placeholder_name(),
},
id: tunnel.tunnel_id,
cluster: tunnel.cluster,
};
let mut mgmt = self.client.build();
mgmt.authorization(tunnels::management::Authorization::Tunnel(
tunnel.host_token.clone(),
));
let client = mgmt.into();
self.sync_tunnel_tags(
&client,
&tunnel_details.name,
Tunnel {
cluster_id: Some(tunnel_details.cluster.clone()),
tunnel_id: Some(tunnel_details.id.clone()),
..Default::default()
},
&HOST_TUNNEL_REQUEST_OPTIONS,
)
.await?;
self.start_tunnel(
tunnel_details.locator(),
&tunnel_details,
client,
StaticAccessTokenProvider::new(tunnel.host_token),
)
.await
}
async fn start_tunnel(
&mut self,
locator: TunnelLocator,
tunnel_details: &PersistedTunnel,
client: TunnelManagementClient,
access_token: impl AccessTokenProvider + 'static,
) -> Result<ActiveTunnel, AnyError> {
let mut manager = ActiveTunnelManager::new(self.log.clone(), client, locator, access_token);
let endpoint_result = spanf!(
self.log,
self.log.span("dev-tunnel.serve.callback"),
manager.get_endpoint()
);
let endpoint = match endpoint_result {
Ok(endpoint) => endpoint,
Err(e) => {
error!(self.log, "Error connecting to tunnel endpoint: {}", e);
manager.kill().await.ok();
return Err(e);
}
};
debug!(self.log, "Connected to tunnel endpoint: {:?}", endpoint);
Ok(ActiveTunnel {
name: tunnel_details.name.clone(),
id: tunnel_details.id.clone(),
manager,
})
}
}
#[derive(Clone, Default)]
pub struct StatusLock(Arc<std::sync::Mutex<protocol::singleton::Status>>);
impl StatusLock {
fn succeed(&self) {
let mut status = self.0.lock().unwrap();
status.tunnel = protocol::singleton::TunnelState::Connected;
status.last_connected_at = Some(chrono::Utc::now());
}
fn fail(&self, reason: String) {
let mut status = self.0.lock().unwrap();
if let protocol::singleton::TunnelState::Connected = status.tunnel {
status.last_disconnected_at = Some(chrono::Utc::now());
status.tunnel = protocol::singleton::TunnelState::Disconnected;
}
status.last_fail_reason = Some(reason);
}
pub fn read(&self) -> protocol::singleton::Status {
let status = self.0.lock().unwrap();
status.clone()
}
}
struct ActiveTunnelManager {
close_tx: Option<mpsc::Sender<()>>,
endpoint_rx: watch::Receiver<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
status: StatusLock,
}
impl ActiveTunnelManager {
pub fn new(
log: log::Logger,
mgmt: TunnelManagementClient,
locator: TunnelLocator,
access_token: impl AccessTokenProvider + 'static,
) -> ActiveTunnelManager {
let (endpoint_tx, endpoint_rx) = watch::channel(None);
let (close_tx, close_rx) = mpsc::channel(1);
let relay = Arc::new(tokio::sync::Mutex::new(RelayTunnelHost::new(locator, mgmt)));
let relay_spawned = relay.clone();
let status = StatusLock::default();
let status_spawned = status.clone();
tokio::spawn(async move {
ActiveTunnelManager::spawn_tunnel(
log,
relay_spawned,
close_rx,
endpoint_tx,
access_token,
status_spawned,
)
.await;
});
ActiveTunnelManager {
endpoint_rx,
relay,
close_tx: Some(close_tx),
status,
}
}
pub fn get_status(&self) -> StatusLock {
self.status.clone()
}
pub async fn add_port_tcp(
&self,
port_number: u16,
privacy: PortPrivacy,
protocol: PortProtocol,
) -> Result<(), WrappedError> {
self.relay
.lock()
.await
.add_port(&TunnelPort {
port_number,
protocol: Some(protocol.to_contract_str().to_string()),
access_control: Some(privacy_to_tunnel_acl(privacy)),
..Default::default()
})
.await
.map_err(|e| wrap(e, "error adding port to relay"))?;
Ok(())
}
pub async fn add_port_direct(
&self,
port_number: u16,
) -> Result<mpsc::UnboundedReceiver<ForwardedPortConnection>, WrappedError> {
self.relay
.lock()
.await
.add_port_raw(&TunnelPort {
port_number,
protocol: Some(TUNNEL_PROTOCOL_AUTO.to_owned()),
access_control: Some(privacy_to_tunnel_acl(PortPrivacy::Private)),
..Default::default()
})
.await
.map_err(|e| wrap(e, "error adding port to relay"))
}
pub async fn remove_port(&self, port_number: u16) -> Result<(), WrappedError> {
self.relay
.lock()
.await
.remove_port(port_number)
.await
.map_err(|e| wrap(e, "error remove port from relay"))
}
pub async fn get_endpoint(&mut self) -> Result<TunnelRelayTunnelEndpoint, AnyError> {
loop {
if let Some(details) = &*self.endpoint_rx.borrow() {
return details.clone().map_err(AnyError::from);
}
if self.endpoint_rx.changed().await.is_err() {
return Err(DevTunnelError("tunnel creation cancelled".to_string()).into());
}
}
}
pub async fn kill(&mut self) -> Result<(), AnyError> {
if let Some(tx) = self.close_tx.take() {
drop(tx);
}
self.relay
.lock()
.await
.unregister()
.await
.map_err(|e| wrap(e, "error unregistering relay"))?;
while self.endpoint_rx.changed().await.is_ok() {}
Ok(())
}
async fn spawn_tunnel(
log: log::Logger,
relay: Arc<tokio::sync::Mutex<RelayTunnelHost>>,
mut close_rx: mpsc::Receiver<()>,
endpoint_tx: watch::Sender<Option<Result<TunnelRelayTunnelEndpoint, WrappedError>>>,
access_token_provider: impl AccessTokenProvider + 'static,
status: StatusLock,
) {
let mut token_ka = access_token_provider.keep_alive();
let mut backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(120));
macro_rules! fail {
($e: expr, $msg: expr) => {
let fmt = format!("{}: {}", $msg, $e);
warning!(log, &fmt);
status.fail(fmt);
endpoint_tx.send(Some(Err($e))).ok();
backoff.delay().await;
};
}
loop {
debug!(log, "Starting tunnel to server...");
let access_token = match access_token_provider.refresh_token().await {
Ok(t) => t,
Err(e) => {
fail!(e, "Error refreshing access token, will retry");
continue;
}
};
let handle_res = {
let mut relay = relay.lock().await;
relay
.connect(&access_token)
.await
.map_err(|e| wrap(e, "error connecting to tunnel"))
};
let mut handle = match handle_res {
Ok(handle) => handle,
Err(e) => {
fail!(e, "Error connecting to relay, will retry");
continue;
}
};
backoff.reset();
status.succeed();
endpoint_tx.send(Some(Ok(handle.endpoint().clone()))).ok();
tokio::select! {
res = (&mut handle).map_err(|e| wrap(e, "error from tunnel connection")) => {
if let Err(e) = res {
fail!(e, "Tunnel exited unexpectedly, reconnecting");
} else {
warning!(log, "Tunnel exited unexpectedly but gracefully, reconnecting");
backoff.delay().await;
}
},
Err(e) = &mut token_ka => {
error!(log, "access token is no longer valid, exiting: {}", e);
return;
},
_ = close_rx.recv() => {
trace!(log, "Tunnel closing gracefully");
trace!(log, "Tunnel closed with result: {:?}", handle.close().await);
break;
}
}
}
}
}
struct Backoff {
failures: u32,
base_duration: Duration,
max_duration: Duration,
}
impl Backoff {
pub fn new(base_duration: Duration, max_duration: Duration) -> Self {
Self {
failures: 0,
base_duration,
max_duration,
}
}
pub async fn delay(&mut self) {
tokio::time::sleep(self.next()).await
}
pub fn next(&mut self) -> Duration {
self.failures += 1;
let duration = self
.base_duration
.checked_mul(self.failures)
.unwrap_or(self.max_duration);
std::cmp::min(duration, self.max_duration)
}
pub fn reset(&mut self) {
self.failures = 0;
}
}
fn clean_hostname_for_tunnel(hostname: &str) -> String {
let mut out = String::new();
for char in hostname.chars().take(60) {
match char {
'-' | '_' | ' ' => {
out.push('-');
}
'0'..='9' | 'a'..='z' | 'A'..='Z' => {
out.push(char);
}
_ => {}
}
}
let trimmed = out.trim_matches('-');
if trimmed.len() < 2 {
"remote-machine".to_string()
} else {
trimmed.to_owned()
}
}
fn vec_eq_as_set(a: &[String], b: &[String]) -> bool {
if a.len() != b.len() {
return false;
}
for item in a {
if !b.contains(item) {
return false;
}
}
true
}
fn privacy_to_tunnel_acl(privacy: PortPrivacy) -> TunnelAccessControl {
TunnelAccessControl {
entries: vec![match privacy {
PortPrivacy::Public => tunnels::contracts::TunnelAccessControlEntry {
kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
provider: None,
is_inherited: false,
is_deny: false,
is_inverse: false,
organization: None,
expiration: None,
subjects: vec![],
scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
},
PortPrivacy::Private => tunnels::contracts::TunnelAccessControlEntry {
kind: tunnels::contracts::TunnelAccessControlEntryType::Anonymous,
provider: None,
is_inherited: false,
is_deny: true,
is_inverse: false,
organization: None,
expiration: None,
subjects: vec![],
scopes: vec![TUNNEL_ACCESS_SCOPES_CONNECT.to_string()],
},
}],
}
}
fn tunnel_has_host_connection(tunnel: &Tunnel) -> bool {
tunnel
.status
.as_ref()
.and_then(|s| s.host_connection_count.as_ref().map(|c| c.get_count() > 0))
.unwrap_or_default()
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_clean_hostname_for_tunnel() {
assert_eq!(
clean_hostname_for_tunnel("hello123"),
"hello123".to_string()
);
assert_eq!(
clean_hostname_for_tunnel("-cool-name-"),
"cool-name".to_string()
);
assert_eq!(
clean_hostname_for_tunnel("cool!name with_chars"),
"coolname-with-chars".to_string()
);
assert_eq!(clean_hostname_for_tunnel("z"), "remote-machine".to_string());
}
}