Path: blob/main/crates/polars-json/src/json/infer_schema.rs
8427 views
use std::borrow::Borrow;12use arrow::datatypes::{ArrowDataType, Field};3use indexmap::map::Entry;4use polars_utils::pl_str::PlSmallStr;5use simd_json::borrowed::Object;6use simd_json::{BorrowedValue, StaticNode};78use super::*;910const ITEM_NAME: &str = "item";1112/// Infers [`ArrowDataType`] from [`Value`][Value].13///14/// [Value]: simd_json::value::Value15pub fn infer(json: &BorrowedValue) -> PolarsResult<ArrowDataType> {16Ok(match json {17BorrowedValue::Static(StaticNode::Bool(_)) => ArrowDataType::Boolean,18BorrowedValue::Static(StaticNode::I64(_)) => ArrowDataType::Int64,19BorrowedValue::Static(StaticNode::U64(x)) if *x <= i64::MAX as u64 => ArrowDataType::Int64,20BorrowedValue::Static(StaticNode::U64(_) | StaticNode::U128(_) | StaticNode::I128(_)) => {21ArrowDataType::Int12822},23BorrowedValue::Static(StaticNode::F64(_)) => ArrowDataType::Float64,24BorrowedValue::Static(StaticNode::Null) => ArrowDataType::Null,25BorrowedValue::Array(array) => infer_array(array)?,26BorrowedValue::String(_) => ArrowDataType::LargeUtf8,27BorrowedValue::Object(inner) => infer_object(inner)?,28})29}3031fn infer_object(inner: &Object) -> PolarsResult<ArrowDataType> {32let fields = inner33.iter()34.map(|(key, value)| infer(value).map(|dt| (key, dt)))35.map(|maybe_dt| {36let (key, dt) = maybe_dt?;37Ok(Field::new(key.as_ref().into(), dt, true))38})39.collect::<PolarsResult<Vec<_>>>()?;40Ok(ArrowDataType::Struct(fields))41}4243fn infer_array(values: &[BorrowedValue]) -> PolarsResult<ArrowDataType> {44let types = values45.iter()46.map(infer)47// deduplicate entries48.collect::<PolarsResult<PlIndexSet<_>>>()?;4950let dt = if !types.is_empty() {51let types = types.into_iter().collect::<Vec<_>>();52coerce_dtype(&types)53} else {54ArrowDataType::Null55};5657Ok(ArrowDataType::LargeList(Box::new(Field::new(58PlSmallStr::from_static(ITEM_NAME),59dt,60true,61))))62}6364/// Coerce an heterogeneous set of [`ArrowDataType`] into a single one. Rules:65/// * The empty set is coerced to `Null`66/// * `Int64` and `Float64` are `Float64`67/// * Lists and scalars are coerced to a list of a compatible scalar68/// * Structs contain the union of all fields69/// * All other types are coerced to `Utf8`70pub(crate) fn coerce_dtype<A: Borrow<ArrowDataType>>(datatypes: &[A]) -> ArrowDataType {71use ArrowDataType::*;7273if datatypes.is_empty() {74return Null;75}7677let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow());7879if are_all_equal {80return datatypes[0].borrow().clone();81}82let mut are_all_structs = true;83let mut are_all_lists = true;84for dt in datatypes {85are_all_structs &= matches!(dt.borrow(), Struct(_));86are_all_lists &= matches!(dt.borrow(), LargeList(_));87}8889if are_all_structs {90// all are structs => union of all fields (that may have equal names)91let fields = datatypes.iter().fold(vec![], |mut acc, dt| {92if let Struct(new_fields) = dt.borrow() {93acc.extend(new_fields);94};95acc96});97// group fields by unique98let fields = fields.iter().fold(99PlIndexMap::<&str, PlHashSet<&ArrowDataType>>::default(),100|mut acc, field| {101match acc.entry(field.name.as_str()) {102Entry::Occupied(mut v) => {103v.get_mut().insert(&field.dtype);104},105Entry::Vacant(v) => {106let mut a = PlHashSet::default();107a.insert(&field.dtype);108v.insert(a);109},110}111acc112},113);114// and finally, coerce each of the fields within the same name115let fields = fields116.into_iter()117.map(|(name, dts)| {118let dts = dts.into_iter().collect::<Vec<_>>();119Field::new(name.into(), coerce_dtype(&dts), true)120})121.collect();122return Struct(fields);123} else if are_all_lists {124let inner_types: Vec<&ArrowDataType> = datatypes125.iter()126.map(|dt| {127if let LargeList(inner) = dt.borrow() {128inner.dtype()129} else {130unreachable!();131}132})133.collect();134return LargeList(Box::new(Field::new(135PlSmallStr::from_static(ITEM_NAME),136coerce_dtype(inner_types.as_slice()),137true,138)));139} else if datatypes.len() > 2 {140return datatypes141.iter()142.map(|t| t.borrow().clone())143.reduce(|a, b| coerce_dtype(&[a, b]))144.expect("not empty");145}146let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow());147148match (lhs, rhs) {149(lhs, rhs) if lhs == rhs => lhs.clone(),150(LargeList(lhs), LargeList(rhs)) => {151let inner = coerce_dtype(&[lhs.dtype(), rhs.dtype()]);152LargeList(Box::new(Field::new(153PlSmallStr::from_static(ITEM_NAME),154inner,155true,156)))157},158(scalar, LargeList(list)) => {159let inner = coerce_dtype(&[scalar, list.dtype()]);160LargeList(Box::new(Field::new(161PlSmallStr::from_static(ITEM_NAME),162inner,163true,164)))165},166(LargeList(list), scalar) => {167let inner = coerce_dtype(&[scalar, list.dtype()]);168LargeList(Box::new(Field::new(169PlSmallStr::from_static(ITEM_NAME),170inner,171true,172)))173},174(Float64, Int64) => Float64,175(Int64, Float64) => Float64,176(Int64, Boolean) => Int64,177(Boolean, Int64) => Int64,178(Null, rhs) => rhs.clone(),179(lhs, Null) => lhs.clone(),180(_, _) => LargeUtf8,181}182}183184185