use std::collections::HashMap;
use std::convert::Infallible;
use std::fs;
use std::io::{Read, Write};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::{pin, time};
use crate::async_pipe::{
get_socket_name, get_socket_rw_stream, listen_socket_rw_stream, AsyncPipe,
};
use crate::constants::VSCODE_CLI_QUALITY;
use crate::download_cache::DownloadCache;
use crate::log;
use crate::options::Quality;
use crate::state::{LauncherPaths, PersistedState};
use crate::tunnels::shutdown_signal::ShutdownRequest;
use crate::update_service::{
unzip_downloaded_release, Platform, Release, TargetKind, UpdateService,
};
use crate::util::command::new_script_command;
use crate::util::errors::AnyError;
use crate::util::http::{self, ReqwestSimpleHttp};
use crate::util::io::SilentCopyProgress;
use crate::util::sync::{new_barrier, Barrier, BarrierOpener};
use crate::{
tunnels::legal,
util::{errors::CodeError, prereqs::PreReqChecker},
};
use super::{args::ServeWebArgs, CommandContext};
/// Length of a full hexadecimal commit hash, as accepted in `/<quality>-<commit>/` paths.
const COMMIT_HASH_LEN: usize = 40;
/// Seconds a spawned server may run with no active clients before it is killed (one hour).
const SERVER_IDLE_TIMEOUT_SECS: u64 = 60 * 60;
/// Kill deadline used while at least one client is connected (about 360 days, i.e. effectively never).
const SERVER_ACTIVE_TIMEOUT_SECS: u64 = SERVER_IDLE_TIMEOUT_SECS * 24 * 30 * 12;
/// Release check interval, in seconds (one hour).
const RELEASE_CHECK_INTERVAL: u64 = 60 * 60;
/// Size in bytes of each secret key half and of the minted key.
const SECRET_KEY_BYTES: usize = 32;
/// Path, relative to the server base path, at which the combined secret key is minted.
const SECRET_KEY_MINT_PATH: &str = "_vscode-cli/mint-key";
/// Cookie telling the client where the mint endpoint lives.
const PATH_COOKIE_NAME: &str = "vscode-secret-key-path";
/// HttpOnly cookie carrying the client's half of the secret key.
const SECRET_KEY_COOKIE_NAME: &str = "vscode-cli-secret-half";
pub async fn serve_web(ctx: CommandContext, mut args: ServeWebArgs) -> Result<i32, AnyError> {
legal::require_consent(&ctx.paths, args.accept_server_license_terms)?;
let platform: crate::update_service::Platform = PreReqChecker::new().verify().await?;
if !args.without_connection_token {
if let Some(p) = args.connection_token_file.as_deref() {
let token = fs::read_to_string(PathBuf::from(p))
.map_err(CodeError::CouldNotReadConnectionTokenFile)?;
args.connection_token = Some(token.trim().to_string());
} else {
let token_path = ctx.paths.root().join("serve-web-token");
let token = mint_connection_token(&token_path, args.connection_token.clone())
.map_err(CodeError::CouldNotCreateConnectionTokenFile)?;
args.connection_token = Some(token);
args.connection_token_file = Some(token_path.to_string_lossy().to_string());
}
}
let cm: Arc<ConnectionManager> = ConnectionManager::new(&ctx, platform, args.clone());
let update_check_interval = 3600;
if args.commit_id.is_none() {
cm.clone()
.start_update_checker(Duration::from_secs(update_check_interval));
} else {
if let Err(e) = cm.get_latest_release().await {
warning!(cm.log, "error getting latest version: {}", e);
}
}
let key = get_server_key_half(&ctx.paths);
let make_svc = move || {
let ctx = HandleContext {
cm: cm.clone(),
log: cm.log.clone(),
server_secret_key: key.clone(),
};
let service = service_fn(move |req| handle(ctx.clone(), req));
async move { Ok::<_, Infallible>(service) }
};
let mut shutdown = ShutdownRequest::create_rx([ShutdownRequest::CtrlC]);
let r = if let Some(s) = args.socket_path {
let s = PathBuf::from(&s);
let socket = listen_socket_rw_stream(&s).await?;
ctx.log
.result(format!("Web UI available on {}", s.display()));
let r = Server::builder(socket.into_pollable())
.serve(make_service_fn(|_| make_svc()))
.with_graceful_shutdown(async {
let _ = shutdown.wait().await;
})
.await;
let _ = std::fs::remove_file(&s);
r
} else {
let addr: SocketAddr = match &args.host {
Some(h) => {
SocketAddr::new(h.parse().map_err(CodeError::InvalidHostAddress)?, args.port)
}
None => SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), args.port),
};
let builder = Server::try_bind(&addr).map_err(CodeError::CouldNotListenOnInterface)?;
let bound_addr = builder.local_addr();
let mut listening = format!("Web UI available at http://{bound_addr}");
if let Some(base) = args.server_base_path {
if !base.starts_with('/') {
listening.push('/');
}
listening.push_str(&base);
}
if let Some(ct) = args.connection_token {
listening.push_str(&format!("?tkn={ct}"));
}
ctx.log.result(listening);
builder
.serve(make_service_fn(|_| make_svc()))
.with_graceful_shutdown(async {
let _ = shutdown.wait().await;
})
.await
};
r.map_err(CodeError::CouldNotListenOnInterface)?;
Ok(0)
}
/// State cloned into every request handler invocation.
#[derive(Clone)]
struct HandleContext {
/// Manager that resolves releases and hands out connections to server instances.
cm: Arc<ConnectionManager>,
/// Logger for request handling (created as a clone of the manager's logger).
log: log::Logger,
/// Server half of the secret key, loaded or generated once when serving starts.
server_secret_key: SecretKeyPart,
}
async fn handle(ctx: HandleContext, req: Request<Body>) -> Result<Response<Body>, Infallible> {
let client_key_half = get_client_key_half(&req);
let path = req.uri().path();
let mut res = if path.starts_with(&ctx.cm.base_path)
&& path.get(ctx.cm.base_path.len()..).unwrap_or_default() == SECRET_KEY_MINT_PATH
{
handle_secret_mint(&ctx, req)
} else {
handle_proxied(&ctx, req).await
};
append_secret_headers(&ctx.cm.base_path, &mut res, &client_key_half);
Ok(res)
}
async fn handle_proxied(ctx: &HandleContext, req: Request<Body>) -> Response<Body> {
let release = if let Some((r, _)) = get_release_from_path(req.uri().path(), ctx.cm.platform) {
r
} else {
match ctx.cm.get_release_from_cache().await {
Ok(r) => r,
Err(e) => {
error!(ctx.log, "error getting latest version: {}", e);
return response::code_err(e);
}
}
};
match ctx.cm.get_connection(release).await {
Ok(rw) => {
if req.headers().contains_key(hyper::header::UPGRADE) {
forward_ws_req_to_server(ctx.log.clone(), rw, req).await
} else {
forward_http_req_to_server(rw, req).await
}
}
Err(CodeError::ServerNotYetDownloaded) => response::wait_for_download(),
Err(e) => response::code_err(e),
}
}
fn handle_secret_mint(ctx: &HandleContext, req: Request<Body>) -> Response<Body> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(ctx.server_secret_key.0.as_ref());
hasher.update(get_client_key_half(&req).0.as_ref());
let hash = hasher.finalize();
let hash = hash[..SECRET_KEY_BYTES].to_vec();
response::secret_key(hash)
}
/// Appends the two secret-key `Set-Cookie` headers to `res`.
///
/// One cookie tells the client where the mint endpoint is (under `base_path`);
/// the other (HttpOnly, 30-day `Max-Age`) stores the client's key half so
/// subsequent requests present the same value.
fn append_secret_headers(
    base_path: &str,
    res: &mut Response<Body>,
    client_key_half: &SecretKeyPart,
) {
    let cookies = [
        format!("{PATH_COOKIE_NAME}={base_path}{SECRET_KEY_MINT_PATH}; SameSite=Strict; Path=/"),
        format!(
            "{}={}; SameSite=Strict; HttpOnly; Max-Age=2592000; Path=/",
            SECRET_KEY_COOKIE_NAME,
            client_key_half.encode()
        ),
    ];

    let headers = res.headers_mut();
    for cookie in cookies {
        headers.append(
            hyper::header::SET_COOKIE,
            hyper::header::HeaderValue::from_str(&cookie).unwrap(),
        );
    }
}
/// Parses a `/<quality>-<commit>[/rest]` request path into a web `Release`.
///
/// Returns the release together with the remaining path (starting with `/`, or
/// empty). Returns `None` if the path does not start with `/`, the first
/// segment has no `-`, the commit is not a full hex hash, or the quality is
/// unknown.
fn get_release_from_path(path: &str, platform: Platform) -> Option<(Release, String)> {
    let rest = path.strip_prefix('/')?;
    let segment_end = rest.find('/').unwrap_or(rest.len());
    let (segment, remaining) = rest.split_at(segment_end);

    // Split at the first '-': everything before it is the quality name.
    let (quality, commit) = segment.split_once('-')?;
    if !is_commit_hash(commit) {
        return None;
    }
    let quality = Quality::try_from(quality).ok()?;

    let release = Release {
        quality,
        commit: commit.to_string(),
        platform,
        target: TargetKind::Web,
        name: String::new(),
    };
    Some((release, remaining.to_string()))
}
/// Forwards a plain HTTP request to a server instance over `rw`.
///
/// Performs a hyper client handshake on the pipe, drives the connection on a
/// background task, and returns the server's response, or a 503 response if
/// the handshake or the request fails. The connection handle is released once
/// the response head has arrived.
async fn forward_http_req_to_server(
    (rw, handle): (AsyncPipe, ConnectionHandle),
    req: Request<Body>,
) -> Response<Body> {
    let handshake = hyper::client::conn::Builder::new().handshake(rw).await;
    let (mut sender, conn) = match handshake {
        Ok(parts) => parts,
        Err(err) => return response::connection_err(err),
    };
    tokio::spawn(conn);

    let upstream = match sender.send_request(req).await {
        Ok(res) => res,
        Err(err) => response::connection_err(err),
    };
    drop(handle);
    upstream
}
/// Proxies an HTTP upgrade (WebSocket) request to a server instance.
///
/// The request is replayed to the server with the client's method, URI and
/// headers but an empty body. The server's status and headers are returned to
/// the client. If the server answers `101 Switching Protocols`, a background
/// task waits for both sides to finish upgrading and then copies bytes in both
/// directions until either side closes. That task owns the connection handle,
/// so the server counts as in use for the lifetime of the socket.
async fn forward_ws_req_to_server(
    log: log::Logger,
    (rw, handle): (AsyncPipe, ConnectionHandle),
    mut req: Request<Body>,
) -> Response<Body> {
    let (mut request_sender, connection) =
        match hyper::client::conn::Builder::new().handshake(rw).await {
            Ok(r) => r,
            Err(e) => return response::connection_err(e),
        };
    tokio::spawn(connection);

    // Replay the upgrade request head. `Builder::header` appends, so repeated
    // request headers are preserved.
    let mut proxied_req = Request::builder()
        .method(req.method().clone())
        .uri(req.uri());
    for (k, v) in req.headers() {
        proxied_req = proxied_req.header(k, v);
    }
    let mut res = request_sender
        .send_request(
            proxied_req
                .body(Body::empty())
                .expect("method, uri and headers come from an already-valid request"),
        )
        .await
        .unwrap_or_else(response::connection_err);

    // Mirror the server's response head. `append` (not `insert`) keeps every
    // value of repeated headers such as `Set-Cookie`.
    let mut proxied_res = Response::new(Body::empty());
    *proxied_res.status_mut() = res.status();
    for (k, v) in res.headers() {
        proxied_res.headers_mut().append(k, v.clone());
    }

    if res.status() == hyper::StatusCode::SWITCHING_PROTOCOLS {
        tokio::spawn(async move {
            let (s_req, s_res) =
                tokio::join!(hyper::upgrade::on(&mut req), hyper::upgrade::on(&mut res));
            match (s_req, s_res) {
                (Err(e1), Err(e2)) => debug!(
                    log,
                    "client ({}) and server ({}) websocket upgrade failed", e1, e2
                ),
                (Err(e1), _) => debug!(log, "client ({}) websocket upgrade failed", e1),
                (_, Err(e2)) => debug!(log, "server ({}) websocket upgrade failed", e2),
                (Ok(mut s_req), Ok(mut s_res)) => {
                    trace!(log, "websocket upgrade succeeded");
                    let r = tokio::io::copy_bidirectional(&mut s_req, &mut s_res).await;
                    trace!(log, "websocket closed (error: {:?})", r.err());
                }
            }
            // Keep the server marked as in use until the socket is done.
            drop(handle);
        });
    }

    proxied_res
}
/// Returns `true` if `s` is exactly `COMMIT_HASH_LEN` ASCII hex digits.
fn is_commit_hash(s: &str) -> bool {
    let bytes = s.as_bytes();
    bytes.len() == COMMIT_HASH_LEN && bytes.iter().all(u8::is_ascii_hexdigit)
}
/// Returns the value of the first cookie named `name` in the request's `Cookie` headers.
///
/// Pairs are split on `;` and trimmed, so both the canonical `"a=1; b=2"` form
/// and the looser `"a=1;b=2"` form are understood. When cookies are spread over
/// several `Cookie` header lines, the lines are searched in order. Header values
/// that are not valid visible ASCII are skipped, as are pairs without `=`.
fn extract_cookie(req: &Request<Body>, name: &str) -> Option<String> {
    req.headers()
        .get_all(hyper::header::COOKIE)
        .iter()
        .filter_map(|header| header.to_str().ok())
        .flat_map(|header| header.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, _)| *key == name)
        .map(|(_, value)| value.to_string())
}
/// One half of the secret key material: `SECRET_KEY_BYTES` raw bytes.
///
/// The server persists one half and each client keeps the other in a cookie;
/// the mint endpoint hashes the two together.
#[derive(Clone)]
struct SecretKeyPart(Box<[u8; SECRET_KEY_BYTES]>);

