Path: crates/polars-expr/src/expressions/eval.rs
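//! Physical executor for eval-style expressions: an `evaluation` expression is
//! run over the elements of each list/array row, or, for the cumulative
//! variant, over growing prefixes of a column (in the user-facing DSL these
//! surface as `list.eval`-style and `cumulative_eval`-style expressions).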
use std::borrow::Cow;
use std::cell::LazyCell;
use std::sync::Arc;

use arrow::bitmap::{Bitmap, BitmapBuilder};
use polars_core::chunked_array::builder::AnonymousOwnedListBuilder;
use polars_core::error::{PolarsResult, feature_gated};
use polars_core::frame::DataFrame;
#[cfg(feature = "dtype-array")]
use polars_core::prelude::ArrayChunked;
use polars_core::prelude::{
    ChunkCast, ChunkExplode, ChunkNestingUtils, Column, Field, GroupPositions, GroupsType, IdxCa,
    IntoColumn, ListBuilderTrait, ListChunked,
};
use polars_core::schema::Schema;
use polars_core::series::Series;
use polars_plan::dsl::{EvalVariant, Expr};
use polars_utils::IdxSize;
use polars_utils::pl_str::PlSmallStr;

use super::{AggregationContext, PhysicalExpr};
use crate::state::ExecutionState;

#[derive(Clone)]
pub struct EvalExpr {
    input: Arc<dyn PhysicalExpr>,
    evaluation: Arc<dyn PhysicalExpr>,
    variant: EvalVariant,
    expr: Expr,
    output_field: Field,
    is_scalar: bool,
    evaluation_is_scalar: bool,
    evaluation_is_elementwise: bool,
    evaluation_is_fallible: bool,
}

impl EvalExpr {
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        input: Arc<dyn PhysicalExpr>,
        evaluation: Arc<dyn PhysicalExpr>,
        variant: EvalVariant,
        expr: Expr,
        output_field: Field,
        is_scalar: bool,
        evaluation_is_scalar: bool,
        evaluation_is_elementwise: bool,
        evaluation_is_fallible: bool,
    ) -> Self {
        Self {
            input,
            evaluation,
            variant,
            expr,
            output_field,
            is_scalar,
            evaluation_is_scalar,
            evaluation_is_elementwise,
            evaluation_is_fallible,
        }
    }

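    /// Runs `self.evaluation` over the elements of every list in `ca`.
    ///
    /// Fast paths: an all-null input returns a full-null column immediately, and
    /// a fully elementwise evaluation that cannot fail on masked-out values is
    /// evaluated once over the flattened inner values. Otherwise each valid list
    /// becomes a slice group, the evaluation runs group-wise, and nulls are
    /// reinserted (`deposit`) at the end.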
    fn evaluate_on_list_chunked(
        &self,
        ca: &ListChunked,
        state: &ExecutionState,
        is_agg: bool,
    ) -> PolarsResult<Column> {
        let df = DataFrame::empty_with_height(ca.len());
        let ca = ca
            .trim_lists_to_normalized_offsets()
            .map_or(Cow::Borrowed(ca), Cow::Owned);

        // Fast path: Empty or only nulls.
        if ca.null_count() == ca.len() {
            let name = self.output_field.name.clone();
            return Ok(Column::full_null(name, ca.len(), self.output_field.dtype()));
        }

        let has_masked_out_values = LazyCell::new(|| ca.has_masked_out_values());
        let may_fail_on_masked_out_elements = self.evaluation_is_fallible && *has_masked_out_values;

        let flattened = ca.get_inner().into_column();
        let flattened_len = flattened.len();
        let validity = ca.rechunk_validity();

        // Fast path: fully elementwise expression without masked out values.
        if self.evaluation_is_elementwise && !may_fail_on_masked_out_elements {
            let mut state = state.clone();
            state.element = Arc::new(Some((flattened, validity.clone())));
            let mut column = self.evaluation.evaluate(&df, &state)?;

            // Since `lit` is marked as elementwise, this may lead to problems.
            if column.len() == 1 && flattened_len != 1 {
                column = column.new_from_index(0, flattened_len);
            }

            if !is_agg || !self.evaluation_is_scalar {
                column = ca
                    .with_inner_values(column.as_materialized_series())
                    .into_column();
            }

            return Ok(column);
        }

        let offsets = ca.offsets()?;
        // Detect accidental inclusion of sliced-out elements from chunks after the 1st (if present).
        assert_eq!(i64::try_from(flattened_len).unwrap(), *offsets.last());

        // Create groups for all valid array elements.
        let groups = if ca.has_nulls() {
            let validity = validity.as_ref().unwrap();
            offsets
                .offset_and_length_iter()
                .zip(validity.iter())
                .filter_map(|((offset, length), validity)| {
                    validity.then_some([offset as IdxSize, length as IdxSize])
                })
                .collect()
        } else {
            offsets
                .offset_and_length_iter()
                .map(|(offset, length)| [offset as IdxSize, length as IdxSize])
                .collect()
        };
        let groups = GroupsType::new_slice(groups, false, true);
        let groups = Cow::Owned(groups.into_sliceable());

        let mut state = state.clone();
        state.element = Arc::new(Some((flattened, validity.clone())));

        let mut ac = self.evaluation.evaluate_on_groups(&df, &groups, &state)?;

        ac.groups(); // Update the groups.

        let flat_naive = ac.flat_naive();

        // Fast path. Groups are pointing to the same offsets in the data buffer.
        if flat_naive.len() == flattened_len
            && let Some(output_groups) = ac.groups.as_ref().as_unrolled_slice()
            && !(is_agg && self.evaluation_is_scalar)
        {
            let groups_are_unchanged = if let Some(validity) = &validity {
                assert_eq!(validity.set_bits(), output_groups.len());
                validity
                    .true_idx_iter()
                    .zip(output_groups)
                    .all(|(j, [start, len])| {
                        let (original_start, original_end) =
                            unsafe { offsets.start_end_unchecked(j) };
                        (*start == original_start as IdxSize)
                            & (*len == (original_end - original_start) as IdxSize)
                    })
            } else {
                output_groups
                    .iter()
                    .zip(offsets.offset_and_length_iter())
                    .all(|([start, len], (original_start, original_len))| {
                        (*start == original_start as IdxSize) & (*len == original_len as IdxSize)
                    })
            };

            if groups_are_unchanged {
                let values = flat_naive.as_materialized_series();
                return Ok(ca.with_inner_values(values).into_column());
            }
        }

        // Slow path. Groups have changed, so we need to gather data again.
        if is_agg && self.evaluation_is_scalar {
            let mut values = ac.finalize();

            // We didn't have any groups for the `null` values so we have to reinsert them.
            if let Some(validity) = validity {
                values = values.deposit(&validity);
            }

            Ok(values)
        } else {
            let mut ca = ac.aggregated_as_list();

            // We didn't have any groups for the `null` values so we have to reinsert them.
            if let Some(validity) = validity {
                ca = Cow::Owned(ca.deposit(&validity));
            }

            Ok(ca.into_owned().into_column())
        }
    }

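    /// Fixed-width counterpart of [`Self::evaluate_on_list_chunked`]: every valid
    /// row maps to a slice group of exactly `ca.width()` elements, so no offsets
    /// are needed. `as_list` selects whether the result is returned as a `List`
    /// column or kept as an `Array` column.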
    #[cfg(feature = "dtype-array")]
    fn evaluate_on_array_chunked(
        &self,
        ca: &ArrayChunked,
        state: &ExecutionState,
        as_list: bool,
        is_agg: bool,
    ) -> PolarsResult<Column> {
        let df = DataFrame::empty_with_height(ca.len());
        let ca = ca
            .trim_lists_to_normalized_offsets()
            .map_or(Cow::Borrowed(ca), Cow::Owned);

        // Fast path: Empty or only nulls.
        if ca.null_count() == ca.len() {
            let name = self.output_field.name.clone();
            return Ok(Column::full_null(name, ca.len(), self.output_field.dtype()));
        }

        let flattened = ca.get_inner().into_column();
        let flattened_len = flattened.len();
        let validity = ca.rechunk_validity();

        let may_fail_on_masked_out_elements = self.evaluation_is_fallible && ca.has_nulls();

        // Fast path: fully elementwise expression without masked out values.
        if self.evaluation_is_elementwise && !may_fail_on_masked_out_elements {
            assert!(!self.evaluation_is_scalar);

            let mut state = state.clone();
            state.element = Arc::new(Some((flattened, None)));

            let mut column = self.evaluation.evaluate(&df, &state)?;
            if column.len() == 1 && flattened_len != 1 {
                column = column.new_from_index(0, flattened_len);
            }
            assert_eq!(column.len(), ca.len() * ca.width());

            let dtype = column.dtype().clone();
            let mut out = ArrayChunked::from_aligned_values(
                self.output_field.name.clone(),
                &dtype,
                ca.width(),
                column.take_materialized_series().into_chunks(),
                ca.len(),
            );

            if let Some(validity) = validity {
                out.set_validity(&validity);
            }

            return Ok(if as_list {
                out.to_list().into_column()
            } else {
                out.clone().into_column()
            });
        }

        assert_eq!(flattened_len, ca.width() * ca.len());

        // Create groups for all valid array elements.
        let groups = if ca.has_nulls() {
            let validity = validity.as_ref().unwrap();
            (0..ca.len())
                .filter(|i| unsafe { validity.get_bit_unchecked(*i) })
                .map(|i| [(i * ca.width()) as IdxSize, ca.width() as IdxSize])
                .collect()
        } else {
            (0..ca.len())
                .map(|i| [(i * ca.width()) as IdxSize, ca.width() as IdxSize])
                .collect()
        };
        let groups = GroupsType::new_slice(groups, false, true);
        let groups = Cow::Owned(groups.into_sliceable());

        let mut state = state.clone();
        state.element = Arc::new(Some((flattened, validity.clone())));

        let mut ac = self.evaluation.evaluate_on_groups(&df, &groups, &state)?;

        ac.groups(); // Update the groups.

        let flat_naive = ac.flat_naive();

        // Fast path. Groups are pointing to the same offsets in the data buffer.
        if flat_naive.len() == ca.len() * ca.width()
            && let Some(output_groups) = ac.groups.as_ref().as_unrolled_slice()
            && !(is_agg && self.evaluation_is_scalar)
        {
            let ca_width = ca.width() as IdxSize;
            let groups_are_unchanged = if let Some(validity) = &validity {
                assert_eq!(validity.set_bits(), output_groups.len());
                validity
                    .true_idx_iter()
                    .zip(output_groups)
                    .all(|(j, [start, len])| {
                        (*start == j as IdxSize * ca_width) & (*len == ca_width)
                    })
            } else {
                use polars_utils::itertools::Itertools;

                output_groups
                    .iter()
                    .enumerate_idx()
                    .all(|(i, [start, len])| (*start == i * ca_width) & (*len == ca_width))
            };

            if groups_are_unchanged {
                let values = flat_naive;
                let dtype = values.dtype().clone();
                let mut out = ArrayChunked::from_aligned_values(
                    self.output_field.name.clone(),
                    &dtype,
                    ca.width(),
                    values.as_materialized_series().chunks().clone(),
                    ca.len(),
                );

                if let Some(validity) = validity {
                    out.set_validity(&validity);
                }

                return Ok(if as_list {
                    out.to_list().into_column()
                } else {
                    out.into_column()
                });
            }
        }

        // Slow path. Groups have changed, so we need to gather data again.
        if is_agg && self.evaluation_is_scalar {
            let mut values = ac.finalize();

            // We didn't have any groups for the `null` values so we have to reinsert them.
            if let Some(validity) = validity {
                values = values.deposit(&validity);
            }

            Ok(values)
        } else {
            let mut ca = ac.aggregated_as_list();

            // We didn't have any groups for the `null` values so we have to reinsert them.
            if let Some(validity) = validity {
                ca = Cow::Owned(ca.deposit(&validity));
            }

            Ok(if as_list {
                ca.into_owned().into_column()
            } else {
                ca.cast(self.output_field.dtype()).unwrap().into_column()
            })
        }
    }

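    /// Runs `self.evaluation` on the growing prefixes of `input`, one scalar per
    /// row. Rows whose prefix holds fewer than `min_samples` valid values get no
    /// group; the `deposit` bitmap is used afterwards to backfill them with nulls.
    ///
    /// Illustrative walk-through (not from the source): with `min_samples = 2` and
    /// input `[1, null, 3, 4]`, only the prefix groups `[0, 3]` and `[0, 4]` are
    /// built, `deposit = [false, false, true, true]`, and the result is
    /// `[null, null, f([1, null, 3]), f([1, null, 3, 4])]`.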
    fn evaluate_cumulative_eval(
        &self,
        input: &Series,
        min_samples: usize,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        if input.is_empty() {
            return Ok(Column::new_empty(
                self.output_field.name().clone(),
                self.output_field.dtype(),
            ));
        }

        let flattened = input.clone().into_column();
        let validity = input.rechunk_validity();

        let mut deposit: Option<Bitmap> = None;

        let groups = if min_samples == 0 {
            (1..input.len() as IdxSize).map(|i| [0, i]).collect()
        } else {
            let validity = validity
                .clone()
                .unwrap_or_else(|| Bitmap::new_with_value(true, input.len()));
            let mut count = 0;
            let mut deposit_builder = BitmapBuilder::with_capacity(input.len());
            let out = (0..input.len() as IdxSize)
                .filter(|i| {
                    count += usize::from(unsafe { validity.get_bit_unchecked(*i as usize) });
                    let is_selected = count >= min_samples;
                    unsafe { deposit_builder.push_unchecked(is_selected) };
                    is_selected
                })
                .map(|i| [0, i + 1])
                .collect();
            deposit = Some(deposit_builder.freeze());
            out
        };

        let groups = GroupsType::new_slice(groups, true, true);

        let groups = groups.into_sliceable();

        let df = DataFrame::empty_with_height(input.len());

        let mut state = state.clone();
        state.element = Arc::new(Some((flattened, validity)));

        let agg = self.evaluation.evaluate_on_groups(&df, &groups, &state)?;
        let (mut out, _) = agg.get_final_aggregation();

        // Since we only evaluated the expressions on the items that satisfied the min samples, we
        // need to fix it up here again.
        if let Some(deposit) = deposit {
            let mut i = 0;
            let gather_idxs = deposit
                .iter()
                .map(|v| {
                    let out = i;
                    i += IdxSize::from(v);
                    out
                })
                .collect::<Vec<IdxSize>>();
            let gather_idxs =
                IdxCa::from_vec_validity(PlSmallStr::EMPTY, gather_idxs, Some(deposit));
            out = unsafe { out.take_unchecked(&gather_idxs) };
        }

        Ok(out)
    }
}

impl PhysicalExpr for EvalExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let input = self.input.evaluate(df, state)?;
        match self.variant {
            EvalVariant::List => {
                let lst = input.list()?;
                self.evaluate_on_list_chunked(lst, state, false)
            },
            EvalVariant::ListAgg => {
                let lst = input.list()?;
                self.evaluate_on_list_chunked(lst, state, true)
            },
            EvalVariant::Array { as_list } => feature_gated!("dtype-array", {
                let arr = input.array()?;
                self.evaluate_on_array_chunked(arr, state, as_list, false)
            }),
            EvalVariant::ArrayAgg => feature_gated!("dtype-array", {
                let arr = input.array()?;
                self.evaluate_on_array_chunked(arr, state, true, true)
            }),
            EvalVariant::Cumulative { min_samples } => {
                self.evaluate_cumulative_eval(input.as_materialized_series(), min_samples, state)
            },
        }
    }

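    /// Group-aware entry point. The list/array variants run on the flattened
    /// (`flat_naive`) values and the result is attached back onto the input
    /// context; `Cumulative` instead re-runs the cumulative evaluation per group
    /// and collects the per-group results into a list column.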
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut input = self.input.evaluate_on_groups(df, groups, state)?;
        input.groups();

        match self.variant {
            EvalVariant::List => {
                let input_col = input.flat_naive();
                let out = self.evaluate_on_list_chunked(input_col.list()?, state, false)?;
                input.with_values(out, false, Some(&self.expr))?;
            },
            EvalVariant::ListAgg => {
                let input_col = input.flat_naive();
                let out = self.evaluate_on_list_chunked(input_col.list()?, state, true)?;
                input.with_values(out, false, Some(&self.expr))?;
            },
            EvalVariant::Array { as_list } => feature_gated!("dtype-array", {
                let arr_col = input.flat_naive();
                let out =
                    self.evaluate_on_array_chunked(arr_col.array()?, state, as_list, false)?;
                input.with_values(out, false, Some(&self.expr))?;
            }),
            EvalVariant::ArrayAgg => feature_gated!("dtype-array", {
                let arr_col = input.flat_naive();
                let out = self.evaluate_on_array_chunked(arr_col.array()?, state, true, true)?;
                input.with_values(out, false, Some(&self.expr))?;
            }),
            EvalVariant::Cumulative { min_samples } => {
                let mut builder = AnonymousOwnedListBuilder::new(
                    self.output_field.name().clone(),
                    input.groups().len(),
                    Some(self.output_field.dtype.clone()),
                );
                for group in input.iter_groups(false) {
                    match group {
                        None => {},
                        Some(group) => {
                            let out =
                                self.evaluate_cumulative_eval(group.as_ref(), min_samples, state)?;
                            builder.append_series(out.as_materialized_series())?;
                        },
                    }
                }

                input.with_values(builder.finish().into_column(), true, Some(&self.expr))?;
            },
        }
        Ok(input)
    }

    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }

    fn is_scalar(&self) -> bool {
        self.is_scalar
    }
}
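
// Variant dispatch summary (restating the `match self.variant` arms above):
//
//   EvalVariant::List                        -> evaluate_on_list_chunked(..,  is_agg = false)
//   EvalVariant::ListAgg                     -> evaluate_on_list_chunked(..,  is_agg = true)
//   EvalVariant::Array { as_list }           -> evaluate_on_array_chunked(.., as_list, is_agg = false)
//   EvalVariant::ArrayAgg                    -> evaluate_on_array_chunked(.., as_list = true, is_agg = true)
//   EvalVariant::Cumulative { min_samples }  -> evaluate_cumulative_eval(.., min_samples)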