Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-arrow/src/io/ipc/read/flight.rs
6940 views
1
use std::io::SeekFrom;
2
use std::pin::Pin;
3
use std::sync::Arc;
4
5
use arrow_format::ipc::planus::ReadAsRoot;
6
use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef};
7
use futures::{Stream, StreamExt};
8
use polars_error::{PolarsResult, polars_bail, polars_err};
9
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
10
11
use crate::datatypes::ArrowSchema;
12
use crate::io::ipc::read::common::read_record_batch;
13
use crate::io::ipc::read::file::{
14
decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer,
15
iter_recordbatch_blocks_from_footer,
16
};
17
use crate::io::ipc::read::schema::deserialize_stream_metadata;
18
use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata};
19
use crate::io::ipc::write::common::EncodedData;
20
use crate::mmap::{mmap_dictionary_from_batch, mmap_record};
21
use crate::record_batch::RecordBatch;
22
23
async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>(
24
reader: &mut R,
25
block: &arrow_format::ipc::Block,
26
scratch: &'a mut Vec<u8>,
27
) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {
28
let offset: u64 = block
29
.offset
30
.try_into()
31
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
32
reader.seek(SeekFrom::Start(offset)).await?;
33
read_ipc_message(reader, scratch).await
34
}
35
36
/// Read an encapsulated IPC Message from the reader
37
async fn read_ipc_message<'a, R: AsyncRead + Unpin>(
38
reader: &mut R,
39
scratch: &'a mut Vec<u8>,
40
) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {
41
let mut message_size: [u8; 4] = [0; 4];
42
43
reader.read_exact(&mut message_size).await?;
44
if message_size == crate::io::ipc::CONTINUATION_MARKER {
45
reader.read_exact(&mut message_size).await?;
46
};
47
let message_length = i32::from_le_bytes(message_size);
48
49
let message_length: usize = message_length
50
.try_into()
51
.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
52
53
scratch.clear();
54
scratch.try_reserve(message_length)?;
55
reader
56
.take(message_length as u64)
57
.read_to_end(scratch)
58
.await?;
59
60
arrow_format::ipc::MessageRef::read_as_root(scratch)
61
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))
62
}
63
64
async fn read_footer_len<R: AsyncRead + AsyncSeek + Unpin>(
65
reader: &mut R,
66
) -> PolarsResult<(u64, usize)> {
67
// read footer length and magic number in footer
68
let end = reader.seek(SeekFrom::End(-10)).await? + 10;
69
70
let mut footer: [u8; 10] = [0; 10];
71
reader.read_exact(&mut footer).await?;
72
73
decode_footer_len(footer, end)
74
}
75
76
/// Read the `footer_len` serialized footer bytes that sit just before the 10-byte trailer.
async fn read_footer<R: AsyncRead + AsyncSeek + Unpin>(
    reader: &mut R,
    footer_len: usize,
) -> PolarsResult<Vec<u8>> {
    let offset_from_end = -10 - footer_len as i64;
    reader.seek(SeekFrom::End(offset_from_end)).await?;

    let mut footer_bytes = Vec::new();
    footer_bytes.try_reserve(footer_len)?;
    reader
        .take(footer_len as u64)
        .read_to_end(&mut footer_bytes)
        .await?;

    Ok(footer_bytes)
}
92
93
fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData {
94
// Turn the IPC schema into an encapsulated message
95
let message = arrow_format::ipc::Message {
96
version: arrow_format::ipc::MetadataVersion::V5,
97
// Assumed the conversion is infallible.
98
header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(
99
schema.try_into().unwrap(),
100
))),
101
body_length: 0,
102
custom_metadata: None, // todo: allow writing custom metadata
103
};
104
let mut builder = arrow_format::ipc::planus::Builder::new();
105
let header = builder.finish(&message, None).to_vec();
106
107
// Use `EncodedData` directly instead of `FlightData`. In FlightData we would only use
108
// `data_header` and `data_body`.
109
EncodedData {
110
ipc_message: header,
111
arrow_data: vec![],
112
}
113
}
114
115
async fn block_to_raw_message<'a, R>(
116
reader: &mut R,
117
block: &arrow_format::ipc::Block,
118
encoded_data: &mut EncodedData,
119
) -> PolarsResult<()>
120
where
121
R: AsyncRead + AsyncSeek + Unpin + Send + 'a,
122
{
123
debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty());
124
let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?;
125
126
let block_length: u64 = message
127
.body_length()
128
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))?
129
.try_into()
130
.map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?;
131
reader
132
.take(block_length)
133
.read_to_end(&mut encoded_data.arrow_data)
134
.await?;
135
136
Ok(())
137
}
138
139
pub async fn into_flight_stream<R: AsyncRead + AsyncSeek + Unpin + Send>(
140
reader: &mut R,
141
) -> PolarsResult<impl Stream<Item = PolarsResult<EncodedData>> + '_> {
142
Ok(async_stream::try_stream! {
143
let (_end, len) = read_footer_len(reader).await?;
144
let footer_data = read_footer(reader, len).await?;
145
let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data)
146
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;
147
let data_blocks = iter_recordbatch_blocks_from_footer(footer)?;
148
let dict_blocks = iter_dictionary_blocks_from_footer(footer)?;
149
150
let schema_ref = deserialize_schema_ref_from_footer(footer)?;
151
let schema = schema_to_raw_message(schema_ref);
152
153
yield schema;
154
155
if let Some(dict_blocks_iter) = dict_blocks {
156
for d in dict_blocks_iter {
157
let mut ed: EncodedData = Default::default();
158
block_to_raw_message(reader, &d?, &mut ed).await?;
159
yield ed
160
}
161
};
162
163
for d in data_blocks {
164
let mut ed: EncodedData = Default::default();
165
block_to_raw_message(reader, &d?, &mut ed).await?;
166
yield ed
167
}
168
})
169
}
170
171
pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> {
172
footer: Option<*const FooterRef<'static>>,
173
footer_data: Vec<u8>,
174
dict_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,
175
data_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,
176
reader: &'a mut R,
177
}
178
179
impl<R: AsyncRead + AsyncSeek + Unpin + Send> Drop for FlightStreamProducer<'_, R> {
180
fn drop(&mut self) {
181
if let Some(p) = self.footer {
182
unsafe {
183
let _ = Box::from_raw(p as *mut FooterRef<'static>);
184
}
185
}
186
}
187
}
188
189
// SAFETY: `Send` is not derived automatically because of the raw `footer` pointer. That
// pointer refers to a heap box created in `init` and freed in `Drop`. The box is owned
// exclusively by this producer and only read through `&self` / `&mut self`, so moving the
// producer to another thread moves that ownership along with it.
// NOTE(review): this also relies on `SendableIterator` implying `Send` for the boxed block
// iterators; confirm in `io::ipc::read`.
unsafe impl<R: AsyncRead + AsyncSeek + Unpin + Send> Send for FlightStreamProducer<'_, R> {}
190
191
impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {
192
pub async fn new(reader: &'a mut R) -> PolarsResult<Pin<Box<Self>>> {
193
let (_end, len) = read_footer_len(reader).await?;
194
let footer_data = read_footer(reader, len).await?;
195
196
Ok(Box::pin(Self {
197
footer: None,
198
footer_data,
199
dict_blocks: None,
200
data_blocks: None,
201
reader,
202
}))
203
}
204
205
pub fn init(self: &mut Pin<Box<Self>>) -> PolarsResult<()> {
206
let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data)
207
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;
208
209
let footer = Box::new(footer);
210
211
#[allow(clippy::unnecessary_cast)]
212
let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>;
213
214
self.footer = Some(ptr);
215
let footer = &unsafe { **self.footer.as_ref().unwrap() };
216
217
self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?)
218
as Box<dyn SendableIterator<Item = _>>);
219
self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)?
220
.map(|i| Box::new(i) as Box<dyn SendableIterator<Item = _>>);
221
222
Ok(())
223
}
224
225
pub fn get_schema(self: &Pin<Box<Self>>) -> PolarsResult<EncodedData> {
226
let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") };
227
228
let schema_ref = deserialize_schema_ref_from_footer(*footer)?;
229
let schema = schema_to_raw_message(schema_ref);
230
231
Ok(schema)
232
}
233
234
pub async fn next_dict(
235
self: &mut Pin<Box<Self>>,
236
encoded_data: &mut EncodedData,
237
) -> PolarsResult<Option<()>> {
238
assert!(self.data_blocks.is_some(), "init must be called first");
239
encoded_data.ipc_message.clear();
240
encoded_data.arrow_data.clear();
241
242
if let Some(iter) = &mut self.dict_blocks {
243
let Some(value) = iter.next() else {
244
return Ok(None);
245
};
246
let block = value?;
247
248
block_to_raw_message(&mut self.reader, &block, encoded_data).await?;
249
Ok(Some(()))
250
} else {
251
Ok(None)
252
}
253
}
254
255
pub async fn next_data(
256
self: &mut Pin<Box<Self>>,
257
encoded_data: &mut EncodedData,
258
) -> PolarsResult<Option<()>> {
259
encoded_data.ipc_message.clear();
260
encoded_data.arrow_data.clear();
261
262
let iter = self
263
.data_blocks
264
.as_mut()
265
.expect("init must be called first");
266
let Some(value) = iter.next() else {
267
return Ok(None);
268
};
269
let block = value?;
270
271
block_to_raw_message(&mut self.reader, &block, encoded_data).await?;
272
Ok(Some(()))
273
}
274
}
275
276
pub struct FlightConsumer {
277
dictionaries: Dictionaries,
278
md: StreamMetadata,
279
scratch: Vec<u8>,
280
}
281
282
impl FlightConsumer {
283
pub fn new(first: EncodedData) -> PolarsResult<Self> {
284
let md = deserialize_stream_metadata(&first.ipc_message)?;
285
Ok(Self {
286
dictionaries: Default::default(),
287
md,
288
scratch: vec![],
289
})
290
}
291
292
pub fn schema(&self) -> &ArrowSchema {
293
&self.md.schema
294
}
295
296
pub fn consume(&mut self, msg: EncodedData) -> PolarsResult<Option<RecordBatch>> {
297
// Parse the header
298
let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message)
299
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;
300
301
let header = message
302
.header()
303
.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?
304
.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;
305
306
// Either append to the dictionaries and return None or return Some(ArrowChunk)
307
match header {
308
MessageHeaderRef::Schema(_) => {
309
polars_bail!(ComputeError: "Unexpected schema message while parsing Stream");
310
},
311
// Add to dictionary state and continue iteration
312
MessageHeaderRef::DictionaryBatch(batch) => unsafe {
313
// Needed to memory map.
314
let arrow_data = Arc::new(msg.arrow_data);
315
mmap_dictionary_from_batch(
316
&self.md.schema,
317
&self.md.ipc_schema.fields,
318
&arrow_data,
319
batch,
320
&mut self.dictionaries,
321
0,
322
)
323
.map(|_| None)
324
},
325
// Return Batch
326
MessageHeaderRef::RecordBatch(batch) => {
327
if batch.compression()?.is_some() {
328
let data_size = msg.arrow_data.len() as u64;
329
let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice());
330
read_record_batch(
331
batch,
332
&self.md.schema,
333
&self.md.ipc_schema,
334
None,
335
None,
336
&self.dictionaries,
337
self.md.version,
338
&mut reader,
339
0,
340
data_size,
341
&mut self.scratch,
342
)
343
.map(Some)
344
} else {
345
// Needed to memory map.
346
let arrow_data = Arc::new(msg.arrow_data);
347
unsafe {
348
mmap_record(
349
&self.md.schema,
350
&self.md.ipc_schema.fields,
351
arrow_data,
352
batch,
353
0,
354
&self.dictionaries,
355
)
356
.map(Some)
357
}
358
}
359
},
360
_ => unimplemented!(),
361
}
362
}
363
}
364
365
pub struct FlightstreamConsumer<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> {
366
inner: FlightConsumer,
367
stream: S,
368
}
369
370
impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S> {
371
pub async fn new(mut stream: S) -> PolarsResult<Self> {
372
let Some(first) = stream.next().await else {
373
polars_bail!(ComputeError: "expected the schema")
374
};
375
let first = first?;
376
377
Ok(FlightstreamConsumer {
378
inner: FlightConsumer::new(first)?,
379
stream,
380
})
381
}
382
383
pub async fn next_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {
384
while let Some(msg) = self.stream.next().await {
385
let msg = msg?;
386
let option_recordbatch = self.inner.consume(msg)?;
387
if option_recordbatch.is_some() {
388
return Ok(option_recordbatch);
389
}
390
}
391
Ok(None)
392
}
393
}
394
395
#[cfg(test)]
mod test {
    use std::path::{Path, PathBuf};

