use std::convert::Infallible;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use ::http::{Request, Response};
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper_util::rt::TokioIo;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::sync::Mutex;
use crate::async_pipe::{get_socket_name, get_socket_rw_stream, AsyncPipe};
use crate::constants::VSCODE_CLI_QUALITY;
use crate::download_cache::DownloadCache;
use crate::log;
use crate::options::Quality;
use crate::update_service::{
unzip_downloaded_release, Platform, Release, TargetKind, UpdateService,
};
use crate::util::command::new_script_command;
use crate::util::errors::CodeError;
use crate::util::http::{self, BoxedHttp};
use crate::util::http::{empty_body, full_body, HyperBody};
use crate::util::io::SilentCopyProgress;
use crate::util::sync::{new_barrier, Barrier, BarrierOpener};
use super::paths::{get_server_folder_name, SERVER_FOLDER_NAME};
/// How often `run_update_loop` asks the update service for a newer server release.
pub const UPDATE_CHECK_INTERVAL: Duration = Duration::from_secs(6 * 3_600);
/// How long `run_update_loop` sleeps between checks for the running server to go away
/// before it downloads a pending update.
pub const UPDATE_POLL_INTERVAL: Duration = Duration::from_secs(600);
/// How long a freshly spawned server may take to log that it is listening before it is
/// reported as ready anyway.
pub const STARTUP_TIMEOUT: Duration = Duration::from_secs(30);
/// Settings consulted when launching the agent host server process.
#[derive(Debug, Clone)]
pub struct AgentHostConfig {
    /// Forwarded as `--server-data-dir <dir>` when set.
    pub server_data_dir: Option<String>,
    /// Adds `--without-connection-token` to the server command line when `true`.
    pub without_connection_token: bool,
    /// Connection token value. The launch code in this file does not forward it
    /// (only `connection_token_file` is passed to the server).
    pub connection_token: Option<String>,
    /// Forwarded as `--connection-token-file <path>` when set.
    pub connection_token_file: Option<String>,
}
/// A launched server process that has finished its startup phase and is being tracked.
struct RunningServer {
    /// Handle to the spawned server process.
    child: tokio::process::Child,
    /// Full commit of the release the process was launched from.
    commit: String,
}
/// Resolves, downloads, launches, and tracks the agent host server process.
pub struct AgentHostManager {
    /// Logger used for all manager and server-output messages.
    log: log::Logger,
    /// Options that become server command-line flags at launch.
    config: AgentHostConfig,
    /// Platform used for update queries and for releases built from cache entries.
    platform: Platform,
    /// Cache of downloaded server builds, keyed by `get_server_folder_name`.
    cache: DownloadCache,
    /// Client used to resolve and download server releases.
    update_service: UpdateService,
    /// Most recently resolved release, with the instant it was recorded.
    latest_release: Mutex<Option<(Instant, Release)>>,
    /// Server process that has completed startup, if any.
    running: Mutex<Option<RunningServer>>,
    /// Barrier for the latest startup. It resolves to the agent-host socket path or to an
    /// error message.
    ready: Mutex<Option<Barrier<Result<PathBuf, String>>>>,
}
impl AgentHostManager {
pub fn new(
log: log::Logger,
platform: Platform,
cache: DownloadCache,
http: BoxedHttp,
config: AgentHostConfig,
) -> Arc<Self> {
Arc::new(Self {
update_service: UpdateService::new(log.clone(), http),
log,
config,
platform,
cache,
latest_release: Mutex::new(None),
running: Mutex::new(None),
ready: Mutex::new(None),
})
}
pub async fn ensure_server(self: &Arc<Self>) -> Result<PathBuf, CodeError> {
{
let ready = self.ready.lock().await;
if let Some(barrier) = &*ready {
if barrier.is_open() {
let running = self.running.lock().await;
if running.is_some() {
return barrier
.clone()
.wait()
.await
.unwrap()
.map_err(CodeError::ServerDownloadError);
}
} else {
let mut barrier = barrier.clone();
drop(ready);
return barrier
.wait()
.await
.unwrap()
.map_err(CodeError::ServerDownloadError);
}
}
}
self.start_server().await
}
async fn start_server(self: &Arc<Self>) -> Result<PathBuf, CodeError> {
let (release, server_dir) = self.get_cached_or_download().await?;
let (mut barrier, opener) = new_barrier::<Result<PathBuf, String>>();
{
let mut ready = self.ready.lock().await;
*ready = Some(barrier.clone());
}
let self_clone = self.clone();
let release_clone = release.clone();
tokio::spawn(async move {
self_clone
.run_server(release_clone, server_dir, opener)
.await;
});
barrier
.wait()
.await
.unwrap()
.map_err(CodeError::ServerDownloadError)
}
async fn run_server(
self: &Arc<Self>,
release: Release,
server_dir: PathBuf,
opener: BarrierOpener<Result<PathBuf, String>>,
) {
let executable = if let Some(p) = option_env!("VSCODE_CLI_OVERRIDE_SERVER_PATH") {
PathBuf::from(p)
} else {
server_dir
.join(SERVER_FOLDER_NAME)
.join("bin")
.join(release.quality.server_entrypoint())
};
let agent_host_socket = get_socket_name();
let mut cmd = new_script_command(&executable);
cmd.stdin(std::process::Stdio::null());
cmd.stderr(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.arg("--socket-path");
cmd.arg(get_socket_name());
cmd.arg("--agent-host-path");
cmd.arg(&agent_host_socket);
cmd.args([
"--start-server",
"--accept-server-license-terms",
"--enable-remote-auto-shutdown",
]);
if let Some(a) = &self.config.server_data_dir {
cmd.arg("--server-data-dir");
cmd.arg(a);
}
if self.config.without_connection_token {
cmd.arg("--without-connection-token");
}
if let Some(ct) = &self.config.connection_token_file {
cmd.arg("--connection-token-file");
cmd.arg(ct);
}
cmd.env_remove("VSCODE_DEV");
let mut child = match cmd.spawn() {
Ok(c) => c,
Err(e) => {
opener.open(Err(e.to_string()));
return;
}
};
let commit_prefix = &release.commit[..release.commit.len().min(7)];
let (mut stdout, mut stderr) = (
BufReader::new(child.stdout.take().unwrap()).lines(),
BufReader::new(child.stderr.take().unwrap()).lines(),
);
let mut opener = Some(opener);
let socket_path = agent_host_socket.clone();
let startup_deadline = tokio::time::sleep(STARTUP_TIMEOUT);
tokio::pin!(startup_deadline);
let mut ready = false;
loop {
tokio::select! {
Ok(Some(l)) = stdout.next_line() => {
debug!(self.log, "[{} stdout]: {}", commit_prefix, l);
if !ready && l.contains("Agent host server listening on") {
ready = true;
if let Some(o) = opener.take() {
o.open(Ok(socket_path.clone()));
}
}
}
Ok(Some(l)) = stderr.next_line() => {
debug!(self.log, "[{} stderr]: {}", commit_prefix, l);
}
_ = &mut startup_deadline, if !ready => {
warning!(self.log, "[{}]: Server did not become ready within {}s", commit_prefix, STARTUP_TIMEOUT.as_secs());
if let Some(o) = opener.take() {
o.open(Ok(socket_path.clone()));
}
ready = true;
}
e = child.wait() => {
info!(self.log, "[{} process]: exited: {:?}", commit_prefix, e);
if let Some(o) = opener.take() {
o.open(Err(format!("Server exited before ready: {e:?}")));
}
return;
}
}
if ready {
break;
}
}
{
let mut running = self.running.lock().await;
*running = Some(RunningServer {
child,
commit: release.commit.clone(),
});
}
info!(self.log, "[{}]: Server ready", commit_prefix);
let log = self.log.clone();
let commit_prefix = commit_prefix.to_string();
let self_clone = self.clone();
tokio::spawn(async move {
loop {
tokio::select! {
Ok(Some(l)) = stdout.next_line() => {
debug!(log, "[{} stdout]: {}", commit_prefix, l);
}
Ok(Some(l)) = stderr.next_line() => {
debug!(log, "[{} stderr]: {}", commit_prefix, l);
}
else => break,
}
}
info!(log, "[{}]: Server process ended", commit_prefix);
let mut running = self_clone.running.lock().await;
if let Some(r) = &*running {
if r.commit == commit_prefix || r.commit.starts_with(&commit_prefix) {
*running = None;
}
}
});
}
async fn get_cached_or_download(&self) -> Result<(Release, PathBuf), CodeError> {
if option_env!("VSCODE_CLI_OVERRIDE_SERVER_PATH").is_some() {
let release = Release {
name: String::new(),
commit: String::from("dev"),
platform: self.platform,
target: TargetKind::Server,
quality: Quality::Insiders,
};
return Ok((release, PathBuf::new()));
}
if let Some((_, release)) = &*self.latest_release.lock().await {
let name = get_server_folder_name(release.quality, &release.commit);
if let Some(dir) = self.cache.exists(&name) {
return Ok((release.clone(), dir));
}
}
let quality = VSCODE_CLI_QUALITY
.ok_or(CodeError::UpdatesNotConfigured("no configured quality"))
.and_then(|q| {
Quality::try_from(q).map_err(|_| CodeError::UpdatesNotConfigured("unknown quality"))
})?;
for entry in self.cache.get() {
if let Some(dir) = self.cache.exists(&entry) {
let (entry_quality, commit) = match entry.split_once('-') {
Some((q, c)) => match Quality::try_from(q.to_lowercase().as_str()) {
Ok(parsed) => (parsed, c.to_string()),
Err(_) => (quality, entry.clone()),
},
None => (quality, entry.clone()),
};
let release = Release {
name: String::new(),
commit,
platform: self.platform,
target: TargetKind::Server,
quality: entry_quality,
};
return Ok((release, dir));
}
}
info!(self.log, "No cached server version, downloading latest...");
let release = self.get_latest_release().await?;
let dir = self.ensure_downloaded(&release).await?;
Ok((release, dir))
}
pub async fn ensure_downloaded(&self, release: &Release) -> Result<PathBuf, CodeError> {
let cache_name = get_server_folder_name(release.quality, &release.commit);
if let Some(dir) = self.cache.exists(&cache_name) {
return Ok(dir);
}
info!(self.log, "Downloading server {}", release.commit);
let release = release.clone();
let log = self.log.clone();
let update_service = self.update_service.clone();
self.cache
.create(&cache_name, |target_dir| async move {
let tmpdir = tempfile::tempdir().unwrap();
let response = update_service.get_download_stream(&release).await?;
let name = response.url_path_basename().unwrap();
let archive_path = tmpdir.path().join(name);
http::download_into_file(
&archive_path,
log.get_download_logger("Downloading server:"),
response,
)
.await?;
let server_dir = target_dir.join(SERVER_FOLDER_NAME);
unzip_downloaded_release(&archive_path, &server_dir, SilentCopyProgress())?;
Ok(())
})
.await
.map_err(|e| CodeError::ServerDownloadError(e.to_string()))
}
pub async fn get_latest_release(&self) -> Result<Release, CodeError> {
let mut latest = self.latest_release.lock().await;
let now = Instant::now();
let quality = VSCODE_CLI_QUALITY
.ok_or(CodeError::UpdatesNotConfigured("no configured quality"))
.and_then(|q| {
Quality::try_from(q).map_err(|_| CodeError::UpdatesNotConfigured("unknown quality"))
})?;
let result = self
.update_service
.get_latest_commit(self.platform, TargetKind::Server, quality)
.await
.map_err(|e| CodeError::UpdateCheckFailed(e.to_string()));
if let (Err(e), Some((_, previous))) = (&result, latest.clone()) {
warning!(self.log, "Error checking for updates, using cached: {}", e);
*latest = Some((now, previous.clone()));
return Ok(previous);
}
let release = result?;
debug!(self.log, "Resolved server version: {}", release);
*latest = Some((now, release.clone()));
Ok(release)
}
pub async fn run_update_loop(self: Arc<Self>) {
let mut interval = tokio::time::interval(UPDATE_CHECK_INTERVAL);
interval.tick().await;
loop {
interval.tick().await;
let new_release = match self.get_latest_release().await {
Ok(r) => r,
Err(e) => {
warning!(self.log, "Update check failed: {}", e);
continue;
}
};
let name = get_server_folder_name(new_release.quality, &new_release.commit);
if self.cache.exists(&name).is_some() {
continue;
}
info!(self.log, "New server version available: {}", new_release);
loop {
{
let running = self.running.lock().await;
if running.is_none() {
break;
}
}
debug!(self.log, "Server still running, waiting before updating...");
tokio::time::sleep(UPDATE_POLL_INTERVAL).await;
}
match self.ensure_downloaded(&new_release).await {
Ok(_) => info!(self.log, "Updated server to {}", new_release),
Err(e) => warning!(self.log, "Failed to download update: {}", e),
}
}
}
pub async fn kill_running_server(&self) {
let mut running = self.running.lock().await;
if let Some(mut server) = running.take() {
let _ = server.child.kill().await;
}
}
}
/// Serves one HTTP request by proxying it to the agent host, starting the server if needed.
///
/// WebSocket upgrade requests (those carrying an `Upgrade` header) are tunnelled. All
/// other requests are forwarded as plain HTTP. Failures to start the server or to open
/// its socket are logged and answered with `503`. This function itself never fails.
pub async fn handle_request(
    manager: Arc<AgentHostManager>,
    req: Request<Incoming>,
) -> Result<Response<HyperBody>, Infallible> {
    // Every failure below is reported to the client as 503 Service Unavailable.
    let unavailable = |message: String| -> Response<HyperBody> {
        Response::builder()
            .status(503)
            .body(full_body(message))
            .unwrap()
    };

    let socket_path = match manager.ensure_server().await {
        Ok(path) => path,
        Err(e) => {
            error!(manager.log, "Error starting agent host: {:?}", e);
            return Ok(unavailable(format!("Error starting agent host: {e:?}")));
        }
    };

    let rw = match get_socket_rw_stream(&socket_path).await {
        Ok(stream) => stream,
        Err(e) => {
            error!(
                manager.log,
                "Error connecting to agent host socket: {:?}", e
            );
            return Ok(unavailable(format!("Error connecting to agent host: {e:?}")));
        }
    };

    let wants_upgrade = req.headers().contains_key(::http::header::UPGRADE);
    let response = if wants_upgrade {
        forward_ws_to_server(manager.log.clone(), rw, req).await
    } else {
        forward_http_to_server(rw, req).await
    };
    Ok(response)
}
async fn forward_http_to_server(rw: AsyncPipe, req: Request<Incoming>) -> Response<HyperBody> {
let (mut request_sender, connection) =
match hyper::client::conn::http1::handshake(TokioIo::new(rw)).await {
Ok(r) => r,
Err(e) => return connection_err(e),
};
tokio::spawn(connection);
match request_sender.send_request(req).await {
Ok(res) => res.map(|b| b.boxed()),
Err(e) => connection_err(e),
}
}
/// Proxies a WebSocket upgrade request to the agent host over `rw`.
///
/// The client's URI and headers are replayed on a new, body-less request to the server.
/// The server's status and all of its header values are copied into the returned
/// response. If the server answers `101 Switching Protocols`, a background task waits
/// for both sides to complete the upgrade. It then pipes bytes between them until either
/// side closes; upgrade and copy failures are only logged.
async fn forward_ws_to_server(
    log: log::Logger,
    rw: AsyncPipe,
    mut req: Request<Incoming>,
) -> Response<HyperBody> {
    let (mut request_sender, connection) =
        match hyper::client::conn::http1::handshake(TokioIo::new(rw)).await {
            Ok(r) => r,
            Err(e) => return connection_err(e),
        };
    // `with_upgrades` is required so that `hyper::upgrade::on(&mut res)` below can take
    // over the connection after a 101 response.
    tokio::spawn(connection.with_upgrades());

    let mut proxied_req = Request::builder().uri(req.uri());
    for (k, v) in req.headers() {
        proxied_req = proxied_req.header(k, v);
    }

    let mut res = match request_sender
        .send_request(
            proxied_req
                .body(http_body_util::Empty::<bytes::Bytes>::new())
                .unwrap(),
        )
        .await
    {
        Ok(r) => r,
        Err(e) => return connection_err(e),
    };

    let mut proxied_res = Response::new(empty_body());
    *proxied_res.status_mut() = res.status();
    // Use `append`, not `insert`: iterating a HeaderMap yields every value of a
    // multi-valued header (e.g. several `Set-Cookie` lines) separately, and `insert`
    // would keep only the last one.
    for (k, v) in res.headers() {
        proxied_res.headers_mut().append(k, v.clone());
    }

    if res.status() == ::http::StatusCode::SWITCHING_PROTOCOLS {
        tokio::spawn(async move {
            let (s_req, s_res) =
                tokio::join!(hyper::upgrade::on(&mut req), hyper::upgrade::on(&mut res));
            match (s_req, s_res) {
                (Ok(s_req), Ok(s_res)) => {
                    let mut s_req = TokioIo::new(s_req);
                    let mut s_res = TokioIo::new(s_res);
                    if let Err(e) = tokio::io::copy_bidirectional(&mut s_req, &mut s_res).await {
                        debug!(log, "Agent host WebSocket proxy ended with error: {:?}", e);
                    }
                }
                (Err(e), _) => {
                    warning!(
                        log,
                        "Agent host client-side WebSocket upgrade failed: {:?}",
                        e
                    );
                }
                (_, Err(e)) => {
                    warning!(
                        log,
                        "Agent host server-side WebSocket upgrade failed: {:?}",
                        e
                    );
                }
            }
        });
    }

    proxied_res
}
fn connection_err(err: hyper::Error) -> Response<HyperBody> {
Response::builder()
.status(503)
.body(full_body(format!(
"Error connecting to agent host: {err:?}"
)))
.unwrap()
}