impl SecretKeyPart {
    /// Generates a new random key half.
    pub fn new() -> Self {
        Self(Box::new(rand::random::<[u8; SECRET_KEY_BYTES]>()))
    }

    /// Decodes a URL-safe base64 string into a key half.
    ///
    /// Fails with the base64 error on malformed input, or with
    /// `OutputSliceTooSmall` when the decoded length is not `SECRET_KEY_BYTES`.
    pub fn decode(s: &str) -> Result<Self, base64::DecodeSliceError> {
        use base64::{engine::general_purpose, Engine as _};

        let bytes = general_purpose::URL_SAFE.decode(s)?;
        let key = <[u8; SECRET_KEY_BYTES]>::try_from(bytes.as_slice())
            .map_err(|_| base64::DecodeSliceError::OutputSliceTooSmall)?;
        Ok(Self(Box::new(key)))
    }

    /// Encodes the key half as URL-safe base64.
    pub fn encode(&self) -> String {
        use base64::{engine::general_purpose, Engine as _};

        general_purpose::URL_SAFE.encode(self.0.as_ref())
    }
}
/// Loads the server's persisted key half, generating and saving a new one if
/// the stored value is missing or cannot be decoded.
///
/// Saving is best-effort: a failed save is ignored, and the freshly generated
/// key is still returned for this run.
fn get_server_key_half(paths: &LauncherPaths) -> SecretKeyPart {
    let persisted = PersistedState::new(paths.root().join("serve-web-key-half"));
    let stored: String = persisted.load();
    match SecretKeyPart::decode(&stored) {
        Ok(existing) => existing,
        Err(_) => {
            let fresh = SecretKeyPart::new();
            let _ = persisted.save(fresh.encode());
            fresh
        }
    }
}
/// Returns the client's key half from its secret-key cookie, or a fresh random
/// half when the cookie is absent or invalid.
fn get_client_key_half(req: &Request<Body>) -> SecretKeyPart {
    extract_cookie(req, SECRET_KEY_COOKIE_NAME)
        .and_then(|cookie| SecretKeyPart::decode(&cookie).ok())
        .unwrap_or_else(SecretKeyPart::new)
}
mod response {
use const_format::concatcp;
use crate::constants::QUALITYLESS_SERVER_NAME;
use super::*;
pub fn connection_err(err: hyper::Error) -> Response<Body> {
Response::builder()
.status(503)
.body(Body::from(format!("Error connecting to server: {err:?}")))
.unwrap()
}
pub fn code_err(err: CodeError) -> Response<Body> {
Response::builder()
.status(500)
.body(Body::from(format!("Error serving request: {err}")))
.unwrap()
}
pub fn wait_for_download() -> Response<Body> {
Response::builder()
.status(202)
.header("Content-Type", "text/html")
.body(Body::from(concatcp!("The latest version of the ", QUALITYLESS_SERVER_NAME, " is downloading, please wait a moment...<script>setTimeout(()=>location.reload(),1500)</script>", )))
.unwrap()
}
pub fn secret_key(hash: Vec<u8>) -> Response<Body> {
Response::builder()
.status(200)
.header("Content-Type", "application/octet-stream")
.body(Body::from(hash))
.unwrap()
}
}
/// RAII token marking one active client of a server instance.
///
/// Creating a handle increments the shared client counter; dropping it
/// decrements the counter again. The server's lifecycle loop watches this
/// counter to choose between the idle and active kill timeouts.
struct ConnectionHandle {
    client_counter: Arc<tokio::sync::watch::Sender<usize>>,
}

