Path: blob/main/crates/polars-expr/src/dispatch/array.rs
7884 views
use polars_core::error::{PolarsResult, polars_bail, polars_ensure, polars_err};1use polars_core::prelude::{Column, DataType, ExplodeOptions, IntoColumn, SortOptions};2use polars_ops::prelude::array::ArrayNameSpace;3#[cfg(feature = "array_to_struct")]4use polars_plan::dsl::DslNameGenerator;5use polars_plan::dsl::{ColumnsUdf, SpecialEq};6use polars_plan::plans::IRArrayFunction;7use polars_utils::pl_str::PlSmallStr;89use super::*;1011pub fn function_expr_to_udf(func: IRArrayFunction) -> SpecialEq<Arc<dyn ColumnsUdf>> {12use IRArrayFunction::*;13match func {14Concat => map_as_slice!(concat_arr),15Length => map!(length),16Min => map!(min),17Max => map!(max),18Sum => map!(sum),19ToList => map!(to_list),20Unique(stable) => map!(unique, stable),21NUnique => map!(n_unique),22Std(ddof) => map!(std, ddof),23Var(ddof) => map!(var, ddof),24Mean => map!(mean),25Median => map!(median),26#[cfg(feature = "array_any_all")]27Any => map!(any),28#[cfg(feature = "array_any_all")]29All => map!(all),30Sort(options) => map!(sort, options),31Reverse => map!(reverse),32ArgMin => map!(arg_min),33ArgMax => map!(arg_max),34Get(null_on_oob) => map_as_slice!(get, null_on_oob),35Join(ignore_nulls) => map_as_slice!(join, ignore_nulls),36#[cfg(feature = "is_in")]37Contains { nulls_equal } => map_as_slice!(contains, nulls_equal),38#[cfg(feature = "array_count")]39CountMatches => map_as_slice!(count_matches),40Shift => map_as_slice!(shift),41Explode(options) => map_as_slice!(explode, options),42Slice(offset, length) => map!(slice, offset, length),43#[cfg(feature = "array_to_struct")]44ToStruct(ng) => map!(arr_to_struct, ng.clone()),45}46}4748pub(super) fn length(s: &Column) -> PolarsResult<Column> {49let array = s.array()?;50let width = array.width();51let width = IdxSize::try_from(width)52.map_err(|_| polars_err!(bigidx, ctx = "array length", size = width))?;5354let mut c = Column::new_scalar(array.name().clone(), width.into(), array.len());55if let Some(validity) = array.rechunk_validity() {56let mut series = c.into_materialized_series().clone();5758// SAFETY: We keep datatypes intact and call compute_len afterwards.59let chunks = unsafe { series.chunks_mut() };60assert_eq!(chunks.len(), 1);6162chunks[0] = chunks[0].with_validity(Some(validity));6364series.compute_len();65c = series.into_column();66}6768Ok(c)69}7071pub(super) fn max(s: &Column) -> PolarsResult<Column> {72Ok(s.array()?.array_max().into())73}7475pub(super) fn min(s: &Column) -> PolarsResult<Column> {76Ok(s.array()?.array_min().into())77}7879pub(super) fn sum(s: &Column) -> PolarsResult<Column> {80s.array()?.array_sum().map(Column::from)81}8283pub(super) fn std(s: &Column, ddof: u8) -> PolarsResult<Column> {84s.array()?.array_std(ddof).map(Column::from)85}8687pub(super) fn var(s: &Column, ddof: u8) -> PolarsResult<Column> {88s.array()?.array_var(ddof).map(Column::from)89}9091pub(super) fn mean(s: &Column) -> PolarsResult<Column> {92s.array()?.array_mean().map(Column::from)93}9495pub(super) fn median(s: &Column) -> PolarsResult<Column> {96s.array()?.array_median().map(Column::from)97}9899pub(super) fn unique(s: &Column, stable: bool) -> PolarsResult<Column> {100let ca = s.array()?;101let out = if stable {102ca.array_unique_stable()103} else {104ca.array_unique()105};106out.map(|ca| ca.into_column())107}108109pub(super) fn n_unique(s: &Column) -> PolarsResult<Column> {110Ok(s.array()?.array_n_unique()?.into_column())111}112113pub(super) fn to_list(s: &Column) -> PolarsResult<Column> {114if let DataType::Array(inner, _) = s.dtype() {115s.cast(&DataType::List(inner.clone()))116} else {117polars_bail!(ComputeError: "expected array dtype")118}119}120121#[cfg(feature = "array_any_all")]122pub(super) fn any(s: &Column) -> PolarsResult<Column> {123s.array()?.array_any().map(Column::from)124}125126#[cfg(feature = "array_any_all")]127pub(super) fn all(s: &Column) -> PolarsResult<Column> {128s.array()?.array_all().map(Column::from)129}130131pub(super) fn sort(s: &Column, options: SortOptions) -> PolarsResult<Column> {132Ok(s.array()?.array_sort(options)?.into_column())133}134135pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {136Ok(s.array()?.array_reverse().into_column())137}138139pub(super) fn arg_min(s: &Column) -> PolarsResult<Column> {140Ok(s.array()?.array_arg_min().into_column())141}142143pub(super) fn arg_max(s: &Column) -> PolarsResult<Column> {144Ok(s.array()?.array_arg_max().into_column())145}146147pub(super) fn get(s: &[Column], null_on_oob: bool) -> PolarsResult<Column> {148let ca = s[0].array()?;149let index = s[1].cast(&DataType::Int64)?;150let index = index.i64().unwrap();151ca.array_get(index, null_on_oob).map(Column::from)152}153154pub(super) fn join(s: &[Column], ignore_nulls: bool) -> PolarsResult<Column> {155let ca = s[0].array()?;156let separator = s[1].str()?;157ca.array_join(separator, ignore_nulls).map(Column::from)158}159160#[cfg(feature = "is_in")]161pub(super) fn contains(s: &[Column], nulls_equal: bool) -> PolarsResult<Column> {162let array = &s[0];163let item = &s[1];164polars_ensure!(matches!(array.dtype(), DataType::Array(_, _)),165SchemaMismatch: "invalid series dtype: expected `Array`, got `{}`", array.dtype(),166);167let mut ca = polars_ops::series::is_in(168item.as_materialized_series(),169array.as_materialized_series(),170nulls_equal,171)?;172ca.rename(array.name().clone());173Ok(ca.into_column())174}175176#[cfg(feature = "array_count")]177pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {178let s = &args[0];179let element = &args[1];180polars_ensure!(181element.len() == 1,182ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",183element.len()184);185let ca = s.array()?;186ca.array_count_matches(element.get(0).unwrap())187.map(Column::from)188}189190pub(super) fn shift(s: &[Column]) -> PolarsResult<Column> {191let ca = s[0].array()?;192let n = &s[1];193194ca.array_shift(n.as_materialized_series()).map(Column::from)195}196197pub(super) fn slice(s: &Column, offset: i64, length: i64) -> PolarsResult<Column> {198let ca = s.array()?;199ca.array_slice(offset, length).map(Column::from)200}201202fn explode(c: &[Column], options: ExplodeOptions) -> PolarsResult<Column> {203c[0].explode(options)204}205206fn concat_arr(args: &[Column]) -> PolarsResult<Column> {207let dtype = concat_arr_output_dtype(&mut args.iter().map(|c| (c.name().as_str(), c.dtype())))?;208209polars_ops::series::concat_arr::concat_arr(args, &dtype)210}211212/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input213/// dtypes are compatible.214fn concat_arr_output_dtype(215inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,216) -> PolarsResult<DataType> {217#[allow(clippy::len_zero)]218if inputs.len() == 0 {219// should not be reachable - we did not set ALLOW_EMPTY_INPUTS220panic!();221}222223let mut inputs = inputs.map(|(name, dtype)| {224let (inner_dtype, width) = match dtype {225DataType::Array(inner, width) => (inner.as_ref(), *width),226dt => (dt, 1),227};228(name, dtype, inner_dtype, width)229});230let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();231232for (col_name, dtype, inner_dtype, width) in inputs {233out_width += width;234235if inner_dtype != first_inner_dtype {236polars_bail!(237SchemaMismatch:238"concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \239input column (name: {}, dtype: {}), got {} instead for column {}",240first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,241)242}243}244245Ok(DataType::Array(246Box::new(first_inner_dtype.clone()),247out_width,248))249}250251#[cfg(feature = "array_to_struct")]252fn arr_to_struct(s: &Column, name_generator: Option<DslNameGenerator>) -> PolarsResult<Column> {253use polars_ops::prelude::array::ToStruct;254255let name_generator =256name_generator.map(|f| Arc::new(move |i| f.call(i).map(PlSmallStr::from)) as Arc<_>);257s.array()?258.to_struct(name_generator)259.map(IntoColumn::into_column)260}261262263