Path: blob/main/crates/polars-expr/src/expressions/cast.rs
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;

use super::*;
use crate::expressions::{AggState, AggregationContext, PartitionedAggregation, PhysicalExpr};

pub struct CastExpr {
    pub(crate) input: Arc<dyn PhysicalExpr>,
    pub(crate) dtype: DataType,
    pub(crate) expr: Expr,
    pub(crate) options: CastOptions,
}

impl CastExpr {
    /// Cast a column to the target dtype using the configured cast options.
    fn finish(&self, input: &Column) -> PolarsResult<Column> {
        input.cast_with_options(&self.dtype, self.options)
    }
}

impl PhysicalExpr for CastExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        let column = self.input.evaluate(df, state)?;
        self.finish(&column)
    }

    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut ac = self.input.evaluate_on_groups(df, groups, state)?;

        match ac.agg_state() {
            // this will not explode and potentially increase memory due to overlapping groups
            AggState::AggregatedList(s) => {
                let ca = s.list().unwrap();
                let casted = ca.apply_to_inner(&|s| {
                    self.finish(&s.into_column())
                        .map(|c| c.take_materialized_series())
                })?;
                ac.with_values(casted.into_column(), true, None)?;
            },
            AggState::AggregatedScalar(s) => {
                let s = self.finish(&s.clone().into_column())?;
                if ac.is_literal() {
                    ac.with_literal(s);
                } else {
                    ac.with_values(s, true, None)?;
                }
            },
            AggState::NotAggregated(_) => {
                if match self.options {
                    CastOptions::NonStrict | CastOptions::Overflowing => true,
                    CastOptions::Strict => ac.original_len,
                } {
                    // before we flatten, make sure that groups are updated
                    ac.groups();

                    let s = ac.flat_naive();
                    let s = self.finish(&s.as_ref().clone().into_column())?;

                    ac.with_values(s, false, None)?;
                } else {
                    // We need to perform aggregation only for strict mode, since if this is not done,
                    // filtered-out values may incorrectly cause a cast error.
                    let s = ac.aggregated();
                    let ca = s.list().unwrap();
                    let casted = ca.apply_to_inner(&|s| {
                        self.finish(&s.into_column())
                            .map(|c| c.take_materialized_series())
                    })?;
                    ac.with_values(casted.into_column(), true, None)?;
                }
            },

            AggState::LiteralScalar(s) => {
                let s = self.finish(s)?;
                ac.with_literal(s);
            },
        }

        Ok(ac)
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.input.to_field(input_schema).map(|mut fld| {
            fld.coerce(self.dtype.clone());
            fld
        })
    }

    fn is_scalar(&self) -> bool {
        self.input.is_scalar()
    }

    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        Some(self)
    }
}

impl PartitionedAggregation for CastExpr {
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        // evaluate the inner expression per partition, then cast each partition's result
        let e = self.input.as_partitioned_aggregator().unwrap();
        self.finish(&e.evaluate_partitioned(df, groups, state)?)
    }

    fn finalize(
        &self,
        partitioned: Column,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        // the per-partition results were already cast; delegate finalization to the inner expression
        let agg = self.input.as_partitioned_aggregator().unwrap();
        agg.finalize(partitioned, groups, state)
    }
}