Path: blob/main/crates/polars-expr/src/expressions/ternary.rs
8422 views
use polars_core::POOL;1use polars_core::prelude::*;2use polars_plan::prelude::*;34use super::*;5use crate::expressions::{AggregationContext, PhysicalExpr};67pub struct TernaryExpr {8predicate: Arc<dyn PhysicalExpr>,9truthy: Arc<dyn PhysicalExpr>,10falsy: Arc<dyn PhysicalExpr>,11expr: Expr,12// Can be expensive on small data to run literals in parallel.13run_par: bool,14returns_scalar: bool,15}1617impl TernaryExpr {18pub fn new(19predicate: Arc<dyn PhysicalExpr>,20truthy: Arc<dyn PhysicalExpr>,21falsy: Arc<dyn PhysicalExpr>,22expr: Expr,23run_par: bool,24returns_scalar: bool,25) -> Self {26Self {27predicate,28truthy,29falsy,30expr,31run_par,32returns_scalar,33}34}35}3637fn finish_as_iters<'a>(38mut ac_truthy: AggregationContext<'a>,39mut ac_falsy: AggregationContext<'a>,40mut ac_mask: AggregationContext<'a>,41) -> PolarsResult<AggregationContext<'a>> {42let ca = ac_truthy43.iter_groups(false)44.zip(ac_falsy.iter_groups(false))45.zip(ac_mask.iter_groups(false))46.map(|((truthy, falsy), mask)| {47match (truthy, falsy, mask) {48(Some(truthy), Some(falsy), Some(mask)) => Some(49truthy50.as_ref()51.zip_with(mask.as_ref().bool()?, falsy.as_ref()),52),53_ => None,54}55.transpose()56})57.collect::<PolarsResult<ListChunked>>()?58.with_name(ac_truthy.get_values().name().clone());5960// Aggregation leaves only a single chunk.61let arr = ca.downcast_iter().next().unwrap();62let list_vals_len = arr.values().len();6364let mut out = ca.into_column();65if ac_truthy.arity_should_explode() && ac_falsy.arity_should_explode() && ac_mask.arity_should_explode() &&66// Exploded list should be equal to groups length.67list_vals_len == ac_truthy.groups.len()68{69out = out.explode(ExplodeOptions {70empty_as_null: true,71keep_nulls: true,72})?73}7475ac_truthy.with_agg_state(AggState::AggregatedList(out));76ac_truthy.with_update_groups(UpdateGroups::WithSeriesLen);7778Ok(ac_truthy)79}8081impl PhysicalExpr for TernaryExpr {82fn as_expression(&self) -> Option<&Expr> {83Some(&self.expr)84}8586fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {87let mut state = state.split();88// Don't cache window functions as they run in parallel.89state.remove_cache_window_flag();90let mask_series = self.predicate.evaluate(df, &state)?;91let mask = mask_series.bool()?.clone();9293let op_truthy = || self.truthy.evaluate(df, &state);94let op_falsy = || self.falsy.evaluate(df, &state);95let (truthy, falsy) = if self.run_par {96POOL.install(|| rayon::join(op_truthy, op_falsy))97} else {98(op_truthy(), op_falsy())99};100let truthy = truthy?;101let falsy = falsy?;102103truthy.zip_with(&mask, &falsy)104}105106fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {107self.truthy.to_field(input_schema)108}109110#[allow(clippy::ptr_arg)]111fn evaluate_on_groups<'a>(112&self,113df: &DataFrame,114groups: &'a GroupPositions,115state: &ExecutionState,116) -> PolarsResult<AggregationContext<'a>> {117let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);118let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);119let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);120let (ac_mask, (ac_truthy, ac_falsy)) = if self.run_par {121POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy)))122} else {123(op_mask(), (op_truthy(), op_falsy()))124};125126let mut ac_mask = ac_mask?;127let mut ac_truthy = ac_truthy?;128let mut ac_falsy = ac_falsy?;129130use AggState::*;131132// Check if there are any:133// - non-unit literals134// - AggregatedScalar or AggregatedList135let mut has_non_unit_literal = false;136let mut has_aggregated = false;137// If the length has changed then we must not apply on the flat values138// as ternary broadcasting is length-sensitive.139let mut non_aggregated_len_modified = false;140141for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {142match ac.agg_state() {143LiteralScalar(s) => {144has_non_unit_literal = s.len() != 1;145146if has_non_unit_literal {147break;148}149},150NotAggregated(_) => {151non_aggregated_len_modified |= !ac.original_len;152},153AggregatedScalar(_) | AggregatedList(_) => {154has_aggregated = true;155},156}157}158159if has_non_unit_literal {160// finish_as_iters for non-unit literals to avoid materializing the161// literal inputs per-group.162if state.verbose() {163eprintln!("ternary agg: finish as iters due to non-unit literal")164}165return finish_as_iters(ac_truthy, ac_falsy, ac_mask);166}167168if !has_aggregated && !non_aggregated_len_modified {169// Everything is flat (either NotAggregated or a unit literal).170if state.verbose() {171eprintln!("ternary agg: finish all not-aggregated or unit literal");172}173174let out = ac_truthy175.get_values()176.zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;177178for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {179if matches!(ac.agg_state(), NotAggregated(_)) {180let ac_target = ac;181182return Ok(AggregationContext {183state: NotAggregated(out),184groups: ac_target.groups.clone(),185update_groups: ac_target.update_groups,186original_len: ac_target.original_len,187});188}189}190191ac_truthy.with_agg_state(LiteralScalar(out));192193return Ok(ac_truthy);194}195196for ac in [&mut ac_mask, &mut ac_truthy, &mut ac_falsy].into_iter() {197if matches!(ac.agg_state(), NotAggregated(_)) {198let _ = ac.aggregated();199}200}201202// At this point the input agg states are one of the following:203// * `Literal` where `s.len() == 1`204// * `AggregatedList`205// * `AggregatedScalar`206207let mut non_literal_acs = Vec::<&AggregationContext>::with_capacity(3);208209// non_literal_acs will have at least 1 item because has_aggregated was210// true from above.211for ac in [&ac_mask, &ac_truthy, &ac_falsy].into_iter() {212if !matches!(ac.agg_state(), LiteralScalar(_)) {213non_literal_acs.push(ac);214}215}216217for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {218if std::mem::discriminant(ac_l.agg_state()) != std::mem::discriminant(ac_r.agg_state())219{220// Mix of AggregatedScalar and AggregatedList is done per group,221// as every row of the AggregatedScalar must be broadcasted to a222// list of the same length as the corresponding AggregatedList223// row.224if state.verbose() {225eprintln!(226"ternary agg: finish as iters due to mix of AggregatedScalar and AggregatedList"227)228}229return finish_as_iters(ac_truthy, ac_falsy, ac_mask);230}231}232233// At this point, the possible combinations are:234// * mix of unit literals and AggregatedScalar235// * `zip_with` can be called directly with the series236// * mix of unit literals and AggregatedList237// * `zip_with` can be called with the flat values after the offsets238// have been checked for alignment239let ac_target = non_literal_acs.first().unwrap();240241let agg_state_out = match ac_target.agg_state() {242AggregatedList(_) => {243// Ternary can be applied directly on the flattened series,244// given that their offsets have been checked to be equal.245if state.verbose() {246eprintln!("ternary agg: finish AggregatedList")247}248249for (ac_l, ac_r) in non_literal_acs.iter().zip(non_literal_acs.iter().skip(1)) {250match (ac_l.agg_state(), ac_r.agg_state()) {251(AggregatedList(s_l), AggregatedList(s_r)) => {252let check = s_l.list().unwrap().offsets()?.as_slice()253== s_r.list().unwrap().offsets()?.as_slice();254255polars_ensure!(256check,257ShapeMismatch: "shapes of `self`, `mask` and `other` are not suitable for `zip_with` operation"258);259},260_ => unreachable!(),261}262}263264let truthy = if let AggregatedList(s) = ac_truthy.agg_state() {265s.list().unwrap().get_inner().into_column()266} else {267ac_truthy.get_values().clone()268};269270let falsy = if let AggregatedList(s) = ac_falsy.agg_state() {271s.list().unwrap().get_inner().into_column()272} else {273ac_falsy.get_values().clone()274};275276let mask = if let AggregatedList(s) = ac_mask.agg_state() {277s.list().unwrap().get_inner().into_column()278} else {279ac_mask.get_values().clone()280};281282let out = truthy.zip_with(mask.bool()?, &falsy)?;283284// The output series is guaranteed to be aligned with expected285// offsets buffer of the result, so we construct the result286// ListChunked directly from the 2.287let out = out.rechunk();288// @scalar-opt289// @partition-opt290let values = out.as_materialized_series().array_ref(0);291let offsets = ac_target.get_values().list().unwrap().offsets()?;292let inner_type = out.dtype();293let dtype = LargeListArray::default_datatype(values.dtype().clone());294295// SAFETY: offsets are correct.296let out = LargeListArray::new(dtype, offsets, values.clone(), None);297298let mut out = ListChunked::with_chunk(truthy.name().clone(), out);299unsafe { out.to_logical(inner_type.clone()) };300301if ac_target.get_values().list().unwrap()._can_fast_explode() {302out.set_fast_explode();303};304305let out = out.into_column();306307AggregatedList(out)308},309AggregatedScalar(_) => {310if state.verbose() {311eprintln!("ternary agg: finish AggregatedScalar")312}313314let out = ac_truthy315.get_values()316.zip_with(ac_mask.get_values().bool()?, ac_falsy.get_values())?;317AggregatedScalar(out)318},319_ => {320unreachable!()321},322};323324Ok(AggregationContext {325state: agg_state_out,326groups: ac_target.groups.clone(),327update_groups: ac_target.update_groups,328original_len: ac_target.original_len,329})330}331332fn is_scalar(&self) -> bool {333self.returns_scalar334}335}336337338