// Path: blob/main/crates/polars-arrow/src/io/ipc/read/common.rs
// 8430 views
use std::collections::VecDeque;1use std::io::{Read, Seek};2use std::sync::Arc;34use polars_error::{PolarsResult, polars_bail, polars_err};5use polars_utils::aliases::PlHashMap;6use polars_utils::bool::UnsafeBool;7use polars_utils::pl_str::PlSmallStr;89use super::Dictionaries;10use super::deserialize::{read, skip};11use crate::array::*;12use crate::datatypes::{ArrowDataType, ArrowSchema, Field};13use crate::io::ipc::read::OutOfSpecKind;14use crate::io::ipc::{IpcField, IpcSchema};15use crate::record_batch::RecordBatchT;1617#[derive(Debug, Eq, PartialEq, Hash)]18enum ProjectionResult<A> {19Selected(A),20NotSelected(A),21}2223/// An iterator adapter that will return `Some(x)` or `None`24/// # Panics25/// The iterator panics iff the `projection` is not strictly increasing.26struct ProjectionIter<'a, A, I: Iterator<Item = A>> {27projection: &'a [usize],28iter: I,29current_count: usize,30current_projection: usize,31}3233impl<'a, A, I: Iterator<Item = A>> ProjectionIter<'a, A, I> {34/// # Panics35/// iff `projection` is empty36pub fn new(projection: &'a [usize], iter: I) -> Self {37Self {38projection: &projection[1..],39iter,40current_count: 0,41current_projection: projection[0],42}43}44}4546impl<A, I: Iterator<Item = A>> Iterator for ProjectionIter<'_, A, I> {47type Item = ProjectionResult<A>;4849fn next(&mut self) -> Option<Self::Item> {50if let Some(item) = self.iter.next() {51let result = if self.current_count == self.current_projection {52if !self.projection.is_empty() {53assert!(self.projection[0] > self.current_projection);54self.current_projection = self.projection[0];55self.projection = &self.projection[1..];56} else {57self.current_projection = 0 // a value that most likely already passed58};59Some(ProjectionResult::Selected(item))60} else {61Some(ProjectionResult::NotSelected(item))62};63self.current_count += 1;64result65} else {66None67}68}6970fn size_hint(&self) -> (usize, Option<usize>) {71self.iter.size_hint()72}73}7475/// Returns a [`RecordBatchT`] from a 
reader.76/// # Panic77/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid)78#[allow(clippy::too_many_arguments)]79pub fn read_record_batch<R: Read + Seek>(80batch: arrow_format::ipc::RecordBatchRef,81fields: &ArrowSchema,82ipc_schema: &IpcSchema,83projection: Option<&[usize]>,84limit: Option<usize>,85dictionaries: &Dictionaries,86version: arrow_format::ipc::MetadataVersion,87reader: &mut R,88block_offset: u64,89scratch: &mut Vec<u8>,90checked: UnsafeBool,91) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {92assert_eq!(fields.len(), ipc_schema.fields.len());93let buffers = batch94.buffers()95.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferBuffers(err)))?96.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageBuffers))?;97let mut variadic_buffer_counts = batch98.variadic_buffer_counts()99.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))?100.map(|v| v.iter().map(|v| v as usize).collect::<VecDeque<usize>>())101.unwrap_or_else(VecDeque::new);102let mut buffers: VecDeque<arrow_format::ipc::BufferRef> = buffers.iter().collect();103104let field_nodes = batch105.nodes()106.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferNodes(err)))?107.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingMessageNodes))?;108let mut field_nodes = field_nodes.iter().collect::<VecDeque<_>>();109110let columns = if let Some(projection) = projection {111let projection = ProjectionIter::new(112projection,113fields.iter_values().zip(ipc_schema.fields.iter()),114);115116projection117.map(|maybe_field| match maybe_field {118ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read(119&mut field_nodes,120&mut variadic_buffer_counts,121field,122ipc_field,123&mut buffers,124reader,125dictionaries,126block_offset,127ipc_schema.is_little_endian,128batch.compression().map_err(|err| {129polars_err!(oos = 
OutOfSpecKind::InvalidFlatbufferCompression(err))130})?,131limit,132version,133scratch,134checked,135)?)),136ProjectionResult::NotSelected((field, _)) => {137skip(138&mut field_nodes,139&field.dtype,140&mut buffers,141&mut variadic_buffer_counts,142)?;143Ok(None)144},145})146.filter_map(|x| x.transpose())147.collect::<PolarsResult<Vec<_>>>()?148} else {149fields150.iter_values()151.zip(ipc_schema.fields.iter())152.map(|(field, ipc_field)| {153read(154&mut field_nodes,155&mut variadic_buffer_counts,156field,157ipc_field,158&mut buffers,159reader,160dictionaries,161block_offset,162ipc_schema.is_little_endian,163batch.compression().map_err(|err| {164polars_err!(oos = OutOfSpecKind::InvalidFlatbufferCompression(err))165})?,166limit,167version,168scratch,169checked,170)171})172.collect::<PolarsResult<Vec<_>>>()?173};174175let length = batch176.length()177.map_err(|_| polars_err!(oos = OutOfSpecKind::MissingData))178.unwrap()179.try_into()180.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;181let length = limit.map(|limit| limit.min(length)).unwrap_or(length);182183let mut schema: ArrowSchema = fields.iter_values().cloned().collect();184if let Some(projection) = projection {185schema = schema.try_project_indices(projection).unwrap();186}187RecordBatchT::try_new(length, Arc::new(schema), columns)188}189190fn find_first_dict_field_d<'a>(191id: i64,192dtype: &'a ArrowDataType,193ipc_field: &'a IpcField,194) -> Option<(&'a Field, &'a IpcField)> {195use ArrowDataType::*;196match dtype.to_storage() {197Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field),198List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) 
=> {199find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0])200},201Struct(fields) => {202for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) {203if let Some(f) = find_first_dict_field(id, field, ipc_field) {204return Some(f);205}206}207None208},209Union(u) => {210for (field, ipc_field) in u.fields.iter().zip(ipc_field.fields.iter()) {211if let Some(f) = find_first_dict_field(id, field, ipc_field) {212return Some(f);213}214}215None216},217_ => None,218}219}220221fn find_first_dict_field<'a>(222id: i64,223field: &'a Field,224ipc_field: &'a IpcField,225) -> Option<(&'a Field, &'a IpcField)> {226if let Some(field_id) = ipc_field.dictionary_id {227if id == field_id {228return Some((field, ipc_field));229}230}231find_first_dict_field_d(id, &field.dtype, ipc_field)232}233234pub(crate) fn first_dict_field<'a>(235id: i64,236fields: &'a ArrowSchema,237ipc_fields: &'a [IpcField],238) -> PolarsResult<(&'a Field, &'a IpcField)> {239assert_eq!(fields.len(), ipc_fields.len());240for (field, ipc_field) in fields.iter_values().zip(ipc_fields.iter()) {241if let Some(field) = find_first_dict_field(id, field, ipc_field) {242return Ok(field);243}244}245Err(polars_err!(246oos = OutOfSpecKind::InvalidId { requested_id: id }247))248}249250/// Reads a dictionary from the reader,251/// updating `dictionaries` with the resulting dictionary252#[allow(clippy::too_many_arguments)]253pub fn read_dictionary<R: Read + Seek>(254batch: arrow_format::ipc::DictionaryBatchRef,255fields: &ArrowSchema,256ipc_schema: &IpcSchema,257dictionaries: &mut Dictionaries,258reader: &mut R,259block_offset: u64,260scratch: &mut Vec<u8>,261checked: UnsafeBool,262) -> PolarsResult<()> {263if batch264.is_delta()265.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferIsDelta(err)))?266{267polars_bail!(ComputeError: "delta dictionary batches not supported")268}269270let id = batch271.id()272.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferId(err)))?;273let 
(first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?;274275let batch = batch276.data()277.map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferData(err)))?278.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingData))?;279280let value_type =281if let ArrowDataType::Dictionary(_, value_type, _) = first_field.dtype.to_storage() {282value_type.as_ref()283} else {284polars_bail!(oos = OutOfSpecKind::InvalidIdDataType { requested_id: id })285};286287// Make a fake schema for the dictionary batch.288let fields = std::iter::once((289PlSmallStr::EMPTY,290Field::new(PlSmallStr::EMPTY, value_type.clone(), false),291))292.collect();293let ipc_schema = IpcSchema {294fields: vec![first_ipc_field.clone()],295is_little_endian: ipc_schema.is_little_endian,296};297let chunk = read_record_batch(298batch,299&fields,300&ipc_schema,301None,302None, // we must read the whole dictionary303dictionaries,304arrow_format::ipc::MetadataVersion::V5,305reader,306block_offset,307scratch,308checked,309)?;310311dictionaries.insert(id, chunk.into_arrays().pop().unwrap());312313Ok(())314}315316#[derive(Clone)]317pub struct ProjectionInfo {318pub columns: Vec<usize>,319pub map: PlHashMap<usize, usize>,320pub schema: ArrowSchema,321}322323pub fn prepare_projection(schema: &ArrowSchema, mut projection: Vec<usize>) -> ProjectionInfo {324let schema = projection325.iter()326.map(|x| {327let (k, v) = schema.get_at_index(*x).unwrap();328(k.clone(), v.clone())329})330.collect();331332// todo: find way to do this more efficiently333let mut indices = (0..projection.len()).collect::<Vec<_>>();334indices.sort_unstable_by_key(|&i| &projection[i]);335let map = indices.iter().copied().enumerate().fold(336PlHashMap::default(),337|mut acc, (index, new_index)| {338acc.insert(index, new_index);339acc340},341);342projection.sort_unstable();343344// check unique345if !projection.is_empty() {346let mut previous = projection[0];347348for &i in &projection[1..] 
{349assert!(350previous < i,351"The projection on IPC must not contain duplicates"352);353previous = i;354}355}356357ProjectionInfo {358columns: projection,359map,360schema,361}362}363364pub fn apply_projection(365chunk: RecordBatchT<Box<dyn Array>>,366map: &PlHashMap<usize, usize>,367) -> RecordBatchT<Box<dyn Array>> {368let length = chunk.len();369370// re-order according to projection371let (schema, arrays) = chunk.into_schema_and_arrays();372let mut new_schema = schema.as_ref().clone();373let mut new_arrays = arrays.clone();374375map.iter().for_each(|(old, new)| {376let (old_name, old_field) = schema.get_at_index(*old).unwrap();377let (new_name, new_field) = new_schema.get_at_index_mut(*new).unwrap();378379*new_name = old_name.clone();380*new_field = old_field.clone();381382new_arrays[*new] = arrays[*old].clone();383});384385RecordBatchT::new(length, Arc::new(new_schema), new_arrays)386}387388#[cfg(test)]389mod tests {390use super::*;391392#[test]393fn project_iter() {394let iter = 1..6;395let iter = ProjectionIter::new(&[0, 2, 4], iter);396let result: Vec<_> = iter.collect();397use ProjectionResult::*;398assert_eq!(399result,400vec![401Selected(1),402NotSelected(2),403Selected(3),404NotSelected(4),405Selected(5)406]407)408}409}410411412