Path: blob/main/crates/polars-io/src/csv/read/splitfields.rs
6939 views
#![allow(unsafe_op_in_unsafe_fn)]

#[cfg(not(feature = "simd"))]
mod inner {
    /// An adapted version of `std::iter::Split`.
    ///
    /// This exists solely because we cannot split the lines naively on the
    /// separator: a separator or end-of-line byte that appears inside a quoted
    /// field must not terminate that field.
    ///
    /// The iterator yields `(field_bytes, needs_escaping)` pairs, where
    /// `needs_escaping` is `true` when the field starts with the quote
    /// character (and quoting is enabled). The yielded bytes still include the
    /// surrounding quotes. Iteration stops after the first field that is
    /// terminated by `eol_char`, or after the remaining input is exhausted.
    pub(crate) struct SplitFields<'a> {
        /// Remaining, not yet consumed input.
        v: &'a [u8],
        separator: u8,
        /// Set once the last field of the line has been yielded.
        finished: bool,
        quote_char: u8,
        /// Whether `quote_char` should be honored at all.
        quoting: bool,
        eol_char: u8,
    }

    impl<'a> SplitFields<'a> {
        /// Creates a field splitter over `slice`.
        ///
        /// Passing `None` as `quote_char` disables quote handling; the stored
        /// quote byte then defaults to `b'"'` but is never consulted.
        pub(crate) fn new(
            slice: &'a [u8],
            separator: u8,
            quote_char: Option<u8>,
            eol_char: u8,
        ) -> Self {
            Self {
                v: slice,
                separator,
                finished: false,
                quote_char: quote_char.unwrap_or(b'"'),
                quoting: quote_char.is_some(),
                eol_char,
            }
        }

        /// Marks the iterator as finished and yields `self.v[..idx]`.
        ///
        /// # Safety
        /// `idx` must be `<= self.v.len()`.
        unsafe fn finish_eol(
            &mut self,
            need_escaping: bool,
            idx: usize,
        ) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            debug_assert!(idx <= self.v.len());
            Some((self.v.get_unchecked(..idx), need_escaping))
        }

        /// Marks the iterator as finished and yields all remaining bytes.
        fn finish(&mut self, need_escaping: bool) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            Some((self.v, need_escaping))
        }

        /// Returns `true` when `current_ch` terminates a field, i.e. it is the
        /// separator or the end-of-line byte.
        fn eof_eol(&self, current_ch: u8) -> bool {
            current_ch == self.separator || current_ch == self.eol_char
        }
    }

    impl<'a> Iterator for SplitFields<'a> {
        // the bool is used to indicate that it requires escaping
        type Item = (&'a [u8], bool);

        #[inline]
        fn next(&mut self) -> Option<(&'a [u8], bool)> {
            if self.finished {
                return None;
            } else if self.v.is_empty() {
                // A trailing separator (or empty input) yields one empty field.
                return self.finish(false);
            }

            // There can be strings with separators:
            // "Street, City",

            // SAFETY:
            // `self.v` was checked to be non-empty above.
            let needs_escaping =
                self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char;

            let end = if needs_escaping {
                // There can be pair of double-quotes within string.
                // Each of the embedded double-quote characters must be represented
                // by a pair of double-quote characters:
                // e.g. 1997,Ford,E350,"Super, ""luxurious"" truck",20020

                // denotes if we are in a string field, started with a quote
                let mut in_field = false;

                // The index is a `usize` produced by `position`, so it cannot
                // overflow no matter how long the scanned input is.
                self.v.iter().position(|&c| {
                    if c == self.quote_char {
                        // toggle between string field enclosure
                        // if we encounter a starting '"' -> in_field = true;
                        // if we encounter a closing '"' -> in_field = false;
                        in_field = !in_field;
                    }
                    !in_field && self.eof_eol(c)
                })
            } else {
                self.v.iter().position(|&c| self.eof_eol(c))
            };

            let pos = match end {
                // No terminator left: the remainder is the last field.
                None => return self.finish(needs_escaping),
                Some(pos) => pos,
            };

            unsafe {
                // SAFETY:
                // `pos` was returned by `position` over `self.v`, hence
                // `pos < self.v.len()` and `pos + 1 <= self.v.len()`.
                debug_assert!(pos < self.v.len());
                if *self.v.get_unchecked(pos) == self.eol_char {
                    // Make sure the iterator is done when EOL.
                    return self.finish_eol(needs_escaping, pos);
                }
                let ret = Some((self.v.get_unchecked(..pos), needs_escaping));
                self.v = self.v.get_unchecked(pos + 1..);
                ret
            }
        }
    }
}

#[cfg(feature = "simd")]
mod inner {
    use std::simd::prelude::*;

    use polars_utils::clmul::prefix_xorsum_inclusive;

    /// Number of bytes inspected per SIMD step.
    const SIMD_SIZE: usize = 64;
    type SimdVec = u8x64;

