Path: blob/main/crates/polars-plan/src/plans/optimizer/collapse_joins.rs
6940 views
//! Optimization that collapses several a join with several filters into faster join.1//!2//! For example, `join(how='cross').filter(pl.col.l == pl.col.r)` can be collapsed to3//! `join(how='inner', left_on=pl.col.l, right_on=pl.col.r)`.45use std::sync::Arc;67use polars_core::schema::*;8#[cfg(feature = "iejoin")]9use polars_ops::frame::{IEJoinOptions, InequalityOperator};10use polars_ops::frame::{JoinCoalesce, JoinType, MaintainOrderJoin};11use polars_utils::arena::{Arena, Node};1213use super::{AExpr, ExprOrigin, IR, JoinOptionsIR, aexpr_to_leaf_names_iter};14use crate::dsl::{JoinTypeOptionsIR, Operator};15use crate::plans::optimizer::join_utils::remove_suffix;16use crate::plans::{ExprIR, MintermIter, is_elementwise_rec};1718fn and_expr(left: Node, right: Node, expr_arena: &mut Arena<AExpr>) -> Node {19expr_arena.add(AExpr::BinaryExpr {20left,21op: Operator::And,22right,23})24}2526pub fn optimize(27root: Node,28lp_arena: &mut Arena<IR>,29expr_arena: &mut Arena<AExpr>,30streaming: bool,31) {32let mut predicates = Vec::with_capacity(4);3334// Partition to:35// - equality predicates36// - IEjoin supported inequality predicates37// - remaining predicates38#[cfg(feature = "iejoin")]39let mut ie_op = Vec::new();40let mut remaining_predicates = Vec::new();4142let mut ir_stack = Vec::with_capacity(16);43ir_stack.push(root);4445while let Some(current) = ir_stack.pop() {46let current_ir = lp_arena.get(current);47current_ir.copy_inputs(&mut ir_stack);4849match current_ir {50IR::Filter {51input: _,52predicate,53} => {54predicates.push((current, predicate.node()));55},56IR::Join {57input_left,58input_right,59schema,60left_on,61right_on,62options,63} if options.args.how.is_cross() => {64if predicates.is_empty() {65continue;66}6768let suffix = options.args.suffix();6970debug_assert!(left_on.is_empty());71debug_assert!(right_on.is_empty());7273let mut eq_left_on = Vec::new();74let mut eq_right_on = Vec::new();7576#[cfg(feature = "iejoin")]77let mut ie_left_on = Vec::new();78#[cfg(feature = "iejoin")]79let mut ie_right_on = Vec::new();8081#[cfg(feature = "iejoin")]82{83ie_op.clear();84}8586remaining_predicates.clear();8788#[cfg(feature = "iejoin")]89fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {90match op {91Operator::Lt => Some(InequalityOperator::Lt),92Operator::LtEq => Some(InequalityOperator::LtEq),93Operator::Gt => Some(InequalityOperator::Gt),94Operator::GtEq => Some(InequalityOperator::GtEq),95_ => None,96}97}9899let left_schema = lp_arena.get(*input_left).schema(lp_arena);100let right_schema = lp_arena.get(*input_right).schema(lp_arena);101102let left_schema = left_schema.as_ref();103let right_schema = right_schema.as_ref();104105for (_, predicate_node) in &predicates {106for node in MintermIter::new(*predicate_node, expr_arena) {107let AExpr::BinaryExpr { left, op, right } = expr_arena.get(node) else {108remaining_predicates.push(node);109continue;110};111112if !op.is_comparison_or_bitwise() {113// @NOTE: This is not a valid predicate, but we should not handle that114// here.115remaining_predicates.push(node);116continue;117}118119let mut left = *left;120let mut op = *op;121let mut right = *right;122123let left_origin = ExprOrigin::get_expr_origin(124left,125expr_arena,126left_schema,127right_schema,128suffix.as_str(),129None,130)131.unwrap();132let right_origin = ExprOrigin::get_expr_origin(133right,134expr_arena,135left_schema,136right_schema,137suffix.as_str(),138None,139)140.unwrap();141142use ExprOrigin as EO;143144// We can only join if both sides of the binary expression stem from145// different sides of the join.146match (left_origin, right_origin) {147(EO::Both, _) | (_, EO::Both) => {148// If either expression originates from the both sides, we need to149// filter it afterwards.150remaining_predicates.push(node);151continue;152},153(EO::None, _) | (_, EO::None) => {154// @TODO: This should probably be pushed down155remaining_predicates.push(node);156continue;157},158(EO::Left, EO::Left) | (EO::Right, EO::Right) => {159// @TODO: This can probably be pushed down in the predicate160// pushdown, but for now just take it as is.161remaining_predicates.push(node);162continue;163},164(EO::Right, EO::Left) => {165// Swap around the expressions so they match with the left_on and166// right_on.167std::mem::swap(&mut left, &mut right);168op = op.swap_operands();169},170(EO::Left, EO::Right) => {},171}172173if matches!(op, Operator::Eq) {174eq_left_on.push(ExprIR::from_node(left, expr_arena));175eq_right_on.push(ExprIR::from_node(right, expr_arena));176} else {177#[cfg(feature = "iejoin")]178if let Some(ie_op_) = to_inequality_operator(&op) {179fn is_numeric(180node: Node,181expr_arena: &Arena<AExpr>,182schema: &Schema,183) -> bool {184aexpr_to_leaf_names_iter(node, expr_arena).any(|name| {185if let Some(dt) = schema.get(name.as_str()) {186dt.to_physical().is_primitive_numeric()187} else {188false189}190})191}192193// We fallback to remaining if:194// - we already have an IEjoin or Inner join195// - we already have an Inner join196// - data is not numeric (our iejoin doesn't yet implement that)197if ie_op.len() >= 2198|| !eq_left_on.is_empty()199|| !is_numeric(left, expr_arena, left_schema)200{201remaining_predicates.push(node);202} else {203ie_left_on.push(ExprIR::from_node(left, expr_arena));204ie_right_on.push(ExprIR::from_node(right, expr_arena));205ie_op.push(ie_op_);206}207} else {208remaining_predicates.push(node);209}210211#[cfg(not(feature = "iejoin"))]212remaining_predicates.push(node);213}214}215}216217let mut can_simplify_join = false;218219if !eq_left_on.is_empty() {220for expr in eq_right_on.iter_mut() {221remove_suffix(expr, expr_arena, right_schema, suffix.as_str());222}223can_simplify_join = true;224} else {225#[cfg(feature = "iejoin")]226if !ie_op.is_empty() {227for expr in ie_right_on.iter_mut() {228remove_suffix(expr, expr_arena, right_schema, suffix.as_str());229}230can_simplify_join = true;231}232can_simplify_join |= options.args.how.is_cross();233}234235if can_simplify_join {236let new_join = insert_fitting_join(237eq_left_on,238eq_right_on,239#[cfg(feature = "iejoin")]240ie_left_on,241#[cfg(feature = "iejoin")]242ie_right_on,243#[cfg(feature = "iejoin")]244&ie_op,245&remaining_predicates,246lp_arena,247expr_arena,248options.as_ref().clone(),249*input_left,250*input_right,251schema.clone(),252streaming,253);254255lp_arena.swap(predicates[0].0, new_join);256}257258predicates.clear();259},260_ => {261predicates.clear();262},263}264}265}266267#[allow(clippy::too_many_arguments)]268fn insert_fitting_join(269eq_left_on: Vec<ExprIR>,270eq_right_on: Vec<ExprIR>,271#[cfg(feature = "iejoin")] ie_left_on: Vec<ExprIR>,272#[cfg(feature = "iejoin")] ie_right_on: Vec<ExprIR>,273#[cfg(feature = "iejoin")] ie_op: &[InequalityOperator],274remaining_predicates: &[Node],275lp_arena: &mut Arena<IR>,276expr_arena: &mut Arena<AExpr>,277mut options: JoinOptionsIR,278input_left: Node,279input_right: Node,280schema: SchemaRef,281streaming: bool,282) -> Node {283debug_assert_eq!(eq_left_on.len(), eq_right_on.len());284#[cfg(feature = "iejoin")]285{286debug_assert_eq!(ie_op.len(), ie_left_on.len());287debug_assert_eq!(ie_left_on.len(), ie_right_on.len());288debug_assert!(ie_op.len() <= 2);289}290debug_assert!(matches!(options.args.how, JoinType::Cross));291292let remaining_predicates = remaining_predicates293.iter()294.copied()295.reduce(|left, right| and_expr(left, right, expr_arena));296297let (left_on, right_on, remaining_predicates) = match () {298_ if !eq_left_on.is_empty() => {299options.args.how = JoinType::Inner;300// We need to make sure not to delete any columns301options.args.coalesce = JoinCoalesce::KeepColumns;302303#[cfg(feature = "iejoin")]304let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold(305remaining_predicates,306|acc, ((left, op), right)| {307let e = expr_arena.add(AExpr::BinaryExpr {308left: left.node(),309op: (*op).into(),310right: right.node(),311});312Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena)))313},314);315316(eq_left_on, eq_right_on, remaining_predicates)317},318#[cfg(feature = "iejoin")]319_ if !ie_op.is_empty() => {320// We can only IE join up to 2 operators321322let operator1 = ie_op[0];323let operator2 = ie_op.get(1).copied();324325// Do an IEjoin.326options.args.how = JoinType::IEJoin;327options.options = Some(JoinTypeOptionsIR::IEJoin(IEJoinOptions {328operator1,329operator2,330}));331// We need to make sure not to delete any columns332options.args.coalesce = JoinCoalesce::KeepColumns;333334(ie_left_on, ie_right_on, remaining_predicates)335},336// If anything just fall back to a cross join.337_ => {338options.args.how = JoinType::Cross;339// We need to make sure not to delete any columns340options.args.coalesce = JoinCoalesce::KeepColumns;341342#[cfg(feature = "iejoin")]343let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold(344remaining_predicates,345|acc, ((left, op), right)| {346let e = expr_arena.add(AExpr::BinaryExpr {347left: left.node(),348op: (*op).into(),349right: right.node(),350});351Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena)))352},353);354355let mut remaining_predicates = remaining_predicates;356if let Some(pred) = remaining_predicates.take_if(|pred| {357matches!(options.args.maintain_order, MaintainOrderJoin::None)358&& !streaming359&& is_elementwise_rec(*pred, expr_arena)360}) {361options.options = Some(JoinTypeOptionsIR::CrossAndFilter {362predicate: ExprIR::from_node(pred, expr_arena),363})364}365366(Vec::new(), Vec::new(), remaining_predicates)367},368};369370// Note: We expect key type upcasting / expression optimizations have already been done during371// DSL->IR conversion.372373let join_ir = IR::Join {374input_left,375input_right,376schema,377left_on,378right_on,379options: Arc::new(options),380};381382let join_node = lp_arena.add(join_ir);383384if let Some(predicate) = remaining_predicates {385lp_arena.add(IR::Filter {386input: join_node,387predicate: ExprIR::from_node(predicate, &*expr_arena),388})389} else {390join_node391}392}393394395