Path: blob/main/crates/polars-expr/src/expressions/apply.rs
6940 views
use std::borrow::Cow;

use polars_core::POOL;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::chunked_array::from_iterator_par::{
    ChunkedCollectParIterExt, try_list_from_par_iter,
};
use polars_core::prelude::*;
use rayon::prelude::*;

use super::*;
use crate::expressions::{
    AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups,
};

/// Physical expression that applies a user-defined function (`ColumnsUdf`)
/// over the result of one or more input physical expressions.
///
/// Depending on `flags`, the function is applied either elementwise (over the
/// flat column) or group-aware (once per group, on each group's sub-series).
#[derive(Clone)]
pub struct ApplyExpr {
    // The input expressions whose results are fed to `function`.
    inputs: Vec<Arc<dyn PhysicalExpr>>,
    // The UDF itself; `SpecialEq` wraps it so `ApplyExpr` can derive `Clone`.
    function: SpecialEq<Arc<dyn ColumnsUdf>>,
    // The logical expression this physical node was compiled from; used for
    // error messages and `to_field`.
    expr: Expr,
    flags: FunctionFlags,
    function_operates_on_scalar: bool,
    input_schema: SchemaRef,
    // Whether inputs / groups may be evaluated on the rayon thread pool.
    allow_threading: bool,
    // Whether elementwise output length must equal input length (see
    // `check_map_output_len`).
    check_lengths: bool,
    // Pre-computed output field; its dtype (when known) steers parallel
    // collection in `apply_single_group_aware`.
    output_field: Field,
}

