// Path: blob/main/crates/polars-utils/src/async_utils/error_capture.rs
// 8395 views
use std::any::Any;1use std::panic::AssertUnwindSafe;23use futures::FutureExt;45/// Utility to capture errors and propagate them to an associated [`ErrorHandle`].6pub struct ErrorCapture<ErrorT> {7tx: tokio::sync::mpsc::Sender<ErrorMessage<ErrorT>>,8}910impl<ErrorT> Clone for ErrorCapture<ErrorT> {11fn clone(&self) -> Self {12Self {13tx: self.tx.clone(),14}15}16}1718impl<ErrorT> ErrorCapture<ErrorT> {19pub fn new() -> (Self, ErrorHandle<ErrorT>) {20let (tx, rx) = tokio::sync::mpsc::channel(1);21(Self { tx }, ErrorHandle { rx })22}2324/// Wraps a future such that its error result is sent to the associated [`ErrorHandle`].25pub async fn wrap_future<F, O>(self, fut: F)26where27F: Future<Output = Result<O, ErrorT>>,28{29let err: Result<(), tokio::sync::mpsc::error::TrySendError<ErrorMessage<ErrorT>>> =30match AssertUnwindSafe(fut).catch_unwind().await {31Ok(Ok(_)) => return,32Ok(Err(err)) => self.tx.try_send(ErrorMessage::Error(err)),33Err(panic) => self.tx.try_send(ErrorMessage::Panic(panic)),34};35drop(err);36}37}3839enum ErrorMessage<ErrorT> {40Error(ErrorT),41Panic(Box<dyn Any + Send + 'static>),42}4344/// Handle to await the completion of multiple tasks. Propagates error results45/// and resumes unwinds when joined.46pub struct ErrorHandle<ErrorT> {47rx: tokio::sync::mpsc::Receiver<ErrorMessage<ErrorT>>,48}4950impl<ErrorT> ErrorHandle<ErrorT> {51pub fn has_errored(&self) -> bool {52!self.rx.is_empty()53}5455/// Block until either an error is received, or all [`ErrorCapture`]s associated with this56/// handle are dropped (i.e. successful completion of all wrapped futures).57///58/// # Panics59/// If a panic is received, this will resume unwinding.60pub async fn join(self) -> Result<(), ErrorT> {61let ErrorHandle { mut rx } = self;6263match rx.recv().await {64None => Ok(()),65Some(ErrorMessage::Error(e)) => Err(e),66Some(ErrorMessage::Panic(panic)) => std::panic::resume_unwind(panic),67}68}69}707172