    use tokio::fs::File;

    use super::*;
    use crate::record_batch::RecordBatch;

    /// Location of the `foods1.ipc` fixture shared with the Python test-suite.
    fn get_file_path() -> PathBuf {
        let manifest_dir =
            std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");
        Path::new(&manifest_dir).join("../../py-polars/tests/unit/io/files/foods1.ipc")
    }

    /// Reference result: the first record batch read with the synchronous IPC file reader.
    fn read_file(path: &Path) -> RecordBatch {
        let mut file = std::fs::File::open(path).unwrap();
        let metadata = crate::io::ipc::read::read_file_metadata(&mut file).unwrap();
        let mut reader = crate::io::ipc::read::FileReader::new(&mut file, metadata, None, None);
        reader.next().unwrap().unwrap()
    }

    #[tokio::test]
    async fn test_file_flight_simple() {
        let path = get_file_path();
        let mut file = File::open(&path).await.unwrap();
        let stream = into_flight_stream(&mut file).await.unwrap();

        let mut consumer = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap();
        let batch = consumer.next_batch().await.unwrap().unwrap();

        assert_eq!(batch, read_file(&path));
    }

    #[tokio::test]
    async fn test_file_flight_amortized() {
        let path = get_file_path();
        let mut file = File::open(&path).await.unwrap();
        let mut producer = FlightStreamProducer::new(&mut file).await.unwrap();
        producer.init().unwrap();

        let mut messages = vec![producer.get_schema().unwrap()];

        let mut dict = EncodedData::default();
        if producer.next_dict(&mut dict).await.unwrap().is_some() {
            messages.push(dict);
        }

        let mut data = EncodedData::default();
        producer.next_data(&mut data).await.unwrap();
        messages.push(data);

        let replay = futures::stream::iter(messages.into_iter().map(Ok));
        let mut consumer = FlightstreamConsumer::new(Box::pin(replay)).await.unwrap();
        let batch = consumer.next_batch().await.unwrap().unwrap();

        assert_eq!(batch, read_file(&path));
    }
}
458
459