Path: blob/main/crates/polars-arrow/src/io/ipc/write/common.rs
6940 views
use std::borrow::{Borrow, Cow};

use arrow_format::ipc;
use arrow_format::ipc::planus::Builder;
use polars_error::{PolarsResult, polars_bail, polars_err};

use super::super::IpcField;
use super::{write, write_dictionary};
use crate::array::*;
use crate::datatypes::*;
use crate::io::ipc::endianness::is_native_little_endian;
use crate::io::ipc::read::Dictionaries;
use crate::legacy::prelude::LargeListArray;
use crate::match_integer_type;
use crate::record_batch::RecordBatchT;
use crate::types::Index;

/// Compression codec
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Compression {
    /// LZ4 (framed)
    LZ4,
    /// ZSTD
    ZSTD,
}

/// Options declaring the behaviour of writing to IPC
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WriteOptions {
    /// Whether the buffers should be compressed and which codec to use.
    /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.
    pub compression: Option<Compression>,
}

/// Find the dictionaries that are new and need to be encoded.
///
/// Recursively walks `array` (guided by the matching `field`) and, for every
/// dictionary-encoded array whose `(id, values)` pair has not been emitted yet
/// (as decided by `dictionary_tracker`), pushes `(dictionary_id, array)` onto
/// `dicts_to_encode`.
///
/// Errors if a dictionary field has no associated id, if a dictionary with an
/// already-used id would be replaced while the tracker forbids replacement, or
/// if the number of children of a struct/union does not match its `IpcField`.
pub fn dictionaries_to_encode(
    field: &IpcField,
    array: &dyn Array,
    dictionary_tracker: &mut DictionaryTracker,
    dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,
) -> PolarsResult<()> {
    use PhysicalType::*;
    match array.dtype().to_physical_type() {
        // Flat (non-nested, non-dictionary) types cannot contain dictionaries.
        Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null
        | FixedSizeBinary | BinaryView | Utf8View => Ok(()),
        Dictionary(key_type) => match_integer_type!(key_type, |$T| {
            let dict_id = field.dictionary_id
                .ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;

            // `insert` returns Ok(true) only when this (id, values) pair has not
            // been emitted before; only then does the dictionary need encoding.
            if dictionary_tracker.insert(dict_id, array)? {
                dicts_to_encode.push((dict_id, array.to_boxed()));
            }

            let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
            let values = array.values();
            // NOTE(review): this recurses into the dictionary *values* with the same
            // `field`, whereas the List/FixedSizeList/Map arms descend into
            // `field.fields[0]` — confirm whether the child `IpcField` should be
            // used here instead.
            dictionaries_to_encode(field,
                values.as_ref(),
                dictionary_tracker,
                dicts_to_encode,
            )?;

            Ok(())
        }),
        Struct => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let fields = field.fields.as_slice();
            if array.fields().len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a struct must equal the number of children in IpcField".to_string(),
                );
            }
            // Recurse pairwise into every (child IpcField, child array).
            fields
                .iter()
                .zip(array.values().iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        List => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i32>>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        LargeList => {
            let values = array
                .as_any()
                .downcast_ref::<ListArray<i64>>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        FixedSizeList => {
            let values = array
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .unwrap()
                .values();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
        Union => {
            let values = array
                .as_any()
                .downcast_ref::<UnionArray>()
                .unwrap()
                .fields();
            let fields = &field.fields[..]; // todo: error instead
            if values.len() != fields.len() {
                polars_bail!(InvalidOperation:
                    "The number of fields in a union must equal the number of children in IpcField"
                );
            }
            fields
                .iter()
                .zip(values.iter())
                .try_for_each(|(field, values)| {
                    dictionaries_to_encode(
                        field,
                        values.as_ref(),
                        dictionary_tracker,
                        dicts_to_encode,
                    )
                })
        },
        Map => {
            // A MapArray's single child is exposed via `.field()`.
            let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();
            let field = &field.fields[0]; // todo: error instead
            dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)
        },
    }
}

/// Encode a dictionary array with a certain id, appending the resulting IPC
/// dictionary batch to `encoded_dictionaries`.
///
/// # Panics
///
/// This will panic if the given array is not a [`DictionaryArray`].
pub fn encode_dictionary(
    dict_id: i64,
    array: &dyn Array,
    options: &WriteOptions,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> PolarsResult<()> {
    let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {
        panic!("Given array is not a DictionaryArray")
    };

    match_integer_type!(key_type, |$T| {
        let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();
        encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>(
            dict_id,
            array,
            options,
            is_native_little_endian(),
        ));
    });

    Ok(())
}