impl ConnectionHandle {
    /// Registers a new active client on `client_counter`.
    pub fn new(client_counter: Arc<tokio::sync::watch::Sender<usize>>) -> Self {
        client_counter.send_modify(|active| *active += 1);
        ConnectionHandle { client_counter }
    }
}

impl Drop for ConnectionHandle {
    fn drop(&mut self) {
        self.client_counter.send_modify(|active| *active -= 1);
    }
}
/// Result of starting a server: its socket path plus the sender for its active-client counter.
type StartData = (PathBuf, Arc<tokio::sync::watch::Sender<usize>>);
/// Tracking state for one server version (keyed by quality + commit).
struct VersionState {
/// `false` while the build is still being downloaded; requests for the version get
/// `ServerNotYetDownloaded` until the cache hit is seen or the barrier has opened.
downloaded: bool,
/// Opened with the start result (socket data or an error string) once known.
socket_path: Barrier<Result<StartData, String>>,
}
/// Versions currently being downloaded or run, keyed by `(quality, commit)`.
type ConnectionStateMap = Arc<Mutex<HashMap<(Quality, String), VersionState>>>;
/// Downloads, starts, and hands out connections to server versions on demand.
struct ConnectionManager {
/// Platform that server builds are requested for.
pub platform: Platform,
pub log: log::Logger,
/// Command-line arguments, forwarded when spawning servers.
args: ServeWebArgs,
/// Normalized base path: either `/` or `/segment(s)/`.
base_path: String,
/// Local store of downloaded server builds, looked up by commit.
cache: DownloadCache,
state: ConnectionStateMap,
/// Client used to query the latest commit and download builds.
update_service: UpdateService,
/// Most recently resolved release and the instant it was recorded.
latest_version: tokio::sync::Mutex<Option<(Instant, Release)>>,
}
fn key_for_release(release: &Release) -> (Quality, String) {
(release.quality, release.commit.clone())
}
/// Normalizes a user-supplied base path to either `/` or `/segment(s)/`.
///
/// Leading and trailing slashes are collapsed, so `""`, `"/"` and `"///"` all
/// become `"/"`, while `"foo"`, `"/foo"` and `"/foo/"` become `"/foo/"`.
/// Interior slashes are left untouched.
fn normalize_base_path(p: &str) -> String {
    // Trim once; the previous implementation trimmed the already-trimmed value again.
    let trimmed = p.trim_matches('/');
    if trimmed.is_empty() {
        "/".to_string()
    } else {
        format!("/{trimmed}/")
    }
}
impl ConnectionManager {
    /// Creates the connection manager shared by all request handlers.
    ///
    /// When the download cache already holds server builds, its first entry
    /// seeds `latest_version` (timestamped `RELEASE_CHECK_INTERVAL` in the past,
    /// or `now` if that underflows) so requests can be served before any update
    /// check has succeeded.
    pub fn new(ctx: &CommandContext, platform: Platform, args: ServeWebArgs) -> Arc<Self> {
        let base_path = normalize_base_path(args.server_base_path.as_deref().unwrap_or_default());
        let cache = DownloadCache::new(ctx.paths.web_server_storage());
        let target_kind = TargetKind::Web;
        // Unconfigured or unrecognized build qualities fall back to Stable.
        let quality = VSCODE_CLI_QUALITY
            .and_then(|q| Quality::try_from(q).ok())
            .unwrap_or(Quality::Stable);

        let now = Instant::now();
        let latest_version = tokio::sync::Mutex::new(cache.get().first().map(|latest_commit| {
            (
                now.checked_sub(Duration::from_secs(RELEASE_CHECK_INTERVAL))
                    .unwrap_or(now),
                Release {
                    name: String::from("0.0.0"),
                    commit: latest_commit.clone(),
                    platform,
                    target: target_kind,
                    quality,
                },
            )
        }));

        Arc::new(Self {
            platform,
            args,
            base_path,
            log: ctx.log.clone(),
            cache,
            update_service: UpdateService::new(
                ctx.log.clone(),
                Arc::new(ReqwestSimpleHttp::with_client(ctx.http.clone())),
            ),
            state: ConnectionStateMap::default(),
            latest_version,
        })
    }

