// Path: blob/main/crates/polars-arrow/src/io/ipc/read/schema.rs
// 8446 views
use std::sync::Arc;12use arrow_format::ipc::planus::ReadAsRoot;3use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};4use polars_error::{PolarsResult, polars_bail, polars_err};5use polars_utils::pl_str::PlSmallStr;67use super::super::{IpcField, IpcSchema};8use super::{OutOfSpecKind, StreamMetadata};9use crate::datatypes::{10ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType, IntervalUnit,11Metadata, TimeUnit, UnionMode, UnionType, get_extension,12};1314fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(15iter: I,16) -> PolarsResult<(Vec<A>, Vec<B>)> {17let mut a = vec![];18let mut b = vec![];19for maybe_item in iter {20let (a_i, b_i) = maybe_item?;21a.push(a_i);22b.push(b_i);23}2425Ok((a, b))26}2728fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {29let metadata = read_metadata(&ipc_field)?;3031let extension = metadata.as_ref().and_then(get_extension);3233let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;3435let field = Field {36name: PlSmallStr::from_str(37ipc_field38.name()?39.ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,40),41dtype,42is_nullable: ipc_field.nullable()?,43metadata: metadata.map(Arc::new),44};4546Ok((field, ipc_field_))47}4849fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {50Ok(if let Some(list) = field.custom_metadata()? {51let mut metadata_map = Metadata::new();52for kv in list {53let kv = kv?;54if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {55metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));56}57}58Some(metadata_map)59} else {60None61})62}6364fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {65Ok(match (int.bit_width()?, int.is_signed()?) 
{66(8, true) => IntegerType::Int8,67(8, false) => IntegerType::UInt8,68(16, true) => IntegerType::Int16,69(16, false) => IntegerType::UInt16,70(32, true) => IntegerType::Int32,71(32, false) => IntegerType::UInt32,72(64, true) => IntegerType::Int64,73(64, false) => IntegerType::UInt64,74(128, true) => IntegerType::Int128,75(128, false) => IntegerType::UInt128,76_ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),77})78}7980fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {81use arrow_format::ipc::TimeUnit::*;82Ok(match time_unit {83Second => TimeUnit::Second,84Millisecond => TimeUnit::Millisecond,85Microsecond => TimeUnit::Microsecond,86Nanosecond => TimeUnit::Nanosecond,87})88}8990fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {91let unit = deserialize_timeunit(time.unit()?)?;9293let dtype = match (time.bit_width()?, unit) {94(32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),95(32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),96(64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),97(64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),98(bits, precision) => {99polars_bail!(ComputeError:100"Time type with bit width of {bits} and unit of {precision:?}"101)102},103};104Ok((dtype, IpcField::default()))105}106107fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {108let timezone = timestamp.timezone()?;109let time_unit = deserialize_timeunit(timestamp.unit()?)?;110Ok((111ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),112IpcField::default(),113))114}115116fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {117let mode = UnionMode::sparse(union_.mode()? 
== arrow_format::ipc::UnionMode::Sparse);118let ids = union_.type_ids()?.map(|x| x.iter().collect());119120let fields = field121.children()?122.ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;123if fields.is_empty() {124polars_bail!(oos = "IPC: Union must contain at least one child");125}126127let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {128let (field, fields) = deserialize_field(field?)?;129Ok((field, fields))130}))?;131let ipc_field = IpcField {132fields: ipc_fields,133dictionary_id: None,134};135Ok((136ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),137ipc_field,138))139}140141fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {142let is_sorted = map.keys_sorted()?;143144let children = field145.children()?146.ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;147let inner = children148.get(0)149.ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;150let (field, ipc_field) = deserialize_field(inner)?;151152let dtype = ArrowDataType::Map(Box::new(field), is_sorted);153Ok((154dtype,155IpcField {156fields: vec![ipc_field],157dictionary_id: None,158},159))160}161162fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {163let fields = field164.children()?165.ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;166let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {167let (field, fields) = deserialize_field(field?)?;168Ok((field, fields))169}))?;170let ipc_field = IpcField {171fields: ipc_fields,172dictionary_id: None,173};174Ok((ArrowDataType::Struct(fields), ipc_field))175}176177fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {178let children = field179.children()?180.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;181let inner = children182.get(0)183.ok_or_else(|| polars_err!(oos = "IPC: List must contain one 
child"))??;184let (field, ipc_field) = deserialize_field(inner)?;185186Ok((187ArrowDataType::List(Box::new(field)),188IpcField {189fields: vec![ipc_field],190dictionary_id: None,191},192))193}194195fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {196let children = field197.children()?198.ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;199let inner = children200.get(0)201.ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;202let (field, ipc_field) = deserialize_field(inner)?;203204Ok((205ArrowDataType::LargeList(Box::new(field)),206IpcField {207fields: vec![ipc_field],208dictionary_id: None,209},210))211}212213fn deserialize_fixed_size_list(214list: FixedSizeListRef,215field: FieldRef,216) -> PolarsResult<(ArrowDataType, IpcField)> {217let children = field218.children()?219.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;220let inner = children221.get(0)222.ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;223let (field, ipc_field) = deserialize_field(inner)?;224225let size = list226.list_size()?227.try_into()228.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;229230Ok((231ArrowDataType::FixedSizeList(Box::new(field), size),232IpcField {233fields: vec![ipc_field],234dictionary_id: None,235},236))237}238239/// Get the Arrow data type from the flatbuffer Field table240fn get_dtype(241field: arrow_format::ipc::FieldRef,242extension: Extension,243may_be_dictionary: bool,244) -> PolarsResult<(ArrowDataType, IpcField)> {245if let Some(dictionary) = field.dictionary()? 
{246if may_be_dictionary {247let int = dictionary248.index_type()?249.ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;250let index_type = deserialize_integer(int)?;251let (inner, mut ipc_field) = get_dtype(field, extension, false)?;252ipc_field.dictionary_id = Some(dictionary.id()?);253return Ok((254ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),255ipc_field,256));257}258}259260if let Some(extension) = extension {261let (name, metadata) = extension;262let (dtype, fields) = get_dtype(field, None, false)?;263return Ok((264ArrowDataType::Extension(Box::new(ExtensionType {265name,266inner: dtype,267metadata,268})),269fields,270));271}272273let type_ = field274.type_()?275.ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;276277use arrow_format::ipc::TypeRef::*;278Ok(match type_ {279Null(_) => (ArrowDataType::Null, IpcField::default()),280Bool(_) => (ArrowDataType::Boolean, IpcField::default()),281Int(int) => {282let dtype = deserialize_integer(int)?.into();283(dtype, IpcField::default())284},285Binary(_) => (ArrowDataType::Binary, IpcField::default()),286LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),287Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),288LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),289BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),290Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),291FixedSizeBinary(fixed) => (292ArrowDataType::FixedSizeBinary(293fixed294.byte_width()?295.try_into()296.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,297),298IpcField::default(),299),300FloatingPoint(float) => {301let dtype = match float.precision()? 
{302arrow_format::ipc::Precision::Half => ArrowDataType::Float16,303arrow_format::ipc::Precision::Single => ArrowDataType::Float32,304arrow_format::ipc::Precision::Double => ArrowDataType::Float64,305};306(dtype, IpcField::default())307},308Date(date) => {309let dtype = match date.unit()? {310arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,311arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,312};313(dtype, IpcField::default())314},315Time(time) => deserialize_time(time)?,316Timestamp(timestamp) => deserialize_timestamp(timestamp)?,317Interval(interval) => {318let dtype = match interval.unit()? {319arrow_format::ipc::IntervalUnit::YearMonth => {320ArrowDataType::Interval(IntervalUnit::YearMonth)321},322arrow_format::ipc::IntervalUnit::DayTime => {323ArrowDataType::Interval(IntervalUnit::DayTime)324},325arrow_format::ipc::IntervalUnit::MonthDayNano => {326ArrowDataType::Interval(IntervalUnit::MonthDayNano)327},328};329(dtype, IpcField::default())330},331Duration(duration) => {332let time_unit = deserialize_timeunit(duration.unit()?)?;333(ArrowDataType::Duration(time_unit), IpcField::default())334},335Decimal(decimal) => {336let bit_width: usize = decimal337.bit_width()?338.try_into()339.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;340let precision: usize = decimal341.precision()?342.try_into()343.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;344let scale: usize = decimal345.scale()?346.try_into()347.map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;348349let dtype = match bit_width {35032 => ArrowDataType::Decimal32(precision, scale),35164 => ArrowDataType::Decimal64(precision, scale),352128 => ArrowDataType::Decimal(precision, scale),353256 => ArrowDataType::Decimal256(precision, scale),354_ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),355};356357(dtype, IpcField::default())358},359List(_) => deserialize_list(field)?,360LargeList(_) => 
deserialize_large_list(field)?,361FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,362Struct(_) => deserialize_struct(field)?,363Union(union_) => deserialize_union(union_, field)?,364Map(map) => deserialize_map(map, field)?,365RunEndEncoded(_) => todo!(),366LargeListView(_) | ListView(_) => todo!(),367})368}369370/// Deserialize an flatbuffers-encoded Schema message into [`ArrowSchema`] and [`IpcSchema`].371pub fn deserialize_schema(372message: &[u8],373) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {374let message = arrow_format::ipc::MessageRef::read_as_root(message)375.map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;376377let schema = match message378.header()?379.ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?380{381arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),382_ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),383}?;384385fb_to_schema(schema)386}387388/// Deserialize the raw Schema table from IPC format to Schema data type389pub(super) fn fb_to_schema(390schema: arrow_format::ipc::SchemaRef,391) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {392let fields = schema393.fields()?394.ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;395396let mut arrow_schema = ArrowSchema::with_capacity(fields.len());397let mut ipc_fields = Vec::with_capacity(fields.len());398399for field in fields {400let (field, ipc_field) = deserialize_field(field?)?;401arrow_schema.insert(field.name.clone(), field);402ipc_fields.push(ipc_field);403}404405let is_little_endian = match schema.endianness()? {406arrow_format::ipc::Endianness::Little => true,407arrow_format::ipc::Endianness::Big => false,408};409410let custom_schema_metadata = match schema.custom_metadata()? 
{411None => None,412Some(metadata) => {413let metadata: Metadata = metadata414.into_iter()415.filter_map(|kv_result| {416// TODO: silently hiding errors here417let kv_ref = kv_result.ok()?;418Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))419})420.collect();421422if metadata.is_empty() {423None424} else {425Some(metadata)426}427},428};429430Ok((431arrow_schema,432IpcSchema {433fields: ipc_fields,434is_little_endian,435},436custom_schema_metadata,437))438}439440pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {441let message = arrow_format::ipc::MessageRef::read_as_root(meta)442.map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;443let version = message.version()?;444// message header is a Schema, so read it445let header = message446.header()?447.ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;448let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {449schema450} else {451polars_bail!(oos = "The first IPC message of the stream must be a schema")452};453let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;454455Ok(StreamMetadata {456schema,457version,458ipc_schema,459custom_schema_metadata,460})461}462463464