Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/read/flight.rs
8446 views
1
use std::io::SeekFrom;
2
use std::pin::Pin;
3
use std::sync::Arc;
4
5
use arrow_format::ipc::planus::ReadAsRoot;
6
use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef};
7
use futures::{Stream, StreamExt};
8
use polars_error::{PolarsResult, polars_bail, polars_err};
9
use polars_utils::bool::UnsafeBool;
10
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
11
12
use crate::datatypes::ArrowSchema;
13
use crate::io::ipc::read::common::read_record_batch;
14
use crate::io::ipc::read::file::{
15
decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer,
16
iter_recordbatch_blocks_from_footer,
17
};
18
use crate::io::ipc::read::schema::deserialize_stream_metadata;
19
use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata};
20
use crate::io::ipc::write::common::EncodedData;
21
use crate::mmap::{mmap_dictionary_from_batch, mmap_record};
22
use crate::record_batch::RecordBatch;
23
24
async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>(
25
reader: &mut R,
26
block: &arrow_format::ipc::Block,
27
scratch: &'a mut Vec<u8>,
28
) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {
29
let offset: u64 = block
30
.offset
31
.try_into()
32
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
33
reader.seek(SeekFrom::Start(offset)).await?;
34
read_ipc_message(reader, scratch).await
35
}
36
37
/// Read an encapsulated IPC Message from the reader
38
async fn read_ipc_message<'a, R: AsyncRead + Unpin>(
39
reader: &mut R,
40
scratch: &'a mut Vec<u8>,
41
) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {
42
let mut message_size: [u8; 4] = [0; 4];
43
44
reader.read_exact(&mut message_size).await?;
45
if message_size == crate::io::ipc::CONTINUATION_MARKER {
46
reader.read_exact(&mut message_size).await?;
47
};
48
let message_length = i32::from_le_bytes(message_size);
49
50
let message_length: usize = message_length
51
.try_into()
52
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
53
54
scratch.clear();
55
scratch.try_reserve(message_length)?;
56
reader
57
.take(message_length as u64)
58
.read_to_end(scratch)
59
.await?;
60
61
arrow_format::ipc::MessageRef::read_as_root(scratch)
62
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))
63
}
64
65
async fn read_footer_len<R: AsyncRead + AsyncSeek + Unpin>(
66
reader: &mut R,
67
) -> PolarsResult<(u64, usize)> {
68
// read footer length and magic number in footer
69
let end = reader.seek(SeekFrom::End(-10)).await? + 10;
70
71
let mut footer: [u8; 10] = [0; 10];
72
reader.read_exact(&mut footer).await?;
73
74
decode_footer_len(footer, end)
75
}
76
77
/// Reads up to `footer_len` bytes of serialized footer that precede the 10-byte trailer.
///
/// The reader is positioned `10 + footer_len` bytes before the end and the bytes read from
/// there are returned in a freshly allocated buffer.
async fn read_footer<R: AsyncRead + AsyncSeek + Unpin>(
    reader: &mut R,
    footer_len: usize,
) -> PolarsResult<Vec<u8>> {
    let footer_start = SeekFrom::End(-10 - footer_len as i64);
    reader.seek(footer_start).await?;

    // Reserve fallibly so an oversized length yields an error instead of aborting.
    let mut buf = Vec::new();
    buf.try_reserve(footer_len)?;

    reader.take(footer_len as u64).read_to_end(&mut buf).await?;
    Ok(buf)
}
93
94
fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData {
95
// Turn the IPC schema into an encapsulated message
96
let message = arrow_format::ipc::Message {
97
version: arrow_format::ipc::MetadataVersion::V5,
98
// Assumed the conversion is infallible.
99
header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(
100
schema.try_into().unwrap(),
101
))),
102
body_length: 0,
103
custom_metadata: None, // todo: allow writing custom metadata
104
};
105
let mut builder = arrow_format::ipc::planus::Builder::new();
106
let header = builder.finish(&message, None).to_vec();
107
108
// Use `EncodedData` directly instead of `FlightData`. In FlightData we would only use
109
// `data_header` and `data_body`.
110
EncodedData {
111
ipc_message: header,
112
arrow_data: vec![],
113
}
114
}
115
116
async fn block_to_raw_message<'a, R>(
117
reader: &mut R,
118
block: &arrow_format::ipc::Block,
119
encoded_data: &mut EncodedData,
120
) -> PolarsResult<()>
121
where
122
R: AsyncRead + AsyncSeek + Unpin + Send + 'a,
123
{
124
debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty());
125
let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?;
126
127
let block_length: u64 = message
128
.body_length()
129
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))?
130
.try_into()
131
.map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?;
132
reader
133
.take(block_length)
134
.read_to_end(&mut encoded_data.arrow_data)
135
.await?;
136
137
Ok(())
138
}
139
140
pub async fn into_flight_stream<R: AsyncRead + AsyncSeek + Unpin + Send>(
141
reader: &mut R,
142
) -> PolarsResult<impl Stream<Item = PolarsResult<EncodedData>> + '_> {
143
Ok(async_stream::try_stream! {
144
let (_end, len) = read_footer_len(reader).await?;
145
let footer_data = read_footer(reader, len).await?;
146
let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data)
147
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;
148
let data_blocks = iter_recordbatch_blocks_from_footer(footer)?;
149
let dict_blocks = iter_dictionary_blocks_from_footer(footer)?;
150
151
let schema_ref = deserialize_schema_ref_from_footer(footer)?;
152
let schema = schema_to_raw_message(schema_ref);
153
154
yield schema;
155
156
if let Some(dict_blocks_iter) = dict_blocks {
157
for d in dict_blocks_iter {
158
let mut ed: EncodedData = Default::default();
159
block_to_raw_message(reader, &d?, &mut ed).await?;
160
yield ed
161
}
162
};
163
164
for d in data_blocks {
165
let mut ed: EncodedData = Default::default();
166
block_to_raw_message(reader, &d?, &mut ed).await?;
167
yield ed
168
}
169
})
170
}
171
172
pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> {
173
footer: Option<*const FooterRef<'static>>,
174
footer_data: Vec<u8>,
175
dict_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,
176
data_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,
177
reader: &'a mut R,
178
}
179
180
impl<R: AsyncRead + AsyncSeek + Unpin + Send> Drop for FlightStreamProducer<'_, R> {
181
fn drop(&mut self) {
182
if let Some(p) = self.footer {
183
unsafe {
184
let _ = Box::from_raw(p as *mut FooterRef<'static>);
185
}
186
}
187
}
188
}
189
190
// SAFETY: the raw `footer` pointer is the field that suppresses the automatic `Send` impl. It
// refers to a heap allocation managed by this struct itself.
// NOTE(review): this presumes the boxed block iterators are `Send` as well — confirm against
// the definition of `SendableIterator`.
unsafe impl<R> Send for FlightStreamProducer<'_, R> where R: AsyncRead + AsyncSeek + Unpin + Send {}
191
192
impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
193
pub async fn new(reader: &'a mut R) -> PolarsResult<Pin<Box<Self>>> {
194
let (_end, len) = read_footer_len(reader).await?;
195
let footer_data = read_footer(reader, len).await?;
196
197
Ok(Box::pin(Self {
198
footer: None,
199
footer_data,
200
dict_blocks: None,
201
data_blocks: None,
202
reader,
203
}))
204
}
205
206
pub fn init(self: &mut Pin<Box<Self>>) -> PolarsResult<()> {
207
let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data)
208
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;
209
210
let footer = Box::new(footer);
211
212
#[allow(clippy::unnecessary_cast)]
213
let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>;
214
215
self.footer = Some(ptr);
216
let footer = &unsafe { **self.footer.as_ref().unwrap() };
217
218
self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?)
219
as Box<dyn SendableIterator<Item = _>>);
220
self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)?
221
.map(|i| Box::new(i) as Box<dyn SendableIterator<Item = _>>);
222
223
Ok(())
224
}
225
226
pub fn get_schema(self: &Pin<Box<Self>>) -> PolarsResult<EncodedData> {
227
let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") };
228
229
let schema_ref = deserialize_schema_ref_from_footer(*footer)?;
230
let schema = schema_to_raw_message(schema_ref);
231
232
Ok(schema)
233
}
234
235
pub async fn next_dict(
236
self: &mut Pin<Box<Self>>,
237
encoded_data: &mut EncodedData,
238
) -> PolarsResult<Option<()>> {
239
assert!(self.data_blocks.is_some(), "init must be called first");
240
encoded_data.ipc_message.clear();
241
encoded_data.arrow_data.clear();
242
243
if let Some(iter) = &mut self.dict_blocks {
244
let Some(value) = iter.next() else {
245
return Ok(None);
246
};
247
let block = value?;
248
249
block_to_raw_message(&mut self.reader, &block, encoded_data).await?;
250
Ok(Some(()))
251
} else {
252
Ok(None)
253
}
254
}
255
256
pub async fn next_data(
257
self: &mut Pin<Box<Self>>,
258
encoded_data: &mut EncodedData,
259
) -> PolarsResult<Option<()>> {
260
encoded_data.ipc_message.clear();
261
encoded_data.arrow_data.clear();
262
263
let iter = self
264
.data_blocks
265
.as_mut()
266
.expect("init must be called first");
267
let Some(value) = iter.next() else {
268
return Ok(None);
269
};
270
let block = value?;
271
272
block_to_raw_message(&mut self.reader, &block, encoded_data).await?;
273
Ok(Some(()))
274
}
275
}
276
277
pub struct FlightConsumer {
278
dictionaries: Dictionaries,
279
md: StreamMetadata,
280
scratch: Vec<u8>,
281
checked: UnsafeBool,
282
}
283
284
impl FlightConsumer {
285
pub fn new(first: EncodedData) -> PolarsResult<Self> {
286
let md = deserialize_stream_metadata(&first.ipc_message)?;
287
Ok(Self {
288
dictionaries: Default::default(),
289
md,
290
scratch: vec![],
291
checked: Default::default(),
292
})
293
}
294
295
/// # Safety
296
/// Don't do expensive checks.
297
/// This means the data source has to be trusted to be correct.
298
pub unsafe fn unchecked(mut self) -> Self {
299
unsafe {
300
self.checked = UnsafeBool::new_false();
301
}
302
self
303
}
304
305
pub fn schema(&self) -> &ArrowSchema {
306
&self.md.schema
307
}
308
309
pub fn consume(&mut self, msg: EncodedData) -> PolarsResult<Option<RecordBatch>> {
310
// Parse the header
311
let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message)
312
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;
313
314
let header = message
315
.header()
316
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?
317
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;
318
319
// Either append to the dictionaries and return None or return Some(ArrowChunk)
320
match header {
321
MessageHeaderRef::Schema(_) => {
322
polars_bail!(ComputeError: "Unexpected schema message while parsing Stream");
323
},
324
// Add to dictionary state and continue iteration
325
MessageHeaderRef::DictionaryBatch(batch) => unsafe {
326
// Needed to memory map.
327
let arrow_data = Arc::new(msg.arrow_data);
328
mmap_dictionary_from_batch(
329
&self.md.schema,
330
&self.md.ipc_schema.fields,
331
&arrow_data,
332
batch,
333
&mut self.dictionaries,
334
0,
335
)
336
.map(|_| None)
337
},
338
// Return Batch
339
MessageHeaderRef::RecordBatch(batch) => {
340
if batch.compression()?.is_some() {
341
let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice());
342
read_record_batch(
343
batch,
344
&self.md.schema,
345
&self.md.ipc_schema,
346
None,
347
None,
348
&self.dictionaries,
349
self.md.version,
350
&mut reader,
351
0,
352
&mut self.scratch,
353
self.checked,
354
)
355
.map(Some)
356
} else {
357
// Needed to memory map.
358
let arrow_data = Arc::new(msg.arrow_data);
359
unsafe {
360
mmap_record(
361
&self.md.schema,
362
&self.md.ipc_schema.fields,
363
arrow_data,
364
batch,
365
0,
366
&self.dictionaries,
367
)
368
.map(Some)
369
}
370
}
371
},
372
_ => unimplemented!(),
373
}
374
}
375
}
376
377
pub struct FlightstreamConsumer<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> {
378
inner: FlightConsumer,
379
stream: S,
380
}
381
382
impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S> {
383
pub async fn new(mut stream: S) -> PolarsResult<Self> {
384
let Some(first) = stream.next().await else {
385
polars_bail!(ComputeError: "expected the schema")
386
};
387
let first = first?;
388
389
Ok(FlightstreamConsumer {
390
inner: FlightConsumer::new(first)?,
391
stream,
392
})
393
}
394
395
pub async fn next_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
396
while let Some(msg) = self.stream.next().await {
397
let msg = msg?;
398
let option_recordbatch = self.inner.consume(msg)?;
399
if option_recordbatch.is_some() {
400
return Ok(option_recordbatch);
401
}
402
}
403
Ok(None)
404
}
405
}
406
407
#[cfg(test)]
mod test {
    use std::path::{Path, PathBuf};