    /// Spawns a background task that refreshes the latest release every `duration`.
    ///
    /// A tokio interval's first tick completes immediately, so one check also
    /// runs right away. Failures are logged and retried on the next tick.
    pub fn start_update_checker(self: Arc<Self>, duration: Duration) {
        tokio::spawn(async move {
            let mut interval = time::interval(duration);
            loop {
                interval.tick().await;
                if let Err(e) = self.get_latest_release().await {
                    warning!(self.log, "error getting latest version: {}", e);
                }
            }
        });
    }

    /// Returns the most recently resolved release, resolving one if none is known yet.
    pub async fn get_release_from_cache(&self) -> Result<Release, CodeError> {
        let latest = self.latest_version.lock().await;
        if let Some((_, release)) = &*latest {
            return Ok(release.clone());
        }
        // Release the lock first: `get_latest_release` acquires it itself.
        drop(latest);
        self.get_latest_release().await
    }

    /// Connects to a running server for `release`, starting or downloading it as needed.
    ///
    /// The returned `ConnectionHandle` counts as an active client of that
    /// server until it is dropped.
    ///
    /// # Errors
    /// `ServerNotYetDownloaded` while the build downloads, `ServerDownloadError`
    /// if the download or start failed, or an error connecting to the socket.
    pub async fn get_connection(
        &self,
        release: Release,
    ) -> Result<(AsyncPipe, ConnectionHandle), CodeError> {
        let (path, counter) = self.get_version_data(release).await?;
        let handle = ConnectionHandle::new(counter);
        let rw = get_socket_rw_stream(&path).await?;
        Ok((rw, handle))
    }

