// Path: crates/polars-ops/src/chunked_array/array/get.rs
use polars_compute::gather::sublist::fixed_size_list::{1sub_fixed_size_list_get, sub_fixed_size_list_get_literal,2};3use polars_core::prelude::arity::{try_binary_to_series, try_unary_to_series};45use super::*;6use crate::series::convert_and_bound_idx_ca;78/// Get the value by literal index in the array.9/// So index `0` would return the first item of every sub-array10/// and index `-1` would return the last item of every sub-array11/// if an index is out of bounds, it will return a `None`.12pub fn array_get(13ca: &ArrayChunked,14index: &Int64Chunked,15null_on_oob: bool,16) -> PolarsResult<Series> {17polars_ensure!(ca.width() < IdxSize::MAX as usize, ComputeError: "`arr.get` not supported for such wide arrays");1819// Base case. No overflow.20if ca.width() * ca.len() < IdxSize::MAX as usize {21return array_get_impl(ca, index, null_on_oob);22}2324// If the array width * length would overflow. Do it part-by-part.25assert!(ca.len() != 1 || index.len() != 1);26let rows_per_slice = IdxSize::MAX as usize / ca.width();2728let mut ca = ca.clone();29let mut index = index.clone();30let current_ca;31let current_index;32if ca.len() == 1 {33current_ca = ca.clone();34} else {35(current_ca, ca) = ca.split_at(rows_per_slice as i64);36}37if index.len() == 1 {38current_index = index.clone();39} else {40(current_index, index) = index.split_at(rows_per_slice as i64);41}42let mut s = array_get_impl(¤t_ca, ¤t_index, null_on_oob)?;4344while !ca.is_empty() && !index.is_empty() {45let current_ca;46let current_index;47if ca.len() == 1 {48current_ca = ca.clone();49} else {50(current_ca, ca) = ca.split_at(rows_per_slice as i64);51}52if index.len() == 1 {53current_index = index.clone();54} else {55(current_index, index) = index.split_at(rows_per_slice as i64);56}57s.append_owned(array_get_impl(¤t_ca, ¤t_index, null_on_oob)?)?;58}5960Ok(s)61}6263fn array_get_impl(64ca: &ArrayChunked,65index: &Int64Chunked,66null_on_oob: bool,67) -> PolarsResult<Series> {68match index.len() {691 => {70if let 
Some(index) = index.get(0) {71let out = try_unary_to_series(ca, |arr| {72sub_fixed_size_list_get_literal(arr, index, null_on_oob)73})?;74unsafe { out.from_physical_unchecked(ca.inner_dtype()) }75} else {76Ok(Series::full_null(77ca.name().clone(),78ca.len(),79ca.inner_dtype(),80))81}82},8384len if len == ca.len() => {85let out = try_binary_to_series(ca, index, |arr, idx_arr| {86sub_fixed_size_list_get(arr, idx_arr, null_on_oob)87})?;88unsafe { out.from_physical_unchecked(ca.inner_dtype()) }89},9091_len if ca.len() == 1 => {92if let Some(arr) = ca.get(0) {93let idx = convert_and_bound_idx_ca(index, arr.len(), null_on_oob)?;94let s = Series::try_from((ca.name().clone(), vec![arr])).unwrap();95unsafe {96s.take_unchecked(&idx)97.from_physical_unchecked(ca.inner_dtype())98}99} else {100Ok(Series::full_null(101ca.name().clone(),102ca.len(),103ca.inner_dtype(),104))105}106},107108len => polars_bail!(109ComputeError:110"`arr.get` expression got an index array of length {} while the array has {} elements",111len, ca.len()112),113}114}115116117