Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-io/src/csv/read/schema_inference.rs
8424 views
1
use polars_buffer::Buffer;
2
use polars_core::prelude::*;
3
#[cfg(feature = "polars-time")]
4
use polars_time::chunkedarray::string::infer as date_infer;
5
#[cfg(feature = "polars-time")]
6
use polars_time::prelude::string::Pattern;
7
use polars_utils::format_pl_smallstr;
8
9
use super::splitfields::SplitFields;
10
use super::{CsvParseOptions, NullValues};
11
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};
12
13
/// Low-level CSV schema inference function.
14
///
15
/// Use `read_until_start_and_infer_schema` instead.
16
#[allow(clippy::too_many_arguments)]
17
pub(super) fn infer_file_schema_impl(
18
header_line: &Option<Buffer<u8>>,
19
content_lines: &[Buffer<u8>],
20
infer_all_as_str: bool,
21
parse_options: &CsvParseOptions,
22
schema_overwrite: Option<&Schema>,
23
) -> Schema {
24
let mut headers = header_line
25
.as_ref()
26
.map(|line| infer_headers(line, parse_options))
27
.unwrap_or_else(|| Vec::with_capacity(8));
28
29
let extend_header_with_unknown_column = header_line.is_none();
30
31
let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
32
let mut nulls = vec![false; headers.len()];
33
34
for content_line in content_lines {
35
infer_types_from_line(
36
content_line,
37
infer_all_as_str,
38
&mut headers,
39
extend_header_with_unknown_column,
40
parse_options,
41
&mut column_types,
42
&mut nulls,
43
);
44
}
45
46
build_schema(&headers, &column_types, schema_overwrite)
47
}
48
49
fn infer_headers(mut header_line: &[u8], parse_options: &CsvParseOptions) -> Vec<PlSmallStr> {
50
let len = header_line.len();
51
52
if header_line.last().copied() == Some(b'\r') {
53
header_line = &header_line[..len - 1];
54
}
55
56
let byterecord = SplitFields::new(
57
header_line,
58
parse_options.separator,
59
parse_options.quote_char,
60
parse_options.eol_char,
61
);
62
63
let headers = byterecord
64
.map(|(slice, needs_escaping)| {
65
let slice_escaped = if needs_escaping && (slice.len() >= 2) {
66
&slice[1..(slice.len() - 1)]
67
} else {
68
slice
69
};
70
String::from_utf8_lossy(slice_escaped)
71
})
72
.collect::<Vec<_>>();
73
74
let mut deduplicated_headers = Vec::with_capacity(headers.len());
75
let mut header_names = PlHashMap::with_capacity(headers.len());
76
77
for name in &headers {
78
let count = header_names.entry(name.as_ref()).or_insert(0usize);
79
if *count != 0 {
80
deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
81
} else {
82
deduplicated_headers.push(PlSmallStr::from_str(name))
83
}
84
*count += 1;
85
}
86
87
deduplicated_headers
88
}
89
90
fn infer_types_from_line(
91
mut line: &[u8],
92
infer_all_as_str: bool,
93
headers: &mut Vec<PlSmallStr>,
94
extend_header_with_unknown_column: bool,
95
parse_options: &CsvParseOptions,
96
column_types: &mut Vec<PlHashSet<DataType>>,
97
nulls: &mut Vec<bool>,
98
) {
99
let line_len = line.len();
100
if line.last().copied() == Some(b'\r') {
101
line = &line[..line_len - 1];
102
}
103
104
let record = SplitFields::new(
105
line,
106
parse_options.separator,
107
parse_options.quote_char,
108
parse_options.eol_char,
109
);
110
111
for (i, (slice, needs_escaping)) in record.enumerate() {
112
if i >= headers.len() {
113
if extend_header_with_unknown_column {
114
headers.push(column_name(i));
115
column_types.push(Default::default());
116
nulls.push(false);
117
} else {
118
break;
119
}
120
}
121
122
if infer_all_as_str {
123
column_types[i].insert(DataType::String);
124
continue;
125
}
126
127
if slice.is_empty() {
128
nulls[i] = true;
129
} else {
130
let slice_escaped = if needs_escaping && (slice.len() >= 2) {
131
&slice[1..(slice.len() - 1)]
132
} else {
133
slice
134
};
135
let s = String::from_utf8_lossy(slice_escaped);
136
let dtype = match &parse_options.null_values {
137
None => Some(infer_field_schema(
138
&s,
139
parse_options.try_parse_dates,
140
parse_options.decimal_comma,
141
)),
142
Some(NullValues::AllColumns(names)) => {
143
if !names.iter().any(|nv| nv == s.as_ref()) {
144
Some(infer_field_schema(
145
&s,
146
parse_options.try_parse_dates,
147
parse_options.decimal_comma,
148
))
149
} else {
150
None
151
}
152
},
153
Some(NullValues::AllColumnsSingle(name)) => {
154
if s.as_ref() != name.as_str() {
155
Some(infer_field_schema(
156
&s,
157
parse_options.try_parse_dates,
158
parse_options.decimal_comma,
159
))
160
} else {
161
None
162
}
163
},
164
Some(NullValues::Named(names)) => {
165
let current_name = &headers[i];
166
let null_name = &names.iter().find(|name| name.0 == current_name);
167
168
if let Some(null_name) = null_name {
169
if null_name.1.as_str() != s.as_ref() {
170
Some(infer_field_schema(
171
&s,
172
parse_options.try_parse_dates,
173
parse_options.decimal_comma,
174
))
175
} else {
176
None
177
}
178
} else {
179
Some(infer_field_schema(
180
&s,
181
parse_options.try_parse_dates,
182
parse_options.decimal_comma,
183
))
184
}
185
},
186
};
187
if let Some(dtype) = dtype {
188
column_types[i].insert(dtype);
189
}
190
}
191
}
192
}
193
194
fn build_schema(
195
headers: &[PlSmallStr],
196
column_types: &[PlHashSet<DataType>],
197
schema_overwrite: Option<&Schema>,
198
) -> Schema {
199
assert!(headers.len() == column_types.len());
200
201
let get_schema_overwrite = |field_name| {
202
if let Some(schema_overwrite) = schema_overwrite {
203
// Apply schema_overwrite by column name only. Positional overrides are handled
204
// separately via dtype_overwrite.
205
if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
206
return Some((name.clone(), dtype.clone()));
207
}
208
}
209
210
None
211
};
212
213
Schema::from_iter(
214
headers
215
.iter()
216
.zip(column_types)
217
.map(|(field_name, type_possibilities)| {
218
let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
219
(
220
field_name.clone(),
221
finish_infer_field_schema(type_possibilities),
222
)
223
});
224
225
Field::new(name, dtype)
226
}),
227
)
228
}
229
230
pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
231
// determine data type based on possible types
232
// if there are incompatible types, use DataType::String
233
match possibilities.len() {
234
1 => possibilities.iter().next().unwrap().clone(),
235
2 if possibilities.contains(&DataType::Int64)
236
&& possibilities.contains(&DataType::Float64) =>
237
{
238
// we have an integer and double, fall down to double
239
DataType::Float64
240
},
241
// default to String for conflicting datatypes (e.g bool and int)
242
_ => DataType::String,
243
}
244
}
245
246
/// Infer the data type of a record
247
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
248
// when quoting is enabled in the reader, these quotes aren't escaped, we default to
249
// String for them
250
let bytes = string.as_bytes();
251
if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
252
if try_parse_dates {
253
#[cfg(feature = "polars-time")]
254
{
255
match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
256
Some(pattern_with_offset) => match pattern_with_offset {
257
Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
258
DataType::Datetime(TimeUnit::Microseconds, None)
259
},
260
Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
261
Pattern::DatetimeYMDZ => {
262
DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
263
},
264
Pattern::Time => DataType::Time,
265
},
266
None => DataType::String,
267
}
268
}
269
#[cfg(not(feature = "polars-time"))]
270
{
271
panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
272
}
273
} else {
274
DataType::String
275
}
276
}
277
// match regex in a particular order
278
else if BOOLEAN_RE.is_match(string) {
279
DataType::Boolean
280
} else if !decimal_comma && FLOAT_RE.is_match(string)
281
|| decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
282
{
283
DataType::Float64
284
} else if INTEGER_RE.is_match(string) {
285
DataType::Int64
286
} else if try_parse_dates {
287
#[cfg(feature = "polars-time")]
288
{
289
match date_infer::infer_pattern_single(string) {
290
Some(pattern_with_offset) => match pattern_with_offset {
291
Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
292
DataType::Datetime(TimeUnit::Microseconds, None)
293
},
294
Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
295
Pattern::DatetimeYMDZ => {
296
DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
297
},
298
Pattern::Time => DataType::Time,
299
},
300
None => DataType::String,
301
}
302
}
303
#[cfg(not(feature = "polars-time"))]
304
{
305
panic!("activate one of {{'dtype-date', 'dtype-datetime', dtype-time'}} features")
306
}
307
} else {
308
DataType::String
309
}
310
}
311
312
fn column_name(i: usize) -> PlSmallStr {
313
format_pl_smallstr!("column_{}", i + 1)
314
}
315
316