Path: blob/main/crates/polars-expr/src/expressions/ternary.rs
6940 views
use polars_core::POOL;
use polars_core::prelude::*;
use polars_plan::prelude::*;

use super::*;
use crate::expressions::{AggregationContext, PhysicalExpr};

/// Physical expression that selects element-wise between two branches based on
/// a boolean predicate: the result is `truthy.zip_with(mask, falsy)`, where
/// `mask` is the evaluated `predicate`.
/// NOTE(review): presumably this backs the `when/then/otherwise` DSL — confirm
/// against the planner that constructs it.
pub struct TernaryExpr {
    /// Must evaluate to a boolean column; its result is read via `.bool()?`.
    predicate: Arc<dyn PhysicalExpr>,
    /// Values taken where the mask is true.
    truthy: Arc<dyn PhysicalExpr>,
    /// Values taken where the mask is not true.
    falsy: Arc<dyn PhysicalExpr>,
    /// The logical expression this node was built from (returned by `as_expression`).
    expr: Expr,
    // Can be expensive on small data to run literals in parallel.
    // When true, the sub-expressions are evaluated concurrently on `POOL` via
    // `rayon::join`; when false they are evaluated sequentially.
    run_par: bool,
    /// Reported verbatim by `PhysicalExpr::is_scalar`.
    returns_scalar: bool,
}

impl TernaryExpr {
    /// Plain field-by-field constructor; performs no validation.
    pub fn new(
        predicate: Arc<dyn PhysicalExpr>,
        truthy: Arc<dyn PhysicalExpr>,
        falsy: Arc<dyn PhysicalExpr>,
        expr: Expr,
        run_par: bool,
        returns_scalar: bool,
    ) -> Self {
        Self {
            predicate,
            truthy,
            falsy,
            expr,
            run_par,
            returns_scalar,
        }
    }
}

/// Slow-path fallback that applies the ternary one group at a time.
///
/// The groups of the three contexts are iterated in lockstep, and
/// `truthy.zip_with(mask, falsy)` is called per group. If any of the three
/// yields `None` for a group, that group's output is null. The per-group
/// results are collected into a `ListChunked` named after the truthy values.
///
/// The list is exploded only when all three contexts report
/// `arity_should_explode()` and the total number of list values equals the
/// number of groups.
///
/// Returns `ac_truthy` with state `AggregatedList` and
/// `UpdateGroups::WithSeriesLen`.
///
/// # Errors
/// Propagates errors from `.bool()` on a mask group, from `zip_with`, and from
/// `explode`.
///
/// Note the argument order is (truthy, falsy, mask).
fn finish_as_iters<'a>(
    mut ac_truthy: AggregationContext<'a>,
    mut ac_falsy: AggregationContext<'a>,
    mut ac_mask: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
    let ca = ac_truthy
        .iter_groups(false)
        .zip(ac_falsy.iter_groups(false))
        .zip(ac_mask.iter_groups(false))
        .map(|((truthy, falsy), mask)| {
            match (truthy, falsy, mask) {
                (Some(truthy), Some(falsy), Some(mask)) => Some(
                    truthy
                        .as_ref()
                        .zip_with(mask.as_ref().bool()?, falsy.as_ref()),
                ),
                // A missing group in any input makes the output group null.
                _ => None,
            }
            // Option<PolarsResult<_>> -> PolarsResult<Option<_>> so `collect`
            // short-circuits on the first error.
            .transpose()
        })
        .collect::<PolarsResult<ListChunked>>()?
        .with_name(ac_truthy.get_values().name().clone());

    // Aggregation leaves only a single chunk.
    // (The `unwrap` below relies on that: the freshly collected ListChunked is
    // assumed to have exactly one chunk.)
    let arr = ca.downcast_iter().next().unwrap();
    let list_vals_len = arr.values().len();

    let mut out = ca.into_column();
    if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() &&
        // Exploded list should be equal to groups length.
        list_vals_len == ac_truthy.groups.len()
    {
        out = out.explode(false)?
    }

    ac_truthy.with_agg_state(AggState::AggregatedList(out));
    ac_truthy.with_update_groups(UpdateGroups::WithSeriesLen);

    Ok(ac_truthy)
}

