Path: blob/main/crates/polars-expr/src/expressions/binary.rs
use polars_core::POOL;
use polars_core::prelude::*;
#[cfg(feature = "round_series")]
use polars_ops::prelude::floor_div_series;

use super::*;
use crate::expressions::{
    AggState, AggregationContext, PartitionedAggregation, PhysicalExpr, UpdateGroups,
};

#[derive(Clone)]
pub struct BinaryExpr {
    left: Arc<dyn PhysicalExpr>,
    op: Operator,
    right: Arc<dyn PhysicalExpr>,
    expr: Expr,
    has_literal: bool,
    allow_threading: bool,
    is_scalar: bool,
    output_field: Field,
}

impl BinaryExpr {
    #[expect(clippy::too_many_arguments)]
    pub fn new(
        left: Arc<dyn PhysicalExpr>,
        op: Operator,
        right: Arc<dyn PhysicalExpr>,
        expr: Expr,
        has_literal: bool,
        allow_threading: bool,
        is_scalar: bool,
        output_field: Field,
    ) -> Self {
        Self {
            left,
            op,
            right,
            expr,
            has_literal,
            allow_threading,
            is_scalar,
            output_field,
        }
    }
}

/// Can partially do operations in place.
fn apply_operator_owned(left: Column, right: Column, op: Operator) -> PolarsResult<Column> {
    match op {
        Operator::Plus => left.try_add_owned(right),
        Operator::Minus => left.try_sub_owned(right),
        Operator::Multiply
            if left.dtype().is_primitive_numeric() && right.dtype().is_primitive_numeric() =>
        {
            left.try_mul_owned(right)
        },
        _ => apply_operator(&left, &right, op),
    }
}

pub fn apply_operator(left: &Column, right: &Column, op: Operator) -> PolarsResult<Column> {
    use DataType::*;
    match op {
        Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_column()),
        Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_column()),
        Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_column()),
        Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_column()),
        Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| ca.into_column()),
        Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_column()),
        Operator::Plus => left + right,
        Operator::Minus => left - right,
        Operator::Multiply => left * right,
        Operator::Divide => left / right,
        Operator::TrueDivide => match left.dtype() {
            #[cfg(feature = "dtype-decimal")]
            Decimal(_, _) => left / right,
            Duration(_) | Date | Datetime(_, _) | Float32 | Float64 => left / right,
            #[cfg(feature = "dtype-array")]
            Array(..) => left / right,
            #[cfg(feature = "dtype-array")]
            _ if right.dtype().is_array() => left / right,
            List(_) => left / right,
            _ if right.dtype().is_list() => left / right,
            _ if left.dtype().is_string() || right.dtype().is_string() => {
                polars_bail!(InvalidOperation: "cannot divide using strings")
            },
            _ => {
                if right.dtype().is_temporal() {
                    return left / right;
                }
                left.cast(&Float64)? / right.cast(&Float64)?
            },
        },
        Operator::FloorDivide => {
            #[cfg(feature = "round_series")]
            {
                floor_div_series(
                    left.as_materialized_series(),
                    right.as_materialized_series(),
                )
                .map(Column::from)
            }
            #[cfg(not(feature = "round_series"))]
            {
                panic!("activate 'round_series' feature")
            }
        },
        Operator::And => left.bitand(right),
        Operator::Or => left.bitor(right),
        Operator::LogicalOr => left
            .cast(&DataType::Boolean)?
            .bitor(&right.cast(&DataType::Boolean)?),
        Operator::LogicalAnd => left
            .cast(&DataType::Boolean)?
            .bitand(&right.cast(&DataType::Boolean)?),
        Operator::Xor => left.bitxor(right),
        Operator::Modulus => left % right,
        Operator::EqValidity => left.equal_missing(right).map(|ca| ca.into_column()),
        Operator::NotEqValidity => left.not_equal_missing(right).map(|ca| ca.into_column()),
    }
}

impl BinaryExpr {
    fn apply_elementwise<'a>(
        &self,
        mut ac_l: AggregationContext<'a>,
        ac_r: AggregationContext,
        aggregated: bool,
    ) -> PolarsResult<AggregationContext<'a>> {
        // We want to be able to mutate in place, so we take the lhs to make sure that we drop.
        let lhs = ac_l.get_values().clone();
        let rhs = ac_r.get_values().clone();

        // Drop lhs so that we might operate in place.
        drop(ac_l.take());

        let out = apply_operator_owned(lhs, rhs, self.op)?;
        ac_l.with_values(out, aggregated, Some(&self.expr))?;
        Ok(ac_l)
    }

    fn apply_all_literal<'a>(
        &self,
        mut ac_l: AggregationContext<'a>,
        mut ac_r: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let name = self.output_field.name().clone();
        ac_l.groups();
        ac_r.groups();
        polars_ensure!(ac_l.groups.len() == ac_r.groups.len(), ComputeError: "lhs and rhs should have same group length");
        let left_c = ac_l.get_values().rechunk().into_column();
        let right_c = ac_r.get_values().rechunk().into_column();
        let res_c = apply_operator(&left_c, &right_c, self.op)?;
        ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
        let res_s = if res_c.len() == 1 {
            res_c.new_from_index(0, ac_l.groups.len())
        } else {
            ListChunked::full(name, res_c.as_materialized_series(), ac_l.groups.len()).into_column()
        };
        ac_l.with_values(res_s, true, Some(&self.expr))?;
        Ok(ac_l)
    }

    fn apply_group_aware<'a>(
        &self,
        mut ac_l: AggregationContext<'a>,
        mut ac_r: AggregationContext<'a>,
    ) -> PolarsResult<AggregationContext<'a>> {
        let name = self.output_field.name().clone();
        let mut ca = ac_l
            .iter_groups(false)
            .zip(ac_r.iter_groups(false))
            .map(|(l, r)| {
                Some(apply_operator(
                    &l?.as_ref().clone().into_column(),
                    &r?.as_ref().clone().into_column(),
                    self.op,
                ))
            })
            .map(|opt_res| opt_res.transpose())
            .collect::<PolarsResult<ListChunked>>()?
            .with_name(name.clone());
        if ca.is_empty() {
            ca = ListChunked::full_null_with_dtype(name, 0, self.output_field.dtype());
        }

        ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
        ac_l.with_agg_state(AggState::AggregatedList(ca.into_column()));
        Ok(ac_l)
    }
}

impl PhysicalExpr for BinaryExpr {
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }

    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        // Window functions may set a global state that determine their output
        // state, so we don't let them run in parallel as they race
        // they also saturate the thread pool by themselves, so that's fine.
        let has_window = state.has_window();

        let (lhs, rhs);
        if has_window {
            let mut state = state.split();
            state.remove_cache_window_flag();
            lhs = self.left.evaluate(df, &state)?;
            rhs = self.right.evaluate(df, &state)?;
        } else if !self.allow_threading || self.has_literal {
            // Literals are free, don't pay par cost.
            lhs = self.left.evaluate(df, state)?;
            rhs = self.right.evaluate(df, state)?;
        } else {
            let (opt_lhs, opt_rhs) = POOL.install(|| {
                rayon::join(
                    || self.left.evaluate(df, state),
                    || self.right.evaluate(df, state),
                )
            });
            (lhs, rhs) = (opt_lhs?, opt_rhs?);
        };
        polars_ensure!(
            lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1,
            expr = self.expr,
            ShapeMismatch: "cannot evaluate two Series of different lengths ({} and {})",
            lhs.len(), rhs.len(),
        );
        apply_operator_owned(lhs, rhs, self.op)
    }

    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        let (result_a, result_b) = POOL.install(|| {
            rayon::join(
                || self.left.evaluate_on_groups(df, groups, state),
                || self.right.evaluate_on_groups(df, groups, state),
            )
        });
        let mut ac_l = result_a?;
        let ac_r = result_b?;

        match (ac_l.agg_state(), ac_r.agg_state()) {
            (AggState::LiteralScalar(s), AggState::NotAggregated(_))
            | (AggState::NotAggregated(_), AggState::LiteralScalar(s)) => match s.len() {
                1 => self.apply_elementwise(ac_l, ac_r, false),
                _ => self.apply_group_aware(ac_l, ac_r),
            },
            (AggState::LiteralScalar(_), AggState::LiteralScalar(_)) => {
                self.apply_all_literal(ac_l, ac_r)
            },
            (AggState::NotAggregated(_), AggState::NotAggregated(_)) => {
                self.apply_elementwise(ac_l, ac_r, false)
            },
            (
                AggState::AggregatedScalar(_) | AggState::LiteralScalar(_),
                AggState::AggregatedScalar(_) | AggState::LiteralScalar(_),
            ) => self.apply_elementwise(ac_l, ac_r, true),
            (AggState::AggregatedScalar(_), AggState::NotAggregated(_))
            | (AggState::NotAggregated(_), AggState::AggregatedScalar(_)) => {
                self.apply_group_aware(ac_l, ac_r)
            },
            (AggState::AggregatedList(lhs), AggState::AggregatedList(rhs)) => {
                let lhs = lhs.list().unwrap();
                let rhs = rhs.list().unwrap();
                let out = lhs.apply_to_inner(&|lhs| {
                    apply_operator(&lhs.into_column(), &rhs.get_inner().into_column(), self.op)
                        .map(|c| c.take_materialized_series())
                })?;
                ac_l.with_values(out.into_column(), true, Some(&self.expr))?;
                Ok(ac_l)
            },
            _ => self.apply_group_aware(ac_l, ac_r),
        }
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
        self.expr.to_field(input_schema)
    }

    fn is_scalar(&self) -> bool {
        self.is_scalar
    }

    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        Some(self)
    }
}

impl PartitionedAggregation for BinaryExpr {
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        let left = self.left.as_partitioned_aggregator().unwrap();
        let right = self.right.as_partitioned_aggregator().unwrap();
        let left = left.evaluate_partitioned(df, groups, state)?;
        let right = right.evaluate_partitioned(df, groups, state)?;
        apply_operator(&left, &right, self.op)
    }

    fn finalize(
        &self,
        partitioned: Column,
        _groups: &GroupPositions,
        _state: &ExecutionState,
    ) -> PolarsResult<Column> {
        Ok(partitioned)
    }
}
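For orientation, below is a minimal sketch of how apply_operator could be exercised from an in-module unit test. It relies only on identifiers already imported by this file (Column, DataType, Operator, PolarsResult via polars_core::prelude and super); the test module name, the column constructor call, and the assertions are illustrative assumptions, not part of the original file, and constructor details may differ across Polars versions.

#[cfg(test)]
mod binary_op_sketch {
    use super::*;

    // Hypothetical test: drive `apply_operator` directly on two eager integer columns.
    #[test]
    fn plus_and_true_divide_dispatch() -> PolarsResult<()> {
        // Assumed constructor: `Column::new` mirrors `Series::new` (name, values).
        let a = Column::new("a".into(), [1i64, 2, 3]);
        let b = Column::new("b".into(), [10i64, 20, 30]);

        // `Operator::Plus` routes to the `left + right` arithmetic impl on `Column`.
        let sum = apply_operator(&a, &b, Operator::Plus)?;
        assert_eq!(sum.len(), 3);

        // `Operator::TrueDivide` on integer inputs falls through to the catch-all arm,
        // which casts both sides to Float64 before dividing.
        let ratio = apply_operator(&a, &b, Operator::TrueDivide)?;
        assert_eq!(ratio.dtype(), &DataType::Float64);
        Ok(())
    }
}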