Path: blob/main/crates/polars-io/src/csv/read/schema_inference.rs
use polars_buffer::Buffer;
use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;

use super::splitfields::SplitFields;
use super::{CsvParseOptions, NullValues};
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

/// Low-level CSV schema inference function.
///
/// Use `read_until_start_and_infer_schema` instead.
#[allow(clippy::too_many_arguments)]
pub(super) fn infer_file_schema_impl(
    header_line: &Option<Buffer<u8>>,
    content_lines: &[Buffer<u8>],
    infer_all_as_str: bool,
    parse_options: &CsvParseOptions,
    schema_overwrite: Option<&Schema>,
) -> Schema {
    let mut headers = header_line
        .as_ref()
        .map(|line| infer_headers(line, parse_options))
        .unwrap_or_else(|| Vec::with_capacity(8));

    let extend_header_with_unknown_column = header_line.is_none();

    let mut column_types = vec![PlHashSet::<DataType>::with_capacity(4); headers.len()];
    let mut nulls = vec![false; headers.len()];

    for content_line in content_lines {
        infer_types_from_line(
            content_line,
            infer_all_as_str,
            &mut headers,
            extend_header_with_unknown_column,
            parse_options,
            &mut column_types,
            &mut nulls,
        );
    }

    build_schema(&headers, &column_types, schema_overwrite)
}

/// Parse the header line into column names, unquoting fields and deduplicating
/// repeated names as `<name>_duplicated_<n>`.
fn infer_headers(mut header_line: &[u8], parse_options: &CsvParseOptions) -> Vec<PlSmallStr> {
    let len = header_line.len();

    // Strip a trailing '\r' (CRLF line endings).
    if header_line.last().copied() == Some(b'\r') {
        header_line = &header_line[..len - 1];
    }

    let byterecord = SplitFields::new(
        header_line,
        parse_options.separator,
        parse_options.quote_char,
        parse_options.eol_char,
    );

    let headers = byterecord
        .map(|(slice, needs_escaping)| {
            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                &slice[1..(slice.len() - 1)]
            } else {
                slice
            };
            String::from_utf8_lossy(slice_escaped)
        })
        .collect::<Vec<_>>();

    let mut deduplicated_headers = Vec::with_capacity(headers.len());
    let mut header_names = PlHashMap::with_capacity(headers.len());

    for name in &headers {
        let count = header_names.entry(name.as_ref()).or_insert(0usize);
        if *count != 0 {
            deduplicated_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
        } else {
            deduplicated_headers.push(PlSmallStr::from_str(name))
        }
        *count += 1;
    }

    deduplicated_headers
}

/// Record the candidate data types (and null observations) for each field of a single line.
/// When no header line was given, columns beyond the known headers extend `headers` with
/// generated names; otherwise the extra fields are ignored.
fn infer_types_from_line(
    mut line: &[u8],
    infer_all_as_str: bool,
    headers: &mut Vec<PlSmallStr>,
    extend_header_with_unknown_column: bool,
    parse_options: &CsvParseOptions,
    column_types: &mut Vec<PlHashSet<DataType>>,
    nulls: &mut Vec<bool>,
) {
    let line_len = line.len();
    // Strip a trailing '\r' (CRLF line endings).
    if line.last().copied() == Some(b'\r') {
        line = &line[..line_len - 1];
    }

    let record = SplitFields::new(
        line,
        parse_options.separator,
        parse_options.quote_char,
        parse_options.eol_char,
    );

    for (i, (slice, needs_escaping)) in record.enumerate() {
        if i >= headers.len() {
            if extend_header_with_unknown_column {
                headers.push(column_name(i));
                column_types.push(Default::default());
                nulls.push(false);
            } else {
                break;
            }
        }

        if infer_all_as_str {
            column_types[i].insert(DataType::String);
            continue;
        }

        if slice.is_empty() {
            nulls[i] = true;
        } else {
            let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                &slice[1..(slice.len() - 1)]
            } else {
                slice
            };
            let s = String::from_utf8_lossy(slice_escaped);
            let dtype = match &parse_options.null_values {
                None => Some(infer_field_schema(
                    &s,
                    parse_options.try_parse_dates,
                    parse_options.decimal_comma,
                )),
                Some(NullValues::AllColumns(names)) => {
                    if !names.iter().any(|nv| nv == s.as_ref()) {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    } else {
                        None
                    }
                },
                Some(NullValues::AllColumnsSingle(name)) => {
                    if s.as_ref() != name.as_str() {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    } else {
                        None
                    }
                },
                Some(NullValues::Named(names)) => {
                    let current_name = &headers[i];
                    let null_name = &names.iter().find(|name| name.0 == current_name);

                    if let Some(null_name) = null_name {
                        if null_name.1.as_str() != s.as_ref() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    } else {
                        Some(infer_field_schema(
                            &s,
                            parse_options.try_parse_dates,
                            parse_options.decimal_comma,
                        ))
                    }
                },
            };
            if let Some(dtype) = dtype {
                column_types[i].insert(dtype);
            }
        }
    }
}

/// Resolve the collected type possibilities per column into a `Schema`, applying
/// `schema_overwrite` entries matched by column name.
fn build_schema(
    headers: &[PlSmallStr],
    column_types: &[PlHashSet<DataType>],
    schema_overwrite: Option<&Schema>,
) -> Schema {
    assert!(headers.len() == column_types.len());

    let get_schema_overwrite = |field_name| {
        if let Some(schema_overwrite) = schema_overwrite {
            // Apply schema_overwrite by column name only. Positional overrides are handled
            // separately via dtype_overwrite.
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                return Some((name.clone(), dtype.clone()));
            }
        }

        None
    };

    Schema::from_iter(
        headers
            .iter()
            .zip(column_types)
            .map(|(field_name, type_possibilities)| {
                let (name, dtype) = get_schema_overwrite(field_name).unwrap_or_else(|| {
                    (
                        field_name.clone(),
                        finish_infer_field_schema(type_possibilities),
                    )
                });

                Field::new(name, dtype)
            }),
    )
}

pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    // Determine the data type based on the possible types.
    // If there are incompatible types, use DataType::String.
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            // We have an integer and a double; fall back to double.
            DataType::Float64
        },
        // Default to String for conflicting datatypes (e.g. bool and int).
        _ => DataType::String,
    }
}

/// Infer the data type of a record
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    // When quoting is enabled in the reader, these quotes aren't escaped; we default to
    // String for them.
    let bytes = string.as_bytes();
    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => {
                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                        },
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    }
    // Match the regexes in a particular order.
    else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => {
                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                    },
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}

/// Generated name for a column without a header entry: `column_1`, `column_2`, ...
fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}