impl PhysicalExpr for TernaryExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    /// Evaluates the ternary over the whole frame.
    ///
    /// The predicate is evaluated first, on its own. Then `truthy` and `falsy`
    /// are evaluated: concurrently when `run_par` is set, otherwise in
    /// sequence. Both branches are always evaluated before any error is
    /// surfaced. If both fail, the truthy error is returned because it is
    /// checked first.
    ///
    /// # Errors
    /// Returns an error if the predicate result is not boolean, if either
    /// branch fails to evaluate, or if `zip_with` rejects the inputs.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let mut state = state.split();
        // Don't cache window functions as they run in parallel.
        state.remove_cache_window_flag();
        let mask_series = self.predicate.evaluate(df, &state)?;
        let mask = mask_series.bool()?.clone();

        let op_truthy = || self.truthy.evaluate(df, &state);
        let op_falsy = || self.falsy.evaluate(df, &state);
        let (truthy, falsy) = if self.run_par {
            POOL.install(|| rayon::join(op_truthy, op_falsy))
        } else {
            (op_truthy(), op_falsy())
        };
        let truthy = truthy?;
        let falsy = falsy?;

        truthy.zip_with(&mask, &falsy)
    }

    /// The output field is taken from the truthy branch only.
    /// NOTE(review): assumes truthy/falsy have already been coerced to a common
    /// dtype upstream — confirm in the planner.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.truthy.to_field(input_schema)
    }

    /// Evaluates the ternary in a group-by context.
    ///
    /// Strategy, in order:
    /// 1. Evaluate mask/truthy/falsy on the groups (nested `rayon::join` when
    ///    `run_par`).
    /// 2. Any non-unit `LiteralScalar` input goes to the per-group fallback
    ///    `finish_as_iters`.
    /// 3. If nothing is aggregated and no `NotAggregated` input changed length,
    ///    zip the flat values directly. The result is `NotAggregated`, using the
    ///    groups of the first `NotAggregated` input in mask/truthy/falsy order.
    ///    If there is no such input, the result is a `LiteralScalar` on
    ///    `ac_truthy`.
    /// 4. Otherwise, aggregate any remaining `NotAggregated` inputs.
    ///    - A mix of `AggregatedScalar` and `AggregatedList` goes to
    ///      `finish_as_iters`.
    ///    - All-`AggregatedList` inputs must share identical offsets (else
    ///      `ShapeMismatch`) and are zipped on their flattened values.
    ///    - All-`AggregatedScalar` inputs are zipped directly.
    // NOTE(review): the reason for this `allow` is not visible in this file.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
        let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
        let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);
        let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par {
            POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy)))
        } else {
            (op_mask(), (op_truthy(), op_falsy()))
        };

        let mut ac_mask = ac_mask?;
        let mut ac_truthy = ac_truthy?;
        let mut ac_falsy = ac_falsy?;

        use AggState::*;

        // Check if there are any:
        // - non-unit literals
        // - AggregatedScalar or AggregatedList
        let mut has_non_unit_literal = false;
        let mut has_aggregated = false;
        // If the length has changed then we must not apply on the flat values
        // as ternary broadcasting is length-sensitive.
        let mut non_aggregated_len_modified = false;

        for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
            match ac.agg_state() {
                LiteralScalar(s) => {
                    has_non_unit_literal = s.len() != 1;

                    // One non-unit literal is enough to force the fallback;
                    // the other flags are irrelevant on that path.
                    if has_non_unit_literal {
                        break;
                    }
                },
                NotAggregated(_) => {
                    non_aggregated_len_modified |= !ac.original_len;
                },
                AggregatedScalar(_) | AggregatedList(_) => {
                    has_aggregated = true;
                },
            }
        }

        if has_non_unit_literal {
            // finish_as_iters for non-unit literals to avoid materializing the
            // literal inputs per-group.
            if state.verbose() {
                eprintln!("ternary agg: finish as iters due to non-unit literal")
            }
            return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
        }

        if !has_aggregated && !non_aggregated_len_modified {
            // Everything is flat (either NotAggregated or a unit literal).
            if state.verbose() {
                eprintln!("ternary agg: finish all not-aggregated or unit literal");
            }

            let out = ac_truthy
                .get_values()
                .zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;

            // Reuse the grouping metadata of the first NotAggregated input
            // (checked in mask, truthy, falsy order).
            for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
                if matches!(ac.agg_state(), NotAggregated(_)) {
                    let ac_target = ac;

                    return Ok(AggregationContext {
                        state: NotAggregated(out),
                        groups: ac_target.groups.clone(),
                        update_groups: ac_target.update_groups,
                        original_len: ac_target.original_len,
                    });
                }
            }

            // All three inputs were unit literals: the result is a literal too.
            ac_truthy.with_agg_state(LiteralScalar(out));

            return Ok(ac_truthy);
        }

        // Bring every remaining NotAggregated input into an aggregated state;
        // only the state change is needed here, the returned value is unused.
        for ac in [&mut ac_mask, &mut ac_truthy, &mut ac_falsy].into_iter() {
            if matches!(ac.agg_state(), NotAggregated(_)) {
                let _ = ac.aggregated();
            }
        }

        // At this point the input agg states are one of the following:
        // * `LiteralScalar` where `s.len() == 1`
        // * `AggregatedList`
        // * `AggregatedScalar`

        let mut non_literal_acs = Vec::<&AggregationContext>::with_capacity(3);

        // non_literal_acs will have at least 1 item. Reaching this point means
        // either `has_aggregated` was true, or `non_aggregated_len_modified` was
        // true. In the latter case at least one input was NotAggregated and was
        // aggregated just above.
        for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {
            if !matches!(ac.agg_state(), LiteralScalar(_)) {
                non_literal_acs.push(ac);
            }
        }

        // Compare each adjacent pair of non-literal states. If any differ, the
        // inputs are a mix of AggregatedScalar and AggregatedList.
        for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
            if std::mem::discriminant(ac_l.agg_state()) != std::mem::discriminant(ac_r.agg_state())
            {
                // Mix of AggregatedScalar and AggregatedList is done per group,
                // as every row of the AggregatedScalar must be broadcasted to a
                // list of the same length as the corresponding AggregatedList
                // row.
                if state.verbose() {
                    eprintln!(
                        "ternary agg: finish as iters due to mix of AggregatedScalar and AggregatedList"
                    )
                }
                return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
            }
        }

        // At this point, the possible combinations are:
        // * mix of unit literals and AggregatedScalar
        //   * `zip_with` can be called directly with the series
        // * mix of unit literals and AggregatedList
        //   * `zip_with` can be called with the flat values after the offsets
        //     have been checked for alignment
        // The first non-literal input provides the output's grouping metadata
        // (and, for lists, its offsets).
        let ac_target = non_literal_acs.first().unwrap();

        let agg_state_out = match ac_target.agg_state() {
            AggregatedList(_) => {
                // Ternary can be applied directly on the flattened series,
                // given that their offsets have been checked to be equal.
                if state.verbose() {
                    eprintln!("ternary agg: finish AggregatedList")
                }

                // All non-literal inputs share the same state variant (checked
                // above), so every pair here is (AggregatedList, AggregatedList).
                for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {
                    match (ac_l.agg_state(), ac_r.agg_state()) {
                        (AggregatedList(s_l), AggregatedList(s_r)) => {
                            let check = s_l.list().unwrap().offsets()?.as_slice()
                                == s_r.list().unwrap().offsets()?.as_slice();

                            polars_ensure!(
                                check,
                                ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"
                            );
                        },
                        _ => unreachable!(),
                    }
                }

                // Flatten each list input to its inner values. Unit literals are
                // used as-is so that `zip_with` broadcasts them.
                let truthy = if let AggregatedList(s) = ac_truthy.agg_state() {
                    s.list().unwrap().get_inner().into_column()
                } else {
                    ac_truthy.get_values().clone()
                };

                let falsy = if let AggregatedList(s) = ac_falsy.agg_state() {
                    s.list().unwrap().get_inner().into_column()
                } else {
                    ac_falsy.get_values().clone()
                };

                let mask = if let AggregatedList(s) = ac_mask.agg_state() {
                    s.list().unwrap().get_inner().into_column()
                } else {
                    ac_mask.get_values().clone()
                };

                let out = truthy.zip_with(mask.bool()?, &falsy)?;

                // The output series is guaranteed to be aligned with the
                // expected offsets buffer of the result. So we construct the
                // result ListChunked directly from the two pieces: the zipped
                // values and `ac_target`'s offsets.
                let out = out.rechunk();
                // @scalar-opt
                // @partition-opt
                let values = out.as_materialized_series().array_ref(0);
                let offsets = ac_target.get_values().list().unwrap().offsets()?;
                let inner_type = out.dtype();
                let dtype = LargeListArray::default_datatype(values.dtype().clone());

                // SAFETY: offsets are correct.
                // (They come from `ac_target` and were verified above to equal
                // the offsets of every other non-literal input.)
                let out = LargeListArray::new(dtype, offsets, values.clone(), None);

                let mut out = ListChunked::with_chunk(truthy.name().clone(), out);
                // SAFETY (review note): `inner_type` is the dtype of `out`, and
                // the list's values are `out`'s own (rechunked) array, so the
                // logical type matches the physical values.
                unsafe { out.to_logical(inner_type.clone()) };

                // The result shares `ac_target`'s offsets, so it inherits that
                // input's fast-explode flag.
                if ac_target.get_values().list().unwrap()._can_fast_explode() {
                    out.set_fast_explode();
                };

                let out = out.into_column();

                AggregatedList(out)
            },
            AggregatedScalar(_) => {
                if state.verbose() {
                    eprintln!("ternary agg: finish AggregatedScalar")
                }

                let out = ac_truthy
                    .get_values()
                    .zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;
                AggregatedScalar(out)
            },
            // `ac_target` is never a LiteralScalar (filtered out above) nor
            // NotAggregated (aggregated above).
            _ => {
                unreachable!()
            },
        };

        Ok(AggregationContext {
            state: agg_state_out,
            groups: ac_target.groups.clone(),
            update_groups: ac_target.update_groups,
            original_len: ac_target.original_len,
        })
    }

    /// Always advertises partitioned-aggregation support.
    /// NOTE(review): `evaluate_partitioned` unwraps the children's aggregators
    /// and will panic if any child returns `None`. Presumably the caller checks
    /// partitionability of the whole tree first — confirm.
    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        Some(self)
    }

    fn is_scalar(&self) -> bool {
        self.returns_scalar
    }
}

impl PartitionedAggregation for TernaryExpr {
    /// Evaluates truthy, falsy and mask (in that order, sequentially) through
    /// their partitioned aggregators, then zips the partial results.
    ///
    /// # Panics
    /// If any child's `as_partitioned_aggregator()` returns `None`.
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        let truthy = self.truthy.as_partitioned_aggregator().unwrap();
        let falsy = self.falsy.as_partitioned_aggregator().unwrap();
        let mask = self.predicate.as_partitioned_aggregator().unwrap();

        let truthy = truthy.evaluate_partitioned(df, groups, state)?;
        let falsy = falsy.evaluate_partitioned(df, groups, state)?;
        let mask = mask.evaluate_partitioned(df, groups, state)?;
        let mask = mask.bool()?.clone();

        truthy.zip_with(&mask, &falsy)
    }

    /// Identity: the partitioned result is returned unchanged.
    fn finalize(
        &self,
        partitioned: Column,
        _groups: &GroupPositions,
        _state: &ExecutionState,
    ) -> PolarsResult<Column> {
        Ok(partitioned)
    }
}