Path: blob/main/crates/polars-expr/src/expressions/apply.rs
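//! Physical implementation of function application: `ApplyExpr` evaluates its input
//! expressions and calls a `ColumnsUdf` on them, either elementwise or group-aware.
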
use std::borrow::Cow;

use polars_core::POOL;
use polars_core::chunked_array::builder::get_list_builder;
use polars_core::chunked_array::from_iterator_par::{
    ChunkedCollectParIterExt, try_list_from_par_iter,
};
use polars_core::prelude::*;
use rayon::prelude::*;

use super::*;
use crate::dispatch::GroupsUdf;
use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};

#[derive(Clone)]
pub struct ApplyExpr {
    inputs: Vec<Arc<dyn PhysicalExpr>>,
    function: SpecialEq<Arc<dyn ColumnsUdf>>,
    groups_function: Option<SpecialEq<Arc<dyn GroupsUdf>>>,
    expr: Expr,
    flags: FunctionFlags,
    function_operates_on_scalar: bool,
    input_schema: SchemaRef,
    allow_threading: bool,
    check_lengths: bool,
    is_fallible: bool,

    /// Output field of the expression excluding potential aggregation.
    output_field: Field,
}

impl ApplyExpr {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        inputs: Vec<Arc<dyn PhysicalExpr>>,
        function: SpecialEq<Arc<dyn ColumnsUdf>>,
        groups_function: Option<SpecialEq<Arc<dyn GroupsUdf>>>,
        expr: Expr,
        options: FunctionOptions,
        allow_threading: bool,
        input_schema: SchemaRef,
        non_aggregated_output_field: Field,
        function_operates_on_scalar: bool,
        is_fallible: bool,
    ) -> Self {
        debug_assert!(
            !options.is_length_preserving()
                || !options.flags.contains(FunctionFlags::RETURNS_SCALAR),
            "expr {expr:?} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive",
        );

        Self {
            inputs,
            function,
            groups_function,
            expr,
            flags: options.flags,
            function_operates_on_scalar,
            input_schema,
            allow_threading,
            check_lengths: options.check_lengths(),
            output_field: non_aggregated_output_field,
            is_fallible,
        }
    }

    #[allow(clippy::ptr_arg)]
    fn prepare_multiple_inputs<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Vec<AggregationContext<'a>>> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate_on_groups(df, groups, state);
        if self.allow_threading {
            POOL.install(|| self.inputs.par_iter().map(f).collect())
        } else {
            self.inputs.iter().map(f).collect()
        }
    }

    fn finish_apply_groups<'a>(
        &self,
        mut ac: AggregationContext<'a>,
        ca: ListChunked,
    ) -> PolarsResult<AggregationContext<'a>> {
        let c = if self.is_scalar() {
            let out = ca
                .explode(ExplodeOptions {
                    empty_as_null: true,
                    keep_nulls: true,
                })
                .unwrap();
            // if the explode doesn't return the same len, it wasn't scalar.
            polars_ensure!(out.len() == ca.len(), InvalidOperation: "expected scalar for expr: {}, got {}", self.expr, &out);
            ac.update_groups = UpdateGroups::No;
            out.into_column()
        } else {
            ac.with_update_groups(UpdateGroups::WithSeriesLen);
            ca.into_series().into()
        };

        ac.with_values_and_args(c, true, None, false, self.is_scalar())?;

        Ok(ac)
    }

    fn get_input_schema(&self, _df: &DataFrame) -> Cow<'_, Schema> {
        Cow::Borrowed(self.input_schema.as_ref())
    }

    /// Evaluates and flattens `Option<Column>` to `Column`.
    fn eval_and_flatten(&self, inputs: &mut [Column]) -> PolarsResult<Column> {
        self.function.call_udf(inputs)
    }

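    /// Group-aware application of the UDF for a single input: the input is materialized
    /// per group and the function is called once for every group, collecting the results
    /// into a `ListChunked`.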
    fn apply_single_group_aware<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Fix up groups for AggregatedScalar, so that we can pretend they are just normal groups.
        ac.set_groups_for_undefined_agg_states();

        let name = ac.get_values().name().clone();
        let f = |opt_s: Option<Series>| match opt_s {
            None => Ok(None),
            Some(mut s) => {
                if self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY) {
                    s.rename(name.clone());
                }
                Ok(Some(
                    self.function
                        .call_udf(&mut [Column::from(s)])?
                        .take_materialized_series(),
                ))
            },
        };

        // In case of overlapping (rolling) groups, we build groups in a lazy manner to avoid
        // memory explosion.
        // TODO: support Idx GroupsType.
        if matches!(ac.agg_state(), AggState::NotAggregated(_)) && ac.groups.is_overlapping() {
            let ca: ChunkedArray<_> = if self.allow_threading {
                ac.par_iter_groups_lazy()
                    .map(f)
                    .collect::<PolarsResult<_>>()?
            } else {
                ac.iter_groups_lazy().map(f).collect::<PolarsResult<_>>()?
            };
            return self.finish_apply_groups(ac, ca.with_name(name));
        }

        // At this point, calling aggregated() will not lead to memory explosion.
        let agg = match ac.agg_state() {
            AggState::AggregatedScalar(s) => s.as_list().into_column(),
            _ => ac.aggregated(),
        };

        // Collection of empty list leads to a null dtype. See: #3687.
        if agg.is_empty() {
            // Create input for the function to determine the output dtype, see #3946.
            let agg = agg.list().unwrap();
            let input_dtype = agg.inner_dtype();
            let input = Column::full_null(name.clone(), 0, input_dtype);

            let output = self.eval_and_flatten(&mut [input])?;
            let ca = ListChunked::full(name, output.as_materialized_series(), 0);
            return self.finish_apply_groups(ac, ca);
        }

        let ca: ListChunked = if self.allow_threading {
            let lst = agg.list().unwrap();
            let iter = lst.par_iter().map(f);

            if self.output_field.dtype.is_known() {
                let dtype = self.output_field.dtype.clone();
                let dtype = dtype.implode();
                POOL.install(|| {
                    iter.collect_ca_with_dtype::<PolarsResult<_>>(PlSmallStr::EMPTY, dtype)
                })?
            } else {
                POOL.install(|| try_list_from_par_iter(iter, PlSmallStr::EMPTY))?
            }
        } else {
            agg.list()
                .unwrap()
                .into_iter()
                .map(f)
                .collect::<PolarsResult<_>>()?
        };

        self.finish_apply_groups(ac, ca.with_name(name))
    }

    /// Apply elementwise, i.e. ignore the group/list indices.
    fn apply_single_elementwise<'a>(
        &self,
        mut ac: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let (c, aggregated) = match ac.agg_state() {
            AggState::AggregatedList(c) => {
                let ca = c.list().unwrap();
                let out = ca.apply_to_inner(&|s| {
                    Ok(self
                        .eval_and_flatten(&mut [s.into_column()])?
                        .take_materialized_series())
                })?;
                (out.into_column(), true)
            },
            AggState::NotAggregated(c) => {
                let (out, aggregated) = (self.eval_and_flatten(&mut [c.clone()])?, false);
                check_map_output_len(c.len(), out.len(), &self.expr)?;
                (out, aggregated)
            },
            agg_state => {
                ac.with_agg_state(agg_state.try_map(|s| self.eval_and_flatten(&mut [s.clone()]))?);
                return Ok(ac);
            },
        };

        ac.with_values_and_args(c, aggregated, Some(&self.expr), true, self.is_scalar())?;
        Ok(ac)
    }

    // Fast-path when every AggState is a LiteralScalar. This path avoids calling aggregated() or
    // groups(), and returns a LiteralScalar, on the implicit condition that the function is pure.
    fn apply_all_literal_elementwise<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut cols = acs
            .iter()
            .map(|ac| ac.get_values().clone())
            .collect::<Vec<_>>();
        let out = self.function.call_udf(&mut cols)?;
        polars_ensure!(
            out.len() == 1,
            ComputeError: "elementwise expression {:?} must return exactly 1 value on literals, got {}",
            &self.expr, out.len()
        );
        let mut ac = acs.pop().unwrap();
        ac.with_literal(out);
        Ok(ac)
    }

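    /// Elementwise application of the UDF for multiple inputs. All non-literal inputs must
    /// share a compatible layout (see the dispatch notes in `evaluate_on_groups`); when
    /// `must_aggregate` is set, `NotAggregated` inputs are aggregated first so that the
    /// flattened inputs line up.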
    fn apply_multiple_elementwise<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
        must_aggregate: bool,
    ) -> PolarsResult<AggregationContext<'a>> {
        // At this stage, we either have (with or without LiteralScalars):
        // - one or more AggregatedList or NotAggregated ACs
        // - one or more AggregatedScalar ACs

        let mut previous = None;
        for ac in acs.iter_mut() {
            // TBD: If we want to be strict, we would check all groups
            if matches!(
                ac.state,
                AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
            ) {
                continue;
            }

            if must_aggregate {
                ac.aggregated();
            }

            if matches!(ac.state, AggState::AggregatedList(_)) {
                if let Some(p) = previous {
                    ac.groups().check_lengths(p)?;
                }
                previous = Some(ac.groups());
            }
        }

        // At this stage, we do not have both AggregatedList and NotAggregated ACs

        // The first non-LiteralScalar AC will be used as the base AC to retain the context
        let base_ac_idx = acs.iter().position(|ac| !ac.is_literal()).unwrap();

        match acs[base_ac_idx].agg_state() {
            AggState::AggregatedList(s) => {
                let aggregated = acs.iter().any(|ac| ac.is_aggregated());
                let ca = s.list().unwrap();
                let input_len = s.len();

                let out = ca.apply_to_inner(&|_| {
                    let mut cols = acs
                        .iter()
                        .map(|ac| ac.flat_naive().into_owned())
                        .collect::<Vec<_>>();
                    Ok(self
                        .function
                        .call_udf(&mut cols)?
                        .as_materialized_series()
                        .clone())
                })?;

                let out = out.into_column();
                if self.check_lengths {
                    check_map_output_len(input_len, out.len(), &self.expr)?;
                }

                let mut ac = acs.swap_remove(base_ac_idx);
                ac.with_values_and_args(
                    out,
                    aggregated,
                    Some(&self.expr),
                    false,
                    self.is_scalar(),
                )?;
                Ok(ac)
            },
            _ => {
                let aggregated = acs.iter().any(|ac| ac.is_aggregated());
                debug_assert!(aggregated == self.is_scalar());

                let mut cols = acs
                    .iter()
                    .map(|ac| ac.flat_naive().into_owned())
                    .collect::<Vec<_>>();

                let input_len = cols[base_ac_idx].len();
                let out = self.function.call_udf(&mut cols)?;
                if self.check_lengths {
                    check_map_output_len(input_len, out.len(), &self.expr)?;
                }

                let mut ac = acs.swap_remove(base_ac_idx);
                ac.with_values_and_args(
                    out,
                    aggregated,
                    Some(&self.expr),
                    false,
                    self.is_scalar(),
                )?;
                Ok(ac)
            },
        }
    }

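    /// Group-aware application of the UDF for multiple inputs: the inputs are iterated
    /// per group and the function is called once per group with the group-local columns.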
    fn apply_multiple_group_aware<'a>(
        &self,
        mut acs: Vec<AggregationContext<'a>>,
        df: &DataFrame,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut container = vec![Default::default(); acs.len()];
        let schema = self.get_input_schema(df);
        let field = self.to_field(&schema)?;

        // Aggregate representation of the aggregation contexts,
        // then unpack the lists and finally create iterators from these list chunked arrays.
        let mut iters = acs
            .iter_mut()
            .map(|ac| ac.iter_groups(self.flags.contains(FunctionFlags::PASS_NAME_TO_APPLY)))
            .collect::<Vec<_>>();

        // Length of the items to iterate over.
        let len = iters[0].size_hint().0;

        let ca = if field.dtype().is_known() {
            let mut builder = get_list_builder(&field.dtype, len * 5, len, field.name);
            for _ in 0..len {
                container.clear();
                for iter in &mut iters {
                    match iter.next().unwrap() {
                        None => {
                            builder.append_null();
                        },
                        Some(s) => container.push(s.deep_clone().into()),
                    }
                }
                let out = self
                    .function
                    .call_udf(&mut container)
                    .map(|c| c.take_materialized_series())?;

                builder.append_series(&out)?
            }
            builder.finish()
        } else {
            // We still need this branch to materialize unknown / data-dependent types in eager. :(
            (0..len)
                .map(|_| {
                    container.clear();
                    for iter in &mut iters {
                        match iter.next().unwrap() {
                            None => return Ok(None),
                            Some(s) => container.push(s.deep_clone().into()),
                        }
                    }
                    Ok(Some(
                        self.function
                            .call_udf(&mut container)?
                            .take_materialized_series(),
                    ))
                })
                .collect::<PolarsResult<ListChunked>>()?
                .with_name(field.name.clone())
        };
        #[cfg(debug_assertions)]
        {
            let inner = ca.dtype().inner_dtype().unwrap();
            if field.dtype.is_known() {
                assert_eq!(inner, &field.dtype);
            }
        }

        drop(iters);

        // Take the first aggregation context, as that is the input series.
        let ac = acs.swap_remove(0);
        self.finish_apply_groups(ac, ca)
    }
}

fn check_map_output_len(input_len: usize, output_len: usize, expr: &Expr) -> PolarsResult<()> {
    polars_ensure!(
        input_len == output_len, expr = expr, InvalidOperation:
        "output length of `map` ({}) must be equal to the input length ({}); \
        consider using `apply` instead", output_len, input_len
    );
    Ok(())
}

impl PhysicalExpr for ApplyExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let f = |e: &Arc<dyn PhysicalExpr>| e.evaluate(df, state);
        let mut inputs = if self.allow_threading && self.inputs.len() > 1 {
            POOL.install(|| {
                self.inputs
                    .par_iter()
                    .map(f)
                    .collect::<PolarsResult<Vec<_>>>()
            })
        } else {
            self.inputs.iter().map(f).collect::<PolarsResult<Vec<_>>>()
        }?;

        if self.flags.contains(FunctionFlags::ALLOW_RENAME) {
            self.eval_and_flatten(&mut inputs)
        } else {
            let in_name = inputs[0].name().clone();
            Ok(self.eval_and_flatten(&mut inputs)?.with_name(in_name))
        }
    }

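    // Evaluation within a group-by / aggregation context. Dispatches between the
    // elementwise and the group-aware implementations based on the aggregation states
    // of the inputs; the comments below spell out the exact rules.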
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Some functions have a specialized implementation.
        if let Some(groups_function) = self.groups_function.as_ref() {
            return groups_function.evaluate_on_groups(&self.inputs, df, groups, state);
        }

        if self.inputs.len() == 1 {
            let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

            if self.flags.is_elementwise() && (!self.is_fallible || ac.groups_cover_all_values()) {
                self.apply_single_elementwise(ac)
            } else {
                self.apply_single_group_aware(ac)
            }
        } else {
            let mut acs = self.prepare_multiple_inputs(df, groups, state)?;

            match self.flags.is_elementwise() {
                false => self.apply_multiple_group_aware(acs, df),
                true => {
                    // Implementation dispatch:
                    // The current implementation of `apply_multiple_elementwise` requires the
                    // multiple inputs to have a compatible data layout as it invokes `flat_naive()`.
                    // Compatible means matching as-is, or possibly matching after aggregation,
                    // or matching after an implicit broadcast by the function.

                    // The dispatch logic between the implementations depends on the combination of aggstates:
                    // - Any presence of LiteralScalar is immaterial as it gets broadcast in the UDF.
                    // - Combination of AggregatedScalar and AggregatedList => NOT compatible.
                    // - Combination of AggregatedScalar and NotAggregated => NOT compatible.
                    // - Any other combination => compatible, and therefore allowed for elementwise.
                    //   In this case, aggregated() on NotAggregated may be required; however, it can be
                    //   prohibitively memory expensive when dealing with overlapping (e.g., rolling) groups,
                    //   in which case we fall back to group_aware.

                    // Consequently, these may follow the elementwise path (not exhaustive):
                    // - All AggregatedScalar
                    // - A combination of AggregatedList(s) and NotAggregated(s) without expensive aggregation.
                    // - Either of the above with or without LiteralScalar

                    // Visually, in the case of 2 aggstates:
                    // Legend:
                    // - el = elementwise, no need to aggregate() NotAgg
                    // - el + agg = elementwise, but must aggregate() NotAgg
                    // - ga = group_aware
                    // - alit = all_literal
                    // - * = broadcast falls back to group_aware
                    // - ~ = same as its mirror pair (symmetric)
                    //
                    //           | AggList | NotAgg   | AggScalar | LitScalar
                    // --------------------------------------------------------
                    // AggList   | el*     | depends* | ga        | el
                    // NotAgg    | ~       | depends* | ga        | el
                    // AggScalar | ~       | ~        | el        | el
                    // LitScalar | ~       | ~        | ~         | alit
                    //
                    // In case it depends, extending to any combination of multiple aggstates:
                    // (a) Multiple NotAggs, w/o AggList
                    //
                    //                | !has_rolling | has_rolling
                    // -------------------------------------------------
                    // groups match   | el           | el
                    // groups diverge | el+agg       | ga
                    //
                    // (b) Multiple NotAggs, with at least 1 AggList
                    //
                    //                | !has_rolling | has_rolling
                    // -------------------------------------------------
                    // groups match   | el+agg       | ga
                    // groups diverge | el+agg       | ga
                    //
                    // * Finally, when a broadcast is required in the non-scalar case we switch to group_aware.

                    // Collect statistics on the input aggstates.
                    let mut has_agg_list = false;
                    let mut has_agg_scalar = false;
                    let mut has_not_agg = false;
                    let mut has_not_agg_with_overlapping_groups = false;
                    let mut not_agg_groups_may_diverge = false;

                    let mut previous: Option<&AggregationContext<'_>> = None;
                    for ac in &acs {
                        match ac.state {
                            AggState::AggregatedList(_) => {
                                has_agg_list = true;
                            },
                            AggState::AggregatedScalar(_) => has_agg_scalar = true,
                            AggState::NotAggregated(_) => {
                                has_not_agg = true;
                                if let Some(p) = previous {
                                    not_agg_groups_may_diverge |=
                                        !std::ptr::eq(p.groups.as_ref(), ac.groups.as_ref());
                                }
                                previous = Some(ac);
                                if ac.groups.is_overlapping() {
                                    has_not_agg_with_overlapping_groups = true;
                                }
                            },
                            _ => {},
                        }
                    }

                    let all_literal = !(has_agg_list || has_agg_scalar || has_not_agg);
                    let elementwise_must_aggregate =
                        has_not_agg && (has_agg_list || not_agg_groups_may_diverge);

                    if all_literal {
                        // Fast path
                        self.apply_all_literal_elementwise(acs)
                    } else if has_agg_scalar && (has_agg_list || has_not_agg) {
                        // Not compatible
                        self.apply_multiple_group_aware(acs, df)
                    } else if elementwise_must_aggregate && has_not_agg_with_overlapping_groups {
                        // Compatible, but calling aggregated() is too expensive
                        self.apply_multiple_group_aware(acs, df)
                    } else if self.is_fallible
                        && acs.iter_mut().any(|ac| !ac.groups_cover_all_values())
                    {
                        // Fallible expression and there are elements that are masked out.
                        self.apply_multiple_group_aware(acs, df)
                    } else {
                        // Broadcast in NotAgg or AggList requires group_aware.
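                        // A broadcast is detected when, for some group, the base (first
                        // non-literal) context and another non-literal context report
                        // different group lengths and one of them has length 1.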
                        acs.iter_mut().filter(|ac| !ac.is_literal()).for_each(|ac| {
                            ac.groups();
                        });
                        let has_broadcast =
                            if let Some(base_ac_idx) = acs.iter().position(|ac| !ac.is_literal()) {
                                acs.iter()
                                    .enumerate()
                                    .filter(|(i, ac)| *i != base_ac_idx && !ac.is_literal())
                                    .any(|(_, ac)| {
                                        acs[base_ac_idx].groups.iter().zip(ac.groups.iter()).any(
                                            |(l, r)| {
                                                l.len() != r.len() && (l.len() == 1 || r.len() == 1)
                                            },
                                        )
                                    })
                            } else {
                                false
                            };
                        if has_broadcast {
                            // Broadcast fall-back.
                            self.apply_multiple_group_aware(acs, df)
                        } else {
                            self.apply_multiple_elementwise(acs, elementwise_must_aggregate)
                        }
                    }
                },
            }
        }
    }

    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }
    fn is_scalar(&self) -> bool {
        self.flags.returns_scalar()
            || (self.function_operates_on_scalar && self.flags.is_length_preserving())
    }
}