Path: blob/main/crates/polars-io/src/parquet/write/batched_writer.rs
8506 views
use std::io::Write;
use std::sync::Mutex;

use arrow::record_batch::RecordBatch;
use polars_buffer::Buffer;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_parquet::read::{ParquetError, fallible_streaming_iterator};
use polars_parquet::write::{
    CompressedPage, Compressor, DynIter, DynStreamingIterator, Encoding, FallibleStreamingIterator,
    FileWriter, Page, ParquetType, RowGroupIterColumns, SchemaDescriptor, WriteOptions,
    array_to_columns, schema_to_metadata_key,
};
use rayon::prelude::*;

use super::{KeyValueMetadata, ParquetMetadataContext};

/// Incremental parquet writer: accepts [`DataFrame`] batches (or pre-encoded
/// row groups) one at a time and writes each as a parquet row group through
/// the wrapped [`FileWriter`].
pub struct BatchedWriter<W: Write> {
    // A mutex so that streaming engine can get concurrent read access to
    // compress pages.
    //
    // @TODO: Remove mutex when old streaming engine is removed
    pub(super) writer: Mutex<FileWriter<W>>,
    // @TODO: Remove when old streaming engine is removed
    //
    // NOTE(review): `new` initializes this to an empty schema; since the field
    // is `pub(super)`, it is presumably populated by sibling code in this
    // module before `write_batch`/`encode_and_compress` are used — confirm.
    pub(super) parquet_schema: SchemaDescriptor,
    // Per-column encodings; zipped against the columns of each batch.
    pub(super) encodings: Buffer<Vec<Encoding>>,
    pub(super) options: WriteOptions,
    // When true, columns of a row group are encoded/compressed on the rayon POOL.
    pub(super) parallel: bool,
    // Extra key-value metadata to emit in the file footer (see `finish`).
    pub(super) key_value_metadata: Option<KeyValueMetadata>,
}

impl<W: Write> BatchedWriter<W> {
    /// Create a new batched writer around an already-constructed [`FileWriter`].
    ///
    /// Note: `parquet_schema` is initialized empty here (see the field note above).
    pub fn new(
        writer: Mutex<FileWriter<W>>,
        encodings: Buffer<Vec<Encoding>>,
        options: WriteOptions,
        parallel: bool,
        key_value_metadata: Option<KeyValueMetadata>,
    ) -> Self {
        Self {
            writer,
            parquet_schema: SchemaDescriptor::new(PlSmallStr::EMPTY, vec![]),
            encodings,
            options,
            parallel,
            key_value_metadata,
        }
    }

    /// Eagerly encode + compress every non-empty chunk of `df` into row-group
    /// column iterators, without writing them. Pairs with `write_row_groups`
    /// so that compute (encoding) can be separated from IO.
    pub fn encode_and_compress<'a>(
        &'a self,
        df: &'a DataFrame,
    ) -> impl Iterator<Item = PolarsResult<RowGroupIterColumns<'static, PolarsError>>> + 'a {
        let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
        rb_iter.filter_map(move |batch| match batch.len() {
            // Empty chunks produce no row group.
            0 => None,
            _ => {
                let row_group = create_eager_serializer(
                    batch,
                    self.parquet_schema.fields(),
                    self.encodings.as_ref(),
                    self.options,
                );

                Some(row_group)
            },
        })
    }

    /// Write a batch to the parquet writer.
    ///
    /// # Panics
    /// The caller must ensure the chunks in the given [`DataFrame`] are aligned.
    pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
        let row_group_iter = prepare_rg_iter(
            df,
            &self.parquet_schema,
            &self.encodings,
            self.options,
            self.parallel,
        );
        // Lock before looping so that order is maintained under contention.
        let mut writer = self.writer.lock().unwrap();
        for (num_rows, group) in row_group_iter {
            writer.write(num_rows as u64, group?)?;
        }
        Ok(())
    }

    /// The parquet schema as known by the underlying [`FileWriter`]
    /// (not the possibly-empty `self.parquet_schema` field).
    pub fn parquet_schema(&mut self) -> &SchemaDescriptor {
        let writer = self.writer.get_mut().unwrap();
        writer.parquet_schema()
    }

    /// Note: `num_rows` can be passed as `u64::MAX` to infer `num_rows` from the encoded data.
    pub fn write_row_group(
        &mut self,
        num_rows: u64,
        rg: &[Vec<CompressedPage>],
    ) -> PolarsResult<()> {
        let writer = self.writer.get_mut().unwrap();
        // Adapt the borrowed pages into the DynIter-of-streaming-iterators
        // shape that `FileWriter::write` expects; each inner Vec is one column.
        let rg = DynIter::new(rg.iter().map(|col_pages| {
            Ok(DynStreamingIterator::new(
                fallible_streaming_iterator::convert(col_pages.iter().map(PolarsResult::Ok)),
            ))
        }));
        writer.write(num_rows, rg)?;
        Ok(())
    }

    /// Shared access to the underlying writer mutex (used by the streaming engine).
    pub fn get_writer(&self) -> &Mutex<FileWriter<W>> {
        &self.writer
    }

    /// Write pre-encoded row groups (e.g. produced by `encode_and_compress`).
    /// Row counts are inferred from the encoded data (`u64::MAX` sentinel).
    pub fn write_row_groups(
        &self,
        rgs: Vec<RowGroupIterColumns<'static, PolarsError>>,
    ) -> PolarsResult<()> {
        // Lock before looping so that order is maintained.
        let mut writer = self.writer.lock().unwrap();
        for group in rgs {
            writer.write(u64::MAX, group)?;
        }
        Ok(())
    }

    /// Writes the footer of the parquet file. Returns the total size of the file.
    pub fn finish(&self) -> PolarsResult<u64> {
        let mut writer = self.writer.lock().unwrap();

        // Collect user-supplied key-value metadata, making sure the arrow
        // schema entry is present exactly once (prepended if the user did not
        // already provide one under the same key).
        let key_value_metadata = self
            .key_value_metadata
            .as_ref()
            .map(|meta| {
                let arrow_schema = schema_to_metadata_key(writer.schema());
                let ctx = ParquetMetadataContext {
                    // NOTE(review): unwrap assumes `schema_to_metadata_key`
                    // always produces a value — confirm in its definition.
                    arrow_schema: arrow_schema.value.as_ref().unwrap(),
                };
                let mut out = meta.collect(ctx)?;
                if !out.iter().any(|kv| kv.key == arrow_schema.key) {
                    out.insert(0, arrow_schema);
                }
                PolarsResult::Ok(out)
            })
            .transpose()?;

        let size = writer.end(key_value_metadata)?;
        Ok(size)
    }
}

