Path: blob/main/crates/polars-io/src/cloud/cloud_writer/internal_writer.rs
8430 views
use std::num::NonZeroUsize;12use futures::StreamExt as _;3use futures::stream::FuturesUnordered;4use object_store::PutPayload;5use polars_error::{PolarsError, PolarsResult};6use polars_utils::async_utils::error_capture::{ErrorCapture, ErrorHandle};7use polars_utils::async_utils::tokio_handle_ext;89use crate::cloud::PolarsObjectStore;10use crate::cloud::cloud_writer::multipart_upload::PlMultipartUpload;11use crate::metrics::OptIOMetrics;1213/// Cloud writer that provides the `put()` function, does not perform any buffering.14pub(super) struct InternalCloudWriter {15pub(super) store: PolarsObjectStore,16pub(super) path: object_store::path::Path,17pub(super) max_concurrency: NonZeroUsize,18pub(super) io_metrics: OptIOMetrics,19pub(super) state: InternalCloudWriterState,20}2122pub(super) enum InternalCloudWriterState {23NotStarted,24Started(StartedState),25Finished,26}2728type WriterState = InternalCloudWriterState;2930pub(super) struct StartedState {31multipart: PlMultipartUpload,32tasks: FuturesUnordered<tokio_handle_ext::AbortOnDropHandle<()>>,33error_handle: ErrorHandle<PolarsError>,34error_capture: ErrorCapture<PolarsError>,35}3637impl InternalCloudWriter {38pub(super) async fn start(&mut self) -> PolarsResult<()> {39if let WriterState::NotStarted = &self.state {40let path_ref = &self.path;41let multipart = PlMultipartUpload::new(42self.store43.exec_with_rebuild_retry_on_err(|s| async move {44s.put_multipart_opts(path_ref, object_store::PutMultipartOptions::default())45.await46})47.await?,48self.store.error_context(),49);5051let (error_capture, error_handle) = ErrorCapture::new();5253self.state = WriterState::Started(StartedState {54multipart,55tasks: FuturesUnordered::new(),56error_handle,57error_capture,58});59}6061Ok(())62}6364async fn get_or_init_started_state(&mut self) -> PolarsResult<&mut StartedState> {65loop {66match &self.state {67WriterState::Started(_) => {68let WriterState::Started(state) = &mut self.state else {69unreachable!()70};71return Ok(state);72},73WriterState::NotStarted => self.start().await?,74WriterState::Finished => panic!(),75}76}77}7879/// Takes `self.state`, replacing with it `Finished`. Returns `None` if `self.state` is not80/// `Started`.81fn take_started_state(&mut self) -> Option<StartedState> {82if !matches!(&self.state, WriterState::Started(_)) {83return None;84}8586let WriterState::Started(state) = std::mem::replace(&mut self.state, WriterState::Finished)87else {88unreachable!()89};9091Some(state)92}9394pub(super) async fn put(&mut self, payload: PutPayload) -> PolarsResult<()> {95let io_metrics = self.io_metrics.clone();96let max_concurrency = self.max_concurrency.get();9798let state = self.get_or_init_started_state().await?;99100if state.error_handle.has_errored() {101let state = self.take_started_state().unwrap();102return Err(state.error_handle.join().await.unwrap_err());103}104105while state.tasks.len() >= max_concurrency {106state.tasks.next().await;107}108109let num_bytes = payload.content_length() as u64;110let upload_fut = state.multipart.put(payload);111112let fut = async move { io_metrics.record_bytes_tx(num_bytes, upload_fut).await };113114let handle = tokio_handle_ext::AbortOnDropHandle(tokio::spawn(115state.error_capture.clone().wrap_future(fut),116));117118state.tasks.push(handle);119120Ok(())121}122123pub(super) async fn finish(&mut self) -> PolarsResult<()> {124let Some(StartedState {125mut multipart,126tasks,127error_handle,128error_capture,129}) = self.take_started_state()130else {131return Ok(());132};133134drop(error_capture);135error_handle.join().await?;136137for handle in tasks {138handle.await.unwrap();139}140141multipart.finish().await?;142143Ok(())144}145}146147148