impl ApplyExpr {
    /// Builds an `ApplyExpr` from its parts.
    ///
    /// Debug-asserts that `RETURNS_SCALAR` is not combined with a
    /// length-preserving function, as those two contracts are mutually
    /// exclusive.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        inputs: Vec<Arc<dyn PhysicalExpr>>,
        function: SpecialEq<Arc<dyn ColumnsUdf>>,
        expr: Expr,
        options: FunctionOptions,
        allow_threading: bool,
        input_schema: SchemaRef,
        output_field: Field,
        function_operates_on_scalar: bool,
    ) -> Self {
        debug_assert!(
            !options.is_length_preserving()
                || !options.flags.contains(FunctionFlags::RETURNS_SCALAR),
            "expr {expr:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive",
        );

        Self {
            inputs,
            function,
            expr,
            flags: options.flags,
            function_operates_on_scalar,
            input_schema,
            allow_threading,
            check_lengths: options.check_lengths(),
            output_field,
        }
    }

    /// Evaluates all input expressions on the given groups, in parallel on the
    /// rayon pool when `allow_threading` is set.
    #[allow(clippy::ptr_arg)]
    fn prepare_multiple_inputs<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Vec<AggregationContext<'a>>> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate_on_groups(df, groups, state);
        if self.allow_threading {
            POOL.install(|| self.inputs.par_iter().map(f).collect())
        } else {
            self.inputs.iter().map(f).collect()
        }
    }

    /// Stores the per-group results `ca` back into the aggregation context.
    ///
    /// For `returns_scalar` functions the list is exploded to a flat column of
    /// one value per group; a length mismatch after exploding means the UDF
    /// did not actually return a scalar per group, which is an error.
    fn finish_apply_groups<'a>(
        &self,
        mut ac: AggregationContext<'a>,
        ca: ListChunked,
    ) -> PolarsResult<AggregationContext<'a>> {
        let c = if self.flags.returns_scalar() {
            let out = ca.explode(false).unwrap();
            // if the explode doesn't return the same len, it wasn't scalar.
            polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out);
            ac.update_groups = UpdateGroups::No;
            out.into_column()
        } else {
            // Group lengths may have changed; recompute them lazily from the
            // list's series lengths.
            ac.with_update_groups(UpdateGroups::WithSeriesLen);
            ca.into_series().into()
        };

        ac.with_values_and_args(c, true, None, false, self.flags.returns_scalar())?;

        Ok(ac)
    }

    // NOTE(review): `_df` is unused; the schema always comes from the cached
    // `input_schema` — presumably the parameter is kept for signature parity.
    fn get_input_schema(&self, _df: &DataFrame) -> Cow<'_, Schema> {
        Cow::Borrowed(self.input_schema.as_ref())
    }

    /// Evaluates and flattens `Option<Column>` to `Column`.
    fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult<Column> {
        self.function.call_udf(inputs)
    }

    /// Applies the UDF once per group for a single input expression.
    ///
    /// Each group's sub-series is passed to the UDF independently; the results
    /// are collected into a `ListChunked` (one list entry per group) and then
    /// finished via `finish_apply_groups`.
    fn apply_single_group_aware<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let s = ac.get_values();

        // A non-list scalar aggregation cannot be aggregated a second time.
        #[allow(clippy::nonminimal_bool)]
        {
            polars_ensure!(
                !(matches!(ac.agg_state(), AggState::AggregatedScalar(_)) && !s.dtype().is_list()),
                expr = self.expr,
                ComputeError: "cannot aggregate, the column is already aggregated",
            );
        }

        let name = s.name().clone();
        let agg = ac.aggregated();
        // Collection of empty list leads to a null dtype. See: #3687.
        if agg.is_empty() {
            // Create input for the function to determine the output dtype, see #3946.
            let agg = agg.list().unwrap();
            let input_dtype = agg.inner_dtype();
            let input = Column::full_null(name.clone(), 0, input_dtype);

            let output = self.eval_and_flatten(&mut [input])?;
            let ca = ListChunked::full(name, output.as_materialized_series(), 0);
            return self.finish_apply_groups(ac, ca);
        }

        // Per-group closure: `None` groups stay `None`; otherwise the group's
        // series is (optionally renamed and) fed to the UDF.
        let f = |opt_s: Option<Series>| match opt_s {
            None => Ok(None),
            Some(mut s) => {
                if self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY) {
                    s.rename(name.clone());
                }
                Ok(Some(
                    self.function
                        .call_udf(&mut [Column::from(s)])?
                        .take_materialized_series(),
                ))
            },
        };

        let ca: ListChunked = if self.allow_threading {
            // Only pin the output dtype when it is known and non-null;
            // otherwise let the collection infer it.
            let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null()
            {
                Some(self.output_field.dtype.clone())
            } else {
                None
            };

            let lst = agg.list().unwrap();
            let iter = lst.par_iter().map(f);

            if let Some(dtype) = dtype {
                // @NOTE: Since the output type for scalars does an implicit explode, we need to
                // patch up the type here to also be a list.
                let out_dtype = if self.is_scalar() {
                    DataType::List(Box::new(dtype))
                } else {
                    dtype
                };

                let out: ListChunked = POOL.install(|| {
                    iter.collect_ca_with_dtype::<PolarsResult<_>>(PlSmallStr::EMPTY, out_dtype)
                })?;
                out
            } else {
                POOL.install(|| try_list_from_par_iter(iter, PlSmallStr::EMPTY))?
            }
        } else {
            // Sequential fallback: same closure, plain iterator.
            agg.list()
                .unwrap()
                .into_iter()
                .map(f)
                .collect::<PolarsResult<_>>()?
        };

        self.finish_apply_groups(ac, ca.with_name(name))
    }

    /// Apply elementwise e.g. ignore the group/list indices.
    fn apply_single_elementwise<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let (c, aggregated) = match ac.agg_state() {
            // Already aggregated into lists: apply the UDF to the inner
            // (flattened) values, keeping the list structure intact.
            AggState::AggregatedList(c) => {
                let ca = c.list().unwrap();
                let out = ca.apply_to_inner(&|s| {
                    Ok(self
                        .eval_and_flatten(&mut [s.into_column()])?
                        .take_materialized_series())
                })?;
                (out.into_column(), true)
            },
            // Flat column: apply directly; output length must match input.
            AggState::NotAggregated(c) => {
                let (out, aggregated) = (self.eval_and_flatten(&mut [c.clone()])?, false);
                check_map_output_len(c.len(), out.len(), &self.expr)?;
                (out, aggregated)
            },
            // Remaining states (literals / scalar aggregations): map the state
            // in place and return early.
            agg_state => {
                ac.with_agg_state(agg_state.try_map(|s| self.eval_and_flatten(&mut [s.clone()]))?);
                return Ok(ac);
            },
        };

        ac.with_values_and_args(c, aggregated, Some(&self.expr), true, self.is_scalar())?;
        Ok(ac)
    }

    /// Applies the UDF once per group over multiple input expressions.
    ///
    /// The inputs are iterated group-by-group in lockstep; per group the
    /// sub-series of all inputs are gathered into `container` and passed to
    /// the UDF together.
    fn apply_multiple_group_aware<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
        df: &DataFrame,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Reusable argument buffer, one slot per input.
        let mut container = vec![Default::default(); acs.len()];
        let schema = self.get_input_schema(df);
        let field = self.to_field(&schema)?;

        // Aggregate representation of the aggregation contexts,
        // then unpack the lists and finally create iterators from this list chunked arrays.
        let mut iters = acs
            .iter_mut()
            .map(|ac| ac.iter_groups(self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY)))
            .collect::<Vec<_>>();

        // Length of the items to iterate over.
        let len = iters[0].size_hint().0;

        let ca = if len == 0 {
            // Zero groups: collecting would yield a null dtype, so build an
            // empty list with the schema-derived dtype instead. The loop body
            // below never runs when `len == 0`.
            let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name);
            for _ in 0..len {
                container.clear();
                for iter in &mut iters {
                    match iter.next().unwrap() {
                        None => {
                            builder.append_null();
                        },
                        Some(s) => container.push(s.deep_clone().into()),
                    }
                }
                let out = self
                    .function
                    .call_udf(&mut container)
                    .map(|c| c.take_materialized_series())?;

                builder.append_series(&out)?
            }
            builder.finish()
        } else {
            // We still need this branch to materialize unknown/ data dependent types in eager. :(
            (0..len)
                .map(|_| {
                    container.clear();
                    for iter in &mut iters {
                        match iter.next().unwrap() {
                            // Any missing input for this group yields a null
                            // output entry for the whole group.
                            None => return Ok(None),
                            Some(s) => container.push(s.deep_clone().into()),
                        }
                    }
                    Ok(Some(
                        self.function
                            .call_udf(&mut container)?
                            .take_materialized_series(),
                    ))
                })
                .collect::<PolarsResult<ListChunked>>()?
                .with_name(field.name.clone())
        };
        // Sanity check: the collected inner dtype must agree with the planned
        // output dtype whenever the latter is known.
        #[cfg(debug_assertions)]
        {
            let inner = ca.dtype().inner_dtype().unwrap();
            if field.dtype.is_known() {
                assert_eq!(inner, &field.dtype);
            }
        }

        // End the mutable borrows of `acs` before moving out of it.
        drop(iters);

        // Take the first aggregation context, as that is the input series.
        let ac = acs.swap_remove(0);
        self.finish_apply_groups(ac, ca)
    }
}