// Note that the df should be rechunked
//
// Lazily turn each non-empty chunk of `df` into a (row_count, row-group
// serializer) pair for `write_batch`.
fn prepare_rg_iter<'a>(
    df: &'a DataFrame,
    parquet_schema: &'a SchemaDescriptor,
    encodings: &'a [Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> impl Iterator<
    Item = (
        usize,
        PolarsResult<RowGroupIterColumns<'static, PolarsError>>,
    ),
> + 'a {
    let rb_iter = df.iter_chunks(CompatLevel::newest(), false);
    rb_iter.filter_map(move |batch| match batch.len() {
        // Skip empty chunks entirely.
        0 => None,
        num_rows => {
            let row_group =
                create_serializer(batch, parquet_schema.fields(), encodings, options, parallel);

            Some((num_rows, row_group))
        },
    })
}

// Wrap each column's encoded-page iterator in a `Compressor` so pages are
// compressed (per `options.compression`) as they are pulled.
fn pages_iter_to_compressor(
    encoded_columns: Vec<DynIter<'static, PolarsResult<Page>>>,
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    encoded_columns
        .into_iter()
        .map(|encoded_pages| {
            // iterator over pages
            let pages = DynStreamingIterator::new(
                Compressor::new_from_vec(
                    // Encoding errors are re-wrapped as a ParquetError so the
                    // compressor's error type lines up.
                    encoded_pages.map(|result| {
                        result.map_err(|e| {
                            ParquetError::FeatureNotSupported(format!("reraised in polars: {e}",))
                        })
                    }),
                    options.compression,
                    vec![],
                )
                .map_err(PolarsError::from),
            );

            Ok(pages)
        })
        .collect::<Vec<_>>()
}

// Encode one arrow array into parquet column chunks (a nested array may yield
// several leaf columns), then attach compressors.
fn array_to_pages_iter(
    array: &ArrayRef,
    type_: &ParquetType,
    encoding: &[Encoding],
    options: WriteOptions,
) -> Vec<PolarsResult<DynStreamingIterator<'static, CompressedPage, PolarsError>>> {
    // NOTE(review): unwrap means an unsupported array/type combination panics
    // rather than returning Err — presumably upheld by schema construction
    // upstream; confirm.
    let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap();
    pages_iter_to_compressor(encoded_columns, options)
}

// Build a lazy row-group serializer for one record batch; columns are
// encoded/compressed on the rayon POOL when `parallel` is set.
fn create_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
    parallel: bool,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = if parallel {
        // `install` makes the par_iter run on polars' global rayon pool.
        POOL.install(|| {
            batch
                .columns()
                .par_iter()
                .zip(fields)
                .zip(encodings)
                .flat_map(func)
                .collect::<Vec<_>>()
        })
    } else {
        batch
            .columns()
            .iter()
            .zip(fields)
            .zip(encodings)
            .flat_map(func)
            .collect::<Vec<_>>()
    };

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}

/// This serializer encodes and compresses all eagerly in memory.
/// Used for separating compute from IO.
fn create_eager_serializer(
    batch: RecordBatch,
    fields: &[ParquetType],
    encodings: &[Vec<Encoding>],
    options: WriteOptions,
) -> PolarsResult<RowGroupIterColumns<'static, PolarsError>> {
    let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec<Encoding>)| {
        array_to_pages_iter(array, type_, encoding, options)
    };

    let columns = batch
        .columns()
        .iter()
        .zip(fields)
        .zip(encodings)
        .flat_map(func)
        .collect::<Vec<_>>();

    let row_group = DynIter::new(columns.into_iter());

    Ok(row_group)
}