Path: crates/polars-io/src/csv/read/schema_inference.rs
use std::borrow::Cow;

use polars_core::prelude::*;
#[cfg(feature = "polars-time")]
use polars_time::chunkedarray::string::infer as date_infer;
#[cfg(feature = "polars-time")]
use polars_time::prelude::string::Pattern;
use polars_utils::format_pl_smallstr;

use super::parser::{SplitLines, is_comment_line, skip_bom, skip_line_ending};
use super::splitfields::SplitFields;
use super::{CsvEncoding, CsvParseOptions, CsvReadOptions, NullValues};
use crate::csv::read::parser::skip_lines_naive;
use crate::mmap::ReaderBytes;
use crate::utils::{BOOLEAN_RE, FLOAT_RE, FLOAT_RE_DECIMAL, INTEGER_RE};

#[derive(Clone, Debug, Default)]
pub struct SchemaInferenceResult {
    inferred_schema: SchemaRef,
    rows_read: usize,
    bytes_read: usize,
    bytes_total: usize,
    n_threads: Option<usize>,
}

impl SchemaInferenceResult {
    pub fn try_from_reader_bytes_and_options(
        reader_bytes: &ReaderBytes,
        options: &CsvReadOptions,
    ) -> PolarsResult<Self> {
        let parse_options = options.get_parse_options();

        let infer_schema_length = options.infer_schema_length;
        let has_header = options.has_header;
        let schema_overwrite_arc = options.schema_overwrite.clone();
        let schema_overwrite = schema_overwrite_arc.as_ref().map(|x| x.as_ref());
        let skip_rows = options.skip_rows;
        let skip_lines = options.skip_lines;
        let skip_rows_after_header = options.skip_rows_after_header;
        let raise_if_empty = options.raise_if_empty;
        let n_threads = options.n_threads;

        let bytes_total = reader_bytes.len();

        let (inferred_schema, rows_read, bytes_read) = infer_file_schema(
            reader_bytes,
            &parse_options,
            infer_schema_length,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_lines,
            skip_rows_after_header,
            raise_if_empty,
        )?;

        let this = Self {
            inferred_schema: Arc::new(inferred_schema),
            rows_read,
            bytes_read,
            bytes_total,
            n_threads,
        };

        Ok(this)
    }

    pub fn with_inferred_schema(mut self, inferred_schema: SchemaRef) -> Self {
        self.inferred_schema = inferred_schema;
        self
    }

    pub fn get_inferred_schema(&self) -> SchemaRef {
        self.inferred_schema.clone()
    }

    pub fn get_estimated_n_rows(&self) -> usize {
        (self.rows_read as f64 / self.bytes_read as f64 * self.bytes_total as f64) as usize
    }
}

impl CsvReadOptions {
    /// Note: This does not update the schema from the inference result.
    pub fn update_with_inference_result(&mut self, si_result: &SchemaInferenceResult) {
        self.n_threads = si_result.n_threads;
    }
}

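/// Unify the set of dtypes observed for a column while sampling rows into a
/// single dtype.
///
/// A minimal sketch of the unification rules (illustrative example, not taken
/// from the original source):
///
/// ```ignore
/// let mut seen = PlHashSet::default();
/// seen.insert(DataType::Int64);
/// seen.insert(DataType::Float64);
/// // Int64 + Float64 unify to Float64; any other conflict falls back to String.
/// assert_eq!(finish_infer_field_schema(&seen), DataType::Float64);
/// ```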
pub fn finish_infer_field_schema(possibilities: &PlHashSet<DataType>) -> DataType {
    // determine data type based on possible types
    // if there are incompatible types, use DataType::String
    match possibilities.len() {
        1 => possibilities.iter().next().unwrap().clone(),
        2 if possibilities.contains(&DataType::Int64)
            && possibilities.contains(&DataType::Float64) =>
        {
            // we have an integer and double, fall down to double
            DataType::Float64
        },
        // default to String for conflicting datatypes (e.g. bool and int)
        _ => DataType::String,
    }
}

/// Infer the data type of a record
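///
/// A hedged sketch of typical inferences (illustrative only; the exact results
/// depend on the regexes above and on the enabled date/time features):
///
/// ```ignore
/// assert_eq!(infer_field_schema("true", false, false), DataType::Boolean);
/// assert_eq!(infer_field_schema("7", false, false), DataType::Int64);
/// assert_eq!(infer_field_schema("7.5", false, false), DataType::Float64);
/// // with `decimal_comma = true`, "7,5" is matched by FLOAT_RE_DECIMAL
/// assert_eq!(infer_field_schema("7,5", false, true), DataType::Float64);
/// assert_eq!(infer_field_schema("banana", false, false), DataType::String);
/// ```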
pub fn infer_field_schema(string: &str, try_parse_dates: bool, decimal_comma: bool) -> DataType {
    // when quoting is enabled in the reader, these quotes aren't escaped, we default to
    // String for them
    let bytes = string.as_bytes();
    if bytes.len() >= 2 && *bytes.first().unwrap() == b'"' && *bytes.last().unwrap() == b'"' {
        if try_parse_dates {
            #[cfg(feature = "polars-time")]
            {
                match date_infer::infer_pattern_single(&string[1..string.len() - 1]) {
                    Some(pattern_with_offset) => match pattern_with_offset {
                        Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                            DataType::Datetime(TimeUnit::Microseconds, None)
                        },
                        Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                        Pattern::DatetimeYMDZ => {
                            DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                        },
                        Pattern::Time => DataType::Time,
                    },
                    None => DataType::String,
                }
            }
            #[cfg(not(feature = "polars-time"))]
            {
                panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
            }
        } else {
            DataType::String
        }
    }
    // match regex in a particular order
    else if BOOLEAN_RE.is_match(string) {
        DataType::Boolean
    } else if !decimal_comma && FLOAT_RE.is_match(string)
        || decimal_comma && FLOAT_RE_DECIMAL.is_match(string)
    {
        DataType::Float64
    } else if INTEGER_RE.is_match(string) {
        DataType::Int64
    } else if try_parse_dates {
        #[cfg(feature = "polars-time")]
        {
            match date_infer::infer_pattern_single(string) {
                Some(pattern_with_offset) => match pattern_with_offset {
                    Pattern::DatetimeYMD | Pattern::DatetimeDMY => {
                        DataType::Datetime(TimeUnit::Microseconds, None)
                    },
                    Pattern::DateYMD | Pattern::DateDMY => DataType::Date,
                    Pattern::DatetimeYMDZ => {
                        DataType::Datetime(TimeUnit::Microseconds, Some(TimeZone::UTC))
                    },
                    Pattern::Time => DataType::Time,
                },
                None => DataType::String,
            }
        }
        #[cfg(not(feature = "polars-time"))]
        {
            panic!("activate one of {{'dtype-date', 'dtype-datetime', 'dtype-time'}} features")
        }
    } else {
        DataType::String
    }
}

#[inline]
fn parse_bytes_with_encoding(bytes: &[u8], encoding: CsvEncoding) -> PolarsResult<Cow<'_, str>> {
    Ok(match encoding {
        CsvEncoding::Utf8 => simdutf8::basic::from_utf8(bytes)
            .map_err(|_| polars_err!(ComputeError: "invalid utf-8 sequence"))?
            .into(),
        CsvEncoding::LossyUtf8 => String::from_utf8_lossy(bytes),
    })
}

fn column_name(i: usize) -> PlSmallStr {
    format_pl_smallstr!("column_{}", i + 1)
}

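/// Inner worker for schema inference: scans sample rows (at most `max_read_rows`
/// when set) and returns the inferred schema together with the number of rows and
/// the number of bytes consumed while inferring.
///
/// If nothing was read because the buffer has no trailing eol char, the bytes are
/// copied, an eol char is appended and the function recurses once so that inference
/// is consistent with and without a trailing eol.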
#[allow(clippy::too_many_arguments)]
fn infer_file_schema_inner(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    // we take &mut because we maybe need to skip more rows dependent
    // on the schema inference
    mut skip_rows: usize,
    skip_rows_after_header: usize,
    recursion_count: u8,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    // keep track so that we can determine the amount of bytes read
    let start_ptr = reader_bytes.as_ptr() as usize;

    // We use lossy utf8 here because we don't want the schema inference to fail on utf8.
    // It may later.
    let encoding = CsvEncoding::LossyUtf8;

    let bytes = skip_line_ending(skip_bom(reader_bytes), parse_options.eol_char);
    if raise_if_empty {
        polars_ensure!(!bytes.is_empty(), NoData: "empty CSV");
    };
    let mut lines = SplitLines::new(
        bytes,
        parse_options.quote_char,
        parse_options.eol_char,
        parse_options.comment_prefix.as_ref(),
    )
    .skip(skip_rows);

    // get or create header names
    // when has_header is false, creates default column names with column_ prefix

    // skip lines that are comments
    let mut first_line = None;

    for (i, line) in (&mut lines).enumerate() {
        if !is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            first_line = Some(line);
            skip_rows += i;
            break;
        }
    }

    if first_line.is_none() {
        first_line = lines.next();
    }

    // now that we've found the first non-comment line we parse the headers, or we create a header
    let mut headers: Vec<PlSmallStr> = if let Some(mut header_line) = first_line {
        let len = header_line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = header_line[len - 1];
            if trailing_byte == b'\r' {
                header_line = &header_line[..len - 1];
            }
        }

        let byterecord = SplitFields::new(
            header_line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );
        if has_header {
            let headers = byterecord
                .map(|(slice, needs_escaping)| {
                    let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                        &slice[1..(slice.len() - 1)]
                    } else {
                        slice
                    };
                    let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                    Ok(s)
                })
                .collect::<PolarsResult<Vec<_>>>()?;

            let mut final_headers = Vec::with_capacity(headers.len());

            let mut header_names = PlHashMap::with_capacity(headers.len());

            for name in &headers {
                let count = header_names.entry(name.as_ref()).or_insert(0usize);
                if *count != 0 {
                    final_headers.push(format_pl_smallstr!("{}_duplicated_{}", name, *count - 1))
                } else {
                    final_headers.push(PlSmallStr::from_str(name))
                }
                *count += 1;
            }
            final_headers
        } else {
            byterecord
                .enumerate()
                .map(|(i, _s)| column_name(i))
                .collect::<Vec<PlSmallStr>>()
        }
    } else if has_header && !bytes.is_empty() && recursion_count == 0 {
        // there was no new line char. So we copy the whole buf and add one
        // this is likely to be cheap as there are no rows.
        let mut buf = Vec::with_capacity(bytes.len() + 2);
        buf.extend_from_slice(bytes);
        buf.push(parse_options.eol_char);

        return infer_file_schema_inner(
            &ReaderBytes::Owned(buf.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    } else if !raise_if_empty {
        return Ok((Schema::default(), 0, 0));
    } else {
        polars_bail!(NoData: "empty CSV");
    };
    if !has_header {
        // re-init lines so that the header is included in type inference.
        lines = SplitLines::new(
            bytes,
            parse_options.quote_char,
            parse_options.eol_char,
            parse_options.comment_prefix.as_ref(),
        )
        .skip(skip_rows);
    }

    // keep track of inferred field types
    let mut column_types: Vec<PlHashSet<DataType>> =
        vec![PlHashSet::with_capacity(4); headers.len()];
    // keep track of columns with nulls
    let mut nulls: Vec<bool> = vec![false; headers.len()];

    let mut rows_count = 0;
    let mut fields = Vec::with_capacity(headers.len());

    // needed to prevent ownership going into the iterator loop
    let records_ref = &mut lines;

    let mut end_ptr = start_ptr;
    for mut line in records_ref
        .take(match max_read_rows {
            Some(max_read_rows) => {
                if max_read_rows <= (usize::MAX - skip_rows_after_header) {
                    // read skip_rows_after_header more rows for inferring
                    // the correct schema as the first skip_rows_after_header
                    // rows will be skipped
                    max_read_rows + skip_rows_after_header
                } else {
                    max_read_rows
                }
            },
            None => usize::MAX,
        })
        .skip(skip_rows_after_header)
    {
        rows_count += 1;
        // keep track so that we can determine the amount of bytes read
        end_ptr = line.as_ptr() as usize + line.len();

        if line.is_empty() {
            continue;
        }

        // line is a comment -> skip
        if is_comment_line(line, parse_options.comment_prefix.as_ref()) {
            continue;
        }

        let len = line.len();
        if len > 1 {
            // remove carriage return
            let trailing_byte = line[len - 1];
            if trailing_byte == b'\r' {
                line = &line[..len - 1];
            }
        }

        let record = SplitFields::new(
            line,
            parse_options.separator,
            parse_options.quote_char,
            parse_options.eol_char,
        );

        for (i, (slice, needs_escaping)) in record.enumerate() {
            // When `has_header = false`, increase the schema if the first line
            // didn't have all columns.
            if i >= headers.len() {
                if !has_header {
                    headers.push(column_name(i));
                    column_types.push(Default::default());
                    nulls.push(false);
                } else {
                    break;
                }
            }

            if slice.is_empty() {
                unsafe { *nulls.get_unchecked_mut(i) = true };
            } else {
                let slice_escaped = if needs_escaping && (slice.len() >= 2) {
                    &slice[1..(slice.len() - 1)]
                } else {
                    slice
                };
                let s = parse_bytes_with_encoding(slice_escaped, encoding)?;
                let dtype = match &parse_options.null_values {
                    None => Some(infer_field_schema(
                        &s,
                        parse_options.try_parse_dates,
                        parse_options.decimal_comma,
                    )),
                    Some(NullValues::AllColumns(names)) => {
                        if !names.iter().any(|nv| nv == s.as_ref()) {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::AllColumnsSingle(name)) => {
                        if s.as_ref() != name.as_str() {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        } else {
                            None
                        }
                    },
                    Some(NullValues::Named(names)) => {
                        // SAFETY:
                        // we iterate over headers length.
                        let current_name = unsafe { headers.get_unchecked(i) };
                        let null_name = &names.iter().find(|name| name.0 == current_name);

                        if let Some(null_name) = null_name {
                            if null_name.1.as_str() != s.as_ref() {
                                Some(infer_field_schema(
                                    &s,
                                    parse_options.try_parse_dates,
                                    parse_options.decimal_comma,
                                ))
                            } else {
                                None
                            }
                        } else {
                            Some(infer_field_schema(
                                &s,
                                parse_options.try_parse_dates,
                                parse_options.decimal_comma,
                            ))
                        }
                    },
                };
                if let Some(dtype) = dtype {
                    unsafe { column_types.get_unchecked_mut(i).insert(dtype) };
                }
            }
        }
    }

    // build schema from inference results
    for i in 0..headers.len() {
        let field_name = &headers[i];

        if let Some(schema_overwrite) = schema_overwrite {
            if let Some((_, name, dtype)) = schema_overwrite.get_full(field_name) {
                fields.push(Field::new(name.clone(), dtype.clone()));
                continue;
            }

            // column might have been renamed
            // execute only if schema is complete
            if schema_overwrite.len() == headers.len() {
                if let Some((name, dtype)) = schema_overwrite.get_at_index(i) {
                    fields.push(Field::new(name.clone(), dtype.clone()));
                    continue;
                }
            }
        }

        let possibilities = &column_types[i];
        let dtype = finish_infer_field_schema(possibilities);
        fields.push(Field::new(field_name.clone(), dtype));
    }
    // if there is a single line after the header without an eol
    // we copy the bytes, add an eol and rerun this function
    // so that the inference is consistent with and without eol char
    if rows_count == 0
        && !reader_bytes.is_empty()
        && reader_bytes[reader_bytes.len() - 1] != parse_options.eol_char
        && recursion_count == 0
    {
        let mut rb = Vec::with_capacity(reader_bytes.len() + 1);
        rb.extend_from_slice(reader_bytes);
        rb.push(parse_options.eol_char);
        return infer_file_schema_inner(
            &ReaderBytes::Owned(rb.into()),
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            recursion_count + 1,
            raise_if_empty,
        );
    }

    Ok((Schema::from_iter(fields), rows_count, end_ptr - start_ptr))
}

/// Infer the schema of a CSV file by reading through the first n rows of the file,
/// with `max_read_rows` controlling the maximum number of rows to read.
///
/// If `max_read_rows` is not set, the whole file is read to infer its schema.
///
/// Returns
/// - inferred schema
/// - number of rows used for inference
/// - bytes read
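///
/// `skip_rows` and `skip_lines` are mutually exclusive; when `skip_lines` is set,
/// that many lines are skipped naively on the eol char before inference starts.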
#[allow(clippy::too_many_arguments)]
pub fn infer_file_schema(
    reader_bytes: &ReaderBytes,
    parse_options: &CsvParseOptions,
    max_read_rows: Option<usize>,
    has_header: bool,
    schema_overwrite: Option<&Schema>,
    skip_rows: usize,
    skip_lines: usize,
    skip_rows_after_header: usize,
    raise_if_empty: bool,
) -> PolarsResult<(Schema, usize, usize)> {
    if skip_lines > 0 {
        polars_ensure!(skip_rows == 0, InvalidOperation: "only one of 'skip_rows'/'skip_lines' may be set");
        let bytes = skip_lines_naive(reader_bytes, parse_options.eol_char, skip_lines);
        let reader_bytes = ReaderBytes::Borrowed(bytes);
        infer_file_schema_inner(
            &reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
        )
    } else {
        infer_file_schema_inner(
            reader_bytes,
            parse_options,
            max_read_rows,
            has_header,
            schema_overwrite,
            skip_rows,
            skip_rows_after_header,
            0,
            raise_if_empty,
        )
    }
}
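
// A hedged sketch of how the public entry point can be exercised on an in-memory
// buffer (not part of the original file). It assumes `CsvParseOptions::default()`
// uses ',' as separator and '"' as quote char, matching the options used above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn infer_schema_from_in_memory_csv() -> PolarsResult<()> {
        let csv = b"a,b,c\n1,2.5,x\n3,4.0,y\n";
        let reader_bytes = ReaderBytes::Borrowed(&csv[..]);
        let parse_options = CsvParseOptions::default();

        let (schema, _rows_read, _bytes_read) = infer_file_schema(
            &reader_bytes,
            &parse_options,
            Some(100), // max_read_rows
            true,      // has_header
            None,      // schema_overwrite
            0,         // skip_rows
            0,         // skip_lines
            0,         // skip_rows_after_header
            true,      // raise_if_empty
        )?;

        // integers stay Int64, floats become Float64, anything else falls back to String
        assert_eq!(schema.get("a"), Some(&DataType::Int64));
        assert_eq!(schema.get("b"), Some(&DataType::Float64));
        assert_eq!(schema.get("c"), Some(&DataType::String));
        Ok(())
    }
}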