GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/csv/write/write_impl.rs
mod serializer;

use arrow::array::NullArray;
use arrow::legacy::time_zone::Tz;
use polars_core::POOL;
use polars_core::prelude::*;
use polars_error::polars_ensure;
use polars_utils::reuse_vec::reuse_vec;
use rayon::prelude::*;
use serializer::{serializer_for, string_serializer};

use crate::csv::write::SerializeOptions;

type ColumnSerializer<'a> =
    dyn crate::csv::write::write_impl::serializer::Serializer<'a> + Send + 'a;

/// Writes CSV from DataFrames.
pub struct CsvSerializer {
    serializers: Vec<Box<ColumnSerializer<'static>>>,
    options: Arc<SerializeOptions>,
    datetime_formats: Arc<[PlSmallStr]>,
    time_zones: Arc<[Option<Tz>]>,
}

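// Cloning keeps the shared options and the precomputed datetime formats and time
// zones, but starts with an empty serializer list: the per-column serializers
// borrow the chunk currently being written and are rebuilt for every DataFrame.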
impl Clone for CsvSerializer {
    fn clone(&self) -> Self {
        Self {
            serializers: vec![],
            options: self.options.clone(),
            datetime_formats: self.datetime_formats.clone(),
            time_zones: self.time_zones.clone(),
        }
    }
}

impl CsvSerializer {
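    /// Validates that the schema contains no nested or otherwise unsupported
    /// dtypes and precomputes each column's datetime format and time zone.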
    pub fn new(schema: SchemaRef, options: Arc<SerializeOptions>) -> PolarsResult<Self> {
        for dtype in schema.iter_values() {
            let nested = match dtype {
                DataType::List(_) => true,
                #[cfg(feature = "dtype-struct")]
                DataType::Struct(_) => true,
                #[cfg(feature = "object")]
                DataType::Object(_) => {
                    return Err(PolarsError::ComputeError(
                        "csv writer does not support object dtype".into(),
                    ));
                },
                _ => false,
            };
            polars_ensure!(
                !nested,
                ComputeError: "CSV format does not support nested data",
            );
        }

        // Check that the quote char, doubled as it is when escaped, is valid UTF-8.
        polars_ensure!(
            std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(),
            ComputeError: "quote char results in invalid utf-8",
        );

        let (datetime_formats, time_zones): (Vec<PlSmallStr>, Vec<Option<Tz>>) = schema
            .iter_values()
            .map(|dtype| {
                let (datetime_format_str, time_zone) = match dtype {
                    DataType::Datetime(TimeUnit::Milliseconds, tz) => {
                        let (format, tz_parsed) = match tz {
                            #[cfg(feature = "timezones")]
                            Some(tz) => (
                                options
                                    .datetime_format
                                    .as_deref()
                                    .unwrap_or("%FT%H:%M:%S.%3f%z"),
                                tz.parse::<Tz>().ok(),
                            ),
                            _ => (
                                options
                                    .datetime_format
                                    .as_deref()
                                    .unwrap_or("%FT%H:%M:%S.%3f"),
                                None,
                            ),
                        };
                        (format, tz_parsed)
                    },
                    DataType::Datetime(TimeUnit::Microseconds, tz) => {
                        let (format, tz_parsed) = match tz {
                            #[cfg(feature = "timezones")]
                            Some(tz) => (
                                options
                                    .datetime_format
                                    .as_deref()
                                    .unwrap_or("%FT%H:%M:%S.%6f%z"),
                                tz.parse::<Tz>().ok(),
                            ),
                            _ => (
                                options
                                    .datetime_format
                                    .as_deref()
                                    .unwrap_or("%FT%H:%M:%S.%6f"),
                                None,
                            ),
                        };
                        (format, tz_parsed)
                    },
                    DataType::Datetime(TimeUnit::Nanoseconds, tz) => {
                        let (format, tz_parsed) = match tz {
                            #[cfg(feature = "timezones")]
                            Some(tz) => (
                                options
                                    .datetime_format
                                    .as_deref()
                                    .unwrap_or("%FT%H:%M:%S.%9f%z"),
                                tz.parse::<Tz>().ok(),
                            ),
                            _ => (
                                options
                                    .datetime_format
                                    .as_deref()
                                    .unwrap_or("%FT%H:%M:%S.%9f"),
                                None,
                            ),
                        };
                        (format, tz_parsed)
                    },
                    _ => ("", None),
                };

                (datetime_format_str.into(), time_zone)
            })
            .collect();

        Ok(Self {
            serializers: vec![],
            options,
            datetime_formats: Arc::from_iter(datetime_formats),
            time_zones: Arc::from_iter(time_zones),
        })
    }

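    /// Serializes every row of `df` into `buffer` using this serializer's options.
    ///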
    /// # Panics
    /// Panics if a column has >1 chunk.
    pub fn serialize_to_csv<'a>(
        &'a mut self,
        df: &'a DataFrame,
        buffer: &mut Vec<u8>,
    ) -> PolarsResult<()> {
        if df.height() == 0 || df.width() == 0 {
            return Ok(());
        }

        let options = Arc::clone(&self.options);
        let options = options.as_ref();

        let mut serializers_vec = reuse_vec(std::mem::take(&mut self.serializers));
        let serializers = self.build_serializers(df.get_columns(), &mut serializers_vec)?;

        // Write row by row: the first column's value, then separator + value for the
        // remaining columns, then the line terminator.
        for _ in 0..df.height() {
            serializers[0].serialize(buffer, options);
            for serializer in &mut serializers[1..] {
                buffer.push(options.separator);
                serializer.serialize(buffer, options);
            }

            buffer.extend_from_slice(options.line_terminator.as_bytes());
        }

        self.serializers = reuse_vec(serializers_vec);

        Ok(())
    }

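    /// Builds one serializer per column, each reading from the column's single chunk.
    ///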
    /// # Panics
    /// Panics if a column has >1 chunk.
    fn build_serializers<'a, 'b>(
        &'a mut self,
        columns: &'a [Column],
        serializers: &'b mut Vec<Box<ColumnSerializer<'a>>>,
    ) -> PolarsResult<&'b mut [Box<ColumnSerializer<'a>>]> {
        serializers.clear();
        serializers.reserve(columns.len());

        for (i, c) in columns.iter().enumerate() {
            assert_eq!(c.n_chunks(), 1);

            serializers.push(serializer_for(
                c.as_materialized_series().chunks()[0].as_ref(),
                Arc::as_ref(&self.options),
                c.dtype(),
                self.datetime_formats[i].as_str(),
                self.time_zones[i],
            )?)
        }

        Ok(serializers)
    }
}

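/// Writes `df` as CSV to `writer`. Rows are serialized in passes of
/// `n_threads * chunk_size` rows, in parallel when `n_threads > 1`, and the
/// per-thread buffers are flushed to `writer` in order after every pass.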
pub(crate) fn write(
    mut writer: impl std::io::Write,
    df: &DataFrame,
    chunk_size: usize,
    options: Arc<SerializeOptions>,
    n_threads: usize,
) -> PolarsResult<()> {
    let len = df.height();
    let total_rows_per_pool_iter = n_threads * chunk_size;

    let mut n_rows_finished = 0;

    let csv_serializer = CsvSerializer::new(Arc::clone(df.schema()), options)?;

    let mut buffers: Vec<(Vec<u8>, CsvSerializer)> = (0..n_threads)
        .map(|_| (Vec::new(), csv_serializer.clone()))
        .collect();
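    // Each pass handles up to `n_threads * chunk_size` rows: every thread serializes
    // its own `chunk_size`-row slice into a private buffer, and the buffers are then
    // written to `writer` in thread order so rows stay in order.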
    while n_rows_finished < len {
        let buf_writer =
            |thread_no, write_buffer: &mut Vec<_>, csv_serializer: &mut CsvSerializer| {
                let thread_offset = thread_no * chunk_size;
                let total_offset = n_rows_finished + thread_offset;
                let mut df = df.slice(total_offset as i64, chunk_size);
                // `series.iter` needs rechunked series. We rechunk each slice rather
                // than the whole DataFrame: the slice usually needs far less
                // rechunking (so this is faster), and it lets
                // `pl.concat([df] * 100, rechunk=False).write_csv()` work where a
                // full rechunk would go OOM.
                df.rechunk_mut();

                csv_serializer.serialize_to_csv(&df, write_buffer)?;

                Ok(())
            };

        if n_threads > 1 {
            POOL.install(|| {
                buffers
                    .par_iter_mut()
                    .enumerate()
                    .map(|(i, (w, s))| buf_writer(i, w, s))
                    .collect::<PolarsResult<()>>()
            })?;
        } else {
            let (w, s) = &mut buffers[0];
            buf_writer(0, w, s)?;
        }

        for (write_buffer, _) in &mut buffers {
            writer.write_all(write_buffer)?;
            write_buffer.clear();
        }

        n_rows_finished += total_rows_per_pool_iter;
    }
    Ok(())
}

/// Serializes the CSV header (the column names) and returns it as bytes.
pub fn csv_header(names: &[&str], options: &SerializeOptions) -> PolarsResult<Vec<u8>> {
    let mut header = Vec::new();

    // A hack, but it works for this case: drive the string serializer with an
    // iterator over `names` and a dummy empty array, so the header reuses the
    // normal string serialization logic instead of duplicating it.
    let fake_arr = NullArray::new(ArrowDataType::Null, 0);
    let mut names_serializer = string_serializer(
        |iter: &mut std::slice::Iter<&str>| iter.next().copied(),
        options,
        |_| names.iter(),
        &fake_arr,
    );
    for i in 0..names.len() {
        names_serializer.serialize(&mut header, options);
        if i != names.len() - 1 {
            header.push(options.separator);
        }
    }
    header.extend_from_slice(options.line_terminator.as_bytes());
    Ok(header)
}

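/// The UTF-8 byte order mark.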
pub const UTF8_BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];