Path: blob/main/crates/polars-plan/src/plans/conversion/dsl_to_ir/join.rs
7889 views
use arrow::legacy::error::PolarsResult;1use either::Either;2use polars_core::chunked_array::cast::CastOptions;3use polars_core::error::feature_gated;4use polars_core::utils::{get_numeric_upcast_supertype_lossless, try_get_supertype};5use polars_utils::format_pl_smallstr;6use polars_utils::itertools::Itertools;78use super::*;9use crate::constants::POLARS_TMP_PREFIX;10use crate::dsl::Expr;11#[cfg(feature = "iejoin")]12use crate::plans::AExpr;1314fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {15for e in keys {16if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {17polars_bail!(18InvalidOperation:19"'alias' is not allowed in a join key, use 'with_columns' first",20)21}22}23Ok(())24}2526/// Returns: left: join_node, right: last_node (often both the same)27pub fn resolve_join(28input_left: Either<Arc<DslPlan>, Node>,29input_right: Either<Arc<DslPlan>, Node>,30left_on: Vec<Expr>,31right_on: Vec<Expr>,32predicates: Vec<Expr>,33mut options: JoinOptionsIR,34ctxt: &mut DslConversionContext,35) -> PolarsResult<(Node, Node)> {36if !predicates.is_empty() {37feature_gated!("iejoin", {38debug_assert!(left_on.is_empty() && right_on.is_empty());39return resolve_join_where(40input_left.unwrap_left(),41input_right.unwrap_left(),42predicates,43options,44ctxt,45);46})47}4849let owned = Arc::unwrap_or_clone;50let mut input_left = input_left.map_right(Ok).right_or_else(|input| {51to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))52})?;53let mut input_right = input_right.map_right(Ok).right_or_else(|input| {54to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))55})?;5657let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);58let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);5960if options.args.how.is_cross() {61polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");62} else {63polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates");64check_join_keys(&left_on)?;65check_join_keys(&right_on)?;6667let mut turn_off_coalesce = false;68for e in left_on.iter().chain(right_on.iter()) {69// Any expression that is not a simple column expression will turn of coalescing.70turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));71}72if turn_off_coalesce {73if matches!(options.args.coalesce, JoinCoalesce::CoalesceColumns) {74polars_warn!(75"coalescing join requested but not all join keys are column references, turning off key coalescing"76);77}78options.args.coalesce = JoinCoalesce::KeepColumns;79}8081options.args.validation.is_valid_join(&options.args.how)?;8283#[cfg(feature = "asof_join")]84if let JoinType::AsOf(options) = &options.args.how {85match (&options.left_by, &options.right_by) {86(None, None) => {},87(Some(l), Some(r)) => {88polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'");89validate_columns_in_input(l, &schema_left, "asof_join")?;90validate_columns_in_input(r, &schema_right, "asof_join")?;91},92_ => {93polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'")94},95}96}9798polars_ensure!(99left_on.len() == right_on.len(),100InvalidOperation:101format!(102"the number of columns given as join key (left: {}, right:{}) should be equal",103left_on.len(),104right_on.len()105)106);107}108109let mut left_on = left_on110.into_iter()111.map(|e| {112to_expr_ir_materialized_lit(113e,114&mut ExprToIRContext::new_with_opt_eager(115ctxt.expr_arena,116&schema_left,117ctxt.opt_flags,118),119)120})121.collect::<PolarsResult<Vec<_>>>()?;122let mut right_on = right_on123.into_iter()124.map(|e| {125to_expr_ir_materialized_lit(126e,127&mut ExprToIRContext::new_with_opt_eager(128ctxt.expr_arena,129&schema_right,130ctxt.opt_flags,131),132)133})134.collect::<PolarsResult<Vec<_>>>()?;135let mut joined_on = PlHashSet::new();136137#[cfg(feature = "iejoin")]138let check = !matches!(options.args.how, JoinType::IEJoin);139#[cfg(not(feature = "iejoin"))]140let check = true;141if check {142for (l, r) in left_on.iter().zip(right_on.iter()) {143polars_ensure!(144joined_on.insert((l.output_name(), r.output_name())),145InvalidOperation: "joining with repeated key names; already joined on {} and {}",146l.output_name(),147r.output_name()148)149}150}151drop(joined_on);152153ctxt.conversion_optimizer154.fill_scratch(&left_on, ctxt.expr_arena);155ctxt.conversion_optimizer156.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_left, true)157.map_err(|e| e.context("'join' failed".into()))?;158ctxt.conversion_optimizer159.fill_scratch(&right_on, ctxt.expr_arena);160ctxt.conversion_optimizer161.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_right, true)162.map_err(|e| e.context("'join' failed".into()))?;163164// Re-evaluate because of mutable borrows earlier.165let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);166let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);167168// # Resolve scalars169//170// Scalars need to be expanded. We translate them to temporary columns added with171// `with_columns` and remove them later with `project`172// This way the backends don't have to expand the literals in the join implementation173174let has_scalars = left_on175.iter()176.chain(right_on.iter())177.any(|e| e.is_scalar(ctxt.expr_arena));178179let (schema_left, schema_right) = if has_scalars {180let mut as_with_columns_l = vec![];181let mut as_with_columns_r = vec![];182for (i, e) in left_on.iter().enumerate() {183if e.is_scalar(ctxt.expr_arena) {184as_with_columns_l.push((i, e.clone()));185}186}187for (i, e) in right_on.iter().enumerate() {188if e.is_scalar(ctxt.expr_arena) {189as_with_columns_r.push((i, e.clone()));190}191}192193let mut count = 0;194let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}");195196// Early clone because of bck.197let mut schema_right_new = if !as_with_columns_r.is_empty() {198(**schema_right).clone()199} else {200Default::default()201};202if !as_with_columns_l.is_empty() {203let mut schema_left_new = (**schema_left).clone();204205let mut exprs = Vec::with_capacity(as_with_columns_l.len());206for (i, mut e) in as_with_columns_l {207let tmp_name = get_tmp_name(count);208count += 1;209e.set_alias(tmp_name.clone());210let dtype = e.dtype(&schema_left_new, ctxt.expr_arena)?;211schema_left_new.with_column(tmp_name.clone(), dtype.clone());212213let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));214left_on[i] = ExprIR::from_node(col, ctxt.expr_arena);215exprs.push(e);216}217input_left = ctxt.lp_arena.add(IR::HStack {218input: input_left,219exprs,220schema: Arc::new(schema_left_new),221options: ProjectionOptions::default(),222})223}224if !as_with_columns_r.is_empty() {225let mut exprs = Vec::with_capacity(as_with_columns_r.len());226for (i, mut e) in as_with_columns_r {227let tmp_name = get_tmp_name(count);228count += 1;229e.set_alias(tmp_name.clone());230let dtype = e.dtype(&schema_right_new, ctxt.expr_arena)?;231schema_right_new.with_column(tmp_name.clone(), dtype.clone());232233let col = ctxt.expr_arena.add(AExpr::Column(tmp_name));234right_on[i] = ExprIR::from_node(col, ctxt.expr_arena);235exprs.push(e);236}237input_right = ctxt.lp_arena.add(IR::HStack {238input: input_right,239exprs,240schema: Arc::new(schema_right_new),241options: ProjectionOptions::default(),242})243}244245(246ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena),247ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena),248)249} else {250(schema_left, schema_right)251};252253// Not a closure to avoid borrow issues because we mutate expr_arena as well.254macro_rules! get_dtype {255($expr:expr, $schema:expr) => {256ctxt.expr_arena257.get($expr.node())258.to_dtype(&ToFieldContext::new(ctxt.expr_arena, $schema))259};260}261262// As an optimization, when inserting casts for coalescing joins we only insert them beforehand for full-join.263// This means for e.g. left-join, the LHS key preserves its dtype in the output even if it is joined264// with an RHS key of wider type.265let key_cols_coalesced =266options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full);267let mut as_with_columns_l = vec![];268let mut as_with_columns_r = vec![];269for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) {270let ltype = get_dtype!(lnode, &schema_left)?;271let rtype = get_dtype!(rnode, &schema_right)?;272273if let Some(dtype) = get_numeric_upcast_supertype_lossless(<ype, &rtype) {274// We use overflowing cast to allow better optimization as we are casting to a known275// lossless supertype.276//277// We have unique references to these nodes (they are created by this function),278// so we can mutate in-place without causing side effects somewhere else.279let casted_l = ctxt.expr_arena.add(AExpr::Cast {280expr: lnode.node(),281dtype: dtype.clone(),282options: CastOptions::Overflowing,283});284let casted_r = ctxt.expr_arena.add(AExpr::Cast {285expr: rnode.node(),286dtype,287options: CastOptions::Overflowing,288});289290if key_cols_coalesced {291let mut lnode = lnode.clone();292let mut rnode = rnode.clone();293294let ae_l = ctxt.expr_arena.get(lnode.node());295let ae_r = ctxt.expr_arena.get(rnode.node());296297polars_ensure!(298ae_l.is_col() && ae_r.is_col(),299SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions",300);301302lnode.set_node(casted_l);303rnode.set_node(casted_r);304305as_with_columns_r.push(rnode);306as_with_columns_l.push(lnode);307} else {308lnode.set_node(casted_l);309rnode.set_node(casted_r);310}311} else {312polars_ensure!(313ltype == rtype,314SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right (and no other type was available to cast to)",315lnode.output_name(), ltype, rnode.output_name(), rtype316)317}318}319320// Every expression must be elementwise so that we are321// guaranteed the keys for a join are all the same length.322323polars_ensure!(324all_elementwise(&left_on, ctxt.expr_arena) && all_elementwise(&right_on, ctxt.expr_arena),325InvalidOperation: "all join key expressions must be elementwise."326);327328#[cfg(feature = "asof_join")]329if let JoinType::AsOf(options) = &mut options.args.how {330use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY;331332// prepare the tolerance333// we must ensure that we use the right units334if let Some(tol) = &options.tolerance_str {335let duration = polars_time::Duration::try_parse(tol)?;336polars_ensure!(337duration.months() == 0,338ComputeError: "cannot use month offset in timedelta of an asof join; \339consider using 4 weeks"340);341use DataType::*;342match ctxt343.expr_arena344.get(left_on[0].node())345.to_dtype(&ToFieldContext::new(ctxt.expr_arena, &schema_left))?346{347Datetime(tu, _) | Duration(tu) => {348let tolerance = match tu {349TimeUnit::Nanoseconds => duration.duration_ns(),350TimeUnit::Microseconds => duration.duration_us(),351TimeUnit::Milliseconds => duration.duration_ms(),352};353options.tolerance = Some(Scalar::from(tolerance))354},355Date => {356let days = (duration.duration_ms() / MILLISECONDS_IN_DAY) as i32;357options.tolerance = Some(Scalar::from(days))358},359Time => {360let tolerance = duration.duration_ns();361options.tolerance = Some(Scalar::from(tolerance))362},363_ => {364panic!(365"can only use timedelta string language with Date/Datetime/Duration/Time dtypes"366)367},368}369}370}371372// These are Arc<Schema>, into_owned is free.373let schema_left = schema_left.into_owned();374let schema_right = schema_right.into_owned();375376let join_schema = det_join_schema(377&schema_left,378&schema_right,379&left_on,380&right_on,381&options,382ctxt.expr_arena,383)384.map_err(|e| e.context(failed_here!(join schema resolving)))?;385386if key_cols_coalesced {387input_left = if as_with_columns_l.is_empty() {388input_left389} else {390ctxt.lp_arena.add(IR::HStack {391input: input_left,392exprs: as_with_columns_l,393schema: schema_left,394options: ProjectionOptions::default(),395})396};397398input_right = if as_with_columns_r.is_empty() {399input_right400} else {401ctxt.lp_arena.add(IR::HStack {402input: input_right,403exprs: as_with_columns_r,404schema: schema_right,405options: ProjectionOptions::default(),406})407};408}409410let ir = IR::Join {411input_left,412input_right,413schema: join_schema.clone(),414left_on,415right_on,416options: Arc::new(options),417};418let join_node = ctxt.lp_arena.add(ir);419420if has_scalars {421let names = join_schema422.iter_names()423.filter_map(|n| {424if n.starts_with(POLARS_TMP_PREFIX) {425None426} else {427Some(n.clone())428}429})430.collect_vec();431432let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena);433let ir = builder.project_simple(names).map(|b| b.build())?;434let select_node = ctxt.lp_arena.add(ir);435436Ok((select_node, join_node))437} else {438Ok((join_node, join_node))439}440}441442#[cfg(feature = "iejoin")]443impl From<InequalityOperator> for Operator {444fn from(value: InequalityOperator) -> Self {445match value {446InequalityOperator::LtEq => Operator::LtEq,447InequalityOperator::Lt => Operator::Lt,448InequalityOperator::GtEq => Operator::GtEq,449InequalityOperator::Gt => Operator::Gt,450}451}452}453454#[cfg(feature = "iejoin")]455/// Returns: left: join_node, right: last_node (often both the same)456fn resolve_join_where(457input_left: Arc<DslPlan>,458input_right: Arc<DslPlan>,459predicates: Vec<Expr>,460mut options: JoinOptionsIR,461ctxt: &mut DslConversionContext,462) -> PolarsResult<(Node, Node)> {463// If not eager, respect the flag.464if ctxt.opt_flags.eager() {465ctxt.opt_flags.set(OptFlags::PREDICATE_PUSHDOWN, true);466}467check_join_keys(&predicates)?;468let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)469.map_err(|e| e.context(failed_here!(join left)))?;470let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)471.map_err(|e| e.context(failed_here!(join left)))?;472473let schema_left = ctxt474.lp_arena475.get(input_left)476.schema(ctxt.lp_arena)477.into_owned();478479options.args.how = JoinType::Cross;480481let (mut last_node, join_node) = resolve_join(482Either::Right(input_left),483Either::Right(input_right),484vec![],485vec![],486vec![],487options,488ctxt,489)?;490491let schema_merged = ctxt492.lp_arena493.get(last_node)494.schema(ctxt.lp_arena)495.into_owned();496497// Perform predicate validation.498let mut upcast_exprs = Vec::<(Node, DataType)>::new();499for e in predicates {500let arena = &mut ctxt.expr_arena;501let predicate = to_expr_ir_materialized_lit(502e,503&mut ExprToIRContext::new_with_opt_eager(arena, &schema_merged, ctxt.opt_flags),504)?;505let node = predicate.node();506507// Ensure the predicate dtype output of the root node is Boolean508let ae = arena.get(node);509let dt_out = ae.to_dtype(&ToFieldContext::new(arena, &schema_merged))?;510polars_ensure!(511dt_out == DataType::Boolean,512ComputeError: "'join_where' predicates must resolve to boolean"513);514515ensure_lossless_binary_comparisons(516&node,517&schema_left,518&schema_merged,519arena,520&mut upcast_exprs,521)?;522523ctxt.conversion_optimizer524.push_scratch(predicate.node(), ctxt.expr_arena);525526let ir = IR::Filter {527input: last_node,528predicate,529};530531last_node = ctxt.lp_arena.add(ir);532}533534ctxt.conversion_optimizer535.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, last_node, false)536.map_err(|e| e.context("'join_where' failed".into()))?;537538Ok((last_node, join_node))539}540541/// Locate nodes that are operands in a binary comparison involving both tables, and ensure that542/// these nodes are losslessly upcast to a safe dtype.543fn ensure_lossless_binary_comparisons(544node: &Node,545schema_left: &Schema,546schema_merged: &Schema,547expr_arena: &mut Arena<AExpr>,548upcast_exprs: &mut Vec<(Node, DataType)>,549) -> PolarsResult<()> {550// let mut upcast_exprs = Vec::<(Node, DataType)>::new();551// Ensure that all binary comparisons that use both tables are lossless.552build_upcast_node_list(node, schema_left, schema_merged, expr_arena, upcast_exprs)?;553// Replace each node with its casted counterpart554for (expr, dtype) in upcast_exprs.drain(..) {555let old_expr = expr_arena.duplicate(expr);556let new_aexpr = AExpr::Cast {557expr: old_expr,558dtype,559options: CastOptions::Overflowing,560};561expr_arena.replace(expr, new_aexpr);562}563Ok(())564}565566/// If we are dealing with a binary comparison involving columns from exclusively the left table567/// on the LHS and the right table on the RHS side, ensure that the cast is lossless.568/// Expressions involving binaries using either table alone we leave up to the user to verify569/// that they are valid, as they could theoretically be pushed outside of the join.570#[recursive]571fn build_upcast_node_list(572node: &Node,573schema_left: &Schema,574schema_merged: &Schema,575expr_arena: &Arena<AExpr>,576to_replace: &mut Vec<(Node, DataType)>,577) -> PolarsResult<ExprOrigin> {578let expr_origin = match expr_arena.get(*node) {579AExpr::Column(name) => {580if schema_left.contains(name) {581ExprOrigin::Left582} else if schema_merged.contains(name) {583ExprOrigin::Right584} else {585polars_bail!(ColumnNotFound: "{}", name);586}587},588AExpr::Literal(..) => ExprOrigin::None,589AExpr::Cast { expr: node, .. } => {590build_upcast_node_list(node, schema_left, schema_merged, expr_arena, to_replace)?591},592AExpr::BinaryExpr {593left: left_node,594op,595right: right_node,596} => {597// If left and right node has both, ensure the dtypes are valid.598let left_origin = build_upcast_node_list(599left_node,600schema_left,601schema_merged,602expr_arena,603to_replace,604)?;605let right_origin = build_upcast_node_list(606right_node,607schema_left,608schema_merged,609expr_arena,610to_replace,611)?;612// We only update casts during comparisons if the operands are from different tables.613if op.is_comparison() {614match (left_origin, right_origin) {615(ExprOrigin::Left, ExprOrigin::Right)616| (ExprOrigin::Right, ExprOrigin::Left) => {617// Ensure our dtype casts are lossless618let left = expr_arena.get(*left_node);619let right = expr_arena.get(*right_node);620let dtype_left =621left.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;622let dtype_right =623right.to_dtype(&ToFieldContext::new(expr_arena, schema_merged))?;624if dtype_left != dtype_right {625// Ensure that we have a lossless cast between the two types.626let dt = if dtype_left.is_primitive_numeric()627|| dtype_right.is_primitive_numeric()628{629get_numeric_upcast_supertype_lossless(&dtype_left, &dtype_right)630.ok_or(PolarsError::SchemaMismatch(631format!(632"'join_where' cannot compare {dtype_left:?} with {dtype_right:?}"633)634.into(),635))636} else {637try_get_supertype(&dtype_left, &dtype_right)638}?;639640// Store the nodes and their replacements if a cast is required.641let replace_left = dt != dtype_left;642let replace_right = dt != dtype_right;643if replace_left && replace_right {644to_replace.push((*left_node, dt.clone()));645to_replace.push((*right_node, dt));646} else if replace_left {647to_replace.push((*left_node, dt));648} else if replace_right {649to_replace.push((*right_node, dt));650}651}652},653_ => (),654}655}656left_origin | right_origin657},658_ => ExprOrigin::None,659};660Ok(expr_origin)661}662663664