Path: blob/main/crates/polars-arrow/src/array/union/mod.rs
6939 views
use polars_error::{PolarsResult, polars_bail, polars_err};12use super::{Array, Splitable, new_empty_array, new_null_array};3use crate::bitmap::Bitmap;4use crate::buffer::Buffer;5use crate::datatypes::{ArrowDataType, Field, UnionMode};6use crate::scalar::{Scalar, new_scalar};78mod ffi;9pub(super) mod fmt;10mod iterator;1112type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);1314/// [`UnionArray`] represents an array whose each slot can contain different values.15///16// How to read a value at slot i:17// ```18// let index = self.types()[i] as usize;19// let field = self.fields()[index];20// let offset = self.offsets().map(|x| x[index]).unwrap_or(i);21// let field = field.as_any().downcast to correct type;22// let value = field.value(offset);23// ```24#[derive(Clone)]25pub struct UnionArray {26// Invariant: every item in `types` is `> 0 && < fields.len()`27types: Buffer<i8>,28// Invariant: `map.len() == fields.len()`29// Invariant: every item in `map` is `> 0 && < fields.len()`30map: Option<[usize; 127]>,31fields: Vec<Box<dyn Array>>,32// Invariant: when set, `offsets.len() == types.len()`33offsets: Option<Buffer<i32>>,34dtype: ArrowDataType,35offset: usize,36}3738impl UnionArray {39/// Returns a new [`UnionArray`].40/// # Errors41/// This function errors iff:42/// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`].43/// * the fields's len is different from the `dtype`'s children's length44/// * The number of `fields` is larger than `i8::MAX`45/// * any of the values's data type is different from its corresponding children' data type46pub fn try_new(47dtype: ArrowDataType,48types: Buffer<i8>,49fields: Vec<Box<dyn Array>>,50offsets: Option<Buffer<i32>>,51) -> PolarsResult<Self> {52let (f, ids, mode) = Self::try_get_all(&dtype)?;5354if f.len() != fields.len() {55polars_bail!(ComputeError: "the number of `fields` must equal the number of children fields in DataType::Union")56};57let number_of_fields: i8 = fields.len().try_into().map_err(58|_| polars_err!(ComputeError: "the number of `fields` cannot be larger than i8::MAX"),59)?;6061f62.iter().map(|a| a.dtype())63.zip(fields.iter().map(|a| a.dtype()))64.enumerate()65.try_for_each(|(index, (dtype, child))| {66if dtype != child {67polars_bail!(ComputeError:68"the children DataTypes of a UnionArray must equal the children data types.69However, the field {index} has data type {dtype:?} but the value has data type {child:?}"70)71} else {72Ok(())73}74})?;7576if let Some(offsets) = &offsets {77if offsets.len() != types.len() {78polars_bail!(ComputeError:79"in a UnionArray, the offsets' length must be equal to the number of types"80)81}82}83if offsets.is_none() != mode.is_sparse() {84polars_bail!(ComputeError:85"in a sparse UnionArray, the offsets must be set (and vice-versa)",86)87}8889// build hash90let map = if let Some(&ids) = ids.as_ref() {91if ids.len() != fields.len() {92polars_bail!(ComputeError:93"in a union, when the ids are set, their length must be equal to the number of fields",94)95}9697// example:98// * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5]99// * ids = [5, 7]100// => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...]101let mut hash = [0; 127];102103for (pos, &id) in ids.iter().enumerate() {104if !(0..=127).contains(&id) {105polars_bail!(ComputeError:106"in a union, when the ids are set, every id must belong to [0, 128[",107)108}109hash[id as usize] = pos;110}111112types.iter().try_for_each(|&type_| {113if type_ < 0 {114polars_bail!(ComputeError:115"in a union, when the ids are set, every type must be >= 0"116)117}118let id = hash[type_ as usize];119if id >= fields.len() {120polars_bail!(ComputeError:121"in a union, when the ids are set, each id must be smaller than the number of fields."122)123} else {124Ok(())125}126})?;127128Some(hash)129} else {130// SAFETY: every type in types is smaller than number of fields131let mut is_valid = true;132for &type_ in types.iter() {133if type_ < 0 || type_ >= number_of_fields {134is_valid = false135}136}137if !is_valid {138polars_bail!(ComputeError:139"every type in `types` must be larger than 0 and smaller than the number of fields.",140)141}142143None144};145146Ok(Self {147dtype,148map,149fields,150offsets,151types,152offset: 0,153})154}155156/// Returns a new [`UnionArray`].157/// # Panics158/// This function panics iff:159/// * `dtype`'s physical type is not [`crate::datatypes::PhysicalType::Union`].160/// * the fields's len is different from the `dtype`'s children's length161/// * any of the values's data type is different from its corresponding children' data type162pub fn new(163dtype: ArrowDataType,164types: Buffer<i8>,165fields: Vec<Box<dyn Array>>,166offsets: Option<Buffer<i32>>,167) -> Self {168Self::try_new(dtype, types, fields, offsets).unwrap()169}170171/// Creates a new null [`UnionArray`].172pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {173if let ArrowDataType::Union(u) = &dtype {174let fields = u175.fields176.iter()177.map(|x| new_null_array(x.dtype().clone(), length))178.collect();179180let offsets = if u.mode.is_sparse() {181None182} else {183Some((0..length as i32).collect::<Vec<_>>().into())184};185186// all from the same field187let types = vec![0i8; length].into();188189Self::new(dtype, types, fields, offsets)190} else {191panic!("Union struct must be created with the corresponding Union DataType")192}193}194195/// Creates a new empty [`UnionArray`].196pub fn new_empty(dtype: ArrowDataType) -> Self {197if let ArrowDataType::Union(u) = dtype.to_logical_type() {198let fields = u199.fields200.iter()201.map(|x| new_empty_array(x.dtype().clone()))202.collect();203204let offsets = if u.mode.is_sparse() {205None206} else {207Some(Buffer::default())208};209210Self {211dtype,212map: None,213fields,214offsets,215types: Buffer::new(),216offset: 0,217}218} else {219panic!("Union struct must be created with the corresponding Union DataType")220}221}222}223224impl UnionArray {225/// Returns a slice of this [`UnionArray`].226/// # Implementation227/// This operation is `O(F)` where `F` is the number of fields.228/// # Panic229/// This function panics iff `offset + length > self.len()`.230#[inline]231pub fn slice(&mut self, offset: usize, length: usize) {232assert!(233offset + length <= self.len(),234"the offset of the new array cannot exceed the existing length"235);236unsafe { self.slice_unchecked(offset, length) }237}238239/// Returns a slice of this [`UnionArray`].240/// # Implementation241/// This operation is `O(F)` where `F` is the number of fields.242///243/// # Safety244/// The caller must ensure that `offset + length <= self.len()`.245#[inline]246pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {247debug_assert!(offset + length <= self.len());248249self.types.slice_unchecked(offset, length);250if let Some(offsets) = self.offsets.as_mut() {251offsets.slice_unchecked(offset, length)252}253self.offset += offset;254}255256impl_sliced!();257impl_into_array!();258}259260impl UnionArray {261/// Returns the length of this array262#[inline]263pub fn len(&self) -> usize {264self.types.len()265}266267/// The optional offsets.268pub fn offsets(&self) -> Option<&Buffer<i32>> {269self.offsets.as_ref()270}271272/// The fields.273pub fn fields(&self) -> &Vec<Box<dyn Array>> {274&self.fields275}276277/// The types.278pub fn types(&self) -> &Buffer<i8> {279&self.types280}281282#[inline]283unsafe fn field_slot_unchecked(&self, index: usize) -> usize {284self.offsets()285.as_ref()286.map(|x| *x.get_unchecked(index) as usize)287.unwrap_or(index + self.offset)288}289290/// Returns the index and slot of the field to select from `self.fields`.291#[inline]292pub fn index(&self, index: usize) -> (usize, usize) {293assert!(index < self.len());294unsafe { self.index_unchecked(index) }295}296297/// Returns the index and slot of the field to select from `self.fields`.298/// The first value is guaranteed to be `< self.fields().len()`299///300/// # Safety301/// This function is safe iff `index < self.len`.302#[inline]303pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {304debug_assert!(index < self.len());305// SAFETY: assumption of the function306let type_ = unsafe { *self.types.get_unchecked(index) };307// SAFETY: assumption of the struct308let type_ = self309.map310.as_ref()311.map(|map| unsafe { *map.get_unchecked(type_ as usize) })312.unwrap_or(type_ as usize);313// SAFETY: assumption of the function314let index = self.field_slot_unchecked(index);315(type_, index)316}317318/// Returns the slot `index` as a [`Scalar`].319/// # Panics320/// iff `index >= self.len()`321pub fn value(&self, index: usize) -> Box<dyn Scalar> {322assert!(index < self.len());323unsafe { self.value_unchecked(index) }324}325326/// Returns the slot `index` as a [`Scalar`].327///328/// # Safety329/// This function is safe iff `i < self.len`.330pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {331debug_assert!(index < self.len());332let (type_, index) = self.index_unchecked(index);333// SAFETY: assumption of the struct334debug_assert!(type_ < self.fields.len());335let field = self.fields.get_unchecked(type_).as_ref();336new_scalar(field, index)337}338}339340impl Array for UnionArray {341impl_common_array!();342343fn validity(&self) -> Option<&Bitmap> {344None345}346347fn with_validity(&self, _: Option<Bitmap>) -> Box<dyn Array> {348panic!("cannot set validity of a union array")349}350}351352impl UnionArray {353fn try_get_all(dtype: &ArrowDataType) -> PolarsResult<UnionComponents<'_>> {354match dtype.to_logical_type() {355ArrowDataType::Union(u) => Ok((&u.fields, u.ids.as_ref().map(|x| x.as_ref()), u.mode)),356_ => polars_bail!(ComputeError:357"The UnionArray requires a logical type of DataType::Union",358),359}360}361362fn get_all(dtype: &ArrowDataType) -> (&[Field], Option<&[i32]>, UnionMode) {363Self::try_get_all(dtype).unwrap()364}365366/// Returns all fields from [`ArrowDataType::Union`].367/// # Panic368/// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`].369pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {370Self::get_all(dtype).0371}372373/// Returns whether the [`ArrowDataType::Union`] is sparse or not.374/// # Panic375/// Panics iff `dtype`'s logical type is not [`ArrowDataType::Union`].376pub fn is_sparse(dtype: &ArrowDataType) -> bool {377Self::get_all(dtype).2.is_sparse()378}379}380381impl Splitable for UnionArray {382fn check_bound(&self, offset: usize) -> bool {383offset <= self.len()384}385386unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {387let (lhs_types, rhs_types) = unsafe { self.types.split_at_unchecked(offset) };388let (lhs_offsets, rhs_offsets) = self.offsets.as_ref().map_or((None, None), |v| {389let (lhs, rhs) = unsafe { v.split_at_unchecked(offset) };390(Some(lhs), Some(rhs))391});392393(394Self {395types: lhs_types,396map: self.map,397fields: self.fields.clone(),398offsets: lhs_offsets,399dtype: self.dtype.clone(),400offset: self.offset,401},402Self {403types: rhs_types,404map: self.map,405fields: self.fields.clone(),406offsets: rhs_offsets,407dtype: self.dtype.clone(),408offset: self.offset + offset,409},410)411}412}413414415