// Path: crates/polars-arrow/src/io/ipc/write/common.rs
use std::borrow::{Borrow, Cow};12use arrow_format::ipc;3use arrow_format::ipc::KeyValue;4use arrow_format::ipc::planus::Builder;5use bytes::Bytes;6use polars_error::{PolarsResult, polars_bail, polars_err};7use polars_utils::compression::ZstdLevel;89use super::super::IpcField;10use super::write;11use crate::array::*;12use crate::datatypes::*;13use crate::io::ipc::endianness::is_native_little_endian;14use crate::io::ipc::read::Dictionaries;15use crate::legacy::prelude::LargeListArray;16use crate::match_integer_type;17use crate::record_batch::RecordBatchT;18use crate::types::Index;1920/// Compression codec21#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]22pub enum Compression {23/// LZ4 (framed)24LZ4,25/// ZSTD26ZSTD(ZstdLevel),27}2829/// Options declaring the behaviour of writing to IPC30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]31pub struct WriteOptions {32/// Whether the buffers should be compressed and which codec to use.33/// Note: to use compression the crate must be compiled with feature `io_ipc_compression`.34pub compression: Option<Compression>,35}3637/// Find the dictionary that are new and need to be encoded.38pub fn dictionaries_to_encode(39field: &IpcField,40array: &dyn Array,41dictionary_tracker: &mut DictionaryTracker,42dicts_to_encode: &mut Vec<(i64, Box<dyn Array>)>,43) -> PolarsResult<()> {44use PhysicalType::*;45match array.dtype().to_physical_type() {46Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null47| FixedSizeBinary | BinaryView | Utf8View => Ok(()),48Dictionary(key_type) => match_integer_type!(key_type, |$T| {49let dict_id = field.dictionary_id50.ok_or_else(|| polars_err!(InvalidOperation: "Dictionaries must have an associated id"))?;5152if dictionary_tracker.insert(dict_id, array)? {53dicts_to_encode.push((dict_id, array.to_boxed()));54}5556let array = array.as_any().downcast_ref::<DictionaryArray<$T>>().unwrap();57let values = array.values();58// @Q? 
Should this not pick fields[0]?59dictionaries_to_encode(field,60values.as_ref(),61dictionary_tracker,62dicts_to_encode,63)?;6465Ok(())66}),67Struct => {68let array = array.as_any().downcast_ref::<StructArray>().unwrap();69let fields = field.fields.as_slice();70if array.fields().len() != fields.len() {71polars_bail!(InvalidOperation: "The number of fields in a struct must equal the number of children in IpcField");72}73fields74.iter()75.zip(array.values().iter())76.try_for_each(|(field, values)| {77dictionaries_to_encode(78field,79values.as_ref(),80dictionary_tracker,81dicts_to_encode,82)83})84},85List => {86let values = array87.as_any()88.downcast_ref::<ListArray<i32>>()89.unwrap()90.values();91let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;92dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)93},94LargeList => {95let values = array96.as_any()97.downcast_ref::<ListArray<i64>>()98.unwrap()99.values();100let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;101dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)102},103FixedSizeList => {104let values = array105.as_any()106.downcast_ref::<FixedSizeListArray>()107.unwrap()108.values();109let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;110dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)111},112Union => {113let values = array114.as_any()115.downcast_ref::<UnionArray>()116.unwrap()117.fields();118let fields = field.fields.as_slice();119if values.len() != fields.len() {120polars_bail!(InvalidOperation:121"The number of fields in a union must equal the number of children in 
IpcField"122);123}124fields125.iter()126.zip(values.iter())127.try_for_each(|(field, values)| {128dictionaries_to_encode(129field,130values.as_ref(),131dictionary_tracker,132dicts_to_encode,133)134})135},136Map => {137let values = array.as_any().downcast_ref::<MapArray>().unwrap().field();138let field = field.fields.first().ok_or_else(|| polars_err!(ComputeError: "Invalid IPC field structure: expected nested field but fields vector is empty"))?;139dictionaries_to_encode(field, values.as_ref(), dictionary_tracker, dicts_to_encode)140},141}142}143144/// Encode a dictionary array with a certain id.145///146/// # Panics147///148/// This will panic if the given array is not a [`DictionaryArray`].149pub fn encode_dictionary(150dict_id: i64,151array: &dyn Array,152options: &WriteOptions,153) -> PolarsResult<EncodedData> {154let PhysicalType::Dictionary(key_type) = array.dtype().to_physical_type() else {155panic!("Given array is not a DictionaryArray")156};157158match_integer_type!(key_type, |$T| {159let array: &DictionaryArray<$T> = array.as_any().downcast_ref().unwrap();160161encode_dictionary_values(dict_id, array.values().as_ref(), options)162})163}164165pub fn encode_new_dictionaries(166field: &IpcField,167array: &dyn Array,168options: &WriteOptions,169dictionary_tracker: &mut DictionaryTracker,170encoded_dictionaries: &mut Vec<EncodedData>,171) -> PolarsResult<()> {172let mut dicts_to_encode = Vec::new();173dictionaries_to_encode(field, array, dictionary_tracker, &mut dicts_to_encode)?;174for (dict_id, dict_array) in dicts_to_encode {175encoded_dictionaries.push(encode_dictionary(dict_id, dict_array.as_ref(), options)?);176}177Ok(())178}179180pub fn encode_chunk(181chunk: &RecordBatchT<Box<dyn Array>>,182fields: &[IpcField],183dictionary_tracker: &mut DictionaryTracker,184options: &WriteOptions,185) -> PolarsResult<(Vec<EncodedData>, EncodedData)> {186let mut encoded_message = EncodedData::default();187let encoded_dictionaries = 
encode_chunk_amortized(188chunk,189fields,190dictionary_tracker,191options,192&mut encoded_message,193)?;194Ok((encoded_dictionaries, encoded_message))195}196197// Amortizes `EncodedData` allocation.198pub fn encode_chunk_amortized(199chunk: &RecordBatchT<Box<dyn Array>>,200fields: &[IpcField],201dictionary_tracker: &mut DictionaryTracker,202options: &WriteOptions,203encoded_message: &mut EncodedData,204) -> PolarsResult<Vec<EncodedData>> {205let mut encoded_dictionaries = vec![];206207for (field, array) in fields.iter().zip(chunk.as_ref()) {208encode_new_dictionaries(209field,210array.as_ref(),211options,212dictionary_tracker,213&mut encoded_dictionaries,214)?;215}216encode_record_batch(chunk, options, encoded_message);217218Ok(encoded_dictionaries)219}220221fn serialize_compression(222compression: Option<Compression>,223) -> Option<Box<arrow_format::ipc::BodyCompression>> {224if let Some(compression) = compression {225let codec = match compression {226Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame,227Compression::ZSTD(_) => arrow_format::ipc::CompressionType::Zstd,228};229Some(Box::new(arrow_format::ipc::BodyCompression {230codec,231method: arrow_format::ipc::BodyCompressionMethod::Buffer,232}))233} else {234None235}236}237238fn set_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {239match array.dtype().to_storage() {240ArrowDataType::Utf8View => {241let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();242counts.push(array.data_buffers().len() as i64);243},244ArrowDataType::BinaryView => {245let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();246counts.push(array.data_buffers().len() as i64);247},248ArrowDataType::Struct(_) => {249let array = array.as_any().downcast_ref::<StructArray>().unwrap();250for array in array.values() {251set_variadic_buffer_counts(counts, array.as_ref())252}253},254ArrowDataType::LargeList(_) => {255// Subslicing can change the variadic buffer count, so we have to256// 
slice here as well to stay synchronized.257let array = array.as_any().downcast_ref::<LargeListArray>().unwrap();258let offsets = array.offsets().buffer();259let first = *offsets.first().unwrap();260let last = *offsets.last().unwrap();261let subslice = array262.values()263.sliced(first.to_usize(), last.to_usize() - first.to_usize());264set_variadic_buffer_counts(counts, &*subslice)265},266ArrowDataType::FixedSizeList(_, _) => {267let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();268set_variadic_buffer_counts(counts, array.values().as_ref())269},270// Don't traverse dictionary values as those are set when the `Dictionary` IPC struct271// is read.272ArrowDataType::Dictionary(_, _, _) => (),273_ => (),274}275}276277fn gc_bin_view<'a, T: ViewType + ?Sized>(278arr: &'a Box<dyn Array>,279concrete_arr: &'a BinaryViewArrayGeneric<T>,280) -> Cow<'a, Box<dyn Array>> {281let bytes_len = concrete_arr.total_bytes_len();282let buffer_len = concrete_arr.total_buffer_len();283let extra_len = buffer_len.saturating_sub(bytes_len);284if extra_len < bytes_len.min(1024) {285// We can afford some tiny waste.286Cow::Borrowed(arr)287} else {288// Force GC it.289Cow::Owned(concrete_arr.clone().gc().boxed())290}291}292293pub fn encode_array(294array: &Box<dyn Array>,295options: &WriteOptions,296variadic_buffer_counts: &mut Vec<i64>,297buffers: &mut Vec<ipc::Buffer>,298arrow_data: &mut Vec<u8>,299nodes: &mut Vec<ipc::FieldNode>,300offset: &mut i64,301) {302// We don't want to write all buffers in sliced arrays.303let array = match array.dtype() {304ArrowDataType::BinaryView => {305let concrete_arr = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();306gc_bin_view(array, concrete_arr)307},308ArrowDataType::Utf8View => {309let concrete_arr = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();310gc_bin_view(array, concrete_arr)311},312_ => Cow::Borrowed(array),313};314let array = array.as_ref().as_ref();315316set_variadic_buffer_counts(variadic_buffer_counts, 
array);317318write(319array,320buffers,321arrow_data,322nodes,323offset,324is_native_little_endian(),325options.compression,326)327}328329/// Write [`RecordBatchT`] into two sets of bytes, one for the header (ipc::Schema::Message) and the330/// other for the batch's data331pub fn encode_record_batch(332chunk: &RecordBatchT<Box<dyn Array>>,333options: &WriteOptions,334encoded_message: &mut EncodedData,335) {336let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];337let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];338encoded_message.arrow_data.clear();339340let mut offset = 0;341let mut variadic_buffer_counts = vec![];342for array in chunk.arrays() {343encode_array(344array,345options,346&mut variadic_buffer_counts,347&mut buffers,348&mut encoded_message.arrow_data,349&mut nodes,350&mut offset,351);352}353354commit_encoded_arrays(355chunk.len(),356options,357variadic_buffer_counts,358buffers,359nodes,360None,361encoded_message,362);363}364365pub fn commit_encoded_arrays(366array_len: usize,367options: &WriteOptions,368variadic_buffer_counts: Vec<i64>,369buffers: Vec<ipc::Buffer>,370nodes: Vec<ipc::FieldNode>,371custom_metadata: Option<Vec<KeyValue>>,372encoded_message: &mut EncodedData,373) {374let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {375None376} else {377Some(variadic_buffer_counts)378};379380let compression = serialize_compression(options.compression);381382let message = arrow_format::ipc::Message {383version: arrow_format::ipc::MetadataVersion::V5,384header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new(385arrow_format::ipc::RecordBatch {386length: array_len as i64,387nodes: Some(nodes),388buffers: Some(buffers),389compression,390variadic_buffer_counts,391},392))),393body_length: encoded_message.arrow_data.len() as i64,394custom_metadata,395};396397let mut builder = Builder::new();398let ipc_message = builder.finish(&message, None);399encoded_message.ipc_message = ipc_message.to_vec();400}401402pub fn 
encode_dictionary_values(403dict_id: i64,404values_array: &dyn Array,405options: &WriteOptions,406) -> PolarsResult<EncodedData> {407let mut nodes: Vec<arrow_format::ipc::FieldNode> = vec![];408let mut buffers: Vec<arrow_format::ipc::Buffer> = vec![];409let mut arrow_data: Vec<u8> = vec![];410let mut variadic_buffer_counts = vec![];411set_variadic_buffer_counts(&mut variadic_buffer_counts, values_array);412413let variadic_buffer_counts = if variadic_buffer_counts.is_empty() {414None415} else {416Some(variadic_buffer_counts)417};418419write(420values_array,421&mut buffers,422&mut arrow_data,423&mut nodes,424&mut 0,425is_native_little_endian(),426options.compression,427);428429let compression = serialize_compression(options.compression);430431let message = arrow_format::ipc::Message {432version: arrow_format::ipc::MetadataVersion::V5,433header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new(434arrow_format::ipc::DictionaryBatch {435id: dict_id,436data: Some(Box::new(arrow_format::ipc::RecordBatch {437length: values_array.len() as i64,438nodes: Some(nodes),439buffers: Some(buffers),440compression,441variadic_buffer_counts,442})),443is_delta: false,444},445))),446body_length: arrow_data.len() as i64,447custom_metadata: None,448};449450let mut builder = Builder::new();451let ipc_message = builder.finish(&message, None);452453Ok(EncodedData {454ipc_message: ipc_message.to_vec(),455arrow_data,456})457}458459/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary460/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which461/// isn't allowed in the `FileWriter`.462pub struct DictionaryTracker {463pub dictionaries: Dictionaries,464pub cannot_replace: bool,465}466467impl DictionaryTracker {468/// Keep track of the dictionary with the given ID and values. 
Behavior:469///470/// * If this ID has been written already and has the same data, return `Ok(false)` to indicate471/// that the dictionary was not actually inserted (because it's already been seen).472/// * If this ID has been written already but with different data, and this tracker is473/// configured to return an error, return an error.474/// * If the tracker has not been configured to error on replacement or this dictionary475/// has never been seen before, return `Ok(true)` to indicate that the dictionary was just476/// inserted.477pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> PolarsResult<bool> {478let values = match array.dtype().to_storage() {479ArrowDataType::Dictionary(key_type, _, _) => {480match_integer_type!(key_type, |$T| {481let array = array482.as_any()483.downcast_ref::<DictionaryArray<$T>>()484.unwrap();485array.values()486})487},488_ => unreachable!(),489};490491// If a dictionary with this id was already emitted, check if it was the same.492if let Some(last) = self.dictionaries.get(&dict_id) {493if last.as_ref() == values.as_ref() {494// Same dictionary values => no need to emit it again495return Ok(false);496} else if self.cannot_replace {497polars_bail!(InvalidOperation:498"Dictionary replacement detected when writing IPC file format. 
\499Arrow IPC files only support a single dictionary for a given field \500across all batches."501);502}503};504505self.dictionaries.insert(dict_id, values.clone());506Ok(true)507}508}509510/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data511#[derive(Debug, Default)]512pub struct EncodedData {513/// An encoded ipc::Schema::Message514pub ipc_message: Vec<u8>,515/// Arrow buffers to be written, should be an empty vec for schema messages516pub arrow_data: Vec<u8>,517}518519/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data520#[derive(Debug, Default)]521pub struct EncodedDataBytes {522/// An encoded ipc::Schema::Message523pub ipc_message: Bytes,524/// Arrow buffers to be written, should be an empty vec for schema messages525pub arrow_data: Bytes,526}527528/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes529#[inline]530pub(crate) fn pad_to_64(len: usize) -> usize {531((len + 63) & !63) - len532}533534/// An array [`RecordBatchT`] with optional accompanying IPC fields.535#[derive(Debug, Clone, PartialEq)]536pub struct Record<'a> {537columns: Cow<'a, RecordBatchT<Box<dyn Array>>>,538fields: Option<Cow<'a, [IpcField]>>,539}540541impl Record<'_> {542/// Get the IPC fields for this record.543pub fn fields(&self) -> Option<&[IpcField]> {544self.fields.as_deref()545}546547/// Get the Arrow columns in this record.548pub fn columns(&self) -> &RecordBatchT<Box<dyn Array>> {549self.columns.borrow()550}551}552553impl From<RecordBatchT<Box<dyn Array>>> for Record<'static> {554fn from(columns: RecordBatchT<Box<dyn Array>>) -> Self {555Self {556columns: Cow::Owned(columns),557fields: None,558}559}560}561562impl<'a, F> From<(RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>563where564F: Into<Cow<'a, [IpcField]>>,565{566fn from((columns, fields): (RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {567Self {568columns: Cow::Owned(columns),569fields: fields.map(|f| 
f.into()),570}571}572}573574impl<'a, F> From<(&'a RecordBatchT<Box<dyn Array>>, Option<F>)> for Record<'a>575where576F: Into<Cow<'a, [IpcField]>>,577{578fn from((columns, fields): (&'a RecordBatchT<Box<dyn Array>>, Option<F>)) -> Self {579Self {580columns: Cow::Borrowed(columns),581fields: fields.map(|f| f.into()),582}583}584}585586/// Create an IPC Block. Will panic when size limitations are not met.587pub fn arrow_ipc_block(588offset: usize,589meta_data_length: usize,590body_length: usize,591) -> arrow_format::ipc::Block {592arrow_format::ipc::Block {593offset: i64::try_from(offset).unwrap(),594meta_data_length: i32::try_from(meta_data_length).unwrap(),595body_length: i64::try_from(body_length).unwrap(),596}597}598599600