// crates/polars-compute/src/gather/sublist/list.rs
use arrow::array::{Array, ArrayRef, ListArray};1use arrow::legacy::prelude::*;2use arrow::legacy::trusted_len::TrustedLenPush;3use arrow::legacy::utils::CustomIterTools;4use arrow::offset::{Offsets, OffsetsBuffer};5use polars_utils::IdxSize;67use crate::gather::take_unchecked;89/// Get the indices that would result in a get operation on the lists values.10/// for example, consider this list:11/// ```text12/// [[1, 2, 3],13/// [4, 5],14/// [6]]15///16/// This contains the following values array:17/// [1, 2, 3, 4, 5, 6]18///19/// get index 020/// would lead to the following indexes:21/// [0, 3, 5].22/// if we use those in a take operation on the values array we get:23/// [1, 4, 6]24///25///26/// get index -127/// would lead to the following indexes:28/// [2, 4, 5].29/// if we use those in a take operation on the values array we get:30/// [3, 5, 6]31///32/// ```33fn sublist_get_indexes(arr: &ListArray<i64>, index: i64) -> IdxArr {34let offsets = arr.offsets().as_slice();35let mut iter = offsets.iter();3637// the indices can be sliced, so we should not start at 0.38let mut cum_offset = (*offsets.first().unwrap_or(&0)) as IdxSize;3940if let Some(mut previous) = iter.next().copied() {41if arr.null_count() == 0 {42iter.map(|&offset| {43let len = offset - previous;44previous = offset;45// make sure that empty lists don't get accessed46// and out of bounds return null47if len == 0 {48return None;49}50if index >= len {51cum_offset += len as IdxSize;52return None;53}5455let out = index56.negative_to_usize(len as usize)57.map(|idx| idx as IdxSize + cum_offset);58cum_offset += len as IdxSize;59out60})61.collect_trusted()62} else {63// we can ensure that validity is not none as we have null value.64let validity = arr.validity().unwrap();65iter.enumerate()66.map(|(i, &offset)| {67let len = offset - previous;68previous = offset;69// make sure that empty and null lists don't get accessed and return null.70// SAFETY, we are within bounds71if len == 0 || !unsafe { 
validity.get_bit_unchecked(i) } {72cum_offset += len as IdxSize;73return None;74}7576// make sure that out of bounds return null77if index >= len {78cum_offset += len as IdxSize;79return None;80}8182let out = index83.negative_to_usize(len as usize)84.map(|idx| idx as IdxSize + cum_offset);85cum_offset += len as IdxSize;86out87})88.collect_trusted()89}90} else {91IdxArr::from_slice([])92}93}9495pub fn sublist_get(arr: &ListArray<i64>, index: i64) -> ArrayRef {96let take_by = sublist_get_indexes(arr, index);97let values = arr.values();98// SAFETY:99// the indices we generate are in bounds100unsafe { take_unchecked(&**values, &take_by) }101}102103/// Check if an index is out of bounds for at least one sublist.104pub fn index_is_oob(arr: &ListArray<i64>, index: i64) -> bool {105if arr.null_count() == 0 {106arr.offsets()107.lengths()108.any(|len| index.negative_to_usize(len).is_none())109} else {110arr.offsets()111.lengths()112.zip(arr.validity().unwrap())113.any(|(len, valid)| {114if valid {115index.negative_to_usize(len).is_none()116} else {117// skip nulls118false119}120})121}122}123124/// Convert a list `[1, 2, 3]` to a list type of `[[1], [2], [3]]`125pub fn array_to_unit_list(array: ArrayRef) -> ListArray<i64> {126let len = array.len();127let mut offsets = Vec::with_capacity(len + 1);128// SAFETY: we allocated enough129unsafe {130offsets.push_unchecked(0i64);131132for _ in 0..len {133offsets.push_unchecked(offsets.len() as i64)134}135};136137// SAFETY:138// offsets are monotonically increasing139unsafe {140let offsets: OffsetsBuffer<i64> = Offsets::new_unchecked(offsets).into();141let dtype = ListArray::<i64>::default_datatype(array.dtype().clone());142ListArray::<i64>::new(dtype, offsets, array, None)143}144}145146#[cfg(test)]147mod test {148use arrow::array::{Int32Array, PrimitiveArray};149use arrow::datatypes::ArrowDataType;150151use super::*;152153fn get_array() -> ListArray<i64> {154let values = Int32Array::from_slice([1, 2, 3, 4, 5, 6]);155let offsets = 
OffsetsBuffer::try_from(vec![0i64, 3, 5, 6]).unwrap();156157let dtype = ListArray::<i64>::default_datatype(ArrowDataType::Int32);158ListArray::<i64>::new(dtype, offsets, Box::new(values), None)159}160161#[test]162fn test_sublist_get_indexes() {163let arr = get_array();164let out = sublist_get_indexes(&arr, 0);165assert_eq!(out.values().as_slice(), &[0, 3, 5]);166let out = sublist_get_indexes(&arr, -1);167assert_eq!(out.values().as_slice(), &[2, 4, 5]);168let out = sublist_get_indexes(&arr, 3);169assert_eq!(out.null_count(), 3);170171let values = Int32Array::from_iter([172Some(1),173Some(1),174Some(3),175Some(4),176Some(5),177Some(6),178Some(7),179Some(8),180Some(9),181None,182Some(11),183]);184let offsets = OffsetsBuffer::try_from(vec![0i64, 1, 2, 3, 6, 9, 11]).unwrap();185186let dtype = ListArray::<i64>::default_datatype(ArrowDataType::Int32);187let arr = ListArray::<i64>::new(dtype, offsets, Box::new(values), None);188189let out = sublist_get_indexes(&arr, 1);190assert_eq!(191out.into_iter().collect::<Vec<_>>(),192&[None, None, None, Some(4), Some(7), Some(10)]193);194}195196#[test]197fn test_sublist_get() {198let arr = get_array();199200let out = sublist_get(&arr, 0);201let out = out.as_any().downcast_ref::<PrimitiveArray<i32>>().unwrap();202203assert_eq!(out.values().as_slice(), &[1, 4, 6]);204let out = sublist_get(&arr, -1);205let out = out.as_any().downcast_ref::<PrimitiveArray<i32>>().unwrap();206assert_eq!(out.values().as_slice(), &[3, 5, 6]);207}208}209210211