Path: blob/main/crates/polars-plan/src/plans/aexpr/function_expr/array.rs
7889 views
use polars_core::utils::slice_offsets;1use polars_ops::chunked_array::array::*;23use super::*;45#[derive(Clone, Eq, PartialEq, Hash, Debug)]6#[cfg_attr(feature = "ir_serde", derive(serde::Serialize, serde::Deserialize))]7pub enum IRArrayFunction {8Length,9Min,10Max,11Sum,12ToList,13Unique(bool),14NUnique,15Std(u8),16Var(u8),17Mean,18Median,19#[cfg(feature = "array_any_all")]20Any,21#[cfg(feature = "array_any_all")]22All,23Sort(SortOptions),24Reverse,25ArgMin,26ArgMax,27Get(bool),28Join(bool),29#[cfg(feature = "is_in")]30Contains {31nulls_equal: bool,32},33#[cfg(feature = "array_count")]34CountMatches,35Shift,36Explode(ExplodeOptions),37Concat,38Slice(i64, i64),39#[cfg(feature = "array_to_struct")]40ToStruct(Option<DslNameGenerator>),41}4243impl<'a> FieldsMapper<'a> {44/// Validate that the dtype is an array.45pub fn ensure_is_array(self) -> PolarsResult<Self> {46let dt = self.args()[0].dtype();47polars_ensure!(48dt.is_array(),49InvalidOperation: format!("expected Array datatype for array operation, got: {:?}", dt)50);51Ok(self)52}53}5455impl IRArrayFunction {56pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {57use IRArrayFunction::*;5859match self {60Concat => Ok(Field::new(61mapper62.args()63.first()64.map_or(PlSmallStr::EMPTY, |x| x.name.clone()),65concat_arr_output_dtype(66&mut mapper.args().iter().map(|x| (x.name.as_str(), &x.dtype)),67)?,68)),69Length => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),70Min | Max => mapper71.ensure_is_array()?72.map_to_list_and_array_inner_dtype(),73Sum => mapper.ensure_is_array()?.nested_sum_type(),74ToList => mapper75.ensure_is_array()?76.try_map_dtype(map_array_dtype_to_list_dtype),77Unique(_) => mapper78.ensure_is_array()?79.try_map_dtype(map_array_dtype_to_list_dtype),80NUnique => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),81Std(_) => mapper.ensure_is_array()?.moment_dtype(),82Var(_) => mapper.ensure_is_array()?.var_dtype(),83Mean => mapper.ensure_is_array()?.moment_dtype(),84Median => mapper.ensure_is_array()?.moment_dtype(),85#[cfg(feature = "array_any_all")]86Any | All => mapper.ensure_is_array()?.with_dtype(DataType::Boolean),87Sort(_) => mapper.ensure_is_array()?.with_same_dtype(),88Reverse => mapper.ensure_is_array()?.with_same_dtype(),89ArgMin | ArgMax => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),90Get(_) => mapper91.ensure_is_array()?92.map_to_list_and_array_inner_dtype(),93Join(_) => mapper.ensure_is_array()?.with_dtype(DataType::String),94#[cfg(feature = "is_in")]95Contains { nulls_equal: _ } => mapper.ensure_is_array()?.with_dtype(DataType::Boolean),96#[cfg(feature = "array_count")]97CountMatches => mapper.ensure_is_array()?.with_dtype(IDX_DTYPE),98Shift => mapper.ensure_is_array()?.with_same_dtype(),99Explode { .. } => mapper.ensure_is_array()?.try_map_to_array_inner_dtype(),100Slice(offset, length) => mapper101.ensure_is_array()?102.try_map_dtype(map_to_array_fixed_length(offset, length)),103#[cfg(feature = "array_to_struct")]104ToStruct(name_generator) => mapper.ensure_is_array()?.try_map_dtype(|dtype| {105let DataType::Array(inner, width) = dtype else {106polars_bail!(InvalidOperation: "expected Array type, got: {dtype}")107};108109(0..*width)110.map(|i| {111let name = match name_generator {112None => arr_default_struct_name_gen(i),113Some(ng) => PlSmallStr::from_string(ng.call(i)?),114};115Ok(Field::new(name, inner.as_ref().clone()))116})117.collect::<PolarsResult<Vec<Field>>>()118.map(DataType::Struct)119}),120}121}122123pub fn function_options(&self) -> FunctionOptions {124use IRArrayFunction as A;125match self {126#[cfg(feature = "array_any_all")]127A::Any | A::All => FunctionOptions::elementwise(),128#[cfg(feature = "is_in")]129A::Contains { nulls_equal: _ } => FunctionOptions::elementwise(),130#[cfg(feature = "array_count")]131A::CountMatches => FunctionOptions::elementwise(),132A::Concat => FunctionOptions::elementwise()133.with_flags(|f| f | FunctionFlags::INPUT_WILDCARD_EXPANSION),134A::Length135| A::Min136| A::Max137| A::Sum138| A::ToList139| A::Unique(_)140| A::NUnique141| A::Std(_)142| A::Var(_)143| A::Mean144| A::Median145| A::Sort(_)146| A::Reverse147| A::ArgMin148| A::ArgMax149| A::Get(_)150| A::Join(_)151| A::Shift152| A::Slice(_, _) => FunctionOptions::elementwise(),153A::Explode { .. } => FunctionOptions::row_separable(),154#[cfg(feature = "array_to_struct")]155A::ToStruct(_) => FunctionOptions::elementwise(),156}157}158}159160fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {161if let DataType::Array(inner, _) = datatype {162Ok(DataType::List(inner.clone()))163} else {164polars_bail!(ComputeError: "expected array dtype")165}166}167168fn map_to_array_fixed_length(169offset: &i64,170length: &i64,171) -> impl FnOnce(&DataType) -> PolarsResult<DataType> {172move |datatype: &DataType| {173if let DataType::Array(inner, array_len) = datatype {174let length: usize = if *length < 0 {175(*array_len as i64 + *length).max(0)176} else {177*length178}.try_into().map_err(|_| {179polars_err!(OutOfBounds: "length must be a non-negative integer, got: {}", length)180})?;181let (_, slice_offset) = slice_offsets(*offset, length, *array_len);182Ok(DataType::Array(inner.clone(), slice_offset))183} else {184polars_bail!(ComputeError: "expected array dtype, got {}", datatype);185}186}187}188189impl Display for IRArrayFunction {190fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {191use IRArrayFunction::*;192let name = match self {193Concat => "concat",194Length => "length",195Min => "min",196Max => "max",197Sum => "sum",198ToList => "to_list",199Unique(_) => "unique",200NUnique => "n_unique",201Std(_) => "std",202Var(_) => "var",203Mean => "mean",204Median => "median",205#[cfg(feature = "array_any_all")]206Any => "any",207#[cfg(feature = "array_any_all")]208All => "all",209Sort(_) => "sort",210Reverse => "reverse",211ArgMin => "arg_min",212ArgMax => "arg_max",213Get(_) => "get",214Join(_) => "join",215#[cfg(feature = "is_in")]216Contains { nulls_equal: _ } => "contains",217#[cfg(feature = "array_count")]218CountMatches => "count_matches",219Shift => "shift",220Slice(_, _) => "slice",221Explode { .. } => "explode",222#[cfg(feature = "array_to_struct")]223ToStruct(_) => "to_struct",224};225write!(f, "arr.{name}")226}227}228229/// Determine the output dtype of a `concat_arr` operation. Also performs validation to ensure input230/// dtypes are compatible.231fn concat_arr_output_dtype(232inputs: &mut dyn ExactSizeIterator<Item = (&str, &DataType)>,233) -> PolarsResult<DataType> {234#[allow(clippy::len_zero)]235if inputs.len() == 0 {236// should not be reachable - we did not set ALLOW_EMPTY_INPUTS237panic!();238}239240let mut inputs = inputs.map(|(name, dtype)| {241let (inner_dtype, width) = match dtype {242DataType::Array(inner, width) => (inner.as_ref(), *width),243dt => (dt, 1),244};245(name, dtype, inner_dtype, width)246});247let (first_name, first_dtype, first_inner_dtype, mut out_width) = inputs.next().unwrap();248249for (col_name, dtype, inner_dtype, width) in inputs {250out_width += width;251252if inner_dtype != first_inner_dtype {253polars_bail!(254SchemaMismatch:255"concat_arr dtype mismatch: expected {} or array[{}] dtype to match dtype of first \256input column (name: {}, dtype: {}), got {} instead for column {}",257first_inner_dtype, first_inner_dtype, first_name, first_dtype, dtype, col_name,258)259}260}261262Ok(DataType::Array(263Box::new(first_inner_dtype.clone()),264out_width,265))266}267268269