// GitHub repository: pola-rs/polars
// Path: blob/main/crates/polars-io/src/csv/write/write_impl.rs
1
mod serializer;
2
3
use std::io::Write;
4
5
use arrow::array::NullArray;
6
use arrow::legacy::time_zone::Tz;
7
use polars_core::POOL;
8
use polars_core::prelude::*;
9
use polars_error::polars_ensure;
10
use rayon::prelude::*;
11
use serializer::{serializer_for, string_serializer};
12
13
use crate::csv::write::SerializeOptions;
14
15
pub(crate) fn write<W: Write>(
16
writer: &mut W,
17
df: &DataFrame,
18
chunk_size: usize,
19
options: &SerializeOptions,
20
n_threads: usize,
21
) -> PolarsResult<()> {
22
for s in df.get_columns() {
23
let nested = match s.dtype() {
24
DataType::List(_) => true,
25
#[cfg(feature = "dtype-struct")]
26
DataType::Struct(_) => true,
27
#[cfg(feature = "object")]
28
DataType::Object(_) => {
29
return Err(PolarsError::ComputeError(
30
"csv writer does not support object dtype".into(),
31
));
32
},
33
_ => false,
34
};
35
polars_ensure!(
36
!nested,
37
ComputeError: "CSV format does not support nested data",
38
);
39
}
40
41
// Check that the double quote is valid UTF-8.
42
polars_ensure!(
43
std::str::from_utf8(&[options.quote_char, options.quote_char]).is_ok(),
44
ComputeError: "quote char results in invalid utf-8",
45
);
46
47
let (datetime_formats, time_zones): (Vec<&str>, Vec<Option<Tz>>) = df
48
.get_columns()
49
.iter()
50
.map(|column| match column.dtype() {
51
DataType::Datetime(TimeUnit::Milliseconds, tz) => {
52
let (format, tz_parsed) = match tz {
53
#[cfg(feature = "timezones")]
54
Some(tz) => (
55
options
56
.datetime_format
57
.as_deref()
58
.unwrap_or("%FT%H:%M:%S.%3f%z"),
59
tz.parse::<Tz>().ok(),
60
),
61
_ => (
62
options
63
.datetime_format
64
.as_deref()
65
.unwrap_or("%FT%H:%M:%S.%3f"),
66
None,
67
),
68
};
69
(format, tz_parsed)
70
},
71
DataType::Datetime(TimeUnit::Microseconds, tz) => {
72
let (format, tz_parsed) = match tz {
73
#[cfg(feature = "timezones")]
74
Some(tz) => (
75
options
76
.datetime_format
77
.as_deref()
78
.unwrap_or("%FT%H:%M:%S.%6f%z"),
79
tz.parse::<Tz>().ok(),
80
),
81
_ => (
82
options
83
.datetime_format
84
.as_deref()
85
.unwrap_or("%FT%H:%M:%S.%6f"),
86
None,
87
),
88
};
89
(format, tz_parsed)
90
},
91
DataType::Datetime(TimeUnit::Nanoseconds, tz) => {
92
let (format, tz_parsed) = match tz {
93
#[cfg(feature = "timezones")]
94
Some(tz) => (
95
options
96
.datetime_format
97
.as_deref()
98
.unwrap_or("%FT%H:%M:%S.%9f%z"),
99
tz.parse::<Tz>().ok(),
100
),
101
_ => (
102
options
103
.datetime_format
104
.as_deref()
105
.unwrap_or("%FT%H:%M:%S.%9f"),
106
None,
107
),
108
};
109
(format, tz_parsed)
110
},
111
_ => ("", None),
112
})
113
.unzip();
114
115
let len = df.height();
116
let total_rows_per_pool_iter = n_threads * chunk_size;
117
118
let mut n_rows_finished = 0;
119
120
// To comply with the safety requirements for the buf_writer closure, we need to make sure
121
// the column dtype references have a lifetime that exceeds the scope of the serializer, i.e.
122
// the full dataframe. If not, we can run into use-after-free memory issues for types that
123
// allocate, such as Enum or Categorical dtype (see GH issue #23939).
124
let col_dtypes: Vec<_> = df.get_columns().iter().map(|c| c.dtype()).collect();
125
126
let mut buffers: Vec<_> = (0..n_threads).map(|_| (Vec::new(), Vec::new())).collect();
127
while n_rows_finished < len {
128
let buf_writer = |thread_no, write_buffer: &mut Vec<_>, serializers_vec: &mut Vec<_>| {
129
let thread_offset = thread_no * chunk_size;
130
let total_offset = n_rows_finished + thread_offset;
131
let mut df = df.slice(total_offset as i64, chunk_size);
132
// the `series.iter` needs rechunked series.
133
// we don't do this on the whole as this probably needs much less rechunking
134
// so will be faster.
135
// and allows writing `pl.concat([df] * 100, rechunk=False).write_csv()` as the rechunk
136
// would go OOM
137
df.as_single_chunk();
138
let cols = df.get_columns();
139
140
// SAFETY:
141
// the bck thinks the lifetime is bounded to write_buffer_pool, but at the time we return
142
// the vectors the buffer pool, the series have already been removed from the buffers
143
// in other words, the lifetime does not leave this scope
144
let cols = unsafe { std::mem::transmute::<&[Column], &[Column]>(cols) };
145
146
if df.is_empty() {
147
return Ok(());
148
}
149
150
if serializers_vec.is_empty() {
151
debug_assert_eq!(cols.len(), col_dtypes.len());
152
*serializers_vec = std::iter::zip(cols, &col_dtypes)
153
.enumerate()
154
.map(|(i, (col, &col_dtype))| {
155
serializer_for(
156
&*col.as_materialized_series().chunks()[0],
157
options,
158
col_dtype,
159
datetime_formats[i],
160
time_zones[i],
161
)
162
})
163
.collect::<Result<_, _>>()?;
164
} else {
165
debug_assert_eq!(serializers_vec.len(), cols.len());
166
for (col_iter, col) in std::iter::zip(serializers_vec.iter_mut(), cols) {
167
col_iter.update_array(&*col.as_materialized_series().chunks()[0]);
168
}
169
}
170
171
let serializers = serializers_vec.as_mut_slice();
172
173
let len = std::cmp::min(cols[0].len(), chunk_size);
174
175
for _ in 0..len {
176
serializers[0].serialize(write_buffer, options);
177
for serializer in &mut serializers[1..] {
178
write_buffer.push(options.separator);
179
serializer.serialize(write_buffer, options);
180
}
181
182
write_buffer.extend_from_slice(options.line_terminator.as_bytes());
183
}
184
185
Ok(())
186
};
187
188
if n_threads > 1 {
189
POOL.install(|| {
190
buffers
191
.par_iter_mut()
192
.enumerate()
193
.map(|(i, (w, s))| buf_writer(i, w, s))
194
.collect::<PolarsResult<()>>()
195
})?;
196
} else {
197
let (w, s) = &mut buffers[0];
198
buf_writer(0, w, s)?;
199
}
200
201
for (write_buffer, _) in &mut buffers {
202
writer.write_all(write_buffer)?;
203
write_buffer.clear();
204
}
205
206
n_rows_finished += total_rows_per_pool_iter;
207
}
208
Ok(())
209
}
210
211
/// Writes a CSV header to `writer`.
212
pub(crate) fn write_header<W: Write>(
213
writer: &mut W,
214
names: &[&str],
215
options: &SerializeOptions,
216
) -> PolarsResult<()> {
217
let mut header = Vec::new();
218
219
// A hack, but it works for this case.
220
let fake_arr = NullArray::new(ArrowDataType::Null, 0);
221
let mut names_serializer = string_serializer(
222
|iter: &mut std::slice::Iter<&str>| iter.next().copied(),
223
options,
224
|_| names.iter(),
225
&fake_arr,
226
);
227
for i in 0..names.len() {
228
names_serializer.serialize(&mut header, options);
229
if i != names.len() - 1 {
230
header.push(options.separator);
231
}
232
}
233
header.extend_from_slice(options.line_terminator.as_bytes());
234
writer.write_all(&header)?;
235
Ok(())
236
}
237
238
/// Writes a UTF-8 BOM to `writer`.
239
pub(crate) fn write_bom<W: Write>(writer: &mut W) -> PolarsResult<()> {
240
const BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];
241
writer.write_all(&BOM)?;
242
Ok(())
243
}
244
245