Path: blob/main/crates/polars-arrow/src/io/ipc/read/schema.rs
6940 views
use std::sync::Arc;12use arrow_format::ipc::planus::ReadAsRoot;3use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};4use polars_error::{PolarsResult, polars_bail, polars_err};5use polars_utils::pl_str::PlSmallStr;67use super::super::{IpcField, IpcSchema};8use super::{OutOfSpecKind, StreamMetadata};9use crate::datatypes::{10ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType, IntervalUnit,11Metadata, TimeUnit, UnionMode, UnionType, get_extension,12};1314fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(15iter: I,16) -> PolarsResult<(Vec<A>, Vec<B>)> {17let mut a = vec![];18let mut b = vec![];19for maybe_item in iter {20let (a_i, b_i) = maybe_item?;21a.push(a_i);22b.push(b_i);23}2425Ok((a, b))26}2728fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {29let metadata = read_metadata(&ipc_field)?;3031let extension = metadata.as_ref().and_then(get_extension);3233let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;3435let field = Field {36name: PlSmallStr::from_str(37ipc_field38.name()?39.ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,40),41dtype,42is_nullable: ipc_field.nullable()?,43metadata: metadata.map(Arc::new),44};4546Ok((field, ipc_field_))47}4849fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {50Ok(if let Some(list) = field.custom_metadata()? {51let mut metadata_map = Metadata::new();52for kv in list {53let kv = kv?;54if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {55metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));56}57}58Some(metadata_map)59} else {60None61})62}6364fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {65Ok(match (int.bit_width()?, int.is_signed()?) {66(8, true) => IntegerType::Int8,67(8, false) => IntegerType::UInt8,68(16, true) => IntegerType::Int16,69(16, false) => IntegerType::UInt16,70(32, true) => IntegerType::Int32,71(32, false) => IntegerType::UInt32,72(64, true) => IntegerType::Int64,73(64, false) => IntegerType::UInt64,74(128, true) => IntegerType::Int128,75_ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),76})77}7879fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {80use arrow_format::ipc::TimeUnit::*;81Ok(match time_unit {82Second => TimeUnit::Second,83Millisecond => TimeUnit::Millisecond,84Microsecond => TimeUnit::Microsecond,85Nanosecond => TimeUnit::Nanosecond,86})87}8889fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {90let unit = deserialize_timeunit(time.unit()?)?;9192let dtype = match (time.bit_width()?, unit) {93(32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),94(32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),95(64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),96(64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),97(bits, precision) => {98polars_bail!(ComputeError:99"Time type with bit width of {bits} and unit of {precision:?}"100)101},102};103Ok((dtype, IpcField::default()))104}105106fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {107let timezone = timestamp.timezone()?;108let time_unit = deserialize_timeunit(timestamp.unit()?)?;109Ok((110ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),111IpcField::default(),112))113}114115fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {116let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);117let ids = union_.type_ids()?.map(|x| x.iter().collect());118119let fields = field120.children()?121.ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;122if fields.is_empty() {123polars_bail!(oos = "IPC: Union must contain at least one child");124}125126let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {127let (field, fields) = deserialize_field(field?)?;128Ok((field, fields))129}))?;130let ipc_field = IpcField {131fields: ipc_fields,132dictionary_id: None,133};134Ok((135ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),136ipc_field,137))138}139140fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {141let is_sorted = map.keys_sorted()?;142143let children = field144.children()?145.ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;146let inner = children147.get(0)148.ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;149let (field, ipc_field) = deserialize_field(inner)?;150151let dtype = ArrowDataType::Map(Box::new(field), is_sorted);152Ok((153dtype,154IpcField {155fields: vec![ipc_field],156dictionary_id: None,157},158))159}160161fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {162let fields = field163.children()?164.ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;165let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {166let (field, fields) = deserialize_field(field?)?;167Ok((field, fields))168}))?;169let ipc_field = IpcField {170fields: ipc_fields,171dictionary_id: None,172};173Ok((ArrowDataType::Struct(fields), ipc_field))174}175176fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {177let children = field178.children()?179.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;180let inner = children181.get(0)182.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;183let (field, ipc_field) = deserialize_field(inner)?;184185Ok((186ArrowDataType::List(Box::new(field)),187IpcField {188fields: vec![ipc_field],189dictionary_id: None,190},191))192}193194fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {195let children = field196.children()?197.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;198let inner = children199.get(0)200.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;201let (field, ipc_field) = deserialize_field(inner)?;202203Ok((204ArrowDataType::LargeList(Box::new(field)),205IpcField {206fields: vec![ipc_field],207dictionary_id: None,208},209))210}211212fn deserialize_fixed_size_list(213list: FixedSizeListRef,214field: FieldRef,215) -> PolarsResult<(ArrowDataType, IpcField)> {216let children = field217.children()?218.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;219let inner = children220.get(0)221.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;222let (field, ipc_field) = deserialize_field(inner)?;223224let size = list225.list_size()?226.try_into()227.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;228229Ok((230ArrowDataType::FixedSizeList(Box::new(field), size),231IpcField {232fields: vec![ipc_field],233dictionary_id: None,234},235))236}237238/// Get the Arrow data type from the flatbuffer Field table239fn get_dtype(240field: arrow_format::ipc::FieldRef,241extension: Extension,242may_be_dictionary: bool,243) -> PolarsResult<(ArrowDataType, IpcField)> {244if let Some(dictionary) = field.dictionary()? {245if may_be_dictionary {246let int = dictionary247.index_type()?248.ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;249let index_type = deserialize_integer(int)?;250let (inner, mut ipc_field) = get_dtype(field, extension, false)?;251ipc_field.dictionary_id = Some(dictionary.id()?);252return Ok((253ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),254ipc_field,255));256}257}258259if let Some(extension) = extension {260let (name, metadata) = extension;261let (dtype, fields) = get_dtype(field, None, false)?;262return Ok((263ArrowDataType::Extension(Box::new(ExtensionType {264name,265inner: dtype,266metadata,267})),268fields,269));270}271272let type_ = field273.type_()?274.ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;275276use arrow_format::ipc::TypeRef::*;277Ok(match type_ {278Null(_) => (ArrowDataType::Null, IpcField::default()),279Bool(_) => (ArrowDataType::Boolean, IpcField::default()),280Int(int) => {281let dtype = deserialize_integer(int)?.into();282(dtype, IpcField::default())283},284Binary(_) => (ArrowDataType::Binary, IpcField::default()),285LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),286Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),287LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),288BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),289Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),290FixedSizeBinary(fixed) => (291ArrowDataType::FixedSizeBinary(292fixed293.byte_width()?294.try_into()295.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,296),297IpcField::default(),298),299FloatingPoint(float) => {300let dtype = match float.precision()? {301arrow_format::ipc::Precision::Half => ArrowDataType::Float16,302arrow_format::ipc::Precision::Single => ArrowDataType::Float32,303arrow_format::ipc::Precision::Double => ArrowDataType::Float64,304};305(dtype, IpcField::default())306},307Date(date) => {308let dtype = match date.unit()? {309arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,310arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,311};312(dtype, IpcField::default())313},314Time(time) => deserialize_time(time)?,315Timestamp(timestamp) => deserialize_timestamp(timestamp)?,316Interval(interval) => {317let dtype = match interval.unit()? {318arrow_format::ipc::IntervalUnit::YearMonth => {319ArrowDataType::Interval(IntervalUnit::YearMonth)320},321arrow_format::ipc::IntervalUnit::DayTime => {322ArrowDataType::Interval(IntervalUnit::DayTime)323},324arrow_format::ipc::IntervalUnit::MonthDayNano => {325ArrowDataType::Interval(IntervalUnit::MonthDayNano)326},327};328(dtype, IpcField::default())329},330Duration(duration) => {331let time_unit = deserialize_timeunit(duration.unit()?)?;332(ArrowDataType::Duration(time_unit), IpcField::default())333},334Decimal(decimal) => {335let bit_width: usize = decimal336.bit_width()?337.try_into()338.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;339let precision: usize = decimal340.precision()?341.try_into()342.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;343let scale: usize = decimal344.scale()?345.try_into()346.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;347348let dtype = match bit_width {34932 => ArrowDataType::Decimal32(precision, scale),35064 => ArrowDataType::Decimal64(precision, scale),351128 => ArrowDataType::Decimal(precision, scale),352256 => ArrowDataType::Decimal256(precision, scale),353_ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),354};355356(dtype, IpcField::default())357},358List(_) => deserialize_list(field)?,359LargeList(_) => deserialize_large_list(field)?,360FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,361Struct(_) => deserialize_struct(field)?,362Union(union_) => deserialize_union(union_, field)?,363Map(map) => deserialize_map(map, field)?,364RunEndEncoded(_) => todo!(),365LargeListView(_) | ListView(_) => todo!(),366})367}368369/// Deserialize an flatbuffers-encoded Schema message into [`ArrowSchema`] and [`IpcSchema`].370pub fn deserialize_schema(371message: &[u8],372) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {373let message = arrow_format::ipc::MessageRef::read_as_root(message)374.map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;375376let schema = match message377.header()?378.ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?379{380arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),381_ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),382}?;383384fb_to_schema(schema)385}386387/// Deserialize the raw Schema table from IPC format to Schema data type388pub(super) fn fb_to_schema(389schema: arrow_format::ipc::SchemaRef,390) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {391let fields = schema392.fields()?393.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;394395let mut arrow_schema = ArrowSchema::with_capacity(fields.len());396let mut ipc_fields = Vec::with_capacity(fields.len());397398for field in fields {399let (field, ipc_field) = deserialize_field(field?)?;400arrow_schema.insert(field.name.clone(), field);401ipc_fields.push(ipc_field);402}403404let is_little_endian = match schema.endianness()? {405arrow_format::ipc::Endianness::Little => true,406arrow_format::ipc::Endianness::Big => false,407};408409let custom_schema_metadata = match schema.custom_metadata()? {410None => None,411Some(metadata) => {412let metadata: Metadata = metadata413.into_iter()414.filter_map(|kv_result| {415// FIXME: silently hiding errors here416let kv_ref = kv_result.ok()?;417Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))418})419.collect();420421if metadata.is_empty() {422None423} else {424Some(metadata)425}426},427};428429Ok((430arrow_schema,431IpcSchema {432fields: ipc_fields,433is_little_endian,434},435custom_schema_metadata,436))437}438439pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {440let message = arrow_format::ipc::MessageRef::read_as_root(meta)441.map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;442let version = message.version()?;443// message header is a Schema, so read it444let header = message445.header()?446.ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;447let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {448schema449} else {450polars_bail!(oos = "The first IPC message of the stream must be a schema")451};452let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;453454Ok(StreamMetadata {455schema,456version,457ipc_schema,458custom_schema_metadata,459})460}461462463