Path: blob/main/crates/polars-json/src/json/infer_schema.rs
6939 views
use std::borrow::Borrow;12use arrow::datatypes::{ArrowDataType, Field};3use indexmap::map::Entry;4use polars_utils::pl_str::PlSmallStr;5use simd_json::borrowed::Object;6use simd_json::{BorrowedValue, StaticNode};78use super::*;910const ITEM_NAME: &str = "item";1112/// Infers [`ArrowDataType`] from [`Value`][Value].13///14/// [Value]: simd_json::value::Value15pub fn infer(json: &BorrowedValue) -> PolarsResult<ArrowDataType> {16Ok(match json {17BorrowedValue::Static(StaticNode::Bool(_)) => ArrowDataType::Boolean,18BorrowedValue::Static(StaticNode::U64(_) | StaticNode::I64(_)) => ArrowDataType::Int64,19BorrowedValue::Static(StaticNode::F64(_)) => ArrowDataType::Float64,20BorrowedValue::Static(StaticNode::Null) => ArrowDataType::Null,21BorrowedValue::Array(array) => infer_array(array)?,22BorrowedValue::String(_) => ArrowDataType::LargeUtf8,23BorrowedValue::Object(inner) => infer_object(inner)?,24})25}2627fn infer_object(inner: &Object) -> PolarsResult<ArrowDataType> {28let fields = inner29.iter()30.map(|(key, value)| infer(value).map(|dt| (key, dt)))31.map(|maybe_dt| {32let (key, dt) = maybe_dt?;33Ok(Field::new(key.as_ref().into(), dt, true))34})35.collect::<PolarsResult<Vec<_>>>()?;36Ok(ArrowDataType::Struct(fields))37}3839fn infer_array(values: &[BorrowedValue]) -> PolarsResult<ArrowDataType> {40let types = values41.iter()42.map(infer)43// deduplicate entries44.collect::<PolarsResult<PlHashSet<_>>>()?;4546let dt = if !types.is_empty() {47let types = types.into_iter().collect::<Vec<_>>();48coerce_dtype(&types)49} else {50ArrowDataType::Null51};5253Ok(ArrowDataType::LargeList(Box::new(Field::new(54PlSmallStr::from_static(ITEM_NAME),55dt,56true,57))))58}5960/// Coerce an heterogeneous set of [`ArrowDataType`] into a single one. Rules:61/// * The empty set is coerced to `Null`62/// * `Int64` and `Float64` are `Float64`63/// * Lists and scalars are coerced to a list of a compatible scalar64/// * Structs contain the union of all fields65/// * All other types are coerced to `Utf8`66pub(crate) fn coerce_dtype<A: Borrow<ArrowDataType>>(datatypes: &[A]) -> ArrowDataType {67use ArrowDataType::*;6869if datatypes.is_empty() {70return Null;71}7273let are_all_equal = datatypes.windows(2).all(|w| w[0].borrow() == w[1].borrow());7475if are_all_equal {76return datatypes[0].borrow().clone();77}78let mut are_all_structs = true;79let mut are_all_lists = true;80for dt in datatypes {81are_all_structs &= matches!(dt.borrow(), Struct(_));82are_all_lists &= matches!(dt.borrow(), LargeList(_));83}8485if are_all_structs {86// all are structs => union of all fields (that may have equal names)87let fields = datatypes.iter().fold(vec![], |mut acc, dt| {88if let Struct(new_fields) = dt.borrow() {89acc.extend(new_fields);90};91acc92});93// group fields by unique94let fields = fields.iter().fold(95PlIndexMap::<&str, PlHashSet<&ArrowDataType>>::default(),96|mut acc, field| {97match acc.entry(field.name.as_str()) {98Entry::Occupied(mut v) => {99v.get_mut().insert(&field.dtype);100},101Entry::Vacant(v) => {102let mut a = PlHashSet::default();103a.insert(&field.dtype);104v.insert(a);105},106}107acc108},109);110// and finally, coerce each of the fields within the same name111let fields = fields112.into_iter()113.map(|(name, dts)| {114let dts = dts.into_iter().collect::<Vec<_>>();115Field::new(name.into(), coerce_dtype(&dts), true)116})117.collect();118return Struct(fields);119} else if are_all_lists {120let inner_types: Vec<&ArrowDataType> = datatypes121.iter()122.map(|dt| {123if let LargeList(inner) = dt.borrow() {124inner.dtype()125} else {126unreachable!();127}128})129.collect();130return LargeList(Box::new(Field::new(131PlSmallStr::from_static(ITEM_NAME),132coerce_dtype(inner_types.as_slice()),133true,134)));135} else if datatypes.len() > 2 {136return datatypes137.iter()138.map(|t| t.borrow().clone())139.reduce(|a, b| coerce_dtype(&[a, b]))140.expect("not empty");141}142let (lhs, rhs) = (datatypes[0].borrow(), datatypes[1].borrow());143144match (lhs, rhs) {145(lhs, rhs) if lhs == rhs => lhs.clone(),146(LargeList(lhs), LargeList(rhs)) => {147let inner = coerce_dtype(&[lhs.dtype(), rhs.dtype()]);148LargeList(Box::new(Field::new(149PlSmallStr::from_static(ITEM_NAME),150inner,151true,152)))153},154(scalar, LargeList(list)) => {155let inner = coerce_dtype(&[scalar, list.dtype()]);156LargeList(Box::new(Field::new(157PlSmallStr::from_static(ITEM_NAME),158inner,159true,160)))161},162(LargeList(list), scalar) => {163let inner = coerce_dtype(&[scalar, list.dtype()]);164LargeList(Box::new(Field::new(165PlSmallStr::from_static(ITEM_NAME),166inner,167true,168)))169},170(Float64, Int64) => Float64,171(Int64, Float64) => Float64,172(Int64, Boolean) => Int64,173(Boolean, Int64) => Int64,174(Null, rhs) => rhs.clone(),175(lhs, Null) => lhs.clone(),176(_, _) => LargeUtf8,177}178}179180181