// Path: crates/polars-expr/src/expressions/eval.rs
use std::borrow::Cow;1use std::sync::{Arc, Mutex};23use arrow::array::{Array, ListArray};4use polars_core::POOL;5use polars_core::chunked_array::builder::AnonymousOwnedListBuilder;6use polars_core::chunked_array::from_iterator_par::ChunkedCollectParIterExt;7use polars_core::error::{PolarsResult, polars_ensure};8use polars_core::frame::DataFrame;9use polars_core::prelude::{10AnyValue, ChunkCast, ChunkNestingUtils, Column, CompatLevel, DataType, Field, GroupPositions,11GroupsType, IntoColumn, ListBuilderTrait, ListChunked,12};13use polars_core::schema::Schema;14use polars_core::series::Series;15use polars_core::utils::CustomIterTools;16use polars_plan::dsl::{EvalVariant, Expr};17use polars_plan::plans::ExprPushdownGroup;18use polars_utils::IdxSize;19use polars_utils::pl_str::PlSmallStr;20use rayon::iter::{IntoParallelIterator, ParallelIterator};2122use super::{AggState, AggregationContext, PhysicalExpr};23use crate::state::ExecutionState;2425#[derive(Clone)]26pub struct EvalExpr {27input: Arc<dyn PhysicalExpr>,28evaluation: Arc<dyn PhysicalExpr>,29variant: EvalVariant,30expr: Expr,31allow_threading: bool,32// `output_field_with_ctx`` accounts for the aggregation context, if any33// It will 'auto-implode/expplode' if needed.34output_field_with_ctx: Field,35// `non_aggregated_output_dtype`` ignores any aggregation context36non_aggregated_output_dtype: DataType,37is_scalar: bool,38pd_group: ExprPushdownGroup,39evaluation_is_scalar: bool,40}4142fn offsets_to_groups(offsets: &[i64]) -> Option<GroupPositions> {43let mut start = offsets[0];44let end = *offsets.last().unwrap();45if IdxSize::try_from(end - start).is_err() {46return None;47}48let groups = offsets49.iter()50.skip(1)51.map(|end| {52let offset = start as IdxSize;53let len = (*end - start) as IdxSize;54start = *end;55[offset, len]56})57.collect();58Some(59GroupsType::Slice {60groups,61rolling: false,62}63.into_sliceable(),64)65}6667impl EvalExpr {68#[allow(clippy::too_many_arguments)]69pub(crate) fn new(70input: 
Arc<dyn PhysicalExpr>,71evaluation: Arc<dyn PhysicalExpr>,72variant: EvalVariant,73expr: Expr,74allow_threading: bool,75output_field_with_ctx: Field,76non_aggregated_output_dtype: DataType,77is_scalar: bool,78pd_group: ExprPushdownGroup,79evaluation_is_scalar: bool,80) -> Self {81Self {82input,83evaluation,84variant,85expr,86allow_threading,87output_field_with_ctx,88non_aggregated_output_dtype,89is_scalar,90pd_group,91evaluation_is_scalar,92}93}9495fn run_elementwise_on_values(96&self,97lst: &ListChunked,98state: &ExecutionState,99) -> PolarsResult<Column> {100if lst.chunks().is_empty() {101return Ok(Column::new_empty(102self.output_field_with_ctx.name.clone(),103&self.non_aggregated_output_dtype,104));105}106107let lst = lst108.trim_lists_to_normalized_offsets()109.map_or(Cow::Borrowed(lst), Cow::Owned);110111let output_arrow_dtype = self112.non_aggregated_output_dtype113.clone()114.to_arrow(CompatLevel::newest());115let output_arrow_dtype_physical = output_arrow_dtype.underlying_physical_type();116117let apply_to_chunk = |arr: &dyn Array| {118let arr: &ListArray<i64> = arr.as_any().downcast_ref().unwrap();119120let values = unsafe {121Series::from_chunks_and_dtype_unchecked(122PlSmallStr::EMPTY,123vec![arr.values().clone()],124lst.inner_dtype(),125)126};127128let df = values.into_frame();129130self.evaluation.evaluate(&df, state).map(|values| {131let values = values.take_materialized_series().rechunk().chunks()[0].clone();132133ListArray::<i64>::new(134output_arrow_dtype_physical.clone(),135arr.offsets().clone(),136values,137arr.validity().cloned(),138)139.boxed()140})141};142143let chunks = if self.allow_threading && lst.chunks().len() > 1 {144POOL.install(|| {145lst.chunks()146.into_par_iter()147.map(|x| apply_to_chunk(&**x))148.collect::<PolarsResult<Vec<Box<dyn Array>>>>()149})?150} else {151lst.chunks()152.iter()153.map(|x| apply_to_chunk(&**x))154.collect::<PolarsResult<Vec<Box<dyn Array>>>>()?155};156157Ok(unsafe 
{158ListChunked::from_chunks(self.output_field_with_ctx.name.clone(), chunks)159.cast_unchecked(&self.non_aggregated_output_dtype)160.unwrap()161}162.into_column())163}164165fn run_per_sublist(&self, lst: &ListChunked, state: &ExecutionState) -> PolarsResult<Column> {166let mut err = None;167let mut ca: ListChunked = if self.allow_threading {168let m_err = Mutex::new(None);169let ca: ListChunked = POOL.install(|| {170lst.par_iter()171.map(|opt_s| {172opt_s.and_then(|s| {173let df = s.into_frame();174let out = self.evaluation.evaluate(&df, state);175match out {176Ok(s) => Some(s.take_materialized_series()),177Err(e) => {178*m_err.lock().unwrap() = Some(e);179None180},181}182})183})184.collect_ca_with_dtype(185PlSmallStr::EMPTY,186self.non_aggregated_output_dtype.clone(),187)188});189err = m_err.into_inner().unwrap();190ca191} else {192let mut df_container = DataFrame::empty();193194lst.into_iter()195.map(|s| {196s.and_then(|s| unsafe {197df_container.with_column_unchecked(s.into_column());198let out = self.evaluation.evaluate(&df_container, state);199df_container.clear_columns();200match out {201Ok(s) => Some(s.take_materialized_series()),202Err(e) => {203err = Some(e);204None205},206}207})208})209.collect_trusted()210};211if let Some(err) = err {212return Err(err);213}214215ca.rename(lst.name().clone());216217// Cast may still be required in some cases, e.g. 
for an empty frame when running single-threaded218if ca.dtype() != &self.non_aggregated_output_dtype {219ca.cast(&self.non_aggregated_output_dtype).map(Column::from)220} else {221Ok(ca.into_column())222}223}224225fn run_on_group_by_engine(226&self,227lst: &ListChunked,228state: &ExecutionState,229) -> PolarsResult<Column> {230let lst = lst.rechunk();231let arr = lst.downcast_as_array();232let groups = offsets_to_groups(arr.offsets()).unwrap();233234// List elements in a series.235let values = Series::try_from((PlSmallStr::EMPTY, arr.values().clone())).unwrap();236let inner_dtype = lst.inner_dtype();237// SAFETY:238// Invariant in List means values physicals can be cast to inner dtype239let values = unsafe { values.from_physical_unchecked(inner_dtype).unwrap() };240241let df_context = values.into_frame();242243let mut ac = self244.evaluation245.evaluate_on_groups(&df_context, &groups, state)?;246let out = match ac.agg_state() {247AggState::AggregatedScalar(_) => {248let out = ac.aggregated();249out.as_list().into_column()250},251_ => ac.aggregated(),252};253Ok(out254.with_name(self.output_field_with_ctx.name.clone())255.into_column())256}257258fn evaluate_on_list_chunked(259&self,260lst: &ListChunked,261state: &ExecutionState,262) -> PolarsResult<Column> {263let fits_idx_size = lst.get_inner().len() < (IdxSize::MAX as usize);264if match self.pd_group {265ExprPushdownGroup::Pushable => true,266ExprPushdownGroup::Fallible => !lst.has_nulls(),267ExprPushdownGroup::Barrier => false,268} && !self.evaluation_is_scalar269{270self.run_elementwise_on_values(lst, state)271} else if fits_idx_size && lst.null_count() == 0 && self.evaluation_is_scalar {272self.run_on_group_by_engine(lst, state)273} else {274self.run_per_sublist(lst, state)275}276}277278fn evaluate_cumulative_eval(279&self,280input: &Series,281min_samples: usize,282state: &ExecutionState,283) -> PolarsResult<Series> {284let finish = |out: Series| {285polars_ensure!(286out.len() <= 1,287ComputeError:288"expected 
single value, got a result with length {}, {:?}",289out.len(), out,290);291Ok(out.get(0).unwrap().into_static())292};293294let input = input.clone().with_name(PlSmallStr::EMPTY);295let avs = if self.allow_threading {296POOL.install(|| {297(1..input.len() + 1)298.into_par_iter()299.map(|len| {300let c = input.slice(0, len);301if (len - c.null_count()) >= min_samples {302let df = c.into_frame();303let out = self304.evaluation305.evaluate(&df, state)?306.take_materialized_series();307finish(out)308} else {309Ok(AnyValue::Null)310}311})312.collect::<PolarsResult<Vec<_>>>()313})?314} else {315let mut df_container = DataFrame::empty();316(1..input.len() + 1)317.map(|len| {318let c = input.slice(0, len);319if (len - c.null_count()) >= min_samples {320unsafe {321df_container.with_column_unchecked(c.into_column());322let out = self323.evaluation324.evaluate(&df_container, state)?325.take_materialized_series();326df_container.clear_columns();327finish(out)328}329} else {330Ok(AnyValue::Null)331}332})333.collect::<PolarsResult<Vec<_>>>()?334};335336Series::from_any_values_and_dtype(337self.output_field_with_ctx.name().clone(),338&avs,339&self.non_aggregated_output_dtype,340true,341)342}343}344345impl PhysicalExpr for EvalExpr {346fn as_expression(&self) -> Option<&Expr> {347Some(&self.expr)348}349350fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {351let input = self.input.evaluate(df, state)?;352match self.variant {353EvalVariant::List => {354let lst = input.list()?;355self.evaluate_on_list_chunked(lst, state)356},357EvalVariant::Cumulative { min_samples } => self358.evaluate_cumulative_eval(input.as_materialized_series(), min_samples, state)359.map(Column::from),360}361}362363fn evaluate_on_groups<'a>(364&self,365df: &DataFrame,366groups: &'a GroupPositions,367state: &ExecutionState,368) -> PolarsResult<AggregationContext<'a>> {369let mut input = self.input.evaluate_on_groups(df, groups, state)?;370match self.variant {371EvalVariant::List 
=> {372let out = self.evaluate_on_list_chunked(input.get_values().list()?, state)?;373input.with_values(out, false, Some(&self.expr))?;374},375EvalVariant::Cumulative { min_samples } => {376let mut builder = AnonymousOwnedListBuilder::new(377self.output_field_with_ctx.name().clone(),378input.groups().len(),379Some(self.non_aggregated_output_dtype.clone()),380);381for group in input.iter_groups(false) {382match group {383None => {},384Some(group) => {385let out =386self.evaluate_cumulative_eval(group.as_ref(), min_samples, state)?;387builder.append_series(&out)?;388},389}390}391392input.with_values(builder.finish().into_column(), true, Some(&self.expr))?;393},394}395Ok(input)396}397398fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {399Ok(self.output_field_with_ctx.clone())400}401402fn is_scalar(&self) -> bool {403self.is_scalar404}405}406407408