Path: blob/main/crates/polars-io/src/parquet/write/writer.rs
use std::io::Write;
use std::sync::Mutex;

use arrow::datatypes::PhysicalType;
use polars_core::frame::chunk_df_for_writing;
use polars_core::prelude::*;
use polars_parquet::write::{
    ChildWriteOptions, ColumnWriteOptions, CompressionOptions, Encoding, FieldWriteOptions,
    FileWriter, KeyValue, ListLikeFieldWriteOptions, StatisticsOptions, StructFieldWriteOptions,
    Version, WriteOptions, to_parquet_schema,
};

use super::batched_writer::BatchedWriter;
use super::options::ParquetCompression;
use super::{KeyValueMetadata, MetadataKeyValue, ParquetFieldOverwrites, ParquetWriteOptions};
use crate::prelude::ChildFieldOverwrites;
use crate::shared::schema_to_arrow_checked;

impl ParquetWriteOptions {
    pub fn to_writer<F>(&self, f: F) -> ParquetWriter<F>
    where
        F: Write,
    {
        ParquetWriter::new(f)
            .with_compression(self.compression)
            .with_statistics(self.statistics)
            .with_row_group_size(self.row_group_size)
            .with_data_page_size(self.data_page_size)
            .with_key_value_metadata(self.key_value_metadata.clone())
    }
}

/// Write a DataFrame to Parquet format.
#[must_use]
pub struct ParquetWriter<W> {
    writer: W,
    /// Data page compression
    compression: CompressionOptions,
    /// Compute and write column statistics.
    statistics: StatisticsOptions,
    /// if `None` will be 512^2 rows
    row_group_size: Option<usize>,
    /// if `None` will be 1024^2 bytes
    data_page_size: Option<usize>,
    /// Serialize columns in parallel
    parallel: bool,
    field_overwrites: Vec<ParquetFieldOverwrites>,
    /// Custom file-level key value metadata
    key_value_metadata: Option<KeyValueMetadata>,
    /// Context info for the Parquet file being written.
    context_info: Option<PlHashMap<String, String>>,
}

impl<W> ParquetWriter<W>
where
    W: Write,
{
    /// Create a new writer
    pub fn new(writer: W) -> Self
    where
        W: Write,
    {
        ParquetWriter {
            writer,
            compression: ParquetCompression::default().into(),
            statistics: StatisticsOptions::default(),
            row_group_size: None,
            data_page_size: None,
            parallel: true,
            field_overwrites: Vec::new(),
            key_value_metadata: None,
            context_info: None,
        }
    }

    /// Set the compression used. Defaults to `Zstd`.
    ///
    /// The default compression `Zstd` has very good performance, but may not yet be supported
    /// by older readers. If you want more compatibility guarantees, consider using `Snappy`.
    pub fn with_compression(mut self, compression: ParquetCompression) -> Self {
        self.compression = compression.into();
        self
    }

    /// Compute and write column statistics.
    pub fn with_statistics(mut self, statistics: StatisticsOptions) -> Self {
        self.statistics = statistics;
        self
    }

    /// Set the row group size (in number of rows) during writing. This can reduce memory pressure
    /// and improve writing performance.
    pub fn with_row_group_size(mut self, size: Option<usize>) -> Self {
        self.row_group_size = size;
        self
    }

    /// Set the maximum byte size of a data page. If `None` it will be 1024^2 bytes.
    pub fn with_data_page_size(mut self, limit: Option<usize>) -> Self {
        self.data_page_size = limit;
        self
    }

    /// Serialize columns in parallel.
    pub fn set_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Set custom file-level key value metadata for the Parquet file.
    pub fn with_key_value_metadata(mut self, key_value_metadata: Option<KeyValueMetadata>) -> Self {
        self.key_value_metadata = key_value_metadata;
        self
    }

    /// Set context information for the writer.
    pub fn with_context_info(mut self, context_info: Option<PlHashMap<String, String>>) -> Self {
        self.context_info = context_info;
        self
    }

    pub fn batched(self, schema: &Schema) -> PolarsResult<BatchedWriter<W>> {
        let schema = schema_to_arrow_checked(schema, CompatLevel::newest(), "parquet")?;
        let column_options = get_column_write_options(&schema, &self.field_overwrites);
        let parquet_schema = to_parquet_schema(&schema, &column_options)?;
        let options = self.materialize_options();
        let writer = Mutex::new(FileWriter::try_new(
            self.writer,
            schema,
            options,
            &column_options,
        )?);

        Ok(BatchedWriter {
            writer,
            parquet_schema,
            column_options,
            options,
            parallel: self.parallel,
            key_value_metadata: self.key_value_metadata,
        })
    }

    fn materialize_options(&self) -> WriteOptions {
        WriteOptions {
            statistics: self.statistics,
            compression: self.compression,
            version: Version::V1,
            data_page_size: self.data_page_size,
        }
    }

    /// Write the given DataFrame to the writer `W`.
    /// Returns the total size of the file.
    pub fn finish(self, df: &mut DataFrame) -> PolarsResult<u64> {
        let chunked_df = chunk_df_for_writing(df, self.row_group_size.unwrap_or(512 * 512))?;
        let mut batched = self.batched(chunked_df.schema())?;
        batched.write_batch(&chunked_df)?;
        batched.finish()
    }
}
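// A usage sketch added for illustration; not part of the original file. It
// writes a small DataFrame through the builder above into an in-memory
// buffer (`Vec<u8>` implements `Write`, so no file is needed) and assumes
// polars_core's `df!` macro is available in tests.
#[cfg(test)]
mod usage_sketch {
    use polars_core::prelude::*;

    use super::*;

    #[test]
    fn write_df_to_buffer() -> PolarsResult<()> {
        let mut df = df!(
            "a" => &[1i64, 2, 3],
            "b" => &["x", "y", "z"]
        )?;
        // `finish` chunks the frame into row groups and returns the total
        // number of bytes written.
        let bytes_written = ParquetWriter::new(Vec::<u8>::new())
            .with_compression(ParquetCompression::Snappy)
            .with_row_group_size(Some(2))
            .finish(&mut df)?;
        assert!(bytes_written > 0);
        Ok(())
    }
}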
fn convert_metadata(md: &Option<Vec<MetadataKeyValue>>) -> Vec<KeyValue> {
    md.as_ref()
        .map(|metadata| {
            metadata
                .iter()
                .map(|kv| KeyValue {
                    key: kv.key.to_string(),
                    value: kv.value.as_ref().map(|v| v.to_string()),
                })
                .collect()
        })
        .unwrap_or_default()
}

fn to_column_write_options_rec(
    field: &ArrowField,
    overwrites: Option<&ParquetFieldOverwrites>,
) -> ColumnWriteOptions {
    let mut column_options = ColumnWriteOptions {
        field_id: None,
        metadata: Vec::new(),
        required: None,

        // Dummy value.
        children: ChildWriteOptions::Leaf(FieldWriteOptions {
            encoding: Encoding::Plain,
        }),
    };

    if let Some(overwrites) = overwrites {
        column_options.field_id = overwrites.field_id;
        column_options.metadata = convert_metadata(&overwrites.metadata);
        column_options.required = overwrites.required;
    }

    use arrow::datatypes::PhysicalType::*;
    match field.dtype().to_physical_type() {
        Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8
        | Dictionary(_) | LargeUtf8 | BinaryView | Utf8View => {
            column_options.children = ChildWriteOptions::Leaf(FieldWriteOptions {
                encoding: encoding_map(field.dtype()),
            });
        },
        List | FixedSizeList | LargeList => {
            let child_overwrites = overwrites.and_then(|o| match &o.children {
                ChildFieldOverwrites::None => None,
                ChildFieldOverwrites::ListLike(child_overwrites) => Some(child_overwrites.as_ref()),
                _ => unreachable!(),
            });

            let a = field.dtype().to_logical_type();
            let child = if let ArrowDataType::List(inner) = a {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::LargeList(inner) = a {
                to_column_write_options_rec(inner, child_overwrites)
            } else if let ArrowDataType::FixedSizeList(inner, _) = a {
                to_column_write_options_rec(inner, child_overwrites)
            } else {
                unreachable!()
            };

            column_options.children =
                ChildWriteOptions::ListLike(Box::new(ListLikeFieldWriteOptions { child }));
        },
        Struct => {
            if let ArrowDataType::Struct(fields) = field.dtype().to_logical_type() {
                let children_overwrites = overwrites.and_then(|o| match &o.children {
                    ChildFieldOverwrites::None => None,
                    ChildFieldOverwrites::Struct(child_overwrites) => Some(PlHashMap::from_iter(
                        child_overwrites
                            .iter()
                            .map(|f| (f.name.as_ref().unwrap(), f)),
                    )),
                    _ => unreachable!(),
                });

                let children = fields
                    .iter()
                    .map(|f| {
                        let overwrites = children_overwrites
                            .as_ref()
                            .and_then(|o| o.get(&f.name).copied());
                        to_column_write_options_rec(f, overwrites)
                    })
                    .collect();

                column_options.children =
                    ChildWriteOptions::Struct(Box::new(StructFieldWriteOptions { children }));
            } else {
                unreachable!()
            }
        },

        Map | Union => unreachable!(),
    }

    column_options
}
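// Illustration added for clarity; not part of the original file, and the
// `ArrowField::new` calls below are a sketch of the polars-arrow API. A
// `LargeList(Struct { .. })` field recurses into `ListLike` options whose
// child is a `Struct` with one leaf per struct field:
//
//     let item = ArrowField::new("item".into(), ArrowDataType::Int64, true);
//     let strct = ArrowField::new("s".into(), ArrowDataType::Struct(vec![item]), true);
//     let list = ArrowField::new("l".into(), ArrowDataType::LargeList(Box::new(strct)), true);
//     let opts = to_column_write_options_rec(&list, None);
//     assert!(matches!(opts.children, ChildWriteOptions::ListLike(_)));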
pub fn get_column_write_options(
    schema: &ArrowSchema,
    field_overwrites: &[ParquetFieldOverwrites],
) -> Vec<ColumnWriteOptions> {
    let field_overwrites = PlHashMap::from(
        field_overwrites
            .iter()
            .map(|f| (f.name.as_ref().unwrap(), f))
            .collect(),
    );
    schema
        .iter_values()
        .map(|f| to_column_write_options_rec(f, field_overwrites.get(&f.name).copied()))
        .collect()
}

/// Declare the encoding to use for each physical type.
fn encoding_map(dtype: &ArrowDataType) -> Encoding {
    match dtype.to_physical_type() {
        PhysicalType::Dictionary(_)
        | PhysicalType::LargeBinary
        | PhysicalType::LargeUtf8
        | PhysicalType::Utf8View
        | PhysicalType::BinaryView => Encoding::RleDictionary,
        PhysicalType::Primitive(dt) => {
            use arrow::types::PrimitiveType::*;
            match dt {
                Float32 | Float64 | Float16 => Encoding::Plain,
                _ => Encoding::RleDictionary,
            }
        },
        // remaining is plain
        _ => Encoding::Plain,
    }
}
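// Sanity sketch of the encoding table above, added for illustration; not part
// of the original file. View and dictionary types get RLE dictionary
// encoding, while floats stay plain:
//
//     assert!(matches!(encoding_map(&ArrowDataType::Utf8View), Encoding::RleDictionary));
//     assert!(matches!(encoding_map(&ArrowDataType::Float64), Encoding::Plain));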