Path: blob/main/crates/polars-io/src/cloud/cloud_writer/io_trait_wrap.rs
8431 views
use std::pin::Pin;1use std::task::{Poll, ready};23use bytes::Bytes;4use futures::FutureExt;56use crate::cloud::cloud_writer::CloudWriter;7use crate::pl_async;8use crate::utils::file::WriteableTrait;910/// Wrapper on [`CloudWriter`] that implements [`std::io::Write`] and [`tokio::io::AsyncWrite`].11pub struct CloudWriterIoTraitWrap {12state: WriterState,13}1415enum WriterState {16Ready(Box<CloudWriter>),17Poll(18Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,19PollOperation,20),21Finished,22}2324#[derive(Debug, Clone, PartialEq, Eq)]25enum PollOperation {26// (slice_addr, slice_len)27Write { slice_ptr: usize, written: usize },28Flush,29Shutdown,30}3132struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);3334impl<'a> Future for FinishActivePoll<'a> {35type Output = std::io::Result<Option<PollOperation>>;3637fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {38match &mut *self.0 {39WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {40Poll::Ready(Ok(new_state)) => {41debug_assert!(!matches!(&new_state, WriterState::Poll(..)));4243let WriterState::Poll(_, operation) =44std::mem::replace(&mut *self.0, new_state)45else {46unreachable!()47};4849Poll::Ready(Ok(Some(operation)))50},51Poll::Ready(Err(e)) => {52*self.0 = WriterState::Finished;53Poll::Ready(Err(e))54},55Poll::Pending => Poll::Pending,56},5758WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),59}60}61}6263impl CloudWriterIoTraitWrap {64fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {65FinishActivePoll(Pin::new(&mut self.state))66}6768fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {69if !matches!(&self.state, WriterState::Ready(_)) {70return None;71}7273let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)74else {75unreachable!()76};7778Some(writer)79}8081fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {82if let WriterState::Ready(writer) = &mut self.state {83Some(writer.as_mut())84} else {85None86}87}8889pub async fn write_all_owned(&mut self, bytes: Bytes) -> std::io::Result<()> {90self.finish_active_poll().await?;9192self.get_writer_mut_from_ready_state()93.unwrap()94.write_all_owned(bytes)95.await?;9697Ok(())98}99100pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {101self.finish_active_poll().await?;102103match self.state {104WriterState::Ready(writer) => Ok(*writer),105WriterState::Poll(..) => unreachable!(),106WriterState::Finished => panic!(),107}108}109110pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {111if !matches!(self.state, WriterState::Ready(_)) {112match &mut self.state {113WriterState::Ready(_) => unreachable!(),114WriterState::Poll(..) => {115pl_async::get_runtime().block_in_place_on(self.finish_active_poll())?116},117WriterState::Finished => panic!(),118};119}120121let WriterState::Ready(writer) = &mut self.state else {122panic!()123};124125Ok(writer)126}127}128129impl From<CloudWriter> for CloudWriterIoTraitWrap {130fn from(writer: CloudWriter) -> Self {131Self {132state: WriterState::Ready(Box::new(writer)),133}134}135}136137impl std::io::Write for CloudWriterIoTraitWrap {138fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {139let total_buf_len = buf.len();140let buf: &mut &[u8] = &mut buf;141142if let Some(writer) = self.get_writer_mut_from_ready_state() {143let full = writer.fill_buffer_from_slice(buf);144145if !full {146assert!(buf.is_empty());147return Ok(total_buf_len);148}149}150151pl_async::get_runtime().block_in_place_on(async {152self.finish_active_poll().await?;153154let writer = self.get_writer_mut_from_ready_state().unwrap();155156loop {157writer.flush_full_chunk().await?;158159if !writer.fill_buffer_from_slice(buf) {160break;161}162}163164assert!(buf.is_empty());165166Ok(total_buf_len)167})168}169170fn flush(&mut self) -> std::io::Result<()> {171if self172.get_writer_mut_from_ready_state()173.is_some_and(|w| !w.has_buffered_bytes())174{175return Ok(());176}177178pl_async::get_runtime().block_in_place_on(async {179self.finish_active_poll().await?;180181self.get_writer_mut_from_ready_state()182.unwrap()183.flush()184.await?;185186Ok(())187})188}189}190191impl WriteableTrait for CloudWriterIoTraitWrap {192fn close(&mut self) -> std::io::Result<()> {193pl_async::get_runtime().block_in_place_on(async {194self.finish_active_poll().await?;195196let mut writer = self.take_writer_from_ready_state().unwrap();197writer.finish().await?;198199Ok(())200})201}202203fn sync_all(&self) -> std::io::Result<()> {204Ok(())205}206207fn sync_data(&self) -> std::io::Result<()> {208Ok(())209}210}211212impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {213fn poll_write(214mut self: Pin<&mut Self>,215cx: &mut std::task::Context<'_>,216buf: &[u8],217) -> std::task::Poll<std::io::Result<usize>> {218loop {219let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {220Some(PollOperation::Write { slice_ptr, written })221if slice_ptr == buf.as_ptr() as usize =>222{223written224},225Some(_) => panic!(),226None => 0,227};228229let writer = self.get_writer_mut_from_ready_state().unwrap();230231let offset_buf: &mut &[u8] = &mut &buf[offset..];232233let full = writer.fill_buffer_from_slice(offset_buf);234235if !full {236assert!(offset_buf.is_empty());237return Poll::Ready(Ok(buf.len()));238};239240let new_offset = buf.len() - offset_buf.len();241242let mut writer = self.take_writer_from_ready_state().unwrap();243244self.state = WriterState::Poll(245Box::pin(async move {246writer.flush_full_chunk().await?;247Ok(WriterState::Ready(writer))248}),249PollOperation::Write {250slice_ptr: buf.as_ptr() as usize,251written: new_offset,252},253);254}255}256257fn poll_flush(258mut self: Pin<&mut Self>,259cx: &mut std::task::Context<'_>,260) -> std::task::Poll<std::io::Result<()>> {261loop {262match ready!(self.finish_active_poll().poll_unpin(cx))? {263Some(PollOperation::Flush) => return Poll::Ready(Ok(())),264Some(_) => panic!(),265None => {266let mut writer = self.take_writer_from_ready_state().unwrap();267268self.state = WriterState::Poll(269Box::pin(async move {270writer.flush().await?;271Ok(WriterState::Ready(writer))272}),273PollOperation::Flush,274)275},276}277}278}279280fn poll_shutdown(281mut self: Pin<&mut Self>,282cx: &mut std::task::Context<'_>,283) -> std::task::Poll<std::io::Result<()>> {284loop {285match ready!(self.finish_active_poll().poll_unpin(cx))? {286Some(PollOperation::Shutdown) => return Poll::Ready(Ok(())),287Some(_) => panic!(),288None => {289let mut writer = self.take_writer_from_ready_state().unwrap();290291self.state = WriterState::Poll(292Box::pin(async move {293writer.finish().await?;294Ok(WriterState::Finished)295}),296PollOperation::Shutdown,297);298},299}300}301}302}303304305