// Path: blob/main/crates/polars-expr/src/expressions/binary.rs
// 8406 views
use polars_core::POOL;1use polars_core::prelude::*;2#[cfg(feature = "round_series")]3use polars_ops::prelude::floor_div_series;45use super::*;6use crate::expressions::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};78#[derive(Clone)]9pub struct BinaryExpr {10left: Arc<dyn PhysicalExpr>,11op: Operator,12right: Arc<dyn PhysicalExpr>,13expr: Expr,14has_literal: bool,15allow_threading: bool,16is_scalar: bool,17output_field: Field,18}1920impl BinaryExpr {21#[expect(clippy::too_many_arguments)]22pub fn new(23left: Arc<dyn PhysicalExpr>,24op: Operator,25right: Arc<dyn PhysicalExpr>,26expr: Expr,27has_literal: bool,28allow_threading: bool,29is_scalar: bool,30output_field: Field,31) -> Self {32Self {33left,34op,35right,36expr,37has_literal,38allow_threading,39is_scalar,40output_field,41}42}43}4445/// Can partially do operations in place.46fn apply_operator_owned(left: Column, right: Column, op: Operator) -> PolarsResult<Column> {47match op {48Operator::Plus => left.try_add_owned(right),49Operator::Minus => left.try_sub_owned(right),50Operator::Multiply51if left.dtype().is_primitive_numeric() && right.dtype().is_primitive_numeric() =>52{53left.try_mul_owned(right)54},55_ => apply_operator(&left, &right, op),56}57}5859pub fn apply_operator(left: &Column, right: &Column, op: Operator) -> PolarsResult<Column> {60use DataType::*;61match op {62Operator::Gt => ChunkCompareIneq::gt(left, right).map(|ca| ca.into_column()),63Operator::GtEq => ChunkCompareIneq::gt_eq(left, right).map(|ca| ca.into_column()),64Operator::Lt => ChunkCompareIneq::lt(left, right).map(|ca| ca.into_column()),65Operator::LtEq => ChunkCompareIneq::lt_eq(left, right).map(|ca| ca.into_column()),66Operator::Eq => ChunkCompareEq::equal(left, right).map(|ca| ca.into_column()),67Operator::NotEq => ChunkCompareEq::not_equal(left, right).map(|ca| ca.into_column()),68Operator::Plus => left + right,69Operator::Minus => left - right,70Operator::Multiply => left * right,71Operator::RustDivide => left / 
right,72Operator::TrueDivide => match left.dtype() {73#[cfg(feature = "dtype-decimal")]74Decimal(_, _) => left / right,75#[cfg(feature = "dtype-f16")]76Float16 => left / right,77Duration(_) | Date | Datetime(_, _) | Float32 | Float64 => left / right,78#[cfg(feature = "dtype-array")]79Array(..) => left / right,80#[cfg(feature = "dtype-array")]81_ if right.dtype().is_array() => left / right,82List(_) => left / right,83_ if right.dtype().is_list() => left / right,84_ if left.dtype().is_string() || right.dtype().is_string() => {85polars_bail!(InvalidOperation: "cannot divide using strings")86},87_ => {88if right.dtype().is_temporal() {89return left / right;90}91left.cast(&Float64)? / right.cast(&Float64)?92},93},94Operator::FloorDivide => {95#[cfg(feature = "round_series")]96{97floor_div_series(98left.as_materialized_series(),99right.as_materialized_series(),100)101.map(Column::from)102}103#[cfg(not(feature = "round_series"))]104{105panic!("activate 'round_series' feature")106}107},108Operator::And => left.bitand(right),109Operator::Or => left.bitor(right),110Operator::LogicalOr => left111.cast(&DataType::Boolean)?112.bitor(&right.cast(&DataType::Boolean)?),113Operator::LogicalAnd => left114.cast(&DataType::Boolean)?115.bitand(&right.cast(&DataType::Boolean)?),116Operator::Xor => left.bitxor(right),117Operator::Modulus => left % right,118Operator::EqValidity => left.equal_missing(right).map(|ca| ca.into_column()),119Operator::NotEqValidity => left.not_equal_missing(right).map(|ca| ca.into_column()),120}121}122123impl BinaryExpr {124fn apply_elementwise<'a>(125&self,126mut ac_l: AggregationContext<'a>,127mut ac_r: AggregationContext<'a>,128aggregated: bool,129) -> PolarsResult<AggregationContext<'a>> {130// At this stage, there is no combination of AggregatedList and NotAggregated ACs.131132// Check group lengths in case of all AggList133if [&ac_l, &ac_r]134.iter()135.all(|ac| matches!(ac.state, 
AggState::AggregatedList(_)))136{137ac_l.groups().check_lengths(ac_r.groups())?;138}139140match (ac_l.agg_state(), ac_r.agg_state()) {141(AggState::AggregatedList(s), _) | (_, AggState::AggregatedList(s)) => {142let ca = s.list().unwrap();143let [col_l, col_r] = [&ac_l, &ac_r].map(|ac| ac.flat_naive().into_owned());144145let out = ca.apply_to_inner(&|_| {146apply_operator(&col_l, &col_r, self.op).map(|c| c.take_materialized_series())147})?;148let out = out.into_column();149150if ac_l.is_literal() {151std::mem::swap(&mut ac_l, &mut ac_r);152}153154ac_l.with_values(out.into_column(), true, Some(&self.expr))?;155Ok(ac_l)156},157158_ => {159// We want to be able to mutate in place, so we take the lhs to make sure that we drop.160let lhs = ac_l.get_values().clone();161let rhs = ac_r.get_values().clone();162163let out = apply_operator_owned(lhs, rhs, self.op)?;164165if ac_l.is_literal() {166std::mem::swap(&mut ac_l, &mut ac_r);167}168169// Drop lhs so that we might operate in place.170drop(ac_l.take());171172ac_l.with_values(out, aggregated, Some(&self.expr))?;173Ok(ac_l)174},175}176}177178fn apply_all_literal<'a>(179&self,180mut ac_l: AggregationContext<'a>,181ac_r: AggregationContext<'a>,182) -> PolarsResult<AggregationContext<'a>> {183debug_assert!(ac_l.is_literal() && ac_r.is_literal());184polars_ensure!(ac_l.groups.len() == ac_r.groups.len(),185ComputeError: "lhs and rhs should have same number of groups");186187let left_c = ac_l.get_values().rechunk().into_column();188let right_c = ac_r.get_values().rechunk().into_column();189let res_c = apply_operator(&left_c, &right_c, self.op)?;190polars_ensure!(res_c.len() == 1,191ComputeError: "binary operation on literals expected 1 value, found {}", res_c.len());192193ac_l.with_literal(res_c);194Ok(ac_l)195}196197fn apply_group_aware<'a>(198&self,199mut ac_l: AggregationContext<'a>,200mut ac_r: AggregationContext<'a>,201) -> PolarsResult<AggregationContext<'a>> {202let name = self.output_field.name().clone();203let mut ca = 
ac_l204.iter_groups(false)205.zip(ac_r.iter_groups(false))206.map(|(l, r)| {207Some(apply_operator(208&l?.as_ref().clone().into_column(),209&r?.as_ref().clone().into_column(),210self.op,211))212})213.map(|opt_res| opt_res.transpose())214.collect::<PolarsResult<ListChunked>>()?215.with_name(name.clone());216if ca.is_empty() {217ca = ListChunked::full_null_with_dtype(name, 0, self.output_field.dtype());218}219220ac_l.with_update_groups(UpdateGroups::WithSeriesLen);221ac_l.with_agg_state(AggState::AggregatedList(ca.into_column()));222Ok(ac_l)223}224}225226impl PhysicalExpr for BinaryExpr {227fn as_expression(&self) -> Option<&Expr> {228Some(&self.expr)229}230231fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {232// Window functions may set a global state that determine their output233// state, so we don't let them run in parallel as they race234// they also saturate the thread pool by themselves, so that's fine.235let has_window = state.has_window();236237let (lhs, rhs);238if has_window {239let mut state = state.split();240state.remove_cache_window_flag();241lhs = self.left.evaluate(df, &state)?;242rhs = self.right.evaluate(df, &state)?;243} else if !self.allow_threading || self.has_literal {244// Literals are free, don't pay par cost.245lhs = self.left.evaluate(df, state)?;246rhs = self.right.evaluate(df, state)?;247} else {248let (opt_lhs, opt_rhs) = POOL.install(|| {249rayon::join(250|| self.left.evaluate(df, state),251|| self.right.evaluate(df, state),252)253});254(lhs, rhs) = (opt_lhs?, opt_rhs?);255};256polars_ensure!(257lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1,258expr = self.expr,259ShapeMismatch: "cannot evaluate two Series of different lengths ({} and {})",260lhs.len(), rhs.len(),261);262apply_operator_owned(lhs, rhs, self.op)263}264265#[allow(clippy::ptr_arg)]266fn evaluate_on_groups<'a>(267&self,268df: &DataFrame,269groups: &'a GroupPositions,270state: &ExecutionState,271) -> 
PolarsResult<AggregationContext<'a>> {272let (result_a, result_b) = POOL.install(|| {273rayon::join(274|| self.left.evaluate_on_groups(df, groups, state),275|| self.right.evaluate_on_groups(df, groups, state),276)277});278let mut ac_l = result_a?;279let mut ac_r = result_b?;280281// Aggregate NotAggregated into AggregatedList, but only if strictly required AND282// when there is no risk of memory explosion.283// See ApplyExpr for additional context284let mut has_agg_list = false;285let mut has_agg_scalar = false;286let mut has_not_agg = false;287let mut has_not_agg_with_overlapping_groups = false;288let mut not_agg_groups_may_diverge = false;289290let mut previous: Option<&AggregationContext<'_>> = None;291for ac in [&ac_l, &ac_r] {292match ac.state {293AggState::AggregatedList(_) => {294has_agg_list = true;295},296AggState::AggregatedScalar(_) => has_agg_scalar = true,297AggState::NotAggregated(_) => {298has_not_agg = true;299if let Some(p) = previous {300not_agg_groups_may_diverge |= !p.groups.is_same(&ac.groups)301}302previous = Some(ac);303if ac.groups.is_overlapping() {304has_not_agg_with_overlapping_groups = true;305}306},307_ => {},308}309}310311let all_literal = !(has_agg_list || has_agg_scalar || has_not_agg);312let elementwise_must_aggregate =313has_not_agg && (has_agg_list || not_agg_groups_may_diverge);314let mut aggregated = has_agg_list || has_agg_scalar;315316// Arithmetic on Decimal is fallible317let has_decimal_dtype =318ac_l.get_values().dtype().is_decimal() || ac_r.get_values().dtype().is_decimal();319let is_fallible = has_decimal_dtype && self.op.is_arithmetic();320321// Broadcast in NotAgg or AggList requires group_aware322let check_broadcast = [&ac_l, &ac_r].iter().all(|ac| {323matches!(324ac.agg_state(),325AggState::NotAggregated(_) | AggState::AggregatedList(_)326)327});328let has_broadcast = check_broadcast329&& ac_l330.groups()331.iter()332.zip(ac_r.groups().iter())333.any(|(l, r)| l.len() != r.len() && (l.len() == 1 || r.len() == 
1));334335// Dispatch336// See ApplyExpr for reference logic, except that we do any required337// aggregation inline. All BinaryExprs are elementwise.338if all_literal {339// Fast path340self.apply_all_literal(ac_l, ac_r)341} else if has_agg_scalar && (has_agg_list || has_not_agg) {342// Not compatible343self.apply_group_aware(ac_l, ac_r)344} else if elementwise_must_aggregate && has_not_agg_with_overlapping_groups {345// Compatible but calling aggregated() is too expensive346self.apply_group_aware(ac_l, ac_r)347} else if is_fallible348&& (!ac_l.groups_cover_all_values() || !ac_r.groups_cover_all_values())349{350// Fallible expression and there are elements that are masked out.351self.apply_group_aware(ac_l, ac_r)352} else {353if elementwise_must_aggregate {354for ac in [&mut ac_l, &mut ac_r] {355if matches!(ac.state, AggState::NotAggregated(_)) {356ac.aggregated();357}358}359aggregated = true;360}361if has_broadcast {362self.apply_group_aware(ac_l, ac_r)363} else {364self.apply_elementwise(ac_l, ac_r, aggregated)365}366}367}368369fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {370Ok(self.output_field.clone())371}372373fn is_scalar(&self) -> bool {374self.is_scalar375}376}377378379