GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/csv/read/schema_inference.rs
use std::borrow::Cow;

use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;

use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
use super::splitfields::SplitFields;
use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
use crate::csv::read::parser::skip_lines_naive;
use crate::mmap::ReaderBytes;
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

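/// Result of a CSV schema-inference pass: the inferred schema plus bookkeeping
/// on how much of the file was read to produce it.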
#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    inferred_schema: SchemaRef,
    rows_read: usize,
    bytes_read: usize,
    bytes_total: usize,
    n_threads: Option<usize>,
}

impl SchemaInferenceResult {
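    /// Run schema inference over `reader_bytes` using the inference settings
    /// carried by `options`, recording how many rows and bytes were consumed.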
    pub fn try_from_reader_bytes_and_options(
        reader_bytes: &ReaderBytes,
        options: &CsvReadOptions,
    ) -> PolarsResult<Self> {
        let parse_options = options.get_parse_options();

        let infer_schema_length = options.infer_schema_length;
        let has_header = options.has_header;
        let schema_overwrite_arc = options.schema_overwrite.clone();
        let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
        let skip_rows = options.skip_rows;
        let skip_lines = options.skip_lines;
        let skip_rows_after_header = options.skip_rows_after_header;
        let raise_if_empty = options.raise_if_empty;
        let n_threads = options.n_threads;

        let bytes_total = reader_bytes.len();

        let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
            reader_bytes,
            &parse_options,
            infer_schema_length,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_lines,
            skip_rows_after_header,
            raise_if_empty,
        )?;

        let this = Self {
            inferred_schema: Arc::new(inferred_schema),
            rows_read,
            bytes_read,
            bytes_total,
            n_threads,
        };

        Ok(this)
    }

    pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
        self.inferred_schema = inferred_schema;
        self
    }

    pub fn get_inferred_schema(&self) -> SchemaRef {
        self.inferred_schema.clone()
    }

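    /// Estimate the total number of rows in the file by extrapolating the
    /// rows-per-byte ratio observed during inference across all bytes, e.g.
    /// 100 rows read over 1_000 bytes of a 10_000-byte file estimates 1_000 rows.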
    pub fn get_estimated_n_rows(&self) -> usize {
        (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
    }
}

impl CsvReadOptions {
    /// Note: This does not update the schema from the inference result.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}

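/// Resolve the set of candidate dtypes observed for a column to a single dtype.
///
/// A sketch of the resolution rules implemented below (illustrative):
/// - `{Int64}`          -> `Int64` (a single candidate wins)
/// - `{Int64, Float64}` -> `Float64` (integers widen to floats)
/// - `{Boolean, Int64}` -> `String` (incompatible types fall back to String)
/// - `{}`               -> `String` (a column of only nulls also falls back)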
pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    // Determine the data type based on the possible types;
    // if there are incompatible types, use DataType::String.
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            // We have an integer and a double: fall back to double.
            DataType::Float64
        },
        // Default to String for conflicting datatypes (e.g. bool and int).
        _ => DataType::String,
    }
}

/// Infer the data type of a record
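///
/// A sketch of expected results (assuming the default dtype regexes and date
/// inference; illustrative, not a doctest):
///
/// ```ignore
/// assert_eq!(infer_field_schema("true", false, false), DataType::Boolean);
/// assert_eq!(infer_field_schema("-1", false, false), DataType::Int64);
/// assert_eq!(infer_field_schema("1.5", false, false), DataType::Float64);
/// assert_eq!(infer_field_schema("1,5", false, true), DataType::Float64);
/// assert_eq!(infer_field_schema("2021-01-01", true, false), DataType::Date);
/// assert_eq!(infer_field_schema("foo", false, false), DataType::String);
/// ```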
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    // When quoting is enabled in the reader, these quotes aren't escaped; we
    // default to String for such fields.
    let bytes = string.as_bytes();
    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => {
                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                        },
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    }
    // Match the regexes in a particular order.
    else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => {
                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                    },
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}

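/// Decode raw field bytes as UTF-8: strictly for `CsvEncoding::Utf8` (erroring
/// on invalid bytes) or lossily for `CsvEncoding::LossyUtf8` (replacing invalid
/// sequences).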
#[inline]
fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<'_, str>> {
    Ok(match encoding {
        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
            .into(),
        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
    })
}

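/// Default name for the `i`-th (0-based) column, e.g. `column_name(0)` yields
/// `"column_1"`.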
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}

#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    // This binding is mutable because we may need to skip more rows depending
    // on the schema inference.
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    // Keep track of the start pointer so that we can determine the number of bytes read.
    let start_ptr = reader_bytes.as_ptr() as usize;

    // We use lossy UTF-8 decoding here so that schema inference does not fail
    // on invalid UTF-8; the actual read may still fail later.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // Get or create header names; when `has_header` is false, default column
    // names with a `column_` prefix are created.

    // Skip lines that are comments.
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // Now that we've found the first non-comment line, we parse the headers,
    // or we create default ones.
    let mut headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // Remove the carriage return.
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            let mut header_names = PlHashMap::with_capacity(headers.len());

            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            byterecord
                .enumerate()
                .map(|(i, _s)| column_name(i))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // There was no newline char, so we copy the whole buffer and add one;
        // this is likely cheap, as there are no rows.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    } else if !raise_if_empty {
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
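    // With the headers settled, scan up to `max_read_rows` data rows and
    // collect the set of candidate dtypes for every column.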
    if !has_header {
        // Re-init lines so that the header is included in type inference.
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    // Keep track of the inferred field types.
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); headers.len()];
    // Keep track of columns with nulls.
    let mut nulls: Vec<bool> = vec![false; headers.len()];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(headers.len());

    // Needed to prevent ownership from moving into the iterator loop.
    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    // Read `skip_rows_after_header` more rows for inferring
                    // the correct schema, as the first `skip_rows_after_header`
                    // rows will be skipped.
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // Keep track so that we can determine the number of bytes read.
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        // The line is a comment -> skip.
        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // Remove the carriage return.
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for (i, (slice, needs_escaping)) in record.enumerate() {
            // When `has_header = false`, grow the schema if the first line
            // didn't have all columns.
            if i >= headers.len() {
                if !has_header {
                    headers.push(column_name(i));
                    column_types.push(Default::default());
                    nulls.push(false);
                } else {
                    break;
                }
            }

            if slice.is_empty() {
                unsafe { *nulls.get_unchecked_mut(i) = true };
            } else {
                let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                    &slice[1..(slice.len() - 1)]
                } else {
                    slice
                };
                let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                let dtype = match &parse_options.null_values {
                    None => Some(infer_field_schema(
                        &s,
                        parse_options.try_parse_dates,
                        parse_options.decimal_comma,
                    )),
                    Some(NullValues::AllColumns(names)) => {
                        if !names.iter().any(|nv| nv == s.as_ref()) {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::AllColumnsSingle(name)) => {
                        if s.as_ref() != name.as_str() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::Named(names)) => {
                        // SAFETY:
                        // we iterate over headers length.
                        let current_name = unsafe { headers.get_unchecked(i) };
                        let null_name = &names.iter().find(|name| name.0 == current_name);

                        if let Some(null_name) = null_name {
                            if null_name.1.as_str() != s.as_ref() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        } else {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        }
                    },
                };
                if let Some(dtype) = dtype {
                    unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                }
            }
        }
    }

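    // At this point `column_types[i]` holds every dtype observed for column `i`;
    // e.g. a column that parsed as Int64 on some rows and Float64 on others holds
    // {Int64, Float64}, which `finish_infer_field_schema` resolves to Float64 below.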
    // Build the schema from the inference results.
    for i in 0..headers.len() {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // The column might have been renamed; execute only if the
            // overwrite schema is complete.
            if schema_overwrite.len() == headers.len() {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // If there is a single line after the header without an eol char, we copy
    // the bytes, add an eol char, and rerun this function so that the inference
    // is consistent with and without the eol char.
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}

/// Infer the schema of a CSV file by reading through the first n rows of the file,
/// with `max_read_rows` controlling the maximum number of rows to read.
///
/// If `max_read_rows` is not set, the whole file is read to infer its schema.
///
/// Returns:
/// - the inferred schema
/// - the number of rows used for inference
/// - the number of bytes read
#[allow(clippy::too_many_arguments)]
pub fn infer_file_schema(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    skip_rows: usize,
    skip_lines: usize,
    skip_rows_after_header: usize,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    if skip_lines > 0 {
        polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
        let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
        let reader_bytes = ReaderBytes::Borrowed(bytes);
        infer_file_schema_inner(
            &reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
        )
    } else {
        infer_file_schema_inner(
            reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
        )
    }
}
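
// A minimal usage sketch (illustrative, not part of this file; assumes
// `CsvReadOptions` implements `Default`):
//
//     let csv = b"a,b\n1,2.5\n3,x\n";
//     let reader_bytes = ReaderBytes::Borrowed(csv.as_slice());
//     let options = CsvReadOptions::default();
//     let result =
//         SchemaInferenceResult::try_from_reader_bytes_and_options(&reader_bytes, &options)?;
//     // Column "a" saw only Int64 -> Int64; column "b" saw {Float64, String} -> String.
//     let schema = result.get_inferred_schema();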