// Path: blob/main/crates/polars-ops/src/frame/join/args.rs
// 8446 views
use super::*;12pub(super) type JoinIds = Vec<IdxSize>;3pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);4pub type InnerJoinIds = (JoinIds, JoinIds);56#[cfg(feature = "chunked_ids")]7pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;8#[cfg(feature = "chunked_ids")]9pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;1011#[cfg(not(feature = "chunked_ids"))]12pub type ChunkJoinOptIds = Vec<NullableIdxSize>;1314#[cfg(not(feature = "chunked_ids"))]15pub type ChunkJoinIds = Vec<IdxSize>;1617#[cfg(feature = "serde")]18use serde::{Deserialize, Serialize};19use strum_macros::IntoStaticStr;2021/// Parameters for which side to use as the build side in a join. Currently only22/// respected by the streaming engine.23#[derive(Clone, PartialEq, Debug, Hash)]24#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]25#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]26pub enum JoinBuildSide {27/// Unless there's a very good reason to believe that the right side is28/// smaller, use the left side.29PreferLeft,30/// Regardless of other heuristics, use the left side as build side.31ForceLeft,3233// Similar to above.34PreferRight,35ForceRight,36}3738#[derive(Clone, PartialEq, Debug, Hash, Default)]39#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]40#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]41pub struct JoinArgs {42pub how: JoinType,43pub validation: JoinValidation,44pub suffix: Option<PlSmallStr>,45pub slice: Option<(i64, usize)>,46pub nulls_equal: bool,47pub coalesce: JoinCoalesce,48pub maintain_order: MaintainOrderJoin,49pub build_side: Option<JoinBuildSide>,50}5152impl JoinArgs {53pub fn should_coalesce(&self) -> bool {54self.coalesce.coalesce(&self.how)55}56}5758#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]59#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]60#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]61pub enum JoinType 
{62#[default]63Inner,64Left,65Right,66Full,67// Box is okay because this is inside a `Arc<JoinOptionsIR>`68#[cfg(feature = "asof_join")]69AsOf(Box<AsOfOptions>),70#[cfg(feature = "semi_anti_join")]71Semi,72#[cfg(feature = "semi_anti_join")]73Anti,74#[cfg(feature = "iejoin")]75// Options are set by optimizer/planner in Options76IEJoin,77// Options are set by optimizer/planner in Options78Cross,79}8081#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]82#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]83#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]84pub enum JoinCoalesce {85#[default]86JoinSpecific,87CoalesceColumns,88KeepColumns,89}9091impl JoinCoalesce {92pub fn coalesce(&self, join_type: &JoinType) -> bool {93use JoinCoalesce::*;94use JoinType::*;95match join_type {96Left | Inner | Right => {97matches!(self, JoinSpecific | CoalesceColumns)98},99Full => {100matches!(self, CoalesceColumns)101},102#[cfg(feature = "asof_join")]103AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),104#[cfg(feature = "iejoin")]105IEJoin => false,106Cross => false,107#[cfg(feature = "semi_anti_join")]108Semi | Anti => false,109}110}111}112113#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]114#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]115#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]116#[strum(serialize_all = "snake_case")]117pub enum MaintainOrderJoin {118#[default]119None,120Left,121Right,122LeftRight,123RightLeft,124}125126impl MaintainOrderJoin {127pub(super) fn flip(&self) -> Self {128match self {129MaintainOrderJoin::None => MaintainOrderJoin::None,130MaintainOrderJoin::Left => MaintainOrderJoin::Right,131MaintainOrderJoin::Right => MaintainOrderJoin::Left,132MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,133MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,134}135}136}137138impl JoinArgs {139pub fn new(how: JoinType) -> Self {140Self 
{141how,142validation: Default::default(),143suffix: None,144slice: None,145nulls_equal: false,146coalesce: Default::default(),147maintain_order: Default::default(),148build_side: None,149}150}151152pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {153self.coalesce = coalesce;154self155}156157pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {158self.suffix = suffix;159self160}161162pub fn with_build_side(mut self, build_side: Option<JoinBuildSide>) -> Self {163self.build_side = build_side;164self165}166167pub fn suffix(&self) -> &PlSmallStr {168const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");169self.suffix.as_ref().unwrap_or(DEFAULT)170}171}172173impl From<JoinType> for JoinArgs {174fn from(value: JoinType) -> Self {175JoinArgs::new(value)176}177}178179pub trait CrossJoinFilter: Send + Sync {180fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;181}182183impl<T> CrossJoinFilter for T184where185T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,186{187fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {188self(df)189}190}191192#[derive(Clone)]193pub struct CrossJoinOptions {194pub predicate: Arc<dyn CrossJoinFilter>,195}196197impl CrossJoinOptions {198fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {199Arc::as_ptr(&self.predicate)200}201}202203impl Eq for CrossJoinOptions {}204205impl PartialEq for CrossJoinOptions {206fn eq(&self, other: &Self) -> bool {207std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())208}209}210211impl Hash for CrossJoinOptions {212fn hash<H: std::hash::Hasher>(&self, state: &mut H) {213self.as_ptr_ref().hash(state);214}215}216217impl Debug for CrossJoinOptions {218fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {219write!(f, "CrossJoinOptions",)220}221}222223#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]224#[strum(serialize_all = "snake_case")]225pub enum JoinTypeOptions {226#[cfg(feature = 
"iejoin")]227IEJoin(IEJoinOptions),228Cross(CrossJoinOptions),229}230231impl Display for JoinType {232fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {233use JoinType::*;234let val = match self {235Left => "LEFT",236Right => "RIGHT",237Inner => "INNER",238Full => "FULL",239#[cfg(feature = "asof_join")]240AsOf(_) => "ASOF",241#[cfg(feature = "iejoin")]242IEJoin => "IEJOIN",243Cross => "CROSS",244#[cfg(feature = "semi_anti_join")]245Semi => "SEMI",246#[cfg(feature = "semi_anti_join")]247Anti => "ANTI",248};249write!(f, "{val}")250}251}252253impl Debug for JoinType {254fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {255write!(f, "{self}")256}257}258259impl JoinType {260pub fn is_equi(&self) -> bool {261matches!(262self,263JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full264)265}266267pub fn is_semi_anti(&self) -> bool {268#[cfg(feature = "semi_anti_join")]269{270matches!(self, JoinType::Semi | JoinType::Anti)271}272#[cfg(not(feature = "semi_anti_join"))]273{274false275}276}277278pub fn is_semi(&self) -> bool {279#[cfg(feature = "semi_anti_join")]280{281matches!(self, JoinType::Semi)282}283#[cfg(not(feature = "semi_anti_join"))]284{285false286}287}288289pub fn is_anti(&self) -> bool {290#[cfg(feature = "semi_anti_join")]291{292matches!(self, JoinType::Anti)293}294#[cfg(not(feature = "semi_anti_join"))]295{296false297}298}299300pub fn is_asof(&self) -> bool {301#[cfg(feature = "asof_join")]302{303matches!(self, JoinType::AsOf(_))304}305#[cfg(not(feature = "asof_join"))]306{307false308}309}310311pub fn is_cross(&self) -> bool {312matches!(self, JoinType::Cross)313}314315pub fn is_ie(&self) -> bool {316#[cfg(feature = "iejoin")]317{318matches!(self, JoinType::IEJoin)319}320#[cfg(not(feature = "iejoin"))]321{322false323}324}325}326327#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]328#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]329#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]330pub enum 
JoinValidation {331/// No unique checks332#[default]333ManyToMany,334/// Check if join keys are unique in right dataset.335ManyToOne,336/// Check if join keys are unique in left dataset.337OneToMany,338/// Check if join keys are unique in both left and right datasets339OneToOne,340}341342impl JoinValidation {343pub fn needs_checks(&self) -> bool {344!matches!(self, JoinValidation::ManyToMany)345}346347fn swap(self, swap: bool) -> Self {348use JoinValidation::*;349if swap {350match self {351ManyToMany => ManyToMany,352ManyToOne => OneToMany,353OneToMany => ManyToOne,354OneToOne => OneToOne,355}356} else {357self358}359}360361pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {362if !self.needs_checks() {363return Ok(());364}365polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),366ComputeError: "{self} validation on a {join_type} join is not supported");367Ok(())368}369370pub(super) fn validate_probe(371&self,372s_left: &Series,373s_right: &Series,374build_shortest_table: bool,375nulls_equal: bool,376) -> PolarsResult<()> {377// In default, probe is the left series.378//379// In inner join and outer join, the shortest relation will be used to create a hash table.380// In left join, always use the right side to create.381//382// If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.383// If left == right, swap too. (apply the same logic as `det_hash_prone_order`)384let should_swap = build_shortest_table && s_left.len() <= s_right.len();385let probe = if should_swap { s_right } else { s_left };386387use JoinValidation::*;388let valid = match self.swap(should_swap) {389// Only check the `build` side.390// The other side use `validate_build` to check391ManyToMany | ManyToOne => true,392OneToMany | OneToOne => {393if !nulls_equal && probe.null_count() > 0 {394probe.n_unique()? - 1 == probe.len() - probe.null_count()395} else {396probe.n_unique()? 
== probe.len()397}398},399};400polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);401Ok(())402}403404pub(super) fn validate_build(405&self,406build_size: usize,407expected_size: usize,408swapped: bool,409) -> PolarsResult<()> {410use JoinValidation::*;411412// In default, build is in rhs.413let valid = match self.swap(swapped) {414// Only check the `build` side.415// The other side use `validate_prone` to check416ManyToMany | OneToMany => true,417ManyToOne | OneToOne => build_size == expected_size,418};419polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);420Ok(())421}422}423424impl Display for JoinValidation {425fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {426let s = match self {427JoinValidation::ManyToMany => "m:m",428JoinValidation::ManyToOne => "m:1",429JoinValidation::OneToMany => "1:m",430JoinValidation::OneToOne => "1:1",431};432write!(f, "{s}")433}434}435436impl Debug for JoinValidation {437fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {438write!(f, "JoinValidation: {self}")439}440}441442443