Path: blob/main/crates/polars-arrow/src/io/ipc/read/flight.rs
8446 views
use std::io::SeekFrom;1use std::pin::Pin;2use std::sync::Arc;34use arrow_format::ipc::planus::ReadAsRoot;5use arrow_format::ipc::{Block, FooterRef, MessageHeaderRef};6use futures::{Stream, StreamExt};7use polars_error::{PolarsResult, polars_bail, polars_err};8use polars_utils::bool::UnsafeBool;9use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};1011use crate::datatypes::ArrowSchema;12use crate::io::ipc::read::common::read_record_batch;13use crate::io::ipc::read::file::{14decode_footer_len, deserialize_schema_ref_from_footer, iter_dictionary_blocks_from_footer,15iter_recordbatch_blocks_from_footer,16};17use crate::io::ipc::read::schema::deserialize_stream_metadata;18use crate::io::ipc::read::{Dictionaries, OutOfSpecKind, SendableIterator, StreamMetadata};19use crate::io::ipc::write::common::EncodedData;20use crate::mmap::{mmap_dictionary_from_batch, mmap_record};21use crate::record_batch::RecordBatch;2223async fn read_ipc_message_from_block<'a, R: AsyncRead + AsyncSeek + Unpin>(24reader: &mut R,25block: &arrow_format::ipc::Block,26scratch: &'a mut Vec<u8>,27) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {28let offset: u64 = block29.offset30.try_into()31.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;32reader.seek(SeekFrom::Start(offset)).await?;33read_ipc_message(reader, scratch).await34}3536/// Read an encapsulated IPC Message from the reader37async fn read_ipc_message<'a, R: AsyncRead + Unpin>(38reader: &mut R,39scratch: &'a mut Vec<u8>,40) -> PolarsResult<arrow_format::ipc::MessageRef<'a>> {41let mut message_size: [u8; 4] = [0; 4];4243reader.read_exact(&mut message_size).await?;44if message_size == crate::io::ipc::CONTINUATION_MARKER {45reader.read_exact(&mut message_size).await?;46};47let message_length = i32::from_le_bytes(message_size);4849let message_length: usize = message_length50.try_into()51.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;5253scratch.clear();54scratch.try_reserve(message_length)?;55reader56.take(message_length as u64)57.read_to_end(scratch)58.await?;5960arrow_format::ipc::MessageRef::read_as_root(scratch)61.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))62}6364async fn read_footer_len<R: AsyncRead + AsyncSeek + Unpin>(65reader: &mut R,66) -> PolarsResult<(u64, usize)> {67// read footer length and magic number in footer68let end = reader.seek(SeekFrom::End(-10)).await? + 10;6970let mut footer: [u8; 10] = [0; 10];71reader.read_exact(&mut footer).await?;7273decode_footer_len(footer, end)74}7576async fn read_footer<R: AsyncRead + AsyncSeek + Unpin>(77reader: &mut R,78footer_len: usize,79) -> PolarsResult<Vec<u8>> {80// read footer81reader.seek(SeekFrom::End(-10 - footer_len as i64)).await?;8283let mut serialized_footer = vec![];84serialized_footer.try_reserve(footer_len)?;8586reader87.take(footer_len as u64)88.read_to_end(&mut serialized_footer)89.await?;90Ok(serialized_footer)91}9293fn schema_to_raw_message(schema: arrow_format::ipc::SchemaRef) -> EncodedData {94// Turn the IPC schema into an encapsulated message95let message = arrow_format::ipc::Message {96version: arrow_format::ipc::MetadataVersion::V5,97// Assumed the conversion is infallible.98header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(99schema.try_into().unwrap(),100))),101body_length: 0,102custom_metadata: None, // todo: allow writing custom metadata103};104let mut builder = arrow_format::ipc::planus::Builder::new();105let header = builder.finish(&message, None).to_vec();106107// Use `EncodedData` directly instead of `FlightData`. In FlightData we would only use108// `data_header` and `data_body`.109EncodedData {110ipc_message: header,111arrow_data: vec![],112}113}114115async fn block_to_raw_message<'a, R>(116reader: &mut R,117block: &arrow_format::ipc::Block,118encoded_data: &mut EncodedData,119) -> PolarsResult<()>120where121R: AsyncRead + AsyncSeek + Unpin + Send + 'a,122{123debug_assert!(encoded_data.arrow_data.is_empty() && encoded_data.ipc_message.is_empty());124let message = read_ipc_message_from_block(reader, block, &mut encoded_data.ipc_message).await?;125126let block_length: u64 = message127.body_length()128.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBodyLength(err)))?129.try_into()130.map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?;131reader132.take(block_length)133.read_to_end(&mut encoded_data.arrow_data)134.await?;135136Ok(())137}138139pub async fn into_flight_stream<R: AsyncRead + AsyncSeek + Unpin + Send>(140reader: &mut R,141) -> PolarsResult<impl Stream<Item = PolarsResult<EncodedData>> + '_> {142Ok(async_stream::try_stream! {143let (_end, len) = read_footer_len(reader).await?;144let footer_data = read_footer(reader, len).await?;145let footer = arrow_format::ipc::FooterRef::read_as_root(&footer_data)146.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;147let data_blocks = iter_recordbatch_blocks_from_footer(footer)?;148let dict_blocks = iter_dictionary_blocks_from_footer(footer)?;149150let schema_ref = deserialize_schema_ref_from_footer(footer)?;151let schema = schema_to_raw_message(schema_ref);152153yield schema;154155if let Some(dict_blocks_iter) = dict_blocks {156for d in dict_blocks_iter {157let mut ed: EncodedData = Default::default();158block_to_raw_message(reader, &d?, &mut ed).await?;159yield ed160}161};162163for d in data_blocks {164let mut ed: EncodedData = Default::default();165block_to_raw_message(reader, &d?, &mut ed).await?;166yield ed167}168})169}170171pub struct FlightStreamProducer<'a, R: AsyncRead + AsyncSeek + Unpin + Send> {172footer: Option<*const FooterRef<'static>>,173footer_data: Vec<u8>,174dict_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,175data_blocks: Option<Box<dyn SendableIterator<Item = PolarsResult<Block>>>>,176reader: &'a mut R,177}178179impl<R: AsyncRead + AsyncSeek + Unpin + Send> Drop for FlightStreamProducer<'_, R> {180fn drop(&mut self) {181if let Some(p) = self.footer {182unsafe {183let _ = Box::from_raw(p as *mut FooterRef<'static>);184}185}186}187}188189unsafe impl<R: AsyncRead + AsyncSeek + Unpin + Send> Send for FlightStreamProducer<'_, R> {}190191impl<'a, R: AsyncRead + AsyncSeek + Unpin + Send> FlightStreamProducer<'a, R> {192pub async fn new(reader: &'a mut R) -> PolarsResult<Pin<Box<Self>>> {193let (_end, len) = read_footer_len(reader).await?;194let footer_data = read_footer(reader, len).await?;195196Ok(Box::pin(Self {197footer: None,198footer_data,199dict_blocks: None,200data_blocks: None,201reader,202}))203}204205pub fn init(self: &mut Pin<Box<Self>>) -> PolarsResult<()> {206let footer = arrow_format::ipc::FooterRef::read_as_root(&self.footer_data)207.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?;208209let footer = Box::new(footer);210211#[allow(clippy::unnecessary_cast)]212let ptr = Box::leak(footer) as *const _ as *const FooterRef<'static>;213214self.footer = Some(ptr);215let footer = &unsafe { **self.footer.as_ref().unwrap() };216217self.data_blocks = Some(Box::new(iter_recordbatch_blocks_from_footer(*footer)?)218as Box<dyn SendableIterator<Item = _>>);219self.dict_blocks = iter_dictionary_blocks_from_footer(*footer)?220.map(|i| Box::new(i) as Box<dyn SendableIterator<Item = _>>);221222Ok(())223}224225pub fn get_schema(self: &Pin<Box<Self>>) -> PolarsResult<EncodedData> {226let footer = &unsafe { **self.footer.as_ref().expect("init must be called first") };227228let schema_ref = deserialize_schema_ref_from_footer(*footer)?;229let schema = schema_to_raw_message(schema_ref);230231Ok(schema)232}233234pub async fn next_dict(235self: &mut Pin<Box<Self>>,236encoded_data: &mut EncodedData,237) -> PolarsResult<Option<()>> {238assert!(self.data_blocks.is_some(), "init must be called first");239encoded_data.ipc_message.clear();240encoded_data.arrow_data.clear();241242if let Some(iter) = &mut self.dict_blocks {243let Some(value) = iter.next() else {244return Ok(None);245};246let block = value?;247248block_to_raw_message(&mut self.reader, &block, encoded_data).await?;249Ok(Some(()))250} else {251Ok(None)252}253}254255pub async fn next_data(256self: &mut Pin<Box<Self>>,257encoded_data: &mut EncodedData,258) -> PolarsResult<Option<()>> {259encoded_data.ipc_message.clear();260encoded_data.arrow_data.clear();261262let iter = self263.data_blocks264.as_mut()265.expect("init must be called first");266let Some(value) = iter.next() else {267return Ok(None);268};269let block = value?;270271block_to_raw_message(&mut self.reader, &block, encoded_data).await?;272Ok(Some(()))273}274}275276pub struct FlightConsumer {277dictionaries: Dictionaries,278md: StreamMetadata,279scratch: Vec<u8>,280checked: UnsafeBool,281}282283impl FlightConsumer {284pub fn new(first: EncodedData) -> PolarsResult<Self> {285let md = deserialize_stream_metadata(&first.ipc_message)?;286Ok(Self {287dictionaries: Default::default(),288md,289scratch: vec![],290checked: Default::default(),291})292}293294/// # Safety295/// Don't do expensive checks.296/// This means the data source has to be trusted to be correct.297pub unsafe fn unchecked(mut self) -> Self {298unsafe {299self.checked = UnsafeBool::new_false();300}301self302}303304pub fn schema(&self) -> &ArrowSchema {305&self.md.schema306}307308pub fn consume(&mut self, msg: EncodedData) -> PolarsResult<Option<RecordBatch>> {309// Parse the header310let message = arrow_format::ipc::MessageRef::read_as_root(&msg.ipc_message)311.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?;312313let header = message314.header()315.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferHeader(err)))?316.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageHeader))?;317318// Either append to the dictionaries and return None or return Some(ArrowChunk)319match header {320MessageHeaderRef::Schema(_) => {321polars_bail!(ComputeError: "Unexpected schema message while parsing Stream");322},323// Add to dictionary state and continue iteration324MessageHeaderRef::DictionaryBatch(batch) => unsafe {325// Needed to memory map.326let arrow_data = Arc::new(msg.arrow_data);327mmap_dictionary_from_batch(328&self.md.schema,329&self.md.ipc_schema.fields,330&arrow_data,331batch,332&mut self.dictionaries,3330,334)335.map(|_| None)336},337// Return Batch338MessageHeaderRef::RecordBatch(batch) => {339if batch.compression()?.is_some() {340let mut reader = std::io::Cursor::new(msg.arrow_data.as_slice());341read_record_batch(342batch,343&self.md.schema,344&self.md.ipc_schema,345None,346None,347&self.dictionaries,348self.md.version,349&mut reader,3500,351&mut self.scratch,352self.checked,353)354.map(Some)355} else {356// Needed to memory map.357let arrow_data = Arc::new(msg.arrow_data);358unsafe {359mmap_record(360&self.md.schema,361&self.md.ipc_schema.fields,362arrow_data,363batch,3640,365&self.dictionaries,366)367.map(Some)368}369}370},371_ => unimplemented!(),372}373}374}375376pub struct FlightstreamConsumer<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> {377inner: FlightConsumer,378stream: S,379}380381impl<S: Stream<Item = PolarsResult<EncodedData>> + Unpin> FlightstreamConsumer<S> {382pub async fn new(mut stream: S) -> PolarsResult<Self> {383let Some(first) = stream.next().await else {384polars_bail!(ComputeError: "expected the schema")385};386let first = first?;387388Ok(FlightstreamConsumer {389inner: FlightConsumer::new(first)?,390stream,391})392}393394pub async fn next_batch(&mut self) -> PolarsResult<Option<RecordBatch>> {395while let Some(msg) = self.stream.next().await {396let msg = msg?;397let option_recordbatch = self.inner.consume(msg)?;398if option_recordbatch.is_some() {399return Ok(option_recordbatch);400}401}402Ok(None)403}404}405406#[cfg(test)]407mod test {408use std::path::{Path, PathBuf};409410use tokio::fs::File;411412use super::*;413use crate::record_batch::RecordBatch;414415fn get_file_path() -> PathBuf {416let polars_arrow = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");417Path::new(&polars_arrow).join("../../py-polars/tests/unit/io/files/foods1.ipc")418}419420fn read_file(path: &Path) -> RecordBatch {421let mut file = std::fs::File::open(path).unwrap();422let md = crate::io::ipc::read::read_file_metadata(&mut file).unwrap();423let mut ipc_reader = crate::io::ipc::read::FileReader::new(&mut file, md, None, None);424ipc_reader.next().unwrap().unwrap()425}426427#[tokio::test]428async fn test_file_flight_simple() {429let path = &get_file_path();430let mut file = tokio::fs::File::open(path).await.unwrap();431let stream = into_flight_stream(&mut file).await.unwrap();432433let mut c = FlightstreamConsumer::new(Box::pin(stream)).await.unwrap();434let b = c.next_batch().await.unwrap().unwrap();435436assert_eq!(b, read_file(path));437}438439#[tokio::test]440async fn test_file_flight_amortized() {441let path = &get_file_path();442let mut file = File::open(path).await.unwrap();443let mut p = FlightStreamProducer::new(&mut file).await.unwrap();444p.init().unwrap();445446let mut batches = vec![];447448let schema = p.get_schema().unwrap();449batches.push(schema);450451let mut ed = EncodedData::default();452if p.next_dict(&mut ed).await.unwrap().is_some() {453batches.push(ed);454}455456let mut ed = EncodedData::default();457p.next_data(&mut ed).await.unwrap();458batches.push(ed);459460let mut c =461FlightstreamConsumer::new(Box::pin(futures::stream::iter(batches.into_iter().map(Ok))))462.await463.unwrap();464let b = c.next_batch().await.unwrap().unwrap();465466assert_eq!(b, read_file(path));467}468}469470471