    /// Resolves the release to serve and records it in `latest_version`.
    ///
    /// A commit given on the command line is used directly. Otherwise the update
    /// service is queried; if that fails but a release was resolved earlier, the
    /// earlier release is returned (and re-timestamped) instead of the error.
    ///
    /// # Errors
    /// `UpdatesNotConfigured` when the build has no usable quality, and
    /// `UpdateCheckFailed` when the query fails with nothing to fall back on.
    pub async fn get_latest_release(&self) -> Result<Release, CodeError> {
        let mut latest = self.latest_version.lock().await;
        let now = Instant::now();
        let target_kind = TargetKind::Web;
        let quality = VSCODE_CLI_QUALITY
            .ok_or_else(|| CodeError::UpdatesNotConfigured("no configured quality"))
            .and_then(|q| {
                Quality::try_from(q).map_err(|_| CodeError::UpdatesNotConfigured("unknown quality"))
            })?;

        if let Some(commit) = &self.args.commit_id {
            let release = Release {
                name: commit.to_string(),
                commit: commit.to_string(),
                platform: self.platform,
                target: target_kind,
                quality,
            };
            debug!(
                self.log,
                "using provided commit instead of latest release: {}", release
            );
            *latest = Some((now, release.clone()));
            return Ok(release);
        }

        let release = self
            .update_service
            .get_latest_commit(self.platform, target_kind, quality)
            .await
            .map_err(|e| CodeError::UpdateCheckFailed(e.to_string()));

        let release = match release {
            Ok(release) => release,
            Err(e) => {
                // Prefer serving a stale release over failing outright.
                let previous = match &*latest {
                    Some((_, previous)) => previous.clone(),
                    None => return Err(e),
                };
                warning!(self.log, "error getting latest release, using stale: {}", e);
                *latest = Some((now, previous.clone()));
                return Ok(previous);
            }
        };

        debug!(self.log, "refreshed latest release: {}", release);
        *latest = Some((now, release.clone()));
        Ok(release)
    }