/// Ensures an elementwise (`map`) UDF preserved the input length.
fn check_map_output_len(input_len: usize, output_len: usize, expr: &Expr) -> PolarsResult<()> {
    polars_ensure!(
        input_len == output_len, expr = expr, InvalidOperation:
        "output length of `map` ({}) must be equal to the input length ({}); \
        consider using `apply` instead", output_len, input_len
    );
    Ok(())
}

impl PhysicalExpr for ApplyExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Flat (non-grouped) evaluation: evaluate all inputs (in parallel when
    /// allowed and there is more than one), then call the UDF once.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate(df, state);
        let mut inputs = if self.allow_threading && self.inputs.len() > 1 {
            POOL.install(|| {
                self.inputs
                    .par_iter()
                    .map(f)
                    .collect::<PolarsResult<Vec<_>>>()
            })
        } else {
            self.inputs.iter().map(f).collect::<PolarsResult<Vec<_>>>()
        }?;

        if self.flags.contains(FunctionFlags::ALLOW_RENAME) {
            // The UDF is allowed to choose the output name itself.
            self.eval_and_flatten(&mut inputs)
        } else {
            // Otherwise the output keeps the first input's name.
            let in_name = inputs[0].name().clone();
            Ok(self.eval_and_flatten(&mut inputs)?.with_name(in_name))
        }
    }

    /// Grouped evaluation: dispatches to the elementwise or group-aware path
    /// based on the function flags and the inputs' aggregation states.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        if self.inputs.len() == 1 {
            let ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

            match self.flags.is_elementwise() {
                false => self.apply_single_group_aware(ac),
                true => self.apply_single_elementwise(ac),
            }
        } else {
            let acs = self.prepare_multiple_inputs(df, groups, state)?;

            match self.flags.is_elementwise() {
                false => self.apply_multiple_group_aware(acs, df),
                true => {
                    // Even an elementwise function must go group-aware when
                    // the inputs' aggregation states are incompatible (lists,
                    // or a mix of scalar-aggregated and not-aggregated).
                    let mut has_agg_list = false;
                    let mut has_agg_scalar = false;
                    let mut has_not_agg = false;
                    for ac in &acs {
                        match ac.state {
                            AggState::AggregatedList(_) => has_agg_list = true,
                            AggState::AggregatedScalar(_) => has_agg_scalar = true,
                            AggState::NotAggregated(_) => has_not_agg = true,
                            _ => {},
                        }
                    }
                    if has_agg_list || (has_agg_scalar && has_not_agg) {
                        self.apply_multiple_group_aware(acs, df)
                    } else {
                        apply_multiple_elementwise(
                            acs,
                            self.function.as_ref(),
                            &self.expr,
                            self.check_lengths,
                            self.is_scalar(),
                        )
                    }
                },
            }
        }
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.expr.to_field(input_schema)
    }

    /// Partitioned aggregation is only supported for a single elementwise input.
    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        if self.inputs.len() == 1 && self.flags.is_elementwise() {
            Some(self)
        } else {
            None
        }
    }

    fn is_scalar(&self) -> bool {
        self.flags.returns_scalar()
            || (self.function_operates_on_scalar && self.flags.is_length_preserving())
    }
}

