Path: blob/main/crates/polars-arrow/src/io/ipc/read/common.rs
use std::collections::VecDeque;
use std::io::{Read, Seek};
use std::sync::Arc;

use polars_error::{PolarsResult, polars_bail, polars_err};
use polars_utils::aliases::PlHashMap;
use polars_utils::pl_str::PlSmallStr;

use super::Dictionaries;
use super::deserialize::{read, skip};
use crate::array::*;
use crate::datatypes::{ArrowDataType, ArrowSchema, Field};
use crate::io::ipc::read::OutOfSpecKind;
use crate::io::ipc::{IpcField, IpcSchema};
use crate::record_batch::RecordBatchT;

#[derive(Debug, Eq, PartialEq, Hash)]
enum ProjectionResult<A> {
    Selected(A),
    NotSelected(A),
}

/// An iterator adapter that wraps every item in a [`ProjectionResult`], marking whether
/// its position is part of `projection`.
/// # Panics
/// The iterator panics iff the `projection` is not strictly increasing.
struct ProjectionIter<'a, A, I: Iterator<Item = A>> {
    projection: &'a [usize],
    iter: I,
    current_count: usize,
    current_projection: usize,
}

impl<'a, A, I: Iterator<Item = A>> ProjectionIter<'a, A, I> {
    /// # Panics
    /// iff `projection` is empty
    pub fn new(projection: &'a [usize], iter: I) -> Self {
        Self {
            projection: &projection[1..],
            iter,
            current_count: 0,
            current_projection: projection[0],
        }
    }
}

impl<A, I: Iterator<Item = A>> Iterator for ProjectionIter<'_, A, I> {
    type Item = ProjectionResult<A>;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(item) = self.iter.next() {
            let result = if self.current_count == self.current_projection {
                if !self.projection.is_empty() {
                    assert!(self.projection[0] > self.current_projection);
                    self.current_projection = self.projection[0];
                    self.projection = &self.projection[1..];
                } else {
                    self.current_projection = 0 // a value that most likely already passed
                };
                Some(ProjectionResult::Selected(item))
            } else {
                Some(ProjectionResult::NotSelected(item))
            };
            self.current_count += 1;
            result
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

/// Returns a [`RecordBatchT`] from a reader.
/// # Panics
/// Panics iff the projection is not in strictly increasing order
/// (e.g. neither `[1, 0]` nor `[0, 1, 1]` is valid).
#[allow(clippy::too_many_arguments)]
pub fn read_record_batch<R: Read + Seek>(
    batch: arrow_format::ipc::RecordBatchRef,
    fields: &ArrowSchema,
    ipc_schema: &IpcSchema,
    projection: Option<&[usize]>,
    limit: Option<usize>,
    dictionaries: &Dictionaries,
    version: arrow_format::ipc::MetadataVersion,
    reader: &mut R,
    block_offset: u64,
    file_size: u64,
    scratch: &mut Vec<u8>,
) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {
    assert_eq!(fields.len(), ipc_schema.fields.len());
    let buffers = batch
        .buffers()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBuffers(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageBuffers))?;
    let mut variadic_buffer_counts = batch
        .variadic_buffer_counts()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))?
        .map(|v| v.iter().map(|v| v as usize).collect::<VecDeque<usize>>())
        .unwrap_or_else(VecDeque::new);
    let mut buffers: VecDeque<arrow_format::ipc::BufferRef> = buffers.iter().collect();

    // check that the sum of the sizes of all buffers does not exceed the size of the file
    let buffers_size = buffers
        .iter()
        .map(|buffer| {
            let buffer_size: u64 = buffer
                .length()
                .try_into()
                .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
            Ok(buffer_size)
        })
        .sum::<PolarsResult<u64>>()?;
    if buffers_size > file_size {
        return Err(polars_err!(
            oos = OutOfSpecKind::InvalidBuffersLength {
                buffers_size,
                file_size,
            }
        ));
    }

    let field_nodes = batch
        .nodes()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageNodes))?;
    let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();

    let columns = if let Some(projection) = projection {
        let projection = ProjectionIter::new(
            projection,
            fields.iter_values().zip(ipc_schema.fields.iter()),
        );

        projection
            .map(|maybe_field| match maybe_field {
                ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read(
                    &mut field_nodes,
                    &mut variadic_buffer_counts,
                    field,
                    ipc_field,
                    &mut buffers,
                    reader,
                    dictionaries,
                    block_offset,
                    ipc_schema.is_little_endian,
                    batch.compression().map_err(|err| {
                        polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
                    })?,
                    limit,
                    version,
                    scratch,
                )?)),
                ProjectionResult::NotSelected((field, _)) => {
                    skip(
                        &mut field_nodes,
                        &field.dtype,
                        &mut buffers,
                        &mut variadic_buffer_counts,
                    )?;
                    Ok(None)
                },
            })
            .filter_map(|x| x.transpose())
            .collect::<PolarsResult<Vec<_>>>()?
    } else {
        fields
            .iter_values()
            .zip(ipc_schema.fields.iter())
            .map(|(field, ipc_field)| {
                read(
                    &mut field_nodes,
                    &mut variadic_buffer_counts,
                    field,
                    ipc_field,
                    &mut buffers,
                    reader,
                    dictionaries,
                    block_offset,
                    ipc_schema.is_little_endian,
                    batch.compression().map_err(|err| {
                        polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))
                    })?,
                    limit,
                    version,
                    scratch,
                )
            })
            .collect::<PolarsResult<Vec<_>>>()?
    };

    let length = batch
        .length()
        .map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData))
        .unwrap()
        .try_into()
        .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
    let length = limit.map(|limit| limit.min(length)).unwrap_or(length);

    let mut schema: ArrowSchema = fields.iter_values().cloned().collect();
    if let Some(projection) = projection {
        schema = schema.try_project_indices(projection).unwrap();
    }
    RecordBatchT::try_new(length, Arc::new(schema), columns)
}

fn find_first_dict_field_d<'a>(
    id: i64,
    dtype: &'a ArrowDataType,
    ipc_field: &'a IpcField,
) -> Option<(&'a Field, &'a IpcField)> {
    use ArrowDataType::*;
    match dtype {
        Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field),
        List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => {
            find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0])
        },
        Struct(fields) => {
            for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) {
                if let Some(f) = find_first_dict_field(id, field, ipc_field) {
                    return Some(f);
                }
            }
            None
        },
        Union(u) => {
            for (field, ipc_field) in u.fields.iter().zip(ipc_field.fields.iter()) {
                if let Some(f) = find_first_dict_field(id, field, ipc_field) {
                    return Some(f);
                }
            }
            None
        },
        _ => None,
    }
}

fn find_first_dict_field<'a>(
    id: i64,
    field: &'a Field,
    ipc_field: &'a IpcField,
) -> Option<(&'a Field, &'a IpcField)> {
    if let Some(field_id) = ipc_field.dictionary_id {
        if id == field_id {
            return Some((field, ipc_field));
        }
    }
    find_first_dict_field_d(id, &field.dtype, ipc_field)
}

pub(crate) fn first_dict_field<'a>(
    id: i64,
    fields: &'a ArrowSchema,
    ipc_fields: &'a [IpcField],
) -> PolarsResult<(&'a Field, &'a IpcField)> {
    assert_eq!(fields.len(), ipc_fields.len());
    for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {
        if let Some(field) = find_first_dict_field(id, field, ipc_field) {
            return Ok(field);
        }
    }
    Err(polars_err!(
        oos = OutOfSpecKind::InvalidId { requested_id: id }
    ))
}

/// Reads a dictionary from the reader,
/// updating `dictionaries` with the resulting dictionary
#[allow(clippy::too_many_arguments)]
pub fn read_dictionary<R: Read + Seek>(
    batch: arrow_format::ipc::DictionaryBatchRef,
    fields: &ArrowSchema,
    ipc_schema: &IpcSchema,
    dictionaries: &mut Dictionaries,
    reader: &mut R,
    block_offset: u64,
    file_size: u64,
    scratch: &mut Vec<u8>,
) -> PolarsResult<()> {
    if batch
        .is_delta()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferIsDelta(err)))?
    {
        polars_bail!(ComputeError: "delta dictionary batches not supported")
    }

    let id = batch
        .id()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferId(err)))?;
    let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?;

    let batch = batch
        .data()
        .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))?
        .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?;

    let value_type =
        if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_logical_type() {
            value_type.as_ref()
        } else {
            polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id })
        };

    // Make a fake schema for the dictionary batch.
    let fields = std::iter::once((
        PlSmallStr::EMPTY,
        Field::new(PlSmallStr::EMPTY, value_type.clone(), false),
    ))
    .collect();
    let ipc_schema = IpcSchema {
        fields: vec![first_ipc_field.clone()],
        is_little_endian: ipc_schema.is_little_endian,
    };
    let chunk = read_record_batch(
        batch,
        &fields,
        &ipc_schema,
        None,
        None, // we must read the whole dictionary
        dictionaries,
        arrow_format::ipc::MetadataVersion::V5,
        reader,
        block_offset,
        file_size,
        scratch,
    )?;

    dictionaries.insert(id, chunk.into_arrays().pop().unwrap());

    Ok(())
}
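
/// The bookkeeping produced by [`prepare_projection`]: `columns` holds the requested
/// column indices sorted into file order, `map` sends each position in that sorted
/// order to the position the caller originally asked for (consumed by
/// [`apply_projection`]), and `schema` is the projected schema in the requested order.
/// For example, for a requested projection `[4, 1, 3]`, `columns` is `[1, 3, 4]`,
/// `map` is `{0: 1, 1: 2, 2: 0}` and `schema` holds columns 4, 1 and 3 in that order.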
#[derive(Clone)]
pub struct ProjectionInfo {
    pub columns: Vec<usize>,
    pub map: PlHashMap<usize, usize>,
    pub schema: ArrowSchema,
}

pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {
    let schema = projection
        .iter()
        .map(|x| {
            let (k, v) = schema.get_at_index(*x).unwrap();
            (k.clone(), v.clone())
        })
        .collect();

    // todo: find way to do this more efficiently
    let mut indices = (0..projection.len()).collect::<Vec<_>>();
    indices.sort_unstable_by_key(|&i| &projection[i]);
    let map = indices.iter().copied().enumerate().fold(
        PlHashMap::default(),
        |mut acc, (index, new_index)| {
            acc.insert(index, new_index);
            acc
        },
    );
    projection.sort_unstable();

    // check unique
    if !projection.is_empty() {
        let mut previous = projection[0];

        for &i in &projection[1..] {
            assert!(
                previous < i,
                "The projection on IPC must not contain duplicates"
            );
            previous = i;
        }
    }

    ProjectionInfo {
        columns: projection,
        map,
        schema,
    }
}

pub fn apply_projection(
    chunk: RecordBatchT<Box<dyn Array>>,
    map: &PlHashMap<usize, usize>,
) -> RecordBatchT<Box<dyn Array>> {
    let length = chunk.len();

    // re-order according to projection
    let (schema, arrays) = chunk.into_schema_and_arrays();
    let mut new_schema = schema.as_ref().clone();
    let mut new_arrays = arrays.clone();

    map.iter().for_each(|(old, new)| {
        let (old_name, old_field) = schema.get_at_index(*old).unwrap();
        let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();

        *new_name = old_name.clone();
        *new_field = old_field.clone();

        new_arrays[*new] = arrays[*old].clone();
    });

    RecordBatchT::new(length, Arc::new(new_schema), new_arrays)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn project_iter() {
        let iter = 1..6;
        let iter = ProjectionIter::new(&[0, 2, 4], iter);
        let result: Vec<_> = iter.collect();
        use ProjectionResult::*;
        assert_eq!(
            result,
            vec![
                Selected(1),
                NotSelected(2),
                Selected(3),
                NotSelected(4),
                Selected(5)
            ]
        )
    }
}
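
// Illustrative sketches of the projection and dictionary-lookup helpers above.
// They assume only constructors already used in this module (an `ArrowSchema`
// collected from `(PlSmallStr, Field)` pairs, `Field::new`, the public `IpcField`
// struct fields) plus `IntegerType` from `crate::datatypes`; the expected values
// follow directly from the logic of `prepare_projection` and `first_dict_field`.
#[cfg(test)]
mod sketch_tests {
    use super::*;
    use crate::datatypes::IntegerType;

    #[test]
    fn prepare_projection_sorts_and_maps() {
        // Five Int32 columns named c0..c4.
        let schema: ArrowSchema = ["c0", "c1", "c2", "c3", "c4"]
            .into_iter()
            .map(|name| {
                let name = PlSmallStr::from(name);
                (name.clone(), Field::new(name, ArrowDataType::Int32, true))
            })
            .collect();

        // Request the columns out of order.
        let info = prepare_projection(&schema, vec![4, 1, 3]);

        // The column indices are sorted so the batch can be read in file order...
        assert_eq!(info.columns, vec![1, 3, 4]);
        // ...the projected schema keeps the requested order...
        assert_eq!(info.schema.get_at_index(0).unwrap().0.as_str(), "c4");
        assert_eq!(info.schema.get_at_index(1).unwrap().0.as_str(), "c1");
        assert_eq!(info.schema.get_at_index(2).unwrap().0.as_str(), "c3");
        // ...and `map` sends each read position back to the requested position,
        // which is exactly how `apply_projection` re-orders the arrays.
        assert_eq!(info.map[&0], 1); // column 1 was requested at position 1
        assert_eq!(info.map[&1], 2); // column 3 was requested at position 2
        assert_eq!(info.map[&2], 0); // column 4 was requested at position 0
    }

    #[test]
    fn dictionary_field_lookup() {
        // A single Utf8 column that is dictionary-encoded under dictionary id 42.
        let name = PlSmallStr::from("labels");
        let fields: ArrowSchema = std::iter::once((
            name.clone(),
            Field::new(
                name,
                ArrowDataType::Dictionary(
                    IntegerType::Int32,
                    Box::new(ArrowDataType::Utf8),
                    false,
                ),
                true,
            ),
        ))
        .collect();
        let ipc_fields = vec![IpcField {
            fields: vec![],
            dictionary_id: Some(42),
        }];

        let (field, ipc_field) = first_dict_field(42, &fields, &ipc_fields).unwrap();
        assert_eq!(field.name.as_str(), "labels");
        assert_eq!(ipc_field.dictionary_id, Some(42));
        // An id that no field declares is reported as an out-of-spec error.
        assert!(first_dict_field(7, &fields, &ipc_fields).is_err());
    }
}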