Path: blob/main/crates/polars-ops/src/frame/join/merge_join.rs
8480 views
use std::borrow::Cow;1use std::cmp::Ordering;2use std::iter::repeat_n;34use arrow::array::Array;5use arrow::array::builder::ShareStrategy;6use polars_core::frame::builder::DataFrameBuilder;7use polars_core::prelude::*;8use polars_core::with_match_physical_numeric_polars_type;9use polars_utils::itertools::Itertools;10use polars_utils::total_ord::TotalOrd;11use polars_utils::{IdxSize, format_pl_smallstr};1213use crate::frame::{JoinArgs, JoinType};14use crate::series::coalesce_columns;1516#[allow(clippy::too_many_arguments)]17pub fn match_keys(18build_keys: &Series,19probe_keys: &Series,20gather_build: &mut Vec<IdxSize>,21gather_probe: &mut Vec<IdxSize>,22gather_probe_unmatched: Option<&mut Vec<IdxSize>>,23build_emit_unmatched: bool,24descending: bool,25nulls_equal: bool,26limit_results: usize,27build_row_offset: &mut usize,28probe_row_offset: &mut usize,29probe_last_matched: &mut usize,30) {31macro_rules! dispatch {32($build_keys_ca:expr) => {33match_keys_impl(34$build_keys_ca,35probe_keys.as_ref().as_ref(),36gather_build,37gather_probe,38gather_probe_unmatched,39build_emit_unmatched,40descending,41nulls_equal,42limit_results,43build_row_offset,44probe_row_offset,45probe_last_matched,46)47};48}4950assert_eq!(build_keys.dtype(), probe_keys.dtype());51match build_keys.dtype() {52dt if dt.is_primitive_numeric() => {53with_match_physical_numeric_polars_type!(dt, |$T| {54type PhysCa = ChunkedArray<$T>;55let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();56dispatch!(build_keys_ca)57})58},59DataType::Boolean => dispatch!(build_keys.bool().unwrap()),60DataType::String => dispatch!(build_keys.str().unwrap()),61DataType::Binary => dispatch!(build_keys.binary().unwrap()),62DataType::BinaryOffset => dispatch!(build_keys.binary_offset().unwrap()),63#[cfg(feature = "dtype-categorical")]64DataType::Enum(cats, _) => with_match_categorical_physical_type!(cats.physical(), |$C| {65type PhysCa = ChunkedArray<<$C as PolarsCategoricalType>::PolarsPhysical>;66let build_keys_ca: &PhysCa = build_keys.as_ref().as_ref();67dispatch!(build_keys_ca)68}),69DataType::Null => match_null_keys_impl(70build_keys.len(),71probe_keys.len(),72gather_build,73gather_probe,74gather_probe_unmatched,75build_emit_unmatched,76descending,77nulls_equal,78limit_results,79build_row_offset,80probe_row_offset,81probe_last_matched,82),83dt => unimplemented!("merge-join kernel not implemented for {:?}", dt),84}85}8687#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]88fn match_keys_impl<'a, T: PolarsDataType>(89build_keys: &'a ChunkedArray<T>,90probe_keys: &'a ChunkedArray<T>,91gather_build: &mut Vec<IdxSize>,92gather_probe: &mut Vec<IdxSize>,93mut gather_probe_unmatched: Option<&mut Vec<IdxSize>>,94build_emit_unmatched: bool,95descending: bool,96nulls_equal: bool,97limit_results: usize,98build_row_offset: &mut usize,99probe_row_offset: &mut usize,100probe_first_unmatched: &mut usize,101) where102T::Physical<'a>: TotalOrd,103{104assert!(gather_build.is_empty());105assert!(gather_probe.is_empty());106107let build_key = build_keys.downcast_as_array();108let probe_key = probe_keys.downcast_as_array();109110while *build_row_offset < build_key.len() {111if gather_build.len() >= limit_results {112return;113}114115let build_keyval = unsafe { build_key.get_unchecked(*build_row_offset) };116let build_keyval = build_keyval.as_ref();117let mut build_keyval_matched = false;118119if nulls_equal || build_keyval.is_some() {120for probe_idx in *probe_row_offset..probe_key.len() {121let probe_keyval = unsafe { probe_key.get_unchecked(probe_idx) };122let probe_keyval = probe_keyval.as_ref();123124let mut ord: Ordering = match (&build_keyval, &probe_keyval) {125(None, None) if nulls_equal => Ordering::Equal,126(Some(l), Some(r)) => TotalOrd::tot_cmp(*l, *r),127_ => continue,128};129if descending {130ord = ord.reverse();131}132133match ord {134Ordering::Equal => {135if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {136// All probe keys up to and *excluding* this matched key are unmatched137probe_unmatched138.extend(*probe_first_unmatched as IdxSize..probe_idx as IdxSize);139*probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);140}141gather_build.push(*build_row_offset as IdxSize);142gather_probe.push(probe_idx as IdxSize);143build_keyval_matched = true;144},145Ordering::Greater => {146if let Some(probe_unmatched) = gather_probe_unmatched.as_mut() {147// All probe keys up to and *including* this matched key are unmatched148probe_unmatched149.extend(*probe_first_unmatched as IdxSize..=probe_idx as IdxSize);150*probe_first_unmatched = (*probe_first_unmatched).max(probe_idx + 1);151}152*probe_row_offset = probe_idx + 1;153},154Ordering::Less => {155break;156},157}158}159}160if build_emit_unmatched && !build_keyval_matched {161gather_build.push(*build_row_offset as IdxSize);162gather_probe.push(IdxSize::MAX);163}164*build_row_offset += 1;165}166if let Some(probe_unmatched) = gather_probe_unmatched {167probe_unmatched.extend(*probe_first_unmatched as IdxSize..probe_key.len() as IdxSize);168*probe_first_unmatched = probe_key.len();169}170*probe_row_offset = probe_key.len();171}172173#[allow(clippy::mut_range_bound, clippy::too_many_arguments)]174fn match_null_keys_impl(175build_n: usize,176probe_n: usize,177gather_build: &mut Vec<IdxSize>,178gather_probe: &mut Vec<IdxSize>,179gather_probe_unmatched: Option<&mut Vec<IdxSize>>,180build_emit_unmatched: bool,181_descending: bool,182nulls_equal: bool,183limit_results: usize,184build_row_offset: &mut usize,185probe_row_offset: &mut usize,186probe_last_matched: &mut usize,187) {188assert!(gather_build.is_empty());189assert!(gather_probe.is_empty());190191if nulls_equal {192// All keys will match all other keys, so just emit the Cartesian product193while *build_row_offset < build_n {194if gather_build.len() >= limit_results {195return;196}197for probe_idx in *probe_row_offset..probe_n {198gather_build.push(*build_row_offset as IdxSize);199gather_probe.push(probe_idx as IdxSize);200}201*build_row_offset += 1;202}203} else {204// No keys can ever match, so just emit all build keys into gather_build205// and all probe keys into gather_probe_unmatched.206if build_emit_unmatched {207gather_build.extend(0..build_n as IdxSize);208gather_probe.extend(repeat_n(IdxSize::MAX, build_n));209}210if let Some(probe_unmatched) = gather_probe_unmatched {211probe_unmatched.extend(*probe_last_matched as IdxSize..probe_n as IdxSize);212*probe_last_matched = probe_n;213}214}215*build_row_offset = build_n;216*probe_row_offset = probe_n;217}218219#[allow(clippy::too_many_arguments)]220pub fn gather_and_postprocess(221build: DataFrame,222probe: DataFrame,223gather_build: Option<&[IdxSize]>,224gather_probe: Option<&[IdxSize]>,225df_builders: &mut Option<(DataFrameBuilder, DataFrameBuilder)>,226args: &JoinArgs,227left_on: &[PlSmallStr],228right_on: &[PlSmallStr],229left_is_build: bool,230output_schema: &Schema,231) -> PolarsResult<DataFrame> {232let should_coalesce = args.should_coalesce();233let left_emit_unmatched = matches!(args.how, JoinType::Left | JoinType::Full);234let right_emit_unmatched = matches!(args.how, JoinType::Right | JoinType::Full);235236let (mut left, mut right);237let (gather_left, gather_right);238if left_is_build {239(left, right) = (build, probe);240(gather_left, gather_right) = (gather_build, gather_probe);241} else {242(left, right) = (probe, build);243(gather_left, gather_right) = (gather_probe, gather_build);244}245246// Remove non-payload columns247for col in left248.columns()249.iter()250.map(Column::name)251.cloned()252.collect_vec()253{254if left_on.contains(&col) && should_coalesce {255continue;256}257if !output_schema.contains(&col) {258left.drop_in_place(&col).unwrap();259}260}261for col in right262.columns()263.iter()264.map(Column::name)265.cloned()266.collect_vec()267{268if left_on.contains(&col) && should_coalesce {269continue;270}271let renamed = match left.schema().contains(&col) {272true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),273false => Cow::Borrowed(&col),274};275if !output_schema.contains(&renamed) {276right.drop_in_place(&col).unwrap();277}278}279280if df_builders.is_none() {281*df_builders = Some((282DataFrameBuilder::new(left.schema().clone()),283DataFrameBuilder::new(right.schema().clone()),284));285}286287let (left_build, right_build) = df_builders.as_mut().unwrap();288let mut left = match gather_left {289Some(gather_left) if right_emit_unmatched => {290left_build.opt_gather_extend(&left, gather_left, ShareStrategy::Never);291left_build.freeze_reset()292},293Some(gather_left) => unsafe {294left_build.gather_extend(&left, gather_left, ShareStrategy::Never);295left_build.freeze_reset()296},297None => DataFrame::full_null(left.schema(), gather_right.unwrap().len()),298};299let mut right = match gather_right {300Some(gather_right) if left_emit_unmatched => {301right_build.opt_gather_extend(&right, gather_right, ShareStrategy::Never);302right_build.freeze_reset()303},304Some(gather_right) => unsafe {305right_build.gather_extend(&right, gather_right, ShareStrategy::Never);306right_build.freeze_reset()307},308None => DataFrame::full_null(right.schema(), gather_left.unwrap().len()),309};310311// Coalsesce the key columns312if args.how == JoinType::Left && should_coalesce {313for c in left_on {314if right.schema().contains(c) {315right.drop_in_place(c.as_str())?;316}317}318} else if args.how == JoinType::Right && should_coalesce {319for c in right_on {320if left.schema().contains(c) {321left.drop_in_place(c.as_str())?;322}323}324}325326// Rename any right columns to "{}_right"327let left_cols: PlHashSet<_> = left.columns().iter().map(Column::name).cloned().collect();328let right_cols_vec = right.get_column_names_owned();329let renames = right_cols_vec330.iter()331.filter(|c| left_cols.contains(*c))332.map(|c| {333let renamed = format_pl_smallstr!("{}{}", c, args.suffix());334(c.as_str(), renamed)335});336right.rename_many(renames).unwrap();337338left.hstack_mut(right.columns())?;339340if args.how == JoinType::Full && should_coalesce {341// Coalesce key columns342for (left_keycol, right_keycol) in Iterator::zip(left_on.iter(), right_on.iter()) {343let right_keycol = format_pl_smallstr!("{}{}", right_keycol, args.suffix());344let left_col = left.column(left_keycol).unwrap();345let right_col = left.column(&right_keycol).unwrap();346let coalesced = coalesce_columns(&[left_col.clone(), right_col.clone()]).unwrap();347left.replace(left_keycol, coalesced)348.unwrap()349.drop_in_place(&right_keycol)350.unwrap();351}352}353354if should_coalesce {355for col in left_on {356if left.schema().contains(col) && !output_schema.contains(col) {357left.drop_in_place(col).unwrap();358}359}360for col in right_on {361let renamed = match left.schema().contains(col) {362true => Cow::Owned(format_pl_smallstr!("{}{}", col, args.suffix())),363false => Cow::Borrowed(col),364};365if left.schema().contains(&renamed) && !output_schema.contains(&renamed) {366left.drop_in_place(&renamed).unwrap();367}368}369}370371debug_assert_eq!(**left.schema(), *output_schema);372Ok(left)373}374375376