    /// Waits until the server for `release` is listening and returns its start data.
    async fn get_version_data(&self, release: Release) -> Result<StartData, CodeError> {
        self.get_version_data_inner(release)?
            .wait()
            .await
            // The opener can be dropped without a value if the background task
            // ends abnormally (e.g. panics); report that instead of panicking
            // inside the request handler.
            .map_err(|_| {
                CodeError::ServerDownloadError("server startup was abandoned".to_string())
            })?
            .map_err(CodeError::ServerDownloadError)
    }

    /// Returns the start barrier for `release`, kicking off work if the version
    /// is not tracked yet.
    ///
    /// For a cached build the server is started in the background and the
    /// barrier is returned. Otherwise a download (then start) is spawned and
    /// `ServerNotYetDownloaded` is returned. The state entry is removed when the
    /// background task finishes, so a later request starts over.
    fn get_version_data_inner(
        &self,
        release: Release,
    ) -> Result<Barrier<Result<StartData, String>>, CodeError> {
        let mut state = self.state.lock().unwrap();
        let key = key_for_release(&release);
        if let Some(s) = state.get_mut(&key) {
            if !s.downloaded {
                // An opened barrier means the download phase is over.
                if s.socket_path.is_open() {
                    s.downloaded = true;
                } else {
                    return Err(CodeError::ServerNotYetDownloaded);
                }
            }
            return Ok(s.socket_path.clone());
        }

        let (socket_path, opener) = new_barrier();
        let state_map_dup = self.state.clone();
        let args = StartArgs {
            args: self.args.clone(),
            log: self.log.clone(),
            opener,
            release,
        };

        if let Some(p) = self.cache.exists(&args.release.commit) {
            state.insert(
                key.clone(),
                VersionState {
                    socket_path: socket_path.clone(),
                    downloaded: true,
                },
            );
            tokio::spawn(async move {
                Self::start_version(args, p).await;
                state_map_dup.lock().unwrap().remove(&key);
            });
            Ok(socket_path)
        } else {
            state.insert(
                key.clone(),
                VersionState {
                    socket_path,
                    downloaded: false,
                },
            );
            let update_service = self.update_service.clone();
            let cache = self.cache.clone();
            tokio::spawn(async move {
                Self::download_version(args, update_service, cache).await;
                state_map_dup.lock().unwrap().remove(&key);
            });
            Err(CodeError::ServerNotYetDownloaded)
        }
    }