/// Find all dictionaries in `array` that `dictionary_tracker` has not seen yet
/// and append their encoded IPC dictionary batches to `encoded_dictionaries`.
///
/// Convenience wrapper around [`dictionaries_to_encode`] + [`encode_dictionary`].
pub fn encode_new_dictionaries(
    field: &IpcField,
    array: &dyn Array,
    options: &WriteOptions,
    dictionary_tracker: &mut DictionaryTracker,
    encoded_dictionaries: &mut Vec<EncodedData>,
) -> PolarsResult<()> {
    let mut dicts_to_encode = Vec::new();
    dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;
    for (dict_id, dict_array) in dicts_to_encode {
        encode_dictionary(dict_id, dict_array.as_ref(), options, encoded_dictionaries)?;
    }
    Ok(())
}

/// Encode a chunk: returns the dictionary batches that are new as of this chunk
/// together with the encoded record-batch message itself.
pub fn encode_chunk(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {
    let mut encoded_message = EncodedData::default();
    let encoded_dictionaries = encode_chunk_amortized(
        chunk,
        fields,
        dictionary_tracker,
        options,
        &mut encoded_message,
    )?;
    Ok((encoded_dictionaries, encoded_message))
}

// Amortizes `EncodedData` allocation: the caller owns `encoded_message` and can
// reuse its buffers across chunks.
pub fn encode_chunk_amortized(
    chunk: &RecordBatchT<Box<dyn Array>>,
    fields: &[IpcField],
    dictionary_tracker: &mut DictionaryTracker,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) -> PolarsResult<Vec<EncodedData>> {
    let mut encoded_dictionaries = vec![];

    // First emit any dictionaries that become visible with this chunk…
    for (field, array) in fields.iter().zip(chunk.as_ref()) {
        encode_new_dictionaries(
            field,
            array.as_ref(),
            options,
            dictionary_tracker,
            &mut encoded_dictionaries,
        )?;
    }
    // …then encode the record batch itself into the caller-provided buffer.
    encode_record_batch(chunk, options, encoded_message);

    Ok(encoded_dictionaries)
}

/// Map the user-facing [`Compression`] to the flatbuffer `BodyCompression`
/// struct (buffer-level compression method).
fn serialize_compression(
    compression: Option<Compression>,
) -> Option<Box<arrow_format::ipc::BodyCompression>> {
    if let Some(compression) = compression {
        let codec = match compression {
            Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,
            Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd,
        };
        Some(Box::new(arrow_format::ipc::BodyCompression {
            codec,
            method: arrow_format::ipc::BodyCompressionMethod::Buffer,
        }))
    } else {
        None
    }
}

/// Recursively push the number of variadic data buffers of every view array
/// (Utf8View / BinaryView) nested in `array`, in depth-first order, as needed
/// for the IPC `variadicBufferCounts` message field.
fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
    match array.dtype() {
        ArrowDataType::Utf8View => {
            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::BinaryView => {
            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            counts.push(array.data_buffers().len() as i64);
        },
        ArrowDataType::Struct(_) => {
            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
            for array in array.values() {
                set_variadic_buffer_counts(counts, array.as_ref())
            }
        },
        ArrowDataType::LargeList(_) => {
            // Subslicing can change the variadic buffer count, so we have to
            // slice here as well to stay synchronized.
            let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();
            let offsets = array.offsets().buffer();
            let first = *offsets.first().unwrap();
            let last = *offsets.last().unwrap();
            let subslice = array
                .values()
                .sliced(first.to_usize(), last.to_usize() - first.to_usize());
            set_variadic_buffer_counts(counts, &*subslice)
        },
        ArrowDataType::FixedSizeList(_, _) => {
            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
            set_variadic_buffer_counts(counts, array.values().as_ref())
        },
        // Don't traverse dictionary values as those are set when the `Dictionary` IPC struct
        // is read.
        ArrowDataType::Dictionary(_, _, _) => (),
        _ => (),
    }
}

/// Decide whether a view array should be compacted ("GC'd") before writing.
///
/// Returns a borrowed `Cow` when the dead bytes in the array's buffers
/// (`total_buffer_len - total_bytes_len`) stay below `min(total_bytes_len, 1024)`,
/// otherwise an owned, compacted copy so that sliced arrays don't serialize
/// unreachable buffer bytes.
fn gc_bin_view<'a, T: ViewType + ?Sized>(
    arr: &'a Box<dyn Array>,
    concrete_arr: &'a BinaryViewArrayGeneric<T>,
) -> Cow<'a, Box<dyn Array>> {
    let bytes_len = concrete_arr.total_bytes_len();
    let buffer_len = concrete_arr.total_buffer_len();
    let extra_len = buffer_len.saturating_sub(bytes_len);
    if extra_len < bytes_len.min(1024) {
        // We can afford some tiny waste.
        Cow::Borrowed(arr)
    } else {
        // Force GC it.
        Cow::Owned(concrete_arr.clone().gc().boxed())
    }
}

