Path: blob/main/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs
7889 views
use polars_utils::format_pl_smallstr;12use super::*;3use crate::plans::optimizer::join_utils::remove_suffix;45const IEJOIN_MAX_PREDICATES: usize = 2;67#[allow(clippy::too_many_arguments)]8pub(super) fn process_join(9opt: &mut PredicatePushDown,10lp_arena: &mut Arena<IR>,11expr_arena: &mut Arena<AExpr>,12input_left: Node,13input_right: Node,14mut left_on: Vec<ExprIR>,15mut right_on: Vec<ExprIR>,16mut schema: SchemaRef,17mut options: Arc<JoinOptionsIR>,18mut acc_predicates: PlHashMap<PlSmallStr, ExprIR>,19streaming: bool,20) -> PolarsResult<IR> {21let schema_left = lp_arena.get(input_left).schema(lp_arena).into_owned();22let schema_right = lp_arena.get(input_right).schema(lp_arena).into_owned();2324let opt_post_select = try_rewrite_join_type(25&schema_left,26&schema_right,27&mut schema,28&mut options,29&mut left_on,30&mut right_on,31&mut acc_predicates,32expr_arena,33streaming,34)?;3536if match &options.args.how {37// Full-join with no coalesce. We can only push filters if they do not remove NULLs, but38// we don't have a reliable way to guarantee this.39JoinType::Full => !options.args.should_coalesce(),4041_ => false,42} || acc_predicates.is_empty()43{44let lp = IR::Join {45input_left,46input_right,47left_on,48right_on,49schema,50options,51};5253return opt.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena);54}5556let should_coalesce = options.args.should_coalesce();5758// AsOf has the equality join keys under `asof_options.left/right_by`. This code builds an59// iterator to address these generically without creating a `Box<dyn Iterator>`.60let get_lhs_column_keys_iter = || {61let len = match &options.args.how {62#[cfg(feature = "asof_join")]63JoinType::AsOf(asof_options) => {64asof_options.left_by.as_deref().unwrap_or_default().len()65},66_ => left_on.len(),67};6869(0..len).map(|i| match &options.args.how {70#[cfg(feature = "asof_join")]71JoinType::AsOf(asof_options) => Some(72asof_options73.left_by74.as_deref()75.unwrap_or_default()76.get(i)77.unwrap(),78),79_ => {80let expr = left_on.get(i).unwrap();8182// For non full-joins coalesce can still insert casts into the key exprs.83let node = match expr_arena.get(expr.node()) {84AExpr::Cast {85expr,86dtype: _,87options: _,88} if should_coalesce => *expr,8990_ => expr.node(),91};9293if let AExpr::Column(name) = expr_arena.get(node) {94Some(name)95} else {96None97}98},99})100};101102let get_rhs_column_keys_iter = || {103let len = match &options.args.how {104#[cfg(feature = "asof_join")]105JoinType::AsOf(asof_options) => {106asof_options.right_by.as_deref().unwrap_or_default().len()107},108_ => right_on.len(),109};110111(0..len).map(|i| match &options.args.how {112#[cfg(feature = "asof_join")]113JoinType::AsOf(asof_options) => Some(114asof_options115.right_by116.as_deref()117.unwrap_or_default()118.get(i)119.unwrap(),120),121_ => {122let expr = right_on.get(i).unwrap();123124// For non full-joins coalesce can still insert casts into the key exprs.125let node = match expr_arena.get(expr.node()) {126AExpr::Cast {127expr,128dtype: _,129options: _,130} if should_coalesce => *expr,131132_ => expr.node(),133};134135if let AExpr::Column(name) = expr_arena.get(node) {136Some(name)137} else {138None139}140},141})142};143144if cfg!(debug_assertions) && options.args.should_coalesce() {145match &options.args.how {146#[cfg(feature = "asof_join")]147JoinType::AsOf(_) => {},148149_ => {150assert!(get_lhs_column_keys_iter().len() > 0);151assert!(get_rhs_column_keys_iter().len() > 0);152},153}154155assert!(get_lhs_column_keys_iter().all(|x| x.is_some()));156assert!(get_rhs_column_keys_iter().all(|x| x.is_some()));157}158159// Key columns of the left table that are coalesced into an output column of the right table.160let coalesced_to_right: PlHashSet<PlSmallStr> =161if matches!(&options.args.how, JoinType::Right) && options.args.should_coalesce() {162get_lhs_column_keys_iter()163.map(|x| x.unwrap().clone())164.collect()165} else {166Default::default()167};168169let mut output_key_to_left_input_map: PlHashMap<PlSmallStr, PlSmallStr> =170PlHashMap::with_capacity(get_lhs_column_keys_iter().len());171let mut output_key_to_right_input_map: PlHashMap<PlSmallStr, PlSmallStr> =172PlHashMap::with_capacity(get_rhs_column_keys_iter().len());173174for (lhs_input_key, rhs_input_key) in get_lhs_column_keys_iter().zip(get_rhs_column_keys_iter())175{176let (Some(lhs_input_key), Some(rhs_input_key)) = (lhs_input_key, rhs_input_key) else {177continue;178};179180// lhs_input_key: Column name within the left table.181use JoinType::*;182// Map output name of an LHS join key output to an input key column of the right table.183// This will cause predicates referring to LHS join keys to also be pushed to the RHS table.184if match &options.args.how {185Left | Inner | Full => true,186187#[cfg(feature = "asof_join")]188AsOf(_) => true,189#[cfg(feature = "semi_anti_join")]190Semi | Anti => true,191192// NOTE: Right-join is excluded.193Right => false,194195#[cfg(feature = "iejoin")]196IEJoin => false,197198Cross => unreachable!(), // Cross left/right_on should be empty199} {200// Note: `lhs_input_key` maintains its name in the output column for all cases except201// for a coalescing right-join.202output_key_to_right_input_map.insert(lhs_input_key.clone(), rhs_input_key.clone());203}204205// Map output name of an RHS join key output to a key column of the left table.206// This will cause predicates referring to RHS join keys to also be pushed to the LHS table.207if match &options.args.how {208JoinType::Right => true,209// Non-coalesced output columns of an inner join are equivalent between LHS and RHS.210JoinType::Inner => !options.args.should_coalesce(),211_ => false,212} {213let rhs_output_key: PlSmallStr = if schema_left.contains(rhs_input_key.as_str())214&& !coalesced_to_right.contains(rhs_input_key.as_str())215{216format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())217} else {218rhs_input_key.clone()219};220221assert!(schema.contains(&rhs_output_key));222223output_key_to_left_input_map.insert(rhs_output_key.clone(), lhs_input_key.clone());224}225}226227let mut pushdown_left: PlHashMap<PlSmallStr, ExprIR> = init_hashmap(Some(acc_predicates.len()));228let mut pushdown_right: PlHashMap<PlSmallStr, ExprIR> =229init_hashmap(Some(acc_predicates.len()));230let mut local_predicates = Vec::with_capacity(acc_predicates.len());231232for (_, predicate) in acc_predicates {233let mut push_left = true;234let mut push_right = true;235236for col_name in aexpr_to_leaf_names_iter(predicate.node(), expr_arena) {237let origin: ExprOrigin = ExprOrigin::get_column_origin(238col_name.as_str(),239&schema_left,240&schema_right,241options.args.suffix(),242Some(&|name| coalesced_to_right.contains(name)),243)244.unwrap();245246push_left &= matches!(origin, ExprOrigin::Left | ExprOrigin::None)247|| output_key_to_left_input_map.contains_key(col_name);248249push_right &= matches!(origin, ExprOrigin::Right | ExprOrigin::None)250|| output_key_to_right_input_map.contains_key(col_name);251}252253// Note: If `push_left` and `push_right` are both `true`, it means the predicate refers only254// to the join key columns, or the predicate does not refer any columns.255256let has_residual = match &options.args.how {257// Pushing to a single side is enough to observe the full effect of the filter.258JoinType::Inner => !(push_left || push_right),259260// Left-join: Pushing filters to the left table is enough to observe the effect of the261// filter. Pushing filters to the right is optional, but can only be done if the262// filter is also pushed to the left (if this is the case it means the filter only263// references join key columns).264JoinType::Left => {265push_right &= push_left;266!push_left267},268269// Same as left-join, just flipped around.270JoinType::Right => {271push_left &= push_right;272!push_right273},274275// Full-join: Filters must strictly apply only to coalesced output key columns.276JoinType::Full => {277assert!(options.args.should_coalesce());278279let push = push_left && push_right;280push_left = push;281push_right = push;282283!push284},285286JoinType::Cross => {287// Predicate should only refer to a single side.288assert!(output_key_to_left_input_map.is_empty());289assert!(output_key_to_right_input_map.is_empty());290!(push_left || push_right)291},292293// Behaves similarly to left-join on "by" columns (takes a single match instead of294// all matches according to asof strategy).295#[cfg(feature = "asof_join")]296JoinType::AsOf(_) => {297push_right &= push_left;298!push_left299},300301// Same as inner-join.302#[cfg(feature = "semi_anti_join")]303JoinType::Semi => !(push_left || push_right),304305// Anti-join is an exclusion of key tuples that exist in the right table, meaning that306// filters can only be pushed to the right table if they are also pushed to the left.307#[cfg(feature = "semi_anti_join")]308JoinType::Anti => {309push_right &= push_left;310!push_left311},312313// Same as inner-join.314#[cfg(feature = "iejoin")]315JoinType::IEJoin => !(push_left || push_right),316};317318if has_residual {319local_predicates.push(predicate.clone())320}321322if push_left {323let mut predicate = predicate.clone();324map_column_references(&mut predicate, expr_arena, &output_key_to_left_input_map);325insert_predicate_dedup(&mut pushdown_left, &predicate, expr_arena);326}327328if push_right {329let mut predicate = predicate;330map_column_references(&mut predicate, expr_arena, &output_key_to_right_input_map);331remove_suffix(332&mut predicate,333expr_arena,334&schema_right,335options.args.suffix(),336);337insert_predicate_dedup(&mut pushdown_right, &predicate, expr_arena);338}339}340341opt.pushdown_and_assign(input_left, pushdown_left, lp_arena, expr_arena)?;342opt.pushdown_and_assign(input_right, pushdown_right, lp_arena, expr_arena)?;343344let lp = IR::Join {345input_left,346input_right,347left_on,348right_on,349schema,350options,351};352353let lp = opt.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena);354355let lp = if let Some((projections, schema)) = opt_post_select {356IR::Select {357input: lp_arena.add(lp),358expr: projections,359schema,360options: ProjectionOptions {361run_parallel: false,362duplicate_check: false,363should_broadcast: false,364},365}366} else {367lp368};369370Ok(lp)371}372373/// Attempts to rewrite the join-type based on NULL-removing filters.374///375/// Changing between some join types may cause the output column order to change. If this is the376/// case, a Vec of column selectors will be returned that restore the original column order.377#[expect(clippy::too_many_arguments)]378fn try_rewrite_join_type(379schema_left: &SchemaRef,380schema_right: &SchemaRef,381output_schema: &mut SchemaRef,382options: &mut Arc<JoinOptionsIR>,383left_on: &mut Vec<ExprIR>,384right_on: &mut Vec<ExprIR>,385acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,386expr_arena: &mut Arena<AExpr>,387streaming: bool,388) -> PolarsResult<Option<(Vec<ExprIR>, SchemaRef)>> {389if acc_predicates.is_empty() {390return Ok(None);391}392393let suffix = options.args.suffix().clone();394395// * Cross -> Inner | IEJoin396// * IEJoin -> Inner397//398// Note: The join rewrites here all maintain output column ordering, hence this does not need399// to return any post-select (inserted inner joins will use JoinCoalesce::KeepColumns).400(|| {401match &options.args.how {402#[cfg(feature = "iejoin")]403JoinType::IEJoin => {},404JoinType::Cross => {},405406_ => return PolarsResult::Ok(()),407}408409match &options.options {410Some(JoinTypeOptionsIR::CrossAndFilter { .. }) => {411let Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) =412Arc::make_mut(options).options.take()413else {414unreachable!()415};416417insert_predicate_dedup(acc_predicates, &predicate, expr_arena);418},419420#[cfg(feature = "iejoin")]421Some(JoinTypeOptionsIR::IEJoin(_)) => {},422None => {},423}424425// Try converting to inner join426let equality_conditions = take_inner_join_compatible_filters(427acc_predicates,428expr_arena,429schema_left,430schema_right,431&suffix,432)?;433434for InnerJoinKeys {435input_lhs,436input_rhs,437} in equality_conditions438{439let join_options = Arc::make_mut(options);440join_options.args.how = JoinType::Inner;441join_options.args.coalesce = JoinCoalesce::KeepColumns;442443left_on.push(ExprIR::from_node(input_lhs, expr_arena));444let mut rexpr = ExprIR::from_node(input_rhs, expr_arena);445remove_suffix(&mut rexpr, expr_arena, schema_right, &suffix);446right_on.push(rexpr);447}448449if options.args.how == JoinType::Inner {450return Ok(());451}452453// Try converting cross join to IEJoin454#[cfg(feature = "iejoin")]455if matches!(options.args.maintain_order, MaintainOrderJoin::None)456&& left_on.len() < IEJOIN_MAX_PREDICATES457{458let ie_conditions = take_iejoin_compatible_filters(459acc_predicates,460expr_arena,461schema_left,462schema_right,463output_schema,464&suffix,465)?;466467for IEJoinCompatiblePredicate {468input_lhs,469input_rhs,470ie_op,471source_node,472} in ie_conditions473{474let join_options = Arc::make_mut(options);475join_options.args.how = JoinType::IEJoin;476477if left_on.len() >= IEJOIN_MAX_PREDICATES {478// Important: Place these back into acc_predicates.479insert_predicate_dedup(480acc_predicates,481&ExprIR::from_node(source_node, expr_arena),482expr_arena,483);484} else {485left_on.push(ExprIR::from_node(input_lhs, expr_arena));486let mut rexpr = ExprIR::from_node(input_rhs, expr_arena);487remove_suffix(&mut rexpr, expr_arena, schema_right, &suffix);488right_on.push(rexpr);489490let JoinTypeOptionsIR::IEJoin(ie_options) = join_options491.options492.get_or_insert(JoinTypeOptionsIR::IEJoin(IEJoinOptions::default()))493else {494unreachable!()495};496497match left_on.len() {4981 => ie_options.operator1 = ie_op,4992 => ie_options.operator2 = Some(ie_op),500_ => unreachable!("{}", IEJOIN_MAX_PREDICATES),501};502}503}504505if options.args.how == JoinType::IEJoin {506return Ok(());507}508}509510debug_assert_eq!(options.args.how, JoinType::Cross);511512if options.args.how != JoinType::Cross {513return Ok(());514}515516if streaming {517return Ok(());518}519520let Some(nested_loop_predicates) = take_nested_loop_join_compatible_filters(521acc_predicates,522expr_arena,523schema_left,524schema_right,525&suffix,526)?527.reduce(|left, right| {528expr_arena.add(AExpr::BinaryExpr {529left,530op: Operator::And,531right,532})533}) else {534return Ok(());535};536537let existing = Arc::make_mut(options)538.options539.replace(JoinTypeOptionsIR::CrossAndFilter {540predicate: ExprIR::from_node(nested_loop_predicates, expr_arena),541});542assert!(existing.is_none()); // Important543544Ok(())545})()?;546547if !matches!(548&options.args.how,549JoinType::Full | JoinType::Left | JoinType::Right550) {551return Ok(None);552}553554let should_coalesce = options.args.should_coalesce();555556/// Note: This may panic if `args.should_coalesce()` is false.557macro_rules! lhs_input_column_keys_iter {558() => {{559left_on.iter().map(|expr| {560let node = match expr_arena.get(expr.node()) {561AExpr::Cast {562expr,563dtype: _,564options: _,565} if should_coalesce => *expr,566567_ => expr.node(),568};569570let AExpr::Column(name) = expr_arena.get(node) else {571// All keys should be columns when coalesce=True572unreachable!()573};574575name.clone()576})577}};578}579580let mut coalesced_to_right: PlHashSet<PlSmallStr> = Default::default();581// Removing NULLs on these columns do not allow for join downgrading.582// We only need to track these for full-join - e.g. for left-join, removing NULLs from any left583// column does not cause any join rewrites.584let mut coalesced_full_join_key_outputs: PlHashSet<PlSmallStr> = Default::default();585586if options.args.should_coalesce() {587match &options.args.how {588JoinType::Full => {589coalesced_full_join_key_outputs = lhs_input_column_keys_iter!().collect()590},591JoinType::Right => coalesced_to_right = lhs_input_column_keys_iter!().collect(),592_ => {},593}594}595596let mut non_null_side = ExprOrigin::None;597598for predicate in acc_predicates.values() {599for node in MintermIter::new(predicate.node(), expr_arena) {600predicate_non_null_column_outputs(node, expr_arena, &mut |non_null_column| {601if coalesced_full_join_key_outputs.contains(non_null_column) {602return;603}604605non_null_side |= ExprOrigin::get_column_origin(606non_null_column.as_str(),607schema_left,608schema_right,609options.args.suffix(),610Some(&|x| coalesced_to_right.contains(x)),611)612.unwrap();613});614}615}616617let Some(new_join_type) = (match non_null_side {618ExprOrigin::Both => Some(JoinType::Inner),619620ExprOrigin::Left => match &options.args.how {621JoinType::Full => Some(JoinType::Left),622JoinType::Right => Some(JoinType::Inner),623_ => None,624},625626ExprOrigin::Right => match &options.args.how {627JoinType::Full => Some(JoinType::Right),628JoinType::Left => Some(JoinType::Inner),629_ => None,630},631632ExprOrigin::None => None,633}) else {634return Ok(None);635};636637let options = Arc::make_mut(options);638// Ensure JoinSpecific is materialized to a specific config option, as we change the join type.639options.args.coalesce = if options.args.should_coalesce() {640JoinCoalesce::CoalesceColumns641} else {642JoinCoalesce::KeepColumns643};644let original_join_type = std::mem::replace(&mut options.args.how, new_join_type.clone());645let original_output_schema = match (&original_join_type, &new_join_type) {646(JoinType::Right, _) | (_, JoinType::Right) => std::mem::replace(647output_schema,648det_join_schema(649schema_left,650schema_right,651left_on,652right_on,653options,654expr_arena,655)656.unwrap(),657),658_ => {659debug_assert_eq!(660output_schema,661&det_join_schema(662schema_left,663schema_right,664left_on,665right_on,666options,667expr_arena,668)669.unwrap()670);671output_schema.clone()672},673};674675// Maps the original join output names to the new join output names (used for mapping column676// references of the predicates).677let mut original_to_new_names_map: PlHashMap<PlSmallStr, PlSmallStr> = Default::default();678// Projects the new join output table back into the original join output table.679let mut project_to_original: Option<Vec<ExprIR>> = None;680681if options.args.should_coalesce() {682// If we changed join types between a coalescing right-join, we need to do a select() to restore the column683// order of the original join type. The column references in the predicates may also need to be changed.684match (&original_join_type, &new_join_type) {685(JoinType::Right, JoinType::Right) => unreachable!(),686687// Right-join rewritten to inner-join.688//689// E.g.690// Left: | a | b | c |691// Right: | a | b | c |692//693// right_join(left_on='a', right_on='b'): | b | c | a | *b_right | c_right |694// inner_join(left_on='a', right_on='b'): | *a | b | c | a_right | c_right |695// note: '*' means coalesced key output column696//697// project_to_original: | col(b) | col(c) | col(a_right).alias(a) | col(a).alias(b_right) | col(c_right) |698// original_to_new_names_map: {'a': 'a_right', 'b_right': 'a'}699//700(JoinType::Right, JoinType::Inner) => {701let mut join_output_key_selectors = PlHashMap::with_capacity(right_on.len());702703for (l, r) in left_on.iter().zip(right_on) {704let (AExpr::Column(lhs_input_key), AExpr::Column(rhs_input_key)) =705(expr_arena.get(l.node()), expr_arena.get(r.node()))706else {707// `should_coalesce() == true` should guarantee all are columns.708unreachable!()709};710711let original_key_output_name: PlSmallStr = if schema_left712.contains(rhs_input_key.as_str())713&& !coalesced_to_right.contains(rhs_input_key.as_str())714{715format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())716} else {717rhs_input_key.clone()718};719720let new_key_output_name = lhs_input_key.clone();721let rhs_input_key = rhs_input_key.clone();722723let node = expr_arena.add(AExpr::Column(lhs_input_key.clone()));724let mut ae = ExprIR::from_node(node, expr_arena);725726if original_key_output_name != new_key_output_name {727// E.g. left_on=col(a), right_on=col(b)728// rhs_output_key = 'b', lhs_input_key = 'a', the original right-join is supposed to output 'b'.729original_to_new_names_map.insert(730original_key_output_name.clone(),731new_key_output_name.clone(),732);733ae.set_alias(original_key_output_name)734}735736join_output_key_selectors.insert(rhs_input_key, ae);737}738739let mut column_selectors: Vec<ExprIR> = Vec::with_capacity(output_schema.len());740741for lhs_input_col in schema_left.iter_names() {742if coalesced_to_right.contains(lhs_input_col) {743continue;744}745746let node = expr_arena.add(AExpr::Column(lhs_input_col.clone()));747column_selectors.push(ExprIR::from_node(node, expr_arena));748}749750for rhs_input_col in schema_right.iter_names() {751let expr = if let Some(expr) = join_output_key_selectors.get(rhs_input_col) {752expr.clone()753} else if schema_left.contains(rhs_input_col) {754let new_join_output_name =755format_pl_smallstr!("{}{}", rhs_input_col, options.args.suffix());756757let node = expr_arena.add(AExpr::Column(new_join_output_name.clone()));758let mut expr = ExprIR::from_node(node, expr_arena);759760// The column with the same name from the LHS is not projected in the original761// right-join, so we alias to remove the suffix that was added from the inner-join.762if coalesced_to_right.contains(rhs_input_col.as_str()) {763original_to_new_names_map764.insert(rhs_input_col.clone(), new_join_output_name);765expr.set_alias(rhs_input_col.clone());766}767768expr769} else {770let node = expr_arena.add(AExpr::Column(rhs_input_col.clone()));771ExprIR::from_node(node, expr_arena)772};773774column_selectors.push(expr)775}776777assert_eq!(column_selectors.len(), output_schema.len());778assert_eq!(column_selectors.len(), original_output_schema.len());779780if cfg!(debug_assertions) {781assert!(782column_selectors783.iter()784.zip(original_output_schema.iter_names())785.all(|(l, r)| l.output_name() == r)786)787}788789project_to_original = Some(column_selectors)790},791792// Full-join rewritten to right-join793//794// E.g.795// Left: | a | b | c |796// Right: | a | b | c |797//798// full_join(left_on='a', right_on='b'): | *a | b | c | a_right | c_right |799// right_join(left_on='a', right_on='b'): | b | c | a | *b_right | c_right |800// note: '*' means coalesced key output column801//802// project_to_original: | col(b_right).alias(a) | col(b) | col(c) | col(a).alias(a_right) | col(c_right) |803// original_to_new_names_map: {'a': 'b_right', 'a_right': 'a'}804//805(JoinType::Full, JoinType::Right) => {806let mut join_output_key_selectors = PlHashMap::with_capacity(left_on.len());807808// The existing one is empty because the original join type was not a right-join.809assert!(coalesced_to_right.is_empty());810// LHS input key columns that are coalesced (i.e. not projected) for the right-join.811let coalesced_to_right: PlHashSet<PlSmallStr> =812lhs_input_column_keys_iter!().collect();813// RHS input key columns that are coalesced (i.e. not projected) for the full-join.814let mut coalesced_to_left: PlHashSet<PlSmallStr> =815PlHashSet::with_capacity(right_on.len());816817for (l, r) in left_on.iter().zip(right_on) {818let (AExpr::Column(lhs_input_key), AExpr::Column(rhs_input_key)) =819(expr_arena.get(l.node()), expr_arena.get(r.node()))820else {821// `should_coalesce() == true` should guarantee all columns.822unreachable!()823};824825let new_key_output_name: PlSmallStr = if schema_left826.contains(rhs_input_key.as_str())827&& !coalesced_to_right.contains(rhs_input_key.as_str())828{829format_pl_smallstr!("{}{}", rhs_input_key, options.args.suffix())830} else {831rhs_input_key.clone()832};833834let lhs_input_key = lhs_input_key.clone();835let rhs_input_key = rhs_input_key.clone();836let original_key_output_name = &lhs_input_key;837838coalesced_to_left.insert(rhs_input_key);839840let node = expr_arena.add(AExpr::Column(new_key_output_name.clone()));841842let mut ae = ExprIR::from_node(node, expr_arena);843844// E.g. left_on=col(a), right_on=col(b)845// rhs_output_key = 'b', lhs_input_key = 'a'846if new_key_output_name != original_key_output_name {847original_to_new_names_map.insert(848original_key_output_name.clone(),849new_key_output_name.clone(),850);851ae.set_alias(original_key_output_name.clone())852}853854join_output_key_selectors.insert(lhs_input_key.clone(), ae);855}856857let mut column_selectors = Vec::with_capacity(output_schema.len());858859for lhs_input_col in schema_left.iter_names() {860let expr = if let Some(expr) = join_output_key_selectors.get(lhs_input_col) {861expr.clone()862} else {863let node = expr_arena.add(AExpr::Column(lhs_input_col.clone()));864ExprIR::from_node(node, expr_arena)865};866867column_selectors.push(expr)868}869870for rhs_input_col in schema_right.iter_names() {871if coalesced_to_left.contains(rhs_input_col) {872continue;873}874875let mut original_output_name: Option<PlSmallStr> = None;876877let new_join_output_name = if schema_left.contains(rhs_input_col) {878let suffixed =879format_pl_smallstr!("{}{}", rhs_input_col, options.args.suffix());880881if coalesced_to_right.contains(rhs_input_col) {882original_output_name = Some(suffixed);883rhs_input_col.clone()884} else {885suffixed886}887} else {888rhs_input_col.clone()889};890891let node = expr_arena.add(AExpr::Column(new_join_output_name));892893let mut expr = ExprIR::from_node(node, expr_arena);894895if let Some(original_output_name) = original_output_name {896original_to_new_names_map897.insert(original_output_name.clone(), rhs_input_col.clone());898expr.set_alias(original_output_name);899}900901column_selectors.push(expr);902}903904assert_eq!(column_selectors.len(), output_schema.len());905assert_eq!(column_selectors.len(), original_output_schema.len());906907if cfg!(debug_assertions) {908assert!(909column_selectors910.iter()911.zip(original_output_schema.iter_names())912.all(|(l, r)| l.output_name() == r)913)914}915916project_to_original = Some(column_selectors)917},918919(JoinType::Right, _) | (_, JoinType::Right) => unreachable!(),920921_ => {},922}923}924925if !original_to_new_names_map.is_empty() {926assert!(project_to_original.is_some());927928for (_, predicate_expr) in acc_predicates.iter_mut() {929map_column_references(predicate_expr, expr_arena, &original_to_new_names_map);930}931}932933Ok(project_to_original.map(|p| (p, original_output_schema)))934}935936struct InnerJoinKeys {937input_lhs: Node,938input_rhs: Node,939}940941/// Removes all equality predicates that can be used as inner-join conditions from `acc_predicates`.942fn take_inner_join_compatible_filters(943acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,944expr_arena: &mut Arena<AExpr>,945schema_left: &Schema,946schema_right: &Schema,947suffix: &str,948) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, InnerJoinKeys>> {949take_predicates_mut(acc_predicates, expr_arena, |ae, _ae_node, expr_arena| {950Ok(match ae {951AExpr::BinaryExpr {952left,953op: Operator::Eq,954right,955} => {956let left_origin = ExprOrigin::get_expr_origin(957*left,958expr_arena,959schema_left,960schema_right,961suffix,962None, // is_coalesced_to_right963)?;964let right_origin = ExprOrigin::get_expr_origin(965*right,966expr_arena,967schema_left,968schema_right,969suffix,970None,971)?;972973match (left_origin, right_origin) {974(ExprOrigin::Left, ExprOrigin::Right) => Some(InnerJoinKeys {975input_lhs: *left,976input_rhs: *right,977}),978(ExprOrigin::Right, ExprOrigin::Left) => Some(InnerJoinKeys {979input_lhs: *right,980input_rhs: *left,981}),982_ => None,983}984},985_ => None,986})987})988}989990#[cfg(feature = "iejoin")]991struct IEJoinCompatiblePredicate {992input_lhs: Node,993input_rhs: Node,994ie_op: InequalityOperator,995/// Original input node.996source_node: Node,997}998999#[cfg(feature = "iejoin")]1000/// Removes all inequality filters that can be used as iejoin conditions from `acc_predicates`.1001fn take_iejoin_compatible_filters(1002acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,1003expr_arena: &mut Arena<AExpr>,1004schema_left: &Schema,1005schema_right: &Schema,1006output_schema: &Schema,1007suffix: &str,1008) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, IEJoinCompatiblePredicate>> {1009return take_predicates_mut(acc_predicates, expr_arena, |ae, ae_node, expr_arena| {1010Ok(match ae {1011AExpr::BinaryExpr { left, op, right } => {1012if to_inequality_operator(op).is_none() {1013return Ok(None);1014}10151016let left_origin = ExprOrigin::get_expr_origin(1017*left,1018expr_arena,1019schema_left,1020schema_right,1021suffix,1022None, // is_coalesced_to_right1023)?;10241025let right_origin = ExprOrigin::get_expr_origin(1026*right,1027expr_arena,1028schema_left,1029schema_right,1030suffix,1031None,1032)?;10331034macro_rules! is_supported_type {1035($node:expr) => {{1036let node = $node;1037let field = expr_arena1038.get(node)1039.to_field(&ToFieldContext::new(expr_arena, output_schema))?;1040let dtype = field.dtype();10411042!dtype.is_nested() && dtype.to_physical().is_primitive_numeric()1043}};1044}10451046// IEJoin only supports numeric.1047if !is_supported_type!(*left) || !is_supported_type!(*right) {1048return Ok(None);1049}10501051match (left_origin, right_origin) {1052(ExprOrigin::Left, ExprOrigin::Right) => Some(IEJoinCompatiblePredicate {1053input_lhs: *left,1054input_rhs: *right,1055ie_op: to_inequality_operator(op).unwrap(),1056source_node: ae_node,1057}),1058(ExprOrigin::Right, ExprOrigin::Left) => {1059let op = op.swap_operands();10601061Some(IEJoinCompatiblePredicate {1062input_lhs: *right,1063input_rhs: *left,1064ie_op: to_inequality_operator(&op).unwrap(),1065source_node: ae_node,1066})1067},1068_ => None,1069}1070},1071_ => None,1072})1073});10741075fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {1076match op {1077Operator::Lt => Some(InequalityOperator::Lt),1078Operator::LtEq => Some(InequalityOperator::LtEq),1079Operator::Gt => Some(InequalityOperator::Gt),1080Operator::GtEq => Some(InequalityOperator::GtEq),1081_ => None,1082}1083}1084}10851086/// Removes all filters that can be used as nested loop join conditions from `acc_predicates`.1087///1088/// Note that filters that refer only to a single side are not removed so that they can be pushed1089/// into the LHS/RHS tables.1090fn take_nested_loop_join_compatible_filters(1091acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,1092expr_arena: &mut Arena<AExpr>,1093schema_left: &Schema,1094schema_right: &Schema,1095suffix: &str,1096) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, Node>> {1097take_predicates_mut(acc_predicates, expr_arena, |_ae, ae_node, expr_arena| {1098Ok(1099match ExprOrigin::get_expr_origin(1100ae_node,1101expr_arena,1102schema_left,1103schema_right,1104suffix,1105None,1106)? {1107// Leave single-origin exprs as they get pushed to the left/right tables individually.1108ExprOrigin::Left | ExprOrigin::Right | ExprOrigin::None => None,1109_ => Some(ae_node),1110},1111)1112})1113}11141115/// Removes predicates from the map according to a function.1116fn take_predicates_mut<F, T>(1117acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,1118expr_arena: &mut Arena<AExpr>,1119take_predicate: F,1120) -> PolarsResult<hashbrown::hash_map::IntoValues<Node, T>>1121where1122F: Fn(&AExpr, Node, &Arena<AExpr>) -> PolarsResult<Option<T>>,1123{1124let mut selected_predicates: PlHashMap<Node, T> = PlHashMap::new();11251126for predicate in acc_predicates.values() {1127for node in MintermIter::new(predicate.node(), expr_arena) {1128let ae = expr_arena.get(node);11291130if let Some(t) = take_predicate(ae, node, expr_arena)? {1131selected_predicates.insert(node, t);1132}1133}1134}11351136if !selected_predicates.is_empty() {1137remove_min_terms(acc_predicates, expr_arena, &|node| {1138selected_predicates.contains_key(node)1139});1140}11411142return Ok(selected_predicates.into_values());11431144#[inline(never)]1145fn remove_min_terms(1146acc_predicates: &mut PlHashMap<PlSmallStr, ExprIR>,1147expr_arena: &mut Arena<AExpr>,1148should_remove: &dyn Fn(&Node) -> bool,1149) {1150let mut remove_keys = PlHashSet::new();1151let mut nodes_scratch = vec![];11521153for (k, predicate) in acc_predicates.iter_mut() {1154let mut has_removed = false;11551156nodes_scratch.clear();1157nodes_scratch.extend(1158MintermIter::new(predicate.node(), expr_arena).filter(|node| {1159let remove = should_remove(node);1160has_removed |= remove;1161!remove1162}),1163);11641165if nodes_scratch.is_empty() {1166remove_keys.insert(k.clone());1167continue;1168};11691170if has_removed {1171let new_predicate_node = nodes_scratch1172.drain(..)1173.reduce(|left, right| {1174expr_arena.add(AExpr::BinaryExpr {1175left,1176op: Operator::And,1177right,1178})1179})1180.unwrap();11811182*predicate = ExprIR::from_node(new_predicate_node, expr_arena);1183}1184}11851186for k in remove_keys {1187let v = acc_predicates.remove(&k);1188assert!(v.is_some());1189}1190}1191}119211931194