/// Elementwise application of `function` over multiple aggregation contexts.
///
/// Fast path: when the first input is an aggregated list, the UDF is applied
/// to its inner values (other inputs flattened), preserving group structure
/// without recomputing groups. Otherwise all inputs are flattened, the UDF is
/// called once, and the result is stored back into the first context.
fn apply_multiple_elementwise<'a>(
    mut acs: Vec<AggregationContext<'a>>,
    function: &dyn ColumnsUdf,
    expr: &Expr,
    check_lengths: bool,
    returns_scalar: bool,
) -> PolarsResult<AggregationContext<'a>> {
    match acs.first().unwrap().agg_state() {
        // A fast path that doesn't drop groups of the first arg.
        // This doesn't require group re-computation.
        AggState::AggregatedList(s) => {
            let ca = s.list().unwrap();

            let other = acs[1..]
                .iter()
                .map(|ac| ac.flat_naive().into_owned())
                .collect::<Vec<_>>();

            let out = ca.apply_to_inner(&|s| {
                // First argument is the current inner series; the rest are the
                // flattened remaining inputs.
                let mut args = Vec::with_capacity(other.len() + 1);
                args.push(s.into());
                args.extend_from_slice(&other);
                Ok(function
                    .call_udf(&mut args)?
                    .as_materialized_series()
                    .clone())
            })?;
            let mut ac = acs.swap_remove(0);
            ac.with_values(out.into_column(), true, None)?;
            Ok(ac)
        },
        first_as => {
            // Literal scalars broadcast, so skip the length check for them.
            let check_lengths = check_lengths && !matches!(first_as, AggState::LiteralScalar(_));
            // The result counts as aggregated when every input is aggregated
            // or a literal, and at least one is actually aggregated.
            let aggregated = acs.iter().all(|ac| ac.is_aggregated() | ac.is_literal())
                && acs.iter().any(|ac| ac.is_aggregated());
            let mut c = acs
                .iter_mut()
                .enumerate()
                .map(|(i, ac)| {
                    // Make sure the groups are updated because we are about to throw away
                    // the series length information, only on the first iteration.
                    if let (0, UpdateGroups::WithSeriesLen) = (i, &ac.update_groups) {
                        ac.groups();
                    }

                    ac.flat_naive().into_owned()
                })
                .collect::<Vec<_>>();

            let input_len = c[0].len();
            let c = function.call_udf(&mut c)?;
            if check_lengths {
                check_map_output_len(input_len, c.len(), expr)?;
            }

            // Take the first aggregation context, as that is the input series.
            let mut ac = acs.swap_remove(0);
            ac.with_values_and_args(c, aggregated, None, true, returns_scalar)?;
            Ok(ac)
        },
    }
}

impl PartitionedAggregation for ApplyExpr {
    /// Evaluates the single input partitioned, then applies the UDF once.
    /// Only reachable when `as_partitioned_aggregator` returned `Some`, i.e.
    /// for a single elementwise input.
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        let a = self.inputs[0].as_partitioned_aggregator().unwrap();
        let s = a.evaluate_partitioned(df, groups, state)?;

        if self.flags.contains(FunctionFlags::ALLOW_RENAME) {
            self.eval_and_flatten(&mut [s])
        } else {
            let in_name = s.name().clone();
            Ok(self.eval_and_flatten(&mut [s])?.with_name(in_name))
        }
    }

    /// Elementwise functions need no finalization step: the partitioned
    /// result is already the final result.
    fn finalize(
        &self,
        partitioned: Column,
        _groups: &GroupPositions,
        _state: &ExecutionState,
    ) -> PolarsResult<Column> {
        Ok(partitioned)
    }
}