    use tokio::fs::File;

    use super::*;
    use crate::record_batch::RecordBatch;

    /// Path of the IPC fixture, resolved relative to the crate's manifest directory.
    fn get_file_path() -> PathBuf {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
        Path::new(&manifest_dir).join("../../py-polars/tests/unit/io/files/foods1.ipc")
    }

    /// Reads the first record batch of `path` with the synchronous file reader.
    fn read_file(path: &Path) -> RecordBatch {
        let mut file = std::fs::File::open(path).unwrap();
        let metadata = crate::io::ipc::read::read_file_metadata(&mut file).unwrap();
        let mut reader = crate::io::ipc::read::FileReader::new(&mut file, metadata, None, None);
        reader.next().unwrap().unwrap()
    }

    /// The stream-based producer must yield the same first batch as the file reader.
    #[tokio::test]
    async fn test_file_flight_simple() {
        let path = &get_file_path();
        let mut file = tokio::fs::File::open(path).await.unwrap();
        let stream = into_flight_stream(&mut file).await.unwrap();

        let mut consumer = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap();
        let batch = consumer.next_batch().await.unwrap().unwrap();

        assert_eq!(batch, read_file(path));
    }

    /// The buffer-reusing producer must yield the same first batch as the file reader.
    #[tokio::test]
    async fn test_file_flight_amortized() {
        let path = &get_file_path();
        let mut file = File::open(path).await.unwrap();
        let mut producer = FlightStreamProducer::new(&mut file).await.unwrap();
        producer.init().unwrap();

        // Schema first, then at most one dictionary, then one data message.
        let mut messages = vec![producer.get_schema().unwrap()];

        let mut dict = EncodedData::default();
        if producer.next_dict(&mut dict).await.unwrap().is_some() {
            messages.push(dict);
        }

        let mut data = EncodedData::default();
        producer.next_data(&mut data).await.unwrap();
        messages.push(data);

        let mut consumer =
            FlightstreamConsumer::new(Box::pin(futures::stream::iter(messages.into_iter().map(Ok))))
                .await
                .unwrap();
        let batch = consumer.next_batch().await.unwrap().unwrap();

        assert_eq!(batch, read_file(path));
    }
}
470
471