Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/cloud/cloud_writer/internal_writer.rs
8430 views
1
use std::num::NonZeroUsize;
2
3
use futures::StreamExt as _;
4
use futures::stream::FuturesUnordered;
5
use object_store::PutPayload;
6
use polars_error::{PolarsError, PolarsResult};
7
use polars_utils::async_utils::error_capture::{ErrorCapture, ErrorHandle};
8
use polars_utils::async_utils::tokio_handle_ext;
9
10
use crate::cloud::PolarsObjectStore;
11
use crate::cloud::cloud_writer::multipart_upload::PlMultipartUpload;
12
use crate::metrics::OptIOMetrics;
13
14
/// Cloud writer that provides the `put()` function, does not perform any buffering.
15
pub(super) struct InternalCloudWriter {
16
pub(super) store: PolarsObjectStore,
17
pub(super) path: object_store::path::Path,
18
pub(super) max_concurrency: NonZeroUsize,
19
pub(super) io_metrics: OptIOMetrics,
20
pub(super) state: InternalCloudWriterState,
21
}
22
23
pub(super) enum InternalCloudWriterState {
24
NotStarted,
25
Started(StartedState),
26
Finished,
27
}
28
29
type WriterState = InternalCloudWriterState;
30
31
pub(super) struct StartedState {
32
multipart: PlMultipartUpload,
33
tasks: FuturesUnordered<tokio_handle_ext::AbortOnDropHandle<()>>,
34
error_handle: ErrorHandle<PolarsError>,
35
error_capture: ErrorCapture<PolarsError>,
36
}
37
38
impl InternalCloudWriter {
39
pub(super) async fn start(&mut self) -> PolarsResult<()> {
40
if let WriterState::NotStarted = &self.state {
41
let path_ref = &self.path;
42
let multipart = PlMultipartUpload::new(
43
self.store
44
.exec_with_rebuild_retry_on_err(|s| async move {
45
s.put_multipart_opts(path_ref, object_store::PutMultipartOptions::default())
46
.await
47
})
48
.await?,
49
self.store.error_context(),
50
);
51
52
let (error_capture, error_handle) = ErrorCapture::new();
53
54
self.state = WriterState::Started(StartedState {
55
multipart,
56
tasks: FuturesUnordered::new(),
57
error_handle,
58
error_capture,
59
});
60
}
61
62
Ok(())
63
}
64
65
async fn get_or_init_started_state(&mut self) -> PolarsResult<&mut StartedState> {
66
loop {
67
match &self.state {
68
WriterState::Started(_) => {
69
let WriterState::Started(state) = &mut self.state else {
70
unreachable!()
71
};
72
return Ok(state);
73
},
74
WriterState::NotStarted => self.start().await?,
75
WriterState::Finished => panic!(),
76
}
77
}
78
}
79
80
/// Takes `self.state`, replacing with it `Finished`. Returns `None` if `self.state` is not
81
/// `Started`.
82
fn take_started_state(&mut self) -> Option<StartedState> {
83
if !matches!(&self.state, WriterState::Started(_)) {
84
return None;
85
}
86
87
let WriterState::Started(state) = std::mem::replace(&mut self.state, WriterState::Finished)
88
else {
89
unreachable!()
90
};
91
92
Some(state)
93
}
94
95
pub(super) async fn put(&mut self, payload: PutPayload) -> PolarsResult<()> {
96
let io_metrics = self.io_metrics.clone();
97
let max_concurrency = self.max_concurrency.get();
98
99
let state = self.get_or_init_started_state().await?;
100
101
if state.error_handle.has_errored() {
102
let state = self.take_started_state().unwrap();
103
return Err(state.error_handle.join().await.unwrap_err());
104
}
105
106
while state.tasks.len() >= max_concurrency {
107
state.tasks.next().await;
108
}
109
110
let num_bytes = payload.content_length() as u64;
111
let upload_fut = state.multipart.put(payload);
112
113
let fut = async move { io_metrics.record_bytes_tx(num_bytes, upload_fut).await };
114
115
let handle = tokio_handle_ext::AbortOnDropHandle(tokio::spawn(
116
state.error_capture.clone().wrap_future(fut),
117
));
118
119
state.tasks.push(handle);
120
121
Ok(())
122
}
123
124
pub(super) async fn finish(&mut self) -> PolarsResult<()> {
125
let Some(StartedState {
126
mut multipart,
127
tasks,
128
error_handle,
129
error_capture,
130
}) = self.take_started_state()
131
else {
132
return Ok(());
133
};
134
135
drop(error_capture);
136
error_handle.join().await?;
137
138
for handle in tasks {
139
handle.await.unwrap();
140
}
141
142
multipart.finish().await?;
143
144
Ok(())
145
}
146
}
147
148