Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/cloud/cloud_writer/io_trait_wrap.rs
8431 views
1
use std::pin::Pin;
2
use std::task::{Poll, ready};
3
4
use bytes::Bytes;
5
use futures::FutureExt;
6
7
use crate::cloud::cloud_writer::CloudWriter;
8
use crate::pl_async;
9
use crate::utils::file::WriteableTrait;
10
11
/// Wrapper on [`CloudWriter`] that implements [`std::io::Write`] and [`tokio::io::AsyncWrite`].
12
pub struct CloudWriterIoTraitWrap {
13
state: WriterState,
14
}
15
16
enum WriterState {
17
Ready(Box<CloudWriter>),
18
Poll(
19
Pin<Box<dyn Future<Output = std::io::Result<WriterState>> + Send + 'static>>,
20
PollOperation,
21
),
22
Finished,
23
}
24
25
/// Identifies which `AsyncWrite` method started the future stored in `WriterState::Poll`.
///
/// When that future completes, the operation is handed back to the next poll call, so the call
/// can tell whether it is resuming its own work.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PollOperation {
    /// `poll_write` filled the buffer and is uploading a full chunk before it continues.
    Write {
        /// Address of the first byte of the caller's `buf`. A retried `poll_write` compares it
        /// against its own `buf` to confirm it is resuming the same write.
        slice_ptr: usize,
        /// Number of leading bytes of that `buf` that were already copied into the writer's
        /// buffer. This is an offset into the slice, not the slice length.
        written: usize,
    },
    /// `poll_flush` is flushing the writer.
    Flush,
    /// `poll_shutdown` is finishing the writer.
    Shutdown,
}
32
33
struct FinishActivePoll<'a>(Pin<&'a mut WriterState>);
34
35
impl<'a> Future for FinishActivePoll<'a> {
36
type Output = std::io::Result<Option<PollOperation>>;
37
38
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
39
match &mut *self.0 {
40
WriterState::Poll(fut, _) => match fut.poll_unpin(cx) {
41
Poll::Ready(Ok(new_state)) => {
42
debug_assert!(!matches!(&new_state, WriterState::Poll(..)));
43
44
let WriterState::Poll(_, operation) =
45
std::mem::replace(&mut *self.0, new_state)
46
else {
47
unreachable!()
48
};
49
50
Poll::Ready(Ok(Some(operation)))
51
},
52
Poll::Ready(Err(e)) => {
53
*self.0 = WriterState::Finished;
54
Poll::Ready(Err(e))
55
},
56
Poll::Pending => Poll::Pending,
57
},
58
59
WriterState::Ready(_) | WriterState::Finished => Poll::Ready(Ok(None)),
60
}
61
}
62
}
63
64
impl CloudWriterIoTraitWrap {
65
fn finish_active_poll(&mut self) -> FinishActivePoll<'_> {
66
FinishActivePoll(Pin::new(&mut self.state))
67
}
68
69
fn take_writer_from_ready_state(&mut self) -> Option<Box<CloudWriter>> {
70
if !matches!(&self.state, WriterState::Ready(_)) {
71
return None;
72
}
73
74
let WriterState::Ready(writer) = std::mem::replace(&mut self.state, WriterState::Finished)
75
else {
76
unreachable!()
77
};
78
79
Some(writer)
80
}
81
82
fn get_writer_mut_from_ready_state(&mut self) -> Option<&mut CloudWriter> {
83
if let WriterState::Ready(writer) = &mut self.state {
84
Some(writer.as_mut())
85
} else {
86
None
87
}
88
}
89
90
pub async fn write_all_owned(&mut self, bytes: Bytes) -> std::io::Result<()> {
91
self.finish_active_poll().await?;
92
93
self.get_writer_mut_from_ready_state()
94
.unwrap()
95
.write_all_owned(bytes)
96
.await?;
97
98
Ok(())
99
}
100
101
pub async fn into_cloud_writer(mut self) -> std::io::Result<CloudWriter> {
102
self.finish_active_poll().await?;
103
104
match self.state {
105
WriterState::Ready(writer) => Ok(*writer),
106
WriterState::Poll(..) => unreachable!(),
107
WriterState::Finished => panic!(),
108
}
109
}
110
111
pub fn as_cloud_writer(&mut self) -> std::io::Result<&mut CloudWriter> {
112
if !matches!(self.state, WriterState::Ready(_)) {
113
match &mut self.state {
114
WriterState::Ready(_) => unreachable!(),
115
WriterState::Poll(..) => {
116
pl_async::get_runtime().block_in_place_on(self.finish_active_poll())?
117
},
118
WriterState::Finished => panic!(),
119
};
120
}
121
122
let WriterState::Ready(writer) = &mut self.state else {
123
panic!()
124
};
125
126
Ok(writer)
127
}
128
}
129
130
impl From<CloudWriter> for CloudWriterIoTraitWrap {
131
fn from(writer: CloudWriter) -> Self {
132
Self {
133
state: WriterState::Ready(Box::new(writer)),
134
}
135
}
136
}
137
138
impl std::io::Write for CloudWriterIoTraitWrap {
139
fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
140
let total_buf_len = buf.len();
141
let buf: &mut &[u8] = &mut buf;
142
143
if let Some(writer) = self.get_writer_mut_from_ready_state() {
144
let full = writer.fill_buffer_from_slice(buf);
145
146
if !full {
147
assert!(buf.is_empty());
148
return Ok(total_buf_len);
149
}
150
}
151
152
pl_async::get_runtime().block_in_place_on(async {
153
self.finish_active_poll().await?;
154
155
let writer = self.get_writer_mut_from_ready_state().unwrap();
156
157
loop {
158
writer.flush_full_chunk().await?;
159
160
if !writer.fill_buffer_from_slice(buf) {
161
break;
162
}
163
}
164
165
assert!(buf.is_empty());
166
167
Ok(total_buf_len)
168
})
169
}
170
171
fn flush(&mut self) -> std::io::Result<()> {
172
if self
173
.get_writer_mut_from_ready_state()
174
.is_some_and(|w| !w.has_buffered_bytes())
175
{
176
return Ok(());
177
}
178
179
pl_async::get_runtime().block_in_place_on(async {
180
self.finish_active_poll().await?;
181
182
self.get_writer_mut_from_ready_state()
183
.unwrap()
184
.flush()
185
.await?;
186
187
Ok(())
188
})
189
}
190
}
191
192
impl WriteableTrait for CloudWriterIoTraitWrap {
193
fn close(&mut self) -> std::io::Result<()> {
194
pl_async::get_runtime().block_in_place_on(async {
195
self.finish_active_poll().await?;
196
197
let mut writer = self.take_writer_from_ready_state().unwrap();
198
writer.finish().await?;
199
200
Ok(())
201
})
202
}
203
204
fn sync_all(&self) -> std::io::Result<()> {
205
Ok(())
206
}
207
208
fn sync_data(&self) -> std::io::Result<()> {
209
Ok(())
210
}
211
}
212
213
impl tokio::io::AsyncWrite for CloudWriterIoTraitWrap {
214
fn poll_write(
215
mut self: Pin<&mut Self>,
216
cx: &mut std::task::Context<'_>,
217
buf: &[u8],
218
) -> std::task::Poll<std::io::Result<usize>> {
219
loop {
220
let offset = match ready!(self.finish_active_poll().poll_unpin(cx))? {
221
Some(PollOperation::Write { slice_ptr, written })
222
if slice_ptr == buf.as_ptr() as usize =>
223
{
224
written
225
},
226
Some(_) => panic!(),
227
None => 0,
228
};
229
230
let writer = self.get_writer_mut_from_ready_state().unwrap();
231
232
let offset_buf: &mut &[u8] = &mut &buf[offset..];
233
234
let full = writer.fill_buffer_from_slice(offset_buf);
235
236
if !full {
237
assert!(offset_buf.is_empty());
238
return Poll::Ready(Ok(buf.len()));
239
};
240
241
let new_offset = buf.len() - offset_buf.len();
242
243
let mut writer = self.take_writer_from_ready_state().unwrap();
244
245
self.state = WriterState::Poll(
246
Box::pin(async move {
247
writer.flush_full_chunk().await?;
248
Ok(WriterState::Ready(writer))
249
}),
250
PollOperation::Write {
251
slice_ptr: buf.as_ptr() as usize,
252
written: new_offset,
253
},
254
);
255
}
256
}
257
258
fn poll_flush(
259
mut self: Pin<&mut Self>,
260
cx: &mut std::task::Context<'_>,
261
) -> std::task::Poll<std::io::Result<()>> {
262
loop {
263
match ready!(self.finish_active_poll().poll_unpin(cx))? {
264
Some(PollOperation::Flush) => return Poll::Ready(Ok(())),
265
Some(_) => panic!(),
266
None => {
267
let mut writer = self.take_writer_from_ready_state().unwrap();
268
269
self.state = WriterState::Poll(
270
Box::pin(async move {
271
writer.flush().await?;
272
Ok(WriterState::Ready(writer))
273
}),
274
PollOperation::Flush,
275
)
276
},
277
}
278
}
279
}
280
281
fn poll_shutdown(
282
mut self: Pin<&mut Self>,
283
cx: &mut std::task::Context<'_>,
284
) -> std::task::Poll<std::io::Result<()>> {
285
loop {
286
match ready!(self.finish_active_poll().poll_unpin(cx))? {
287
Some(PollOperation::Shutdown) => return Poll::Ready(Ok(())),
288
Some(_) => panic!(),
289
None => {
290
let mut writer = self.take_writer_from_ready_state().unwrap();
291
292
self.state = WriterState::Poll(
293
Box::pin(async move {
294
writer.finish().await?;
295
Ok(WriterState::Finished)
296
}),
297
PollOperation::Shutdown,
298
);
299
},
300
}
301
}
302
}
303
}
304
305