Path: blob/main/crates/polars-arrow/src/io/avro/read/deserialize.rs
7884 views
use std::sync::Arc;12use avro_schema::file::Block;3use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema};4use polars_error::{PolarsResult, polars_bail, polars_err};56use super::nested::*;7use super::util;8use crate::array::*;9use crate::datatypes::*;10use crate::record_batch::RecordBatchT;11use crate::types::months_days_ns;12use crate::with_match_primitive_type_full;1314fn make_mutable(15dtype: &ArrowDataType,16avro_field: Option<&AvroSchema>,17capacity: usize,18) -> PolarsResult<Box<dyn MutableArray>> {19Ok(match dtype.to_physical_type() {20PhysicalType::Boolean => {21Box::new(MutableBooleanArray::with_capacity(capacity)) as Box<dyn MutableArray>22},23PhysicalType::Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {24Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(dtype.clone()))25as Box<dyn MutableArray>26}),27PhysicalType::Binary => {28Box::new(MutableBinaryArray::<i32>::with_capacity(capacity)) as Box<dyn MutableArray>29},30PhysicalType::Utf8 => {31Box::new(MutableUtf8Array::<i32>::with_capacity(capacity)) as Box<dyn MutableArray>32},33PhysicalType::Dictionary(_) => {34if let Some(AvroSchema::Enum(Enum { symbols, .. })) = avro_field {35let values = Utf8Array::<i32>::from_slice(symbols);36Box::new(FixedItemsUtf8Dictionary::with_capacity(values, capacity))37as Box<dyn MutableArray>38} else {39unreachable!()40}41},42_ => match dtype {43ArrowDataType::List(inner) => {44let values = make_mutable(inner.dtype(), None, 0)?;45Box::new(DynMutableListArray::<i32>::new_from(46values,47dtype.clone(),48capacity,49)) as Box<dyn MutableArray>50},51ArrowDataType::FixedSizeBinary(size) => {52Box::new(MutableFixedSizeBinaryArray::with_capacity(*size, capacity))53as Box<dyn MutableArray>54},55ArrowDataType::Struct(fields) => {56let values = fields57.iter()58.map(|field| make_mutable(field.dtype(), None, capacity))59.collect::<PolarsResult<Vec<_>>>()?;60Box::new(DynMutableStructArray::new(values, dtype.clone())) as Box<dyn MutableArray>61},62ArrowDataType::Extension(ext) => make_mutable(&ext.inner, avro_field, capacity)?,63other => {64polars_bail!(nyi = "Deserializing type {other:#?} is still not implemented")65},66},67})68}6970fn is_union_null_first(avro_field: &AvroSchema) -> bool {71if let AvroSchema::Union(schemas) = avro_field {72schemas[0] == AvroSchema::Null73} else {74unreachable!()75}76}7778fn deserialize_item<'a>(79array: &mut dyn MutableArray,80is_nullable: bool,81avro_field: &AvroSchema,82mut block: &'a [u8],83) -> PolarsResult<&'a [u8]> {84if is_nullable {85let variant = util::zigzag_i64(&mut block)?;86let is_null_first = is_union_null_first(avro_field);87if is_null_first && variant == 0 || !is_null_first && variant != 0 {88array.push_null();89return Ok(block);90}91}92deserialize_value(array, avro_field, block)93}9495fn deserialize_value<'a>(96array: &mut dyn MutableArray,97avro_field: &AvroSchema,98mut block: &'a [u8],99) -> PolarsResult<&'a [u8]> {100let dtype = array.dtype();101match dtype {102ArrowDataType::List(inner) => {103let is_nullable = inner.is_nullable;104let avro_inner = match avro_field {105AvroSchema::Array(inner) => inner.as_ref(),106AvroSchema::Union(u) => match &u.as_slice() {107&[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => {108inner.as_ref()109},110_ => unreachable!(),111},112_ => unreachable!(),113};114115let array = array116.as_mut_any()117.downcast_mut::<DynMutableListArray<i32>>()118.unwrap();119// Arrays are encoded as a series of blocks.120loop {121// Each block consists of a long count value, followed by that many array items.122let len = util::zigzag_i64(&mut block)?;123let len = if len < 0 {124// Avro spec: If a block's count is negative, its absolute value is used,125// and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields.126let _ = util::zigzag_i64(&mut block)?;127128-len129} else {130len131};132133// A block with count zero indicates the end of the array.134if len == 0 {135break;136}137138// Each item is encoded per the array’s item schema.139let values = array.mut_values();140for _ in 0..len {141block = deserialize_item(values, is_nullable, avro_inner, block)?;142}143}144array.try_push_valid()?;145},146ArrowDataType::Struct(inner_fields) => {147let fields = match avro_field {148AvroSchema::Record(Record { fields, .. }) => fields,149AvroSchema::Union(u) => match &u.as_slice() {150&[AvroSchema::Record(Record { fields, .. }), _]151| &[_, AvroSchema::Record(Record { fields, .. })] => fields,152_ => unreachable!(),153},154_ => unreachable!(),155};156157let is_nullable = inner_fields158.iter()159.map(|x| x.is_nullable)160.collect::<Vec<_>>();161let array = array162.as_mut_any()163.downcast_mut::<DynMutableStructArray>()164.unwrap();165166for (index, (field, is_nullable)) in fields.iter().zip(is_nullable.iter()).enumerate() {167let values = array.mut_values(index);168block = deserialize_item(values, *is_nullable, &field.schema, block)?;169}170array.try_push_valid()?;171},172_ => match dtype.to_physical_type() {173PhysicalType::Boolean => {174let is_valid = block[0] == 1;175block = &block[1..];176let array = array177.as_mut_any()178.downcast_mut::<MutableBooleanArray>()179.unwrap();180array.push(Some(is_valid))181},182PhysicalType::Primitive(primitive) => match primitive {183PrimitiveType::Int32 => {184let value = util::zigzag_i64(&mut block)? as i32;185let array = array186.as_mut_any()187.downcast_mut::<MutablePrimitiveArray<i32>>()188.unwrap();189array.push(Some(value))190},191PrimitiveType::Int64 => {192let value = util::zigzag_i64(&mut block)?;193let array = array194.as_mut_any()195.downcast_mut::<MutablePrimitiveArray<i64>>()196.unwrap();197array.push(Some(value))198},199PrimitiveType::Float32 => {200let value = f32::from_le_bytes(block[..size_of::<f32>()].try_into().unwrap());201block = &block[size_of::<f32>()..];202let array = array203.as_mut_any()204.downcast_mut::<MutablePrimitiveArray<f32>>()205.unwrap();206array.push(Some(value))207},208PrimitiveType::Float64 => {209let value = f64::from_le_bytes(block[..size_of::<f64>()].try_into().unwrap());210block = &block[size_of::<f64>()..];211let array = array212.as_mut_any()213.downcast_mut::<MutablePrimitiveArray<f64>>()214.unwrap();215array.push(Some(value))216},217PrimitiveType::MonthDayNano => {218// https://avro.apache.org/docs/current/spec.html#Duration219// 12 bytes, months, days, millis in LE220let data = &block[..12];221block = &block[12..];222223let value = months_days_ns::new(224i32::from_le_bytes([data[0], data[1], data[2], data[3]]),225i32::from_le_bytes([data[4], data[5], data[6], data[7]]),226i32::from_le_bytes([data[8], data[9], data[10], data[11]]) as i64227* 1_000_000,228);229230let array = array231.as_mut_any()232.downcast_mut::<MutablePrimitiveArray<months_days_ns>>()233.unwrap();234array.push(Some(value))235},236PrimitiveType::Int128 => {237let avro_inner = match avro_field {238AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field,239AvroSchema::Union(u) => match &u.as_slice() {240&[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e,241_ => unreachable!(),242},243_ => unreachable!(),244};245let len = match avro_inner {246AvroSchema::Bytes(_) => {247util::zigzag_i64(&mut block)?.try_into().map_err(|_| {248polars_err!(249oos = "Avro format contains a non-usize number of bytes"250)251})?252},253AvroSchema::Fixed(b) => b.size,254_ => unreachable!(),255};256if len > 16 {257polars_bail!(oos = "Avro decimal bytes return more than 16 bytes")258}259let mut bytes = [0u8; 16];260bytes[..len].copy_from_slice(&block[..len]);261block = &block[len..];262let data = i128::from_be_bytes(bytes) >> (8 * (16 - len));263let array = array264.as_mut_any()265.downcast_mut::<MutablePrimitiveArray<i128>>()266.unwrap();267array.push(Some(data))268},269_ => unreachable!(),270},271PhysicalType::Utf8 => {272let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| {273polars_err!(oos = "Avro format contains a non-usize number of bytes")274})?;275let data = simdutf8::basic::from_utf8(&block[..len])?;276block = &block[len..];277278let array = array279.as_mut_any()280.downcast_mut::<MutableUtf8Array<i32>>()281.unwrap();282array.push(Some(data))283},284PhysicalType::Binary => {285let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| {286polars_err!(oos = "Avro format contains a non-usize number of bytes")287})?;288let data = &block[..len];289block = &block[len..];290291let array = array292.as_mut_any()293.downcast_mut::<MutableBinaryArray<i32>>()294.unwrap();295array.push(Some(data));296},297PhysicalType::FixedSizeBinary => {298let array = array299.as_mut_any()300.downcast_mut::<MutableFixedSizeBinaryArray>()301.unwrap();302let len = array.size();303let data = &block[..len];304block = &block[len..];305array.push(Some(data));306},307PhysicalType::Dictionary(_) => {308let index = util::zigzag_i64(&mut block)? as i32;309let array = array310.as_mut_any()311.downcast_mut::<FixedItemsUtf8Dictionary>()312.unwrap();313array.push_valid(index);314},315_ => todo!(),316},317};318Ok(block)319}320321fn skip_item<'a>(322field: &Field,323avro_field: &AvroSchema,324mut block: &'a [u8],325) -> PolarsResult<&'a [u8]> {326if field.is_nullable {327let variant = util::zigzag_i64(&mut block)?;328let is_null_first = is_union_null_first(avro_field);329if is_null_first && variant == 0 || !is_null_first && variant != 0 {330return Ok(block);331}332}333match &field.dtype {334ArrowDataType::List(inner) => {335let avro_inner = match avro_field {336AvroSchema::Array(inner) => inner.as_ref(),337AvroSchema::Union(u) => match &u.as_slice() {338&[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => {339inner.as_ref()340},341_ => unreachable!(),342},343_ => unreachable!(),344};345346loop {347let len = util::zigzag_i64(&mut block)?;348let (len, bytes) = if len < 0 {349// Avro spec: If a block's count is negative, its absolute value is used,350// and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields.351let bytes = util::zigzag_i64(&mut block)?;352353(-len, Some(bytes))354} else {355(len, None)356};357358let bytes: Option<usize> = bytes359.map(|bytes| {360bytes361.try_into()362.map_err(|_| polars_err!(oos = "Avro block size negative or too large"))363})364.transpose()?;365366if len == 0 {367break;368}369370if let Some(bytes) = bytes {371block = &block[bytes..];372} else {373for _ in 0..len {374block = skip_item(inner, avro_inner, block)?;375}376}377}378},379ArrowDataType::Struct(inner_fields) => {380let fields = match avro_field {381AvroSchema::Record(Record { fields, .. }) => fields,382AvroSchema::Union(u) => match &u.as_slice() {383&[AvroSchema::Record(Record { fields, .. }), _]384| &[_, AvroSchema::Record(Record { fields, .. })] => fields,385_ => unreachable!(),386},387_ => unreachable!(),388};389390for (field, avro_field) in inner_fields.iter().zip(fields.iter()) {391block = skip_item(field, &avro_field.schema, block)?;392}393},394_ => match field.dtype.to_physical_type() {395PhysicalType::Boolean => {396let _ = block[0] == 1;397block = &block[1..];398},399PhysicalType::Primitive(primitive) => match primitive {400PrimitiveType::Int32 => {401let _ = util::zigzag_i64(&mut block)?;402},403PrimitiveType::Int64 => {404let _ = util::zigzag_i64(&mut block)?;405},406PrimitiveType::Float32 => {407block = &block[size_of::<f32>()..];408},409PrimitiveType::Float64 => {410block = &block[size_of::<f64>()..];411},412PrimitiveType::MonthDayNano => {413block = &block[12..];414},415PrimitiveType::Int128 => {416let avro_inner = match avro_field {417AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field,418AvroSchema::Union(u) => match &u.as_slice() {419&[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e,420_ => unreachable!(),421},422_ => unreachable!(),423};424let len = match avro_inner {425AvroSchema::Bytes(_) => {426util::zigzag_i64(&mut block)?.try_into().map_err(|_| {427polars_err!(428oos = "Avro format contains a non-usize number of bytes"429)430})?431},432AvroSchema::Fixed(b) => b.size,433_ => unreachable!(),434};435block = &block[len..];436},437_ => unreachable!(),438},439PhysicalType::Utf8 | PhysicalType::Binary => {440let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| {441polars_err!(oos = "Avro format contains a non-usize number of bytes")442})?;443block = &block[len..];444},445PhysicalType::FixedSizeBinary => {446let len = if let ArrowDataType::FixedSizeBinary(len) = &field.dtype {447*len448} else {449unreachable!()450};451452block = &block[len..];453},454PhysicalType::Dictionary(_) => {455let _ = util::zigzag_i64(&mut block)? as i32;456},457_ => todo!(),458},459}460Ok(block)461}462463/// Deserializes a [`Block`] assumed to be encoded according to [`AvroField`] into [`RecordBatchT`],464/// using `projection` to ignore `avro_fields`.465/// # Panics466/// `fields`, `avro_fields` and `projection` must have the same length.467pub fn deserialize(468block: &Block,469fields: &ArrowSchema,470avro_fields: &[AvroField],471projection: &[bool],472) -> PolarsResult<RecordBatchT<Box<dyn Array>>> {473assert_eq!(fields.len(), avro_fields.len());474assert_eq!(fields.len(), projection.len());475476let rows = block.number_of_rows;477let mut block = block.data.as_ref();478479// create mutables, one per field480let mut arrays: Vec<Box<dyn MutableArray>> = fields481.iter_values()482.zip(avro_fields.iter())483.zip(projection.iter())484.map(|((field, avro_field), projection)| {485if *projection {486make_mutable(&field.dtype, Some(&avro_field.schema), rows)487} else {488// just something; we are not going to use it489make_mutable(&ArrowDataType::Int32, None, 0)490}491})492.collect::<PolarsResult<_>>()?;493494// this is _the_ expensive transpose (rows -> columns)495for _ in 0..rows {496let iter = arrays497.iter_mut()498.zip(fields.iter_values())499.zip(avro_fields.iter())500.zip(projection.iter());501502for (((array, field), avro_field), projection) in iter {503block = if *projection {504deserialize_item(array.as_mut(), field.is_nullable, &avro_field.schema, block)505} else {506skip_item(field, &avro_field.schema, block)507}?508}509}510511let projected_schema = fields512.iter_values()513.zip(projection)514.filter_map(|(f, p)| (*p).then_some(f))515.cloned()516.collect();517518RecordBatchT::try_new(519rows,520Arc::new(projected_schema),521arrays522.iter_mut()523.zip(projection.iter())524.filter_map(|x| x.1.then(|| x.0))525.map(|array| array.as_box())526.collect(),527)528}529530531