    /// Downloads and unpacks the server build for `args.release` into the cache,
    /// then starts it. Download errors are reported through the start barrier.
    async fn download_version(
        args: StartArgs,
        update_service: UpdateService,
        cache: DownloadCache,
    ) {
        let release_for_fut = args.release.clone();
        let log_for_fut = args.log.clone();
        let dir_fut = cache.create(&args.release.commit, |target_dir| async move {
            info!(log_for_fut, "Downloading server {}", release_for_fut.commit);
            // NOTE(review): these unwraps panic the download task; its state
            // entry is then never removed. Consider mapping them to errors once
            // the conversion into the `create` callback's error type is confirmed.
            let tmpdir = tempfile::tempdir().unwrap();
            let response = update_service.get_download_stream(&release_for_fut).await?;

            let name = response.url_path_basename().unwrap();
            let archive_path = tmpdir.path().join(name);
            http::download_into_file(
                &archive_path,
                log_for_fut.get_download_logger("Downloading server:"),
                response,
            )
            .await?;
            unzip_downloaded_release(&archive_path, &target_dir, SilentCopyProgress())?;
            Ok(())
        });

        match dir_fut.await {
            Err(e) => args.opener.open(Err(e.to_string())),
            Ok(dir) => Self::start_version(args, dir).await,
        }
    }