    /// An adapted version of `std::iter::Split`.
    ///
    /// This exists solely because we cannot split the lines naively on the
    /// separator: a separator or end-of-line byte that appears inside a quoted
    /// field must not terminate that field.
    ///
    /// The iterator yields `(field_bytes, needs_escaping)` pairs, where
    /// `needs_escaping` is `true` when the field starts with the quote
    /// character (and quoting is enabled).
    pub(crate) struct SplitFields<'a> {
        /// Remaining, not yet consumed input.
        pub v: &'a [u8],
        separator: u8,
        /// Set once the last field of the line has been yielded.
        pub finished: bool,
        quote_char: u8,
        /// Whether `quote_char` should be honored at all.
        quoting: bool,
        eol_char: u8,
        // Splatted copies of the scalar bytes above, for lane-wise comparison.
        simd_separator: SimdVec,
        simd_eol_char: SimdVec,
        simd_quote_char: SimdVec,
        /// Cached bitmask of further field terminators found in the last SIMD
        /// chunk of a quoted-field scan. Bit `i` refers to `self.v[i]`
        /// (relative to the current start of `v`).
        previous_valid_ends: u64,
    }

    impl<'a> SplitFields<'a> {
        /// Creates a field splitter over `slice`.
        ///
        /// Passing `None` as `quote_char` disables quote handling; the stored
        /// quote byte then defaults to `b'"'`.
        pub(crate) fn new(
            slice: &'a [u8],
            separator: u8,
            quote_char: Option<u8>,
            eol_char: u8,
        ) -> Self {
            let simd_separator = SimdVec::splat(separator);
            let simd_eol_char = SimdVec::splat(eol_char);
            let quoting = quote_char.is_some();
            let quote_char = quote_char.unwrap_or(b'"');
            let simd_quote_char = SimdVec::splat(quote_char);

            Self {
                v: slice,
                separator,
                finished: false,
                quote_char,
                quoting,
                eol_char,
                simd_separator,
                simd_eol_char,
                simd_quote_char,
                previous_valid_ends: 0,
            }
        }

        /// Marks the iterator as finished and yields `self.v[..pos]`.
        ///
        /// # Safety
        /// `pos` must be `<= self.v.len()`.
        unsafe fn finish_eol(
            &mut self,
            need_escaping: bool,
            pos: usize,
        ) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            debug_assert!(pos <= self.v.len());
            Some((self.v.get_unchecked(..pos), need_escaping))
        }

        /// Marks the iterator as finished and yields all remaining bytes.
        #[inline]
        fn finish(&mut self, need_escaping: bool) -> Option<(&'a [u8], bool)> {
            self.finished = true;
            Some((self.v, need_escaping))
        }

        /// Returns `true` when `current_ch` terminates a field, i.e. it is the
        /// separator or the end-of-line byte.
        fn eof_eol(&self, current_ch: u8) -> bool {
            current_ch == self.separator || current_ch == self.eol_char
        }
    }

    impl<'a> Iterator for SplitFields<'a> {
        // the bool is used to indicate that it requires escaping
        type Item = (&'a [u8], bool);

        #[inline]
        fn next(&mut self) -> Option<(&'a [u8], bool)> {
            // This must be before we check the cached value
            if self.finished {
                return None;
            }
            // Then check cached value as this is hot.
            if self.previous_valid_ends != 0 {
                let pos = self.previous_valid_ends.trailing_zeros() as usize;
                // Re-base the remaining cached terminators onto the new start of `v`.
                self.previous_valid_ends >>= (pos + 1) as u64;

                unsafe {
                    debug_assert!(pos < self.v.len());
                    // SAFETY:
                    // we are in bounds
                    let needs_escaping = self
                        .v
                        .first()
                        .map(|c| *c == self.quote_char && self.quoting)
                        .unwrap_or(false);

                    if *self.v.get_unchecked(pos) == self.eol_char {
                        return self.finish_eol(needs_escaping, pos);
                    }

                    let bytes = self.v.get_unchecked(..pos);

                    self.v = self.v.get_unchecked(pos + 1..);
                    let ret = Some((bytes, needs_escaping));

                    return ret;
                }
            }
            if self.v.is_empty() {
                return self.finish(false);
            }

            let mut needs_escaping = false;
            // There can be strings with separators:
            // "Street, City",

            // SAFETY:
            // we have checked bounds
            let pos = if self.quoting && unsafe { *self.v.get_unchecked(0) } == self.quote_char {
                // Start of an enclosed field
                let mut total_idx = 0;
                needs_escaping = true;
                // Quote state carried from one SIMD chunk to the next.
                let mut not_in_field_previous_iter = true;

                loop {
                    let bytes = unsafe { self.v.get_unchecked(total_idx..) };

                    if bytes.len() > SIMD_SIZE {
                        let lane: [u8; SIMD_SIZE] = unsafe {
                            bytes
                                .get_unchecked(0..SIMD_SIZE)
                                .try_into()
                                .unwrap_unchecked()
                        };
                        let simd_bytes = SimdVec::from(lane);
                        let has_eol = simd_bytes.simd_eq(self.simd_eol_char);
                        let has_sep = simd_bytes.simd_eq(self.simd_separator);
                        let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask();
                        let mut end_mask = (has_sep | has_eol).to_bitmask();

                        // NOTE(review): presumably a running XOR over the quote bits,
                        // i.e. the quote parity at every byte; confirm against
                        // `polars_utils::clmul`.
                        let mut not_in_quote_field = prefix_xorsum_inclusive(quote_mask);

                        if not_in_field_previous_iter {
                            not_in_quote_field = !not_in_quote_field;
                        }
                        // The top bit is the quote state after the last byte of this chunk.
                        not_in_field_previous_iter =
                            (not_in_quote_field & (1 << (SIMD_SIZE - 1))) > 0;
                        // Only keep terminators that are outside of a quoted region.
                        end_mask &= not_in_quote_field;

                        if end_mask != 0 {
                            let pos = end_mask.trailing_zeros() as usize;
                            total_idx += pos;
                            debug_assert!(
                                self.v[total_idx] == self.eol_char
                                    || self.v[total_idx] == self.separator
                            );

                            // Shifting a u64 by 64 would overflow, hence the special case.
                            if pos == SIMD_SIZE - 1 {
                                self.previous_valid_ends = 0;
                            } else {
                                self.previous_valid_ends = end_mask >> (pos + 1) as u64;
                            }

                            break;
                        } else {
                            total_idx += SIMD_SIZE;
                        }
                    } else {
                        // There can be a pair of double-quotes within a string.
                        // Each of the embedded double-quote characters must be represented
                        // by a pair of double-quote characters:
                        // e.g. 1997,Ford,E350,"Super, ""luxurious"" truck",20020

                        // denotes if we are in a string field, started with a quote
                        let mut in_field = !not_in_field_previous_iter;

                        // usize::MAX is unset.
                        let mut idx = usize::MAX;
                        let mut current_idx = 0;
                        // micro optimizations
                        #[allow(clippy::explicit_counter_loop)]
                        for &c in bytes.iter() {
                            if c == self.quote_char {
                                // toggle between string field enclosure
                                // if we encounter a starting '"' -> in_field = true;
                                // if we encounter a closing '"' -> in_field = false;
                                in_field = !in_field;
                            }

                            if !in_field && self.eof_eol(c) {
                                idx = current_idx;
                                break;
                            }
                            current_idx += 1;
                        }

                        if idx == usize::MAX {
                            return self.finish(needs_escaping);
                        }

                        total_idx += idx;
                        debug_assert!(
                            self.v[total_idx] == self.eol_char
                                || self.v[total_idx] == self.separator
                        );
                        break;
                    }
                }
                total_idx
            } else {
                // Start of an unenclosed field
                let mut total_idx = 0;

                loop {
                    let bytes = unsafe { self.v.get_unchecked(total_idx..) };

                    if bytes.len() > SIMD_SIZE {
                        let lane: [u8; SIMD_SIZE] = unsafe {
                            bytes
                                .get_unchecked(0..SIMD_SIZE)
                                .try_into()
                                .unwrap_unchecked()
                        };
                        let simd_bytes = SimdVec::from(lane);
                        let has_eol_char = simd_bytes.simd_eq(self.simd_eol_char);
                        let has_separator = simd_bytes.simd_eq(self.simd_separator);
                        let has_any_mask = (has_separator | has_eol_char).to_bitmask();

                        if has_any_mask != 0 {
                            total_idx += has_any_mask.trailing_zeros() as usize;
                            break;
                        } else {
                            total_idx += SIMD_SIZE;
                        }
                    } else {
                        match bytes.iter().position(|&c| self.eof_eol(c)) {
                            None => return self.finish(needs_escaping),
                            Some(idx) => {
                                total_idx += idx;
                                break;
                            },
                        }
                    }
                }
                total_idx
            };

            // Make sure the iterator is done when EOL.
            let c = unsafe { *self.v.get_unchecked(pos) };
            if c == self.eol_char {
                // SAFETY:
                // we are in bounds
                return unsafe { self.finish_eol(needs_escaping, pos) };
            }

            unsafe {
                debug_assert!(pos < self.v.len());
                // SAFETY:
                // we are in bounds
                let ret = Some((self.v.get_unchecked(..pos), needs_escaping));
                self.v = self.v.get_unchecked(pos + 1..);
                ret
            }
        }
    }
}

pub(crate) use inner::SplitFields;

#[cfg(test)]
mod test {
    use super::SplitFields;

    #[test]
    fn test_splitfields() {
        let input = "\"foo\",\"bar\"";
        let mut fields = SplitFields::new(input.as_bytes(), b',', Some(b'"'), b'\n');

        assert_eq!(fields.next(), Some(("\"foo\"".as_bytes(), true)));
        assert_eq!(fields.next(), Some(("\"bar\"".as_bytes(), true)));
        assert_eq!(fields.next(), None);

        let input2 = "\"foo\n bar\";\"baz\";12345";
        let mut fields2 = SplitFields::new(input2.as_bytes(), b';', Some(b'"'), b'\n');

        assert_eq!(fields2.next(), Some(("\"foo\n bar\"".as_bytes(), true)));
        assert_eq!(fields2.next(), Some(("\"baz\"".as_bytes(), true)));
        assert_eq!(fields2.next(), Some(("12345".as_bytes(), false)));
        assert_eq!(fields2.next(), None);
    }
}