use serde::{Deserialize, Serialize};
use std::{
fs::{File, OpenOptions},
io::{Seek, SeekFrom, Write},
path::{Path, PathBuf},
time::Duration,
};
use sysinfo::{Pid, PidExt};
use crate::{
async_pipe::{
get_socket_name, get_socket_rw_stream, listen_socket_rw_stream, AsyncPipe,
AsyncPipeListener,
},
util::{
errors::CodeError,
file_lock::{FileLock, Lock, PREFIX_LOCKED_BYTES},
machine::wait_until_process_exits,
},
};
/// The process that holds the singleton lock and accepts connections from
/// client processes.
pub struct SingletonServer {
    /// Listener on the socket whose path is written into the lock file.
    server: AsyncPipeListener,
    /// Never read; kept so the lock lives as long as the server.
    /// NOTE(review): presumably dropping the `FileLock` releases the lock —
    /// confirm in `util::file_lock`.
    _lock: FileLock,
}
impl SingletonServer {
    /// Waits for the next client to connect to the singleton's socket.
    pub async fn accept(&mut self) -> Result<AsyncPipe, CodeError> {
        self.server.accept().await
    }
}
/// Result of [`acquire_singleton`]: either this process became the singleton,
/// or it connected to the process that already holds the lock.
pub enum SingletonConnection {
    /// This process acquired the lock and is now serving connections.
    Singleton(SingletonServer),
    /// Another process holds the lock; this is a connection to its socket.
    Client(AsyncPipe),
}
/// Payload the singleton writes into the lock file after the
/// `PREFIX_LOCKED_BYTES` zero prefix, encoded as MessagePack with `rmp_serde`.
#[derive(Deserialize, Serialize)]
struct LockFileMatter {
    /// Path of the socket the singleton listens on.
    socket_path: String,
    /// PID of the singleton; clients watch it so they can stop waiting if it exits.
    pid: u32,
}
/// Tries to become the singleton process guarded by `lock_file`.
///
/// The lock file is created if it does not exist; its existing contents are
/// not truncated. If this process obtains the lock, it starts a server and
/// returns [`SingletonConnection::Singleton`]. If another process already holds
/// the lock, this connects to that process and returns
/// [`SingletonConnection::Client`].
///
/// # Errors
///
/// Returns `CodeError::SingletonLockfileOpenFailed` if the lock file cannot be
/// opened. Otherwise propagates any error from acquiring the lock, starting
/// the server, or connecting as a client.
pub async fn acquire_singleton(lock_file: &Path) -> Result<SingletonConnection, CodeError> {
    let mut options = OpenOptions::new();
    options.read(true).write(true).create(true).truncate(false);
    let handle = options
        .open(lock_file)
        .map_err(CodeError::SingletonLockfileOpenFailed)?;

    match FileLock::acquire(handle)? {
        Lock::Acquired(lock) => {
            let server = start_singleton_server(lock).await?;
            Ok(SingletonConnection::Singleton(server))
        }
        Lock::AlreadyLocked(mut locked_file) => {
            let pipe = connect_as_client_with_file(&mut locked_file).await?;
            Ok(SingletonConnection::Client(pipe))
        }
    }
}
pub async fn connect_as_client(lock_file: &Path) -> Result<AsyncPipe, CodeError> {
let mut file = OpenOptions::new()
.read(true)
.open(lock_file)
.map_err(CodeError::SingletonLockfileOpenFailed)?;
connect_as_client_with_file(&mut file).await
}
async fn start_singleton_server(mut lock: FileLock) -> Result<SingletonServer, CodeError> {
let socket_path = get_socket_name();
let mut vec = Vec::with_capacity(128);
let _ = vec.write(&[0; PREFIX_LOCKED_BYTES]);
let _ = rmp_serde::encode::write(
&mut vec,
&LockFileMatter {
socket_path: socket_path.to_string_lossy().to_string(),
pid: std::process::id(),
},
);
lock.file_mut()
.write_all(&vec)
.map_err(CodeError::SingletonLockfileOpenFailed)?;
let server = listen_socket_rw_stream(&socket_path).await?;
Ok(SingletonServer {
server,
_lock: lock,
})
}
/// Number of retries a client makes after its first attempt when reading the
/// lock file or connecting fails. `connect_as_client_with_file` therefore
/// makes up to `MAX_CLIENT_ATTEMPTS + 1` attempts in total.
const MAX_CLIENT_ATTEMPTS: i32 = 10;
/// Reads the singleton's [`LockFileMatter`] from an already-open lock file and
/// connects to the socket it advertises.
///
/// The first `PREFIX_LOCKED_BYTES` bytes of the file are skipped before
/// decoding. If reading or connecting fails, the whole sequence is retried
/// after a 500ms pause, up to `MAX_CLIENT_ATTEMPTS + 1` attempts in total. The
/// error from the final attempt is returned if none succeed.
///
/// If the advertised process exits while the connection is being attempted,
/// `CodeError::SingletonLockedProcessExited` is returned immediately, with no
/// further retries.
async fn connect_as_client_with_file(mut file: &mut File) -> Result<AsyncPipe, CodeError> {
    for attempt in 0..=MAX_CLIENT_ATTEMPTS {
        // NOTE(review): a seek failure is ignored; decoding then proceeds from
        // wherever the cursor currently is.
        let _ = file.seek(SeekFrom::Start(PREFIX_LOCKED_BYTES as u64));

        let outcome = match rmp_serde::from_read::<_, LockFileMatter>(&mut file) {
            Err(e) => Err(CodeError::SingletonLockfileReadFailed(e)),
            Ok(matter) => {
                let pid = matter.pid;
                let socket_path = PathBuf::from(matter.socket_path);
                // Race the connection attempts against the owning process
                // exiting, so a dead singleton is reported rather than waited on.
                tokio::select! {
                    pipe = retry_get_socket_rw_stream(&socket_path, 5, Duration::from_millis(500)) => pipe,
                    _ = wait_until_process_exits(Pid::from_u32(pid), 500) => {
                        return Err(CodeError::SingletonLockedProcessExited(pid));
                    }
                }
            }
        };

        match outcome {
            Ok(pipe) => return Ok(pipe),
            Err(e) if attempt == MAX_CLIENT_ATTEMPTS => return Err(e),
            Err(_) => tokio::time::sleep(Duration::from_millis(500)).await,
        }
    }
    unreachable!()
}
/// Connects to the socket at `path`, retrying on failure.
///
/// After each failed attempt it sleeps for `interval` and tries again. Once
/// `max_tries` retries have also failed (`max_tries + 1` attempts in total),
/// the last error is returned.
async fn retry_get_socket_rw_stream(
    path: &Path,
    max_tries: usize,
    interval: Duration,
) -> Result<AsyncPipe, CodeError> {
    let mut failures: usize = 0;
    loop {
        let err = match get_socket_rw_stream(path).await {
            Ok(stream) => return Ok(stream),
            Err(err) => err,
        };
        if failures == max_tries {
            return Err(err);
        }
        failures += 1;
        tokio::time::sleep(interval).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The first caller on a fresh lock file becomes the singleton.
    #[tokio::test]
    async fn test_acquires_singleton() {
        let temp = tempfile::tempdir().expect("expected to make temp dir");
        let lock_path = temp.path().join("lock");
        let connection = acquire_singleton(&lock_path)
            .await
            .expect("expected to acquire");
        assert!(
            matches!(connection, SingletonConnection::Singleton(_)),
            "expected to be singleton"
        );
    }

    /// A second caller on the same lock file connects as a client to the first.
    #[tokio::test]
    async fn test_acquires_client() {
        let temp = tempfile::tempdir().expect("expected to make temp dir");
        let lock_path = temp.path().join("lock");

        let first = acquire_singleton(&lock_path)
            .await
            .expect("expected to acquire1");
        let mut server = match first {
            SingletonConnection::Singleton(server) => server,
            SingletonConnection::Client(_) => panic!("expected to be singleton"),
        };
        // Accept one connection in the background so the second caller can connect.
        tokio::spawn(async move {
            server.accept().await.expect("expected to accept");
        });

        let second = acquire_singleton(&lock_path)
            .await
            .expect("expected to acquire2");
        assert!(
            matches!(second, SingletonConnection::Client(_)),
            "expected to be client"
        );
    }
}