use std::any::Any;
use std::panic::{UnwindSafe, catch_unwind};
use std::sync::atomic::{AtomicU64, Ordering};
/// Error value returned by [`catch_keyboard_interrupt`] when a keyboard
/// interrupt (SIGINT) aborted the wrapped computation, or when an interrupt
/// was already pending at the time of the call.
///
/// It carries no data. The derives let callers `unwrap`/`expect` a
/// `Result<_, KeyboardInterrupt>`, compare it, and format it in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KeyboardInterrupt;
// Marker used as the panic payload for keyboard interrupts. Detection in
// `is_keyboard_interrupt` uses substring matching (`contains`), so the marker
// is also recognized when it is embedded in a longer panic message.
static POLARS_KEYBOARD_INTERRUPT_STRING: &str = "__POLARS_KEYBOARD_INTERRUPT";
// Packed interrupt state, only ever touched through atomic operations:
// - bit 0: interrupt-pending flag (set by the SIGINT handler);
// - bits 1..: number of currently registered catchers (each registration adds 2).
static INTERRUPT_STATE: AtomicU64 = AtomicU64::new(0);
fn is_keyboard_interrupt(p: &dyn Any) -> bool {
if let Some(s) = p.downcast_ref::<&str>() {
s.contains(POLARS_KEYBOARD_INTERRUPT_STRING)
} else if let Some(s) = p.downcast_ref::<String>() {
s.contains(POLARS_KEYBOARD_INTERRUPT_STRING)
} else {
false
}
}
/// Installs the hooks that let SIGINT abort work running under
/// [`catch_keyboard_interrupt`].
///
/// Two things are installed:
/// 1. A panic hook wrapping whichever hook was installed before. It skips the
///    previous hook (so nothing is reported) when at least one catcher is
///    registered and the panic payload is the keyboard-interrupt marker. Every
///    other panic is passed through to the previous hook.
/// 2. On non-WASM targets, a SIGINT handler that sets the interrupt flag in
///    `INTERRUPT_STATE`, but only while at least one catcher is registered.
///    With no catchers, the signal leaves the state untouched.
///
/// NOTE(review): each call wraps the current panic hook again and registers
/// another SIGINT handler. This is presumably meant to be called once at
/// startup; confirm with callers.
///
/// # Panics
/// Panics if registering the SIGINT handler returns an error (the result is
/// `unwrap`ped).
pub fn register_polars_keyboard_interrupt_hook() {
    let default_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |p| {
        // `>> 1` drops the flag bit, leaving the number of active catchers.
        let num_catchers = INTERRUPT_STATE.load(Ordering::Relaxed) >> 1;
        // Only silence the marker panic when someone is going to catch it.
        let suppress = num_catchers > 0 && is_keyboard_interrupt(p.payload());
        if !suppress {
            default_hook(p);
        }
    }));
    // Signal handling is not installed on WASM targets.
    #[cfg(not(target_family = "wasm"))]
    unsafe {
        // SAFETY: the registered closure runs in signal-handler context, so it
        // must restrict itself to async-signal-safe work. It only performs a
        // compare-and-swap loop on the atomic `INTERRUPT_STATE`: no allocation,
        // locking, or I/O.
        signal_hook::low_level::register(signal_hook::consts::signal::SIGINT, move || {
            // Set the interrupt flag (bit 0), but only if some catcher is
            // active. Returning `None` leaves the state unchanged.
            INTERRUPT_STATE
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |state| {
                    let num_catchers = state >> 1;
                    if num_catchers > 0 {
                        Some(state | 1)
                    } else {
                        None
                    }
                })
                // `Err` here only means "no catcher registered, nothing to do".
                .ok();
        })
        .unwrap();
    }
}
/// Panics with the keyboard-interrupt marker if an interrupt is pending, and
/// does nothing otherwise.
///
/// The fast path is one relaxed atomic load plus a bit test. The panic itself
/// lives in a separate `#[cold]`, never-inlined function.
#[inline(always)]
pub fn try_raise_keyboard_interrupt() {
    // Bit 0 of the shared state is the interrupt-pending flag.
    let interrupt_pending = (INTERRUPT_STATE.load(Ordering::Relaxed) & 1) == 1;
    if interrupt_pending {
        try_raise_keyboard_interrupt_slow();
    }
}
/// Out-of-line slow path of `try_raise_keyboard_interrupt`.
///
/// Unwinds with the interrupt marker as a `&'static str` panic payload.
#[inline(never)]
#[cold]
fn try_raise_keyboard_interrupt_slow() {
    let payload: &'static str = POLARS_KEYBOARD_INTERRUPT_STRING;
    std::panic::panic_any(payload)
}
/// Runs `try_fn`, converting a keyboard-interrupt panic raised inside it into
/// `Err(KeyboardInterrupt)`.
///
/// While `try_fn` runs, this call counts as a registered catcher. That lets
/// the SIGINT handler set the interrupt flag and lets the panic hook suppress
/// the report of the marker panic. Within this module, the marker panic is
/// raised by `try_raise_keyboard_interrupt`, so the interrupt takes effect
/// where code inside `try_fn` calls it.
///
/// # Errors
/// Returns `Err(KeyboardInterrupt)` in two cases:
/// - an interrupt was already pending when this was called (then `try_fn` is
///   not run);
/// - `try_fn` unwound with the keyboard-interrupt marker.
///
/// # Panics
/// Any other panic from `try_fn` is re-raised unchanged via `resume_unwind`.
pub fn catch_keyboard_interrupt<R, F: FnOnce() -> R + UnwindSafe>(
    try_fn: F,
) -> Result<R, KeyboardInterrupt> {
    // Bail out before running anything if an interrupt is already pending.
    try_register_catcher()?;
    let outcome = catch_unwind(try_fn);
    // `catch_unwind` turned any unwinding panic into a value, so this
    // deregistration runs on both the success and the panic path.
    unregister_catcher();
    match outcome {
        Ok(value) => Ok(value),
        Err(payload) if is_keyboard_interrupt(&*payload) => Err(KeyboardInterrupt),
        // Not a keyboard interrupt: keep propagating the original panic.
        Err(payload) => std::panic::resume_unwind(payload),
    }
}
/// Registers the caller as an active catcher.
///
/// If an interrupt is already pending, the registration is undone right away
/// and `Err(KeyboardInterrupt)` is returned instead.
fn try_register_catcher() -> Result<(), KeyboardInterrupt> {
    // The catcher count lives above the flag bit, so one catcher is worth 2.
    // `fetch_add` hands back the state as it was before this registration.
    let previous = INTERRUPT_STATE.fetch_add(2, Ordering::Relaxed);
    let interrupt_pending = (previous & 1) == 1;
    if !interrupt_pending {
        return Ok(());
    }
    // Report the pending interrupt now rather than raising it later.
    unregister_catcher();
    Err(KeyboardInterrupt)
}
/// Removes one active catcher.
///
/// While other catchers remain, only the count is decremented and the flag bit
/// is preserved. When the last catcher goes away, the whole state is reset to
/// zero, which also clears any pending interrupt flag.
fn unregister_catcher() {
    // The closure always returns `Some`, so the update cannot fail and its
    // result is discarded on purpose.
    let _ = INTERRUPT_STATE.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
        let catchers = current >> 1;
        Some(if catchers <= 1 { 0 } else { current - 2 })
    });
}