    /// Spawns the server found under `path` and supervises it until it exits.
    ///
    /// The start barrier opens with the socket data once the process prints
    /// "Server bound to", or with an error if spawning fails or the process ends
    /// before binding. The process is killed after `SERVER_IDLE_TIMEOUT_SECS`
    /// without active clients (`SERVER_ACTIVE_TIMEOUT_SECS` while any client is
    /// connected).
    async fn start_version(args: StartArgs, path: PathBuf) {
        info!(args.log, "Starting server {}", args.release.commit);
        let executable = path
            .join("bin")
            .join(args.release.quality.server_entrypoint());
        let socket_path = get_socket_name();

        let mut cmd = new_script_command(&executable);
        cmd.stdin(std::process::Stdio::null());
        cmd.stderr(std::process::Stdio::piped());
        cmd.stdout(std::process::Stdio::piped());
        cmd.arg("--socket-path");
        cmd.arg(&socket_path);
        cmd.args(["--accept-server-license-terms"]);
        if let Some(a) = &args.args.server_base_path {
            cmd.arg("--server-base-path");
            cmd.arg(a);
        }
        if let Some(a) = &args.args.server_data_dir {
            cmd.arg("--server-data-dir");
            cmd.arg(a);
        }
        if args.args.without_connection_token {
            cmd.arg("--without-connection-token");
        }
        if let Some(ct) = &args.args.connection_token_file {
            cmd.arg("--connection-token-file");
            cmd.arg(ct);
        }
        cmd.env_remove("VSCODE_DEV");

        let mut child = match cmd.spawn() {
            Ok(c) => c,
            Err(e) => {
                args.opener.open(Err(e.to_string()));
                return;
            }
        };

        let (mut stdout, mut stderr) = (
            BufReader::new(child.stdout.take().unwrap()).lines(),
            BufReader::new(child.stderr.take().unwrap()).lines(),
        );

        let (counter_tx, mut counter_rx) = tokio::sync::watch::channel(0);
        let mut opener = Some((args.opener, socket_path, Arc::new(counter_tx)));
        // `get` avoids panicking on commits shorter than 7 bytes (e.g. a short
        // `--commit-id`) or where byte 7 is not a char boundary.
        let commit_prefix = args
            .release
            .commit
            .get(..7)
            .unwrap_or(args.release.commit.as_str());
        let kill_timer = tokio::time::sleep(Duration::from_secs(SERVER_IDLE_TIMEOUT_SECS));
        pin!(kill_timer);

        loop {
            tokio::select! {
                Ok(Some(l)) = stdout.next_line() => {
                    info!(args.log, "[{} stdout]: {}", commit_prefix, l);

                    if l.contains("Server bound to") {
                        if let Some((opener, path, counter_tx)) = opener.take() {
                            opener.open(Ok((path, counter_tx)));
                        }
                    }
                }
                Ok(Some(l)) = stderr.next_line() => {
                    info!(args.log, "[{} stderr]: {}", commit_prefix, l);
                },
                n = counter_rx.changed() => {
                    kill_timer.as_mut().reset(match n {
                        // All counter senders are gone: shut down immediately.
                        Err(_) => tokio::time::Instant::now(),
                        Ok(_) => {
                            if *counter_rx.borrow() == 0 {
                                tokio::time::Instant::now() + Duration::from_secs(SERVER_IDLE_TIMEOUT_SECS)
                            } else {
                                tokio::time::Instant::now() + Duration::from_secs(SERVER_ACTIVE_TIMEOUT_SECS)
                            }
                        }
                    });
                }
                _ = &mut kill_timer => {
                    info!(args.log, "[{} process]: idle timeout reached, ending", commit_prefix);
                    let _ = child.kill().await;
                    break;
                }
                e = child.wait() => {
                    info!(args.log, "[{} process]: exited: {:?}", commit_prefix, e);
                    break;
                }
            }
        }

        // The process ended without ever reporting that it bound its socket.
        // Release waiters with an error instead of dropping the opener, which
        // would otherwise surface as a `RecvError` to every waiting request.
        if let Some((opener, _, _)) = opener.take() {
            opener.open(Err(format!(
                "server {} exited before it started listening",
                commit_prefix
            )));
        }
    }
}
/// Everything the background download/start task needs for one release.
struct StartArgs {
log: log::Logger,
/// Command-line arguments; several are forwarded to the spawned server.
args: ServeWebArgs,
/// The release being downloaded and/or started.
release: Release,
/// Opened with the socket data once the server is listening, or with an error string.
opener: BarrierOpener<Result<StartData, String>>,
}
/// Returns the connection token stored at `path`, creating or rewriting the file as needed.
///
/// With `prefer_token == None`, a non-blank token already in the file is reused
/// (trimmed). Otherwise the file is overwritten with `prefer_token`, or with a
/// new UUIDv4 when no token was preferred and the file was missing or blank.
/// On non-Windows platforms the file is opened with mode `0o600`; the OS
/// applies that mode only when the file is created.
///
/// # Errors
/// Propagates any I/O error from opening, reading, truncating, seeking, or
/// writing the file.
fn mint_connection_token(path: &Path, prefer_token: Option<String>) -> std::io::Result<String> {
    use std::io::{Seek, SeekFrom};
    #[cfg(not(windows))]
    use std::os::unix::fs::OpenOptionsExt;

    let mut f = fs::OpenOptions::new();
    f.create(true);
    f.write(true);
    f.read(true);
    #[cfg(not(windows))]
    f.mode(0o600);
    let mut f = f.open(path)?;

    if prefer_token.is_none() {
        let mut t = String::new();
        f.read_to_string(&mut t)?;
        let t = t.trim();
        if !t.is_empty() {
            return Ok(t.to_string());
        }
    }

    // `set_len` does not move the cursor. After the read above it may sit past
    // the new end of file, and writing there would put NUL bytes in front of
    // the token (which the server later reads from this file). Rewind first.
    f.set_len(0)?;
    f.seek(SeekFrom::Start(0))?;
    let prefer_token = prefer_token.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
    f.write_all(prefer_token.as_bytes())?;
    Ok(prefer_token)
}