Path: blob/main/crates/polars-io/src/csv/write/write_impl.rs
6939 views
mod serializer;12use std::io::Write;34use arrow::array::NullArray;5use arrow::legacy::time_zone::Tz;6use polars_core::POOL;7use polars_core::prelude::*;8use polars_error::polars_ensure;9use rayon::prelude::*;10use serializer::{serializer_for, string_serializer};1112use crate::csv::write::SerializeOptions;1314pub(crate) fn write<W: Write>(15writer: &mut W,16df: &DataFrame,17chunk_size: usize,18options: &SerializeOptions,19n_threads: usize,20) -> PolarsResult<()> {21for s in df.get_columns() {22let nested = match s.dtype() {23DataType::List(_) => true,24#[cfg(feature = "dtype-struct")]25DataType::Struct(_) => true,26#[cfg(feature = "object")]27DataType::Object(_) => {28return Err(PolarsError::ComputeError(29"csv writer does not support object dtype".into(),30));31},32_ => false,33};34polars_ensure!(35!nested,36ComputeError: "CSV format does not support nested data",37);38}3940// Check that the double quote is valid UTF-8.41polars_ensure!(42std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(),43ComputeError: "quote char results in invalid utf-8",44);4546let (datetime_formats, time_zones): (Vec<&str>, Vec<Option<Tz>>) = df47.get_columns()48.iter()49.map(|column| match column.dtype() {50DataType::Datetime(TimeUnit::Milliseconds, tz) => {51let (format, tz_parsed) = match tz {52#[cfg(feature = "timezones")]53Some(tz) => (54options55.datetime_format56.as_deref()57.unwrap_or("%FT%H:%M:%S.%3f%z"),58tz.parse::<Tz>().ok(),59),60_ => (61options62.datetime_format63.as_deref()64.unwrap_or("%FT%H:%M:%S.%3f"),65None,66),67};68(format, tz_parsed)69},70DataType::Datetime(TimeUnit::Microseconds, tz) => {71let (format, tz_parsed) = match tz {72#[cfg(feature = "timezones")]73Some(tz) => (74options75.datetime_format76.as_deref()77.unwrap_or("%FT%H:%M:%S.%6f%z"),78tz.parse::<Tz>().ok(),79),80_ => (81options82.datetime_format83.as_deref()84.unwrap_or("%FT%H:%M:%S.%6f"),85None,86),87};88(format, tz_parsed)89},90DataType::Datetime(TimeUnit::Nanoseconds, tz) => {91let (format, tz_parsed) = match tz {92#[cfg(feature = "timezones")]93Some(tz) => (94options95.datetime_format96.as_deref()97.unwrap_or("%FT%H:%M:%S.%9f%z"),98tz.parse::<Tz>().ok(),99),100_ => (101options102.datetime_format103.as_deref()104.unwrap_or("%FT%H:%M:%S.%9f"),105None,106),107};108(format, tz_parsed)109},110_ => ("", None),111})112.unzip();113114let len = df.height();115let total_rows_per_pool_iter = n_threads * chunk_size;116117let mut n_rows_finished = 0;118119// To comply with the safety requirements for the buf_writer closure, we need to make sure120// the column dtype references have a lifetime that exceeds the scope of the serializer, i.e.121// the full dataframe. If not, we can run into use-after-free memory issues for types that122// allocate, such as Enum or Categorical dtype (see GH issue #23939).123let col_dtypes: Vec<_> = df.get_columns().iter().map(|c| c.dtype()).collect();124125let mut buffers: Vec<_> = (0..n_threads).map(|_| (Vec::new(), Vec::new())).collect();126while n_rows_finished < len {127let buf_writer = |thread_no, write_buffer: &mut Vec<_>, serializers_vec: &mut Vec<_>| {128let thread_offset = thread_no * chunk_size;129let total_offset = n_rows_finished + thread_offset;130let mut df = df.slice(total_offset as i64, chunk_size);131// the `series.iter` needs rechunked series.132// we don't do this on the whole as this probably needs much less rechunking133// so will be faster.134// and allows writing `pl.concat([df] * 100, rechunk=False).write_csv()` as the rechunk135// would go OOM136df.as_single_chunk();137let cols = df.get_columns();138139// SAFETY:140// the bck thinks the lifetime is bounded to write_buffer_pool, but at the time we return141// the vectors the buffer pool, the series have already been removed from the buffers142// in other words, the lifetime does not leave this scope143let cols = unsafe { std::mem::transmute::<&[Column], &[Column]>(cols) };144145if df.is_empty() {146return Ok(());147}148149if serializers_vec.is_empty() {150debug_assert_eq!(cols.len(), col_dtypes.len());151*serializers_vec = std::iter::zip(cols, &col_dtypes)152.enumerate()153.map(|(i, (col, &col_dtype))| {154serializer_for(155&*col.as_materialized_series().chunks()[0],156options,157col_dtype,158datetime_formats[i],159time_zones[i],160)161})162.collect::<Result<_, _>>()?;163} else {164debug_assert_eq!(serializers_vec.len(), cols.len());165for (col_iter, col) in std::iter::zip(serializers_vec.iter_mut(), cols) {166col_iter.update_array(&*col.as_materialized_series().chunks()[0]);167}168}169170let serializers = serializers_vec.as_mut_slice();171172let len = std::cmp::min(cols[0].len(), chunk_size);173174for _ in 0..len {175serializers[0].serialize(write_buffer, options);176for serializer in &mut serializers[1..] {177write_buffer.push(options.separator);178serializer.serialize(write_buffer, options);179}180181write_buffer.extend_from_slice(options.line_terminator.as_bytes());182}183184Ok(())185};186187if n_threads > 1 {188POOL.install(|| {189buffers190.par_iter_mut()191.enumerate()192.map(|(i, (w, s))| buf_writer(i, w, s))193.collect::<PolarsResult<()>>()194})?;195} else {196let (w, s) = &mut buffers[0];197buf_writer(0, w, s)?;198}199200for (write_buffer, _) in &mut buffers {201writer.write_all(write_buffer)?;202write_buffer.clear();203}204205n_rows_finished += total_rows_per_pool_iter;206}207Ok(())208}209210/// Writes a CSV header to `writer`.211pub(crate) fn write_header<W: Write>(212writer: &mut W,213names: &[&str],214options: &SerializeOptions,215) -> PolarsResult<()> {216let mut header = Vec::new();217218// A hack, but it works for this case.219let fake_arr = NullArray::new(ArrowDataType::Null, 0);220let mut names_serializer = string_serializer(221|iter: &mut std::slice::Iter<&str>| iter.next().copied(),222options,223|_| names.iter(),224&fake_arr,225);226for i in 0..names.len() {227names_serializer.serialize(&mut header, options);228if i != names.len() - 1 {229header.push(options.separator);230}231}232header.extend_from_slice(options.line_terminator.as_bytes());233writer.write_all(&header)?;234Ok(())235}236237/// Writes a UTF-8 BOM to `writer`.238pub(crate) fn write_bom<W: Write>(writer: &mut W) -> PolarsResult<()> {239const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];240writer.write_all(&BOM)?;241Ok(())242}243244245