/// Encode a single array into the IPC body buffers.
///
/// Appends buffer descriptors to `buffers`, raw bytes to `arrow_data`, a field
/// node to `nodes`, variadic buffer counts to `variadic_buffer_counts`, and
/// advances the running byte `offset`.
pub fn encode_array(
    array: &Box<dyn Array>,
    options: &WriteOptions,
    variadic_buffer_counts: &mut Vec<i64>,
    buffers: &mut Vec<ipc::Buffer>,
    arrow_data: &mut Vec<u8>,
    nodes: &mut Vec<ipc::FieldNode>,
    offset: &mut i64,
) {
    // We don't want to write all buffers in sliced arrays.
    let array = match array.dtype() {
        ArrowDataType::BinaryView => {
            let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        ArrowDataType::Utf8View => {
            let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
            gc_bin_view(array, concrete_arr)
        },
        _ => Cow::Borrowed(array),
    };
    // Cow<Box<dyn Array>> -> &Box<dyn Array> -> &dyn Array
    let array = array.as_ref().as_ref();

    set_variadic_buffer_counts(variadic_buffer_counts, array);

    write(
        array,
        buffers,
        arrow_data,
        nodes,
        offset,
        is_native_little_endian(),
        options.compression,
    )
}

/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the batch's data
pub fn encode_record_batch(
    chunk: &RecordBatchT<Box<dyn Array>>,
    options: &WriteOptions,
    encoded_message: &mut EncodedData,
) {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    // Reuse the caller's body buffer across batches.
    encoded_message.arrow_data.clear();

    let mut offset = 0;
    let mut variadic_buffer_counts = vec![];
    for array in chunk.arrays() {
        encode_array(
            array,
            options,
            &mut variadic_buffer_counts,
            &mut buffers,
            &mut encoded_message.arrow_data,
            &mut nodes,
            &mut offset,
        );
    }

    commit_encoded_arrays(
        chunk.len(),
        options,
        variadic_buffer_counts,
        buffers,
        nodes,
        encoded_message,
    );
}

/// Finalize previously encoded arrays into `encoded_message.ipc_message`:
/// builds the flatbuffer `RecordBatch` message header (metadata V5) referencing
/// the given nodes/buffers. The body bytes must already be in
/// `encoded_message.arrow_data`.
pub fn commit_encoded_arrays(
    array_len: usize,
    options: &WriteOptions,
    variadic_buffer_counts: Vec<i64>,
    buffers: Vec<ipc::Buffer>,
    nodes: Vec<ipc::FieldNode>,
    encoded_message: &mut EncodedData,
) {
    // The IPC field is optional; omit it entirely when there are no view arrays.
    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(
            arrow_format::ipc::RecordBatch {
                length: array_len as i64,
                nodes: Some(nodes),
                buffers: Some(buffers),
                compression,
                variadic_buffer_counts,
            },
        ))),
        body_length: encoded_message.arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);
    encoded_message.ipc_message = ipc_message.to_vec();
}

/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the
/// other for the data
fn dictionary_batch_to_bytes<K: DictionaryKey>(
    dict_id: i64,
    array: &DictionaryArray<K>,
    options: &WriteOptions,
    is_little_endian: bool,
) -> EncodedData {
    let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];
    let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];
    let mut arrow_data: Vec<u8> = vec![];
    let mut variadic_buffer_counts = vec![];
    // Only the dictionary *values* are serialized in a dictionary batch.
    set_variadic_buffer_counts(&mut variadic_buffer_counts, array.values().as_ref());

    let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {
        None
    } else {
        Some(variadic_buffer_counts)
    };

    let length = write_dictionary(
        array,
        &mut buffers,
        &mut arrow_data,
        &mut nodes,
        &mut 0, // offset starts at 0 for a standalone dictionary body
        is_little_endian,
        options.compression,
        false,
    );

    let compression = serialize_compression(options.compression);

    let message = arrow_format::ipc::Message {
        version: arrow_format::ipc::MetadataVersion::V5,
        header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(
            arrow_format::ipc::DictionaryBatch {
                id: dict_id,
                data: Some(Box::new(arrow_format::ipc::RecordBatch {
                    length: length as i64,
                    nodes: Some(nodes),
                    buffers: Some(buffers),
                    compression,
                    variadic_buffer_counts,
                })),
                is_delta: false,
            },
        ))),
        body_length: arrow_data.len() as i64,
        custom_metadata: None,
    };

    let mut builder = Builder::new();
    let ipc_message = builder.finish(&message, None);

    EncodedData {
        ipc_message: ipc_message.to_vec(),
        arrow_data,
    }
}

/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary
/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which
/// isn't allowed in the `FileWriter`.
pub struct DictionaryTracker {
    /// Maps dictionary id -> the values last emitted under that id.
    pub dictionaries: Dictionaries,
    /// When true (file format), replacing an existing dictionary is an error.
    pub cannot_replace: bool,
}

impl DictionaryTracker {
    /// Keep track of the dictionary with the given ID and values. Behavior:
    ///
    /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate
    ///   that the dictionary was not actually inserted (because it's already been seen).
    /// * If this ID has been written already but with different data, and this tracker is
    ///   configured to return an error, return an error.
    /// * If the tracker has not been configured to error on replacement or this dictionary
    ///   has never been seen before, return `Ok(true)` to indicate that the dictionary was just
    ///   inserted.
    pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {
        // Extract the values child of the dictionary array; callers guarantee
        // `array` is dictionary-encoded (anything else is a bug).
        let values = match array.dtype() {
            ArrowDataType::Dictionary(key_type, _, _) => {
                match_integer_type!(key_type, |$T| {
                    let array = array
                        .as_any()
                        .downcast_ref::<DictionaryArray<$T>>()
                        .unwrap();
                    array.values()
                })
            },
            _ => unreachable!(),
        };

        // If a dictionary with this id was already emitted, check if it was the same.
        if let Some(last) = self.dictionaries.get(&dict_id) {
            if last.as_ref() == values.as_ref() {
                // Same dictionary values => no need to emit it again
                return Ok(false);
            } else if self.cannot_replace {
                polars_bail!(InvalidOperation:
                    "Dictionary replacement detected when writing IPC file format. \
                    Arrow IPC files only support a single dictionary for a given field \
                    across all batches."
                );
            }
        };

        self.dictionaries.insert(dict_id, values.clone());
        Ok(true)
    }
}

/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data
#[derive(Debug, Default)]
pub struct EncodedData {
    /// An encoded ipc::Schema::Message
    pub ipc_message: Vec<u8>,
    /// Arrow buffers to be written, should be an empty vec for schema messages
    pub arrow_data: Vec<u8>,
}

/// Return the number of padding bytes needed to round `len` up to the next
/// 64-byte boundary (0 when `len` is already a multiple of 64).
#[inline]
pub(crate) fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}

/// An array [`RecordBatchT`] with optional accompanying IPC fields.
#[derive(Debug, Clone, PartialEq)]
pub struct Record<'a> {
    // Columns may be owned or borrowed, enabling zero-copy construction.
    columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,
    // IPC metadata per column; `None` means "derive defaults".
    fields: Option<Cow<'a, [IpcField]>>,
}

impl Record<'_> {
    /// Get the IPC fields for this record.
    pub fn fields(&self) -> Option<&[IpcField]> {
        self.fields.as_deref()
    }

    /// Get the Arrow columns in this record.
    pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {
        self.columns.borrow()
    }
}

impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {
    fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: None,
        }
    }
}

impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Owned(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}

impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>
where
    F: Into<Cow<'a, [IpcField]>>,
{
    fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {
        Self {
            columns: Cow::Borrowed(columns),
            fields: fields.map(|f| f.into()),
        }
    }
}