Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/join.rs
8506 views
use arrow::legacy::error::PolarsResult;1use either::Either;2use polars_core::chunked_array::cast::CastOptions;3use polars_core::error::feature_gated;4use polars_core::utils::{get_numeric_upcast_supertype_lossless, try_get_supertype};5use polars_utils::format_pl_smallstr;6use polars_utils::itertools::Itertools;78use super::*;9use crate::constants::POLARS_TMP_PREFIX;10use crate::dsl::Expr;11#[cfg(feature = "iejoin")]12use crate::plans::AExpr;1314fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {15for e in keys {16if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {17polars_bail!(18InvalidOperation:19"'alias' is not allowed in a join key, use 'with_columns' first",20)21}22}23Ok(())24}2526/// Returns: left: join_node, right: last_node (often both the same)27pub fn resolve_join(28input_left: Either<Arc<DslPlan>, Node>,29input_right: Either<Arc<DslPlan>, Node>,30left_on: Vec<Expr>,31right_on: Vec<Expr>,32predicates: Vec<Expr>,33mut options: JoinOptionsIR,34ctxt: &mut DslConversionContext,35) -> PolarsResult<(Node, Node)> {36if !predicates.is_empty() {37feature_gated!("iejoin", {38debug_assert!(left_on.is_empty() && right_on.is_empty());39return resolve_join_where(40input_left.unwrap_left(),41input_right.unwrap_left(),42predicates,43options,44ctxt,45);46})47}4849let owned = Arc::unwrap_or_clone;50let mut input_left = input_left.map_right(Ok).right_or_else(|input| {51to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))52})?;53let mut input_right = input_right.map_right(Ok).right_or_else(|input| {54to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))55})?;5657let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);58let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);5960if options.args.how.is_cross() {61polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");62} else {63polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates");64check_join_keys(&left_on)?;65check_join_keys(&right_on)?;6667let mut turn_off_coalesce = false;68for e in left_on.iter().chain(right_on.iter()) {69// Any expression that is not a simple column expression will turn of coalescing.70turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));71}72if turn_off_coalesce {73if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) {74polars_warn!(75"coalescing join requested but not all join keys are column references, turning off key coalescing"76);77}78options.args.coalesce = JoinCoalesce::KeepColumns;79}8081options.args.validation.is_valid_join(&options.args.how)?;8283#[cfg(feature = "asof_join")]84if let JoinType::AsOf(options) = &options.args.how {85match (&options.left_by, &options.right_by) {86(None, None) => {},87(Some(l), Some(r)) => {88polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'");89validate_columns_in_input(l, &schema_left, "asof_join")?;90validate_columns_in_input(r, &schema_right, "asof_join")?;91},92_ => {93polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'")94},95}96}9798polars_ensure!(99left_on.len() == right_on.len(),100InvalidOperation:101"the number of columns given as join key (left: {}, right:{}) should be equal",102left_on.len(),103right_on.len()104);105}106107let mut left_on = left_on108.into_iter()109.map(|e| {110to_expr_ir_materialized_lit(111e,112&mut ExprToIRContext::new_with_opt_eager(113ctxt.expr_arena,114&schema_left,115ctxt.opt_flags,116),117)118})119.collect::<PolarsResult<Vec<_>>>()?;120let mut right_on = right_on121.into_iter()122.map(|e| {123to_expr_ir_materialized_lit(124e,125&mut ExprToIRContext::new_with_opt_eager(126ctxt.expr_arena,127&schema_right,128ctxt.opt_flags,129),130)131})132.collect::<PolarsResult<Vec<_>>>()?;133let mut joined_on = PlHashSet::new();134135#[cfg(feature = "iejoin")]136let check = !matches!(options.args.how, JoinType::IEJoin);137#[cfg(not(feature = "iejoin"))]138let check = true;139if check {140for (l, r) in left_on.iter().zip(right_on.iter()) {141polars_ensure!(142joined_on.insert((l.output_name(), r.output_name())),143InvalidOperation: "joining with repeated key names; already joined on {} and {}",144l.output_name(),145r.output_name()146)147}148}149drop(joined_on);150151ctxt.conversion_optimizer152.fill_scratch(&left_on, ctxt.expr_arena);153ctxt.conversion_optimizer154.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_left, true)155.map_err(|e| e.context("'join' failed".into()))?;156ctxt.conversion_optimizer157.fill_scratch(&right_on, ctxt.expr_arena);158ctxt.conversion_optimizer159.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_right, true)160.map_err(|e| e.context("'join' failed".into()))?;161162// Re-evaluate because of mutable borrows earlier.163let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);164let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);165166// # Resolve scalars167//168// Scalars need to be expanded. We translate them to temporary columns added with169// `with_columns` and remove them later with `project`170// This way the backends don't have to expand the literals in the join implementation171172let has_scalars = left_on173.iter()174.chain(right_on.iter())175.any(|e| e.is_scalar(ctxt.expr_arena));176177let (schema_left, schema_right) = if has_scalars {178let mut as_with_columns_l = vec![];179let mut as_with_columns_r = vec![];180for (i, e) in left_on.iter().enumerate() {181if e.is_scalar(ctxt.expr_arena) {182as_with_columns_l.push((i, e.clone()));183}184}185for (i, e) in right_on.iter().enumerate() {186if e.is_scalar(ctxt.expr_arena) {187as_with_columns_r.push((i, e.clone()));188}189}190191let mut count = 0;192let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}");193194// Early clone because of bck.195let mut schema_right_new = if !as_with_columns_r.is_empty() {196(**schema_right).clone()197} else {198Default::default()199};200if !as_with_columns_l.is_empty() {201let mut schema_left_new = (**schema_left).clone();202203let mut exprs = Vec::with_capacity(as_with_columns_l.len());204for (i, mut e) in as_with_columns_l {205let tmp_name = get_tmp_name(count);206count += 1;207e.set_alias(tmp_name.clone());208let dtype = e.dtype(&schema_left_new, ctxt.expr_arena)?;209schema_left_new.with_column(tmp_name.clone(), dtype.clone());210211let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));212left_on[i] = ExprIR::from_node(col, ctxt.expr_arena);213exprs.push(e);214}215input_left = ctxt.lp_arena.add(IR::HStack {216input: input_left,217exprs,218schema: Arc::new(schema_left_new),219options: ProjectionOptions::default(),220})221}222if !as_with_columns_r.is_empty() {223let mut exprs = Vec::with_capacity(as_with_columns_r.len());224for (i, mut e) in as_with_columns_r {225let tmp_name = get_tmp_name(count);226count += 1;227e.set_alias(tmp_name.clone());228let dtype = e.dtype(&schema_right_new, ctxt.expr_arena)?;229schema_right_new.with_column(tmp_name.clone(), dtype.clone());230231let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));232right_on[i] = ExprIR::from_node(col, ctxt.expr_arena);233exprs.push(e);234}235input_right = ctxt.lp_arena.add(IR::HStack {236input: input_right,237exprs,238schema: Arc::new(schema_right_new),239options: ProjectionOptions::default(),240})241}242243(244ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena),245ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena),246)247} else {248(schema_left, schema_right)249};250251// Not a closure to avoid borrow issues because we mutate expr_arena as well.252macro_rules! get_dtype {253($expr:expr, $schema:expr) => {254ctxt.expr_arena255.get($expr.node())256.to_dtype(&ToFieldContext::new(ctxt.expr_arena, $schema))257};258}259260// As an optimization, when inserting casts for coalescing joins we only insert them beforehand for full-join.261// This means for e.g. left-join, the LHS key preserves its dtype in the output even if it is joined262// with an RHS key of wider type.263let key_cols_coalesced =264options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);265let mut as_with_columns_l = vec![];266let mut as_with_columns_r = vec![];267for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {268let ltype = get_dtype!(lnode, &schema_left)?;269let rtype = get_dtype!(rnode, &schema_right)?;270271if let Some(dtype) = get_numeric_upcast_supertype_lossless(<ype, &rtype) {272// We use overflowing cast to allow better optimization as we are casting to a known273// lossless supertype.274//275// We have unique references to these nodes (they are created by this function),276// so we can mutate in-place without causing side effects somewhere else.277let casted_l = ctxt.expr_arena.add(AExpr::Cast {278expr: lnode.node(),279dtype: dtype.clone(),280options: CastOptions::Overflowing,281});282let casted_r = ctxt.expr_arena.add(AExpr::Cast {283expr: rnode.node(),284dtype,285options: CastOptions::Overflowing,286});287288if key_cols_coalesced {289let mut lnode = lnode.clone();290let mut rnode = rnode.clone();291292let ae_l = ctxt.expr_arena.get(lnode.node());293let ae_r = ctxt.expr_arena.get(rnode.node());294295polars_ensure!(296ae_l.is_col() && ae_r.is_col(),297SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions",298);299300lnode.set_node(casted_l);301rnode.set_node(casted_r);302303as_with_columns_r.push(rnode);304as_with_columns_l.push(lnode);305} else {306lnode.set_node(casted_l);307rnode.set_node(casted_r);308}309} else {310polars_ensure!(311ltype == rtype,312SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right (and no other type was available to cast to)",313lnode.output_name(), ltype.pretty_format(), rnode.output_name(), rtype.pretty_format()314);315}316}317318// Every expression must be elementwise so that we are319// guaranteed the keys for a join are all the same length.320321polars_ensure!(322all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena),323InvalidOperation: "all join key expressions must be elementwise."324);325326#[cfg(feature = "asof_join")]327if let JoinType::AsOf(options) = &mut options.args.how {328use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY;329330// prepare the tolerance331// we must ensure that we use the right units332if let Some(tol) = &options.tolerance_str {333let duration = polars_time::Duration::try_parse(tol)?;334polars_ensure!(335duration.months() == 0,336ComputeError: "cannot use month offset in timedelta of an asof join; \337consider using 4 weeks"338);339use DataType::*;340match ctxt341.expr_arena342.get(left_on[0].node())343.to_dtype(&ToFieldContext::new(ctxt.expr_arena, &schema_left))?344{345Datetime(tu, _) | Duration(tu) => {346let tolerance = match tu {347TimeUnit::Nanoseconds => duration.duration_ns(),348TimeUnit::Microseconds => duration.duration_us(),349TimeUnit::Milliseconds => duration.duration_ms(),350};351options.tolerance = Some(Scalar::from(tolerance))352},353Date => {354let days = (duration.duration_ms() / MILLISECONDS_IN_DAY) as i32;355options.tolerance = Some(Scalar::from(days))356},357Time => {358let tolerance = duration.duration_ns();359options.tolerance = Some(Scalar::from(tolerance))360},361_ => {362panic!(363"can only use timedelta string language with Date/Datetime/Duration/Time dtypes"364)365},366}367}368}369370// These are Arc<Schema>, into_owned is free.371let schema_left = schema_left.into_owned();372let schema_right = schema_right.into_owned();373374let join_schema = det_join_schema(375&schema_left,376&schema_right,377&left_on,378&right_on,379&options,380ctxt.expr_arena,381)382.map_err(|e| e.context(failed_here!(join schema resolving)))?;383384if key_cols_coalesced {385input_left = if as_with_columns_l.is_empty() {386input_left387} else {388ctxt.lp_arena.add(IR::HStack {389input: input_left,390exprs: as_with_columns_l,391schema: schema_left,392options: ProjectionOptions::default(),393})394};395396input_right = if as_with_columns_r.is_empty() {397input_right398} else {399ctxt.lp_arena.add(IR::HStack {400input: input_right,401exprs: as_with_columns_r,402schema: schema_right,403options: ProjectionOptions::default(),404})405};406}407408let ir = IR::Join {409input_left,410input_right,411schema: join_schema.clone(),412left_on,413right_on,414options: Arc::new(options),415};416let join_node = ctxt.lp_arena.add(ir);417418if has_scalars {419let names = join_schema420.iter_names()421.filter_map(|n| {422if n.starts_with(POLARS_TMP_PREFIX) {423None424} else {425Some(n.clone())426}427})428.collect_vec();429430let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena);431let ir = builder.project_simple(names).map(|b| b.build())?;432let select_node = ctxt.lp_arena.add(ir);433434Ok((select_node, join_node))435} else {436Ok((join_node, join_node))437}438}439440#[cfg(feature = "iejoin")]441impl From<InequalityOperator> for Operator {442fn from(value: InequalityOperator) -> Self {443match value {444InequalityOperator::LtEq => Operator::LtEq,445InequalityOperator::Lt => Operator::Lt,446InequalityOperator::GtEq => Operator::GtEq,447InequalityOperator::Gt => Operator::Gt,448}449}450}451452#[cfg(feature = "iejoin")]453/// Returns: left: join_node, right: last_node (often both the same)454fn resolve_join_where(455input_left: Arc<DslPlan>,456input_right: Arc<DslPlan>,457predicates: Vec<Expr>,458mut options: JoinOptionsIR,459ctxt: &mut DslConversionContext,460) -> PolarsResult<(Node, Node)> {461// If not eager, respect the flag.462if ctxt.opt_flags.eager() {463ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true);464}465check_join_keys(&predicates)?;466let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)467.map_err(|e| e.context(failed_here!(join left)))?;468let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)469.map_err(|e| e.context(failed_here!(join left)))?;470471let schema_left = ctxt472.lp_arena473.get(input_left)474.schema(ctxt.lp_arena)475.into_owned();476477options.args.how = JoinType::Cross;478479let (mut last_node, join_node) = resolve_join(480Either::Right(input_left),481Either::Right(input_right),482vec![],483vec![],484vec![],485options,486ctxt,487)?;488489let schema_merged = ctxt490.lp_arena491.get(last_node)492.schema(ctxt.lp_arena)493.into_owned();494495// Perform predicate validation.496let mut upcast_exprs = Vec::<(Node, DataType)>::new();497for e in predicates {498let arena = &mut ctxt.expr_arena;499let predicate = to_expr_ir_materialized_lit(500e,501&mut ExprToIRContext::new_with_opt_eager(arena, &schema_merged, ctxt.opt_flags),502)?;503let node = predicate.node();504505// Ensure the predicate dtype output of the root node is Boolean506let ae = arena.get(node);507let dt_out = ae.to_dtype(&ToFieldContext::new(arena, &schema_merged))?;508polars_ensure!(509dt_out == DataType::Boolean,510ComputeError: "'join_where' predicates must resolve to boolean"511);512513ensure_lossless_binary_comparisons(514&node,515&schema_left,516&schema_merged,517arena,518&mut upcast_exprs,519)?;520521ctxt.conversion_optimizer522.push_scratch(predicate.node(), ctxt.expr_arena);523524let ir = IR::Filter {525input: last_node,526predicate,527};528529last_node = ctxt.lp_arena.add(ir);530}531532ctxt.conversion_optimizer533.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, last_node, false)534.map_err(|e| e.context("'join_where' failed".into()))?;535536Ok((last_node, join_node))537}538539/// Locate nodes that are operands in a binary comparison involving both tables, and ensure that540/// these nodes are losslessly upcast to a safe dtype.541fn ensure_lossless_binary_comparisons(542node: &Node,543schema_left: &Schema,544schema_merged: &Schema,545expr_arena: &mut Arena<AExpr>,546upcast_exprs: &mut Vec<(Node, DataType)>,547) -> PolarsResult<()> {548// let mut upcast_exprs = Vec::<(Node, DataType)>::new();549// Ensure that all binary comparisons that use both tables are lossless.550build_upcast_node_list(node, schema_left, schema_merged, expr_arena, upcast_exprs)?;551// Replace each node with its casted counterpart552for (expr, dtype) in upcast_exprs.drain(..) {553let old_expr = expr_arena.duplicate(expr);554let new_aexpr = AExpr::Cast {555expr: old_expr,556dtype,557options: CastOptions::Overflowing,558};559expr_arena.replace(expr, new_aexpr);560}561Ok(())562}563564/// If we are dealing with a binary comparison involving columns from exclusively the left table565/// on the LHS and the right table on the RHS side, ensure that the cast is lossless.566/// Expressions involving binaries using either table alone we leave up to the user to verify567/// that they are valid, as they could theoretically be pushed outside of the join.568#[recursive]569fn build_upcast_node_list(570node: &Node,571schema_left: &Schema,572schema_merged: &Schema,573expr_arena: &Arena<AExpr>,574to_replace: &mut Vec<(Node, DataType)>,575) -> PolarsResult<ExprOrigin> {576let expr_origin = match expr_arena.get(*node) {577AExpr::Column(name) => {578if schema_left.contains(name) {579ExprOrigin::Left580} else if schema_merged.contains(name) {581ExprOrigin::Right582} else {583polars_bail!(ColumnNotFound: "{name}");584}585},586AExpr::Literal(..) => ExprOrigin::None,587AExpr::Cast { expr: node, .. } => {588build_upcast_node_list(node, schema_left, schema_merged, expr_arena, to_replace)?589},590AExpr::BinaryExpr {591left: left_node,592op,593right: right_node,594} => {595// If left and right node has both, ensure the dtypes are valid.596let left_origin = build_upcast_node_list(597left_node,598schema_left,599schema_merged,600expr_arena,601to_replace,602)?;603let right_origin = build_upcast_node_list(604right_node,605schema_left,606schema_merged,607expr_arena,608to_replace,609)?;610// We only update casts during comparisons if the operands are from different tables.611if op.is_comparison() {612match (left_origin, right_origin) {613(ExprOrigin::Left, ExprOrigin::Right)614| (ExprOrigin::Right, ExprOrigin::Left) => {615// Ensure our dtype casts are lossless616let left = expr_arena.get(*left_node);617let right = expr_arena.get(*right_node);618let dtype_left =619left.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;620let dtype_right =621right.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;622if dtype_left != dtype_right {623// Ensure that we have a lossless cast between the two types.624let dt = if dtype_left.is_primitive_numeric()625|| dtype_right.is_primitive_numeric()626{627get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)628.ok_or(PolarsError::SchemaMismatch(629format!(630"'join_where' cannot compare {dtype_left:?} with {dtype_right:?}"631)632.into(),633))634} else {635try_get_supertype(&dtype_left, &dtype_right)636}?;637638// Store the nodes and their replacements if a cast is required.639let replace_left = dt != dtype_left;640let replace_right = dt != dtype_right;641if replace_left && replace_right {642to_replace.push((*left_node, dt.clone()));643to_replace.push((*right_node, dt));644} else if replace_left {645to_replace.push((*left_node, dt));646} else if replace_right {647to_replace.push((*right_node, dt));648}649}650},651_ => (),652}653}654left_origin | right_origin655},656_ => ExprOrigin::None,657};658Ok(expr_origin)659}660661662