// Source path: crates/polars-arrow/src/io/ipc/read/flight.rs (main branch)
use std::io::SeekFrom;1use std::pin::Pin;2use std::sync::Arc;34use arrow_format::ipc::planus::ReadAsRoot;5use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef};6use futures::{Stream, StreamExt};7use polars_error::{PolarsResult, polars_bail, polars_err};8use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};910use crate::datatypes::ArrowSchema;11use crate::io::ipc::read::common::read_record_batch;12use crate::io::ipc::read::file::{13decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer,14iter_recordbatch_blocks_from_footer,15};16use crate::io::ipc::read::schema::deserialize_stream_metadata;17use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata};18use crate::io::ipc::write::common::EncodedData;19use crate::mmap::{mmap_dictionary_from_batch, mmap_record};20use crate::record_batch::RecordBatch;2122async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>(23reader: &mut R,24block: &arrow_format::ipc::Block,25scratch: &'a mut Vec<u8>,26) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {27let offset: u64 = block28.offset29.try_into()30.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;31reader.seek(SeekFrom::Start(offset)).await?;32read_ipc_message(reader, scratch).await33}3435/// Read an encapsulated IPC Message from the reader36async fn read_ipc_message<'a, R: AsyncRead + Unpin>(37reader: &mut R,38scratch: &'a mut Vec<u8>,39) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {40let mut message_size: [u8; 4] = [0; 4];4142reader.read_exact(&mut message_size).await?;43if message_size == crate::io::ipc::CONTINUATION_MARKER {44reader.read_exact(&mut message_size).await?;45};46let message_length = i32::from_le_bytes(message_size);4748let message_length: usize = message_length49.try_into()50.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;5152scratch.clear();53scratch.try_reserve(message_length)?;54reader55.take(message_length 
as u64)56.read_to_end(scratch)57.await?;5859arrow_format::ipc::MessageRef::read_as_root(scratch)60.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))61}6263async fn read_footer_len<R: AsyncRead + AsyncSeek + Unpin>(64reader: &mut R,65) -> PolarsResult<(u64, usize)> {66// read footer length and magic number in footer67let end = reader.seek(SeekFrom::End(-10)).await? + 10;6869let mut footer: [u8; 10] = [0; 10];70reader.read_exact(&mut footer).await?;7172decode_footer_len(footer, end)73}7475async fn read_footer<R: AsyncRead + AsyncSeek + Unpin>(76reader: &mut R,77footer_len: usize,78) -> PolarsResult<Vec<u8>> {79// read footer80reader.seek(SeekFrom::End(-10 - footer_len as i64)).await?;8182let mut serialized_footer = vec![];83serialized_footer.try_reserve(footer_len)?;8485reader86.take(footer_len as u64)87.read_to_end(&mut serialized_footer)88.await?;89Ok(serialized_footer)90}9192fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData {93// Turn the IPC schema into an encapsulated message94let message = arrow_format::ipc::Message {95version: arrow_format::ipc::MetadataVersion::V5,96// Assumed the conversion is infallible.97header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(98schema.try_into().unwrap(),99))),100body_length: 0,101custom_metadata: None, // todo: allow writing custom metadata102};103let mut builder = arrow_format::ipc::planus::Builder::new();104let header = builder.finish(&message, None).to_vec();105106// Use `EncodedData` directly instead of `FlightData`. 
In FlightData we would only use107// `data_header` and `data_body`.108EncodedData {109ipc_message: header,110arrow_data: vec![],111}112}113114async fn block_to_raw_message<'a, R>(115reader: &mut R,116block: &arrow_format::ipc::Block,117encoded_data: &mut EncodedData,118) -> PolarsResult<()>119where120R: AsyncRead + AsyncSeek + Unpin + Send + 'a,121{122debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty());123let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?;124125let block_length: u64 = message126.body_length()127.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))?128.try_into()129.map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?;130reader131.take(block_length)132.read_to_end(&mut encoded_data.arrow_data)133.await?;134135Ok(())136}137138pub async fn into_flight_stream<R: AsyncRead + AsyncSeek + Unpin + Send>(139reader: &mut R,140) -> PolarsResult<impl Stream<Item = PolarsResult<EncodedData>> + '_> {141Ok(async_stream::try_stream! 
{142let (_end, len) = read_footer_len(reader).await?;143let footer_data = read_footer(reader, len).await?;144let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data)145.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;146let data_blocks = iter_recordbatch_blocks_from_footer(footer)?;147let dict_blocks = iter_dictionary_blocks_from_footer(footer)?;148149let schema_ref = deserialize_schema_ref_from_footer(footer)?;150let schema = schema_to_raw_message(schema_ref);151152yield schema;153154if let Some(dict_blocks_iter) = dict_blocks {155for d in dict_blocks_iter {156let mut ed: EncodedData = Default::default();157block_to_raw_message(reader, &d?, &mut ed).await?;158yield ed159}160};161162for d in data_blocks {163let mut ed: EncodedData = Default::default();164block_to_raw_message(reader, &d?, &mut ed).await?;165yield ed166}167})168}169170pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> {171footer: Option<*const FooterRef<'static>>,172footer_data: Vec<u8>,173dict_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,174data_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,175reader: &'a mut R,176}177178impl<R: AsyncRead + AsyncSeek + Unpin + Send> Drop for FlightStreamProducer<'_, R> {179fn drop(&mut self) {180if let Some(p) = self.footer {181unsafe {182let _ = Box::from_raw(p as *mut FooterRef<'static>);183}184}185}186}187188unsafe impl<R: AsyncRead + AsyncSeek + Unpin + Send> Send for FlightStreamProducer<'_, R> {}189190impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {191pub async fn new(reader: &'a mut R) -> PolarsResult<Pin<Box<Self>>> {192let (_end, len) = read_footer_len(reader).await?;193let footer_data = read_footer(reader, len).await?;194195Ok(Box::pin(Self {196footer: None,197footer_data,198dict_blocks: None,199data_blocks: None,200reader,201}))202}203204pub fn init(self: &mut Pin<Box<Self>>) -> PolarsResult<()> {205let 
footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data)206.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;207208let footer = Box::new(footer);209210#[allow(clippy::unnecessary_cast)]211let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>;212213self.footer = Some(ptr);214let footer = &unsafe { **self.footer.as_ref().unwrap() };215216self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?)217as Box<dyn SendableIterator<Item = _>>);218self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)?219.map(|i| Box::new(i) as Box<dyn SendableIterator<Item = _>>);220221Ok(())222}223224pub fn get_schema(self: &Pin<Box<Self>>) -> PolarsResult<EncodedData> {225let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") };226227let schema_ref = deserialize_schema_ref_from_footer(*footer)?;228let schema = schema_to_raw_message(schema_ref);229230Ok(schema)231}232233pub async fn next_dict(234self: &mut Pin<Box<Self>>,235encoded_data: &mut EncodedData,236) -> PolarsResult<Option<()>> {237assert!(self.data_blocks.is_some(), "init must be called first");238encoded_data.ipc_message.clear();239encoded_data.arrow_data.clear();240241if let Some(iter) = &mut self.dict_blocks {242let Some(value) = iter.next() else {243return Ok(None);244};245let block = value?;246247block_to_raw_message(&mut self.reader, &block, encoded_data).await?;248Ok(Some(()))249} else {250Ok(None)251}252}253254pub async fn next_data(255self: &mut Pin<Box<Self>>,256encoded_data: &mut EncodedData,257) -> PolarsResult<Option<()>> {258encoded_data.ipc_message.clear();259encoded_data.arrow_data.clear();260261let iter = self262.data_blocks263.as_mut()264.expect("init must be called first");265let Some(value) = iter.next() else {266return Ok(None);267};268let block = value?;269270block_to_raw_message(&mut self.reader, &block, encoded_data).await?;271Ok(Some(()))272}273}274275pub struct FlightConsumer 
{276dictionaries: Dictionaries,277md: StreamMetadata,278scratch: Vec<u8>,279}280281impl FlightConsumer {282pub fn new(first: EncodedData) -> PolarsResult<Self> {283let md = deserialize_stream_metadata(&first.ipc_message)?;284Ok(Self {285dictionaries: Default::default(),286md,287scratch: vec![],288})289}290291pub fn schema(&self) -> &ArrowSchema {292&self.md.schema293}294295pub fn consume(&mut self, msg: EncodedData) -> PolarsResult<Option<RecordBatch>> {296// Parse the header297let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message)298.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;299300let header = message301.header()302.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?303.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;304305// Either append to the dictionaries and return None or return Some(ArrowChunk)306match header {307MessageHeaderRef::Schema(_) => {308polars_bail!(ComputeError: "Unexpected schema message while parsing Stream");309},310// Add to dictionary state and continue iteration311MessageHeaderRef::DictionaryBatch(batch) => unsafe {312// Needed to memory map.313let arrow_data = Arc::new(msg.arrow_data);314mmap_dictionary_from_batch(315&self.md.schema,316&self.md.ipc_schema.fields,317&arrow_data,318batch,319&mut self.dictionaries,3200,321)322.map(|_| None)323},324// Return Batch325MessageHeaderRef::RecordBatch(batch) => {326if batch.compression()?.is_some() {327let data_size = msg.arrow_data.len() as u64;328let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice());329read_record_batch(330batch,331&self.md.schema,332&self.md.ipc_schema,333None,334None,335&self.dictionaries,336self.md.version,337&mut reader,3380,339data_size,340&mut self.scratch,341)342.map(Some)343} else {344// Needed to memory map.345let arrow_data = Arc::new(msg.arrow_data);346unsafe 
{347mmap_record(348&self.md.schema,349&self.md.ipc_schema.fields,350arrow_data,351batch,3520,353&self.dictionaries,354)355.map(Some)356}357}358},359_ => unimplemented!(),360}361}362}363364pub struct FlightstreamConsumer<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> {365inner: FlightConsumer,366stream: S,367}368369impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S> {370pub async fn new(mut stream: S) -> PolarsResult<Self> {371let Some(first) = stream.next().await else {372polars_bail!(ComputeError: "expected the schema")373};374let first = first?;375376Ok(FlightstreamConsumer {377inner: FlightConsumer::new(first)?,378stream,379})380}381382pub async fn next_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {383while let Some(msg) = self.stream.next().await {384let msg = msg?;385let option_recordbatch = self.inner.consume(msg)?;386if option_recordbatch.is_some() {387return Ok(option_recordbatch);388}389}390Ok(None)391}392}393394#[cfg(test)]395mod test {396use std::path::{Path, PathBuf};397398use tokio::fs::File;399400use super::*;401use crate::record_batch::RecordBatch;402403fn get_file_path() -> PathBuf {404let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");405Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc")406}407408fn read_file(path: &Path) -> RecordBatch {409let mut file = std::fs::File::open(path).unwrap();410let md = crate::io::ipc::read::read_file_metadata(&mut file).unwrap();411let mut ipc_reader = crate::io::ipc::read::FileReader::new(&mut file, md, None, None);412ipc_reader.next().unwrap().unwrap()413}414415#[tokio::test]416async fn test_file_flight_simple() {417let path = &get_file_path();418let mut file = tokio::fs::File::open(path).await.unwrap();419let stream = into_flight_stream(&mut file).await.unwrap();420421let mut c = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap();422let b = 
c.next_batch().await.unwrap().unwrap();423424assert_eq!(b, read_file(path));425}426427#[tokio::test]428async fn test_file_flight_amortized() {429let path = &get_file_path();430let mut file = File::open(path).await.unwrap();431let mut p = FlightStreamProducer::new(&mut file).await.unwrap();432p.init().unwrap();433434let mut batches = vec![];435436let schema = p.get_schema().unwrap();437batches.push(schema);438439let mut ed = EncodedData::default();440if p.next_dict(&mut ed).await.unwrap().is_some() {441batches.push(ed);442}443444let mut ed = EncodedData::default();445p.next_data(&mut ed).await.unwrap();446batches.push(ed);447448let mut c =449FlightstreamConsumer::new(Box::pin(futures::stream::iter(batches.into_iter().map(Ok))))450.await451.unwrap();452let b = c.next_batch().await.unwrap().unwrap();453454assert_eq!(b, read_file(path));455}456}457458459