Path: blob/main/crates/polars-ops/src/frame/join/args.rs
6940 views
use super::*;12pub(super) type JoinIds = Vec<IdxSize>;3pub type LeftJoinIds = (ChunkJoinIds, ChunkJoinOptIds);4pub type InnerJoinIds = (JoinIds, JoinIds);56#[cfg(feature = "chunked_ids")]7pub(super) type ChunkJoinIds = Either<Vec<IdxSize>, Vec<ChunkId>>;8#[cfg(feature = "chunked_ids")]9pub type ChunkJoinOptIds = Either<Vec<NullableIdxSize>, Vec<ChunkId>>;1011#[cfg(not(feature = "chunked_ids"))]12pub type ChunkJoinOptIds = Vec<NullableIdxSize>;1314#[cfg(not(feature = "chunked_ids"))]15pub type ChunkJoinIds = Vec<IdxSize>;1617#[cfg(feature = "serde")]18use serde::{Deserialize, Serialize};19use strum_macros::IntoStaticStr;2021#[derive(Clone, PartialEq, Debug, Hash, Default)]22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]23#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]24pub struct JoinArgs {25pub how: JoinType,26pub validation: JoinValidation,27pub suffix: Option<PlSmallStr>,28pub slice: Option<(i64, usize)>,29pub nulls_equal: bool,30pub coalesce: JoinCoalesce,31pub maintain_order: MaintainOrderJoin,32}3334impl JoinArgs {35pub fn should_coalesce(&self) -> bool {36self.coalesce.coalesce(&self.how)37}38}3940#[derive(Clone, PartialEq, Hash, Default, IntoStaticStr)]41#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]42#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]43pub enum JoinType {44#[default]45Inner,46Left,47Right,48Full,49// Box is okay because this is inside a `Arc<JoinOptionsIR>`50#[cfg(feature = "asof_join")]51AsOf(Box<AsOfOptions>),52#[cfg(feature = "semi_anti_join")]53Semi,54#[cfg(feature = "semi_anti_join")]55Anti,56#[cfg(feature = "iejoin")]57// Options are set by optimizer/planner in Options58IEJoin,59// Options are set by optimizer/planner in Options60Cross,61}6263#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default)]64#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]65#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]66pub enum JoinCoalesce {67#[default]68JoinSpecific,69CoalesceColumns,70KeepColumns,71}7273impl JoinCoalesce {74pub fn coalesce(&self, join_type: &JoinType) -> bool {75use JoinCoalesce::*;76use JoinType::*;77match join_type {78Left | Inner | Right => {79matches!(self, JoinSpecific | CoalesceColumns)80},81Full => {82matches!(self, CoalesceColumns)83},84#[cfg(feature = "asof_join")]85AsOf(_) => matches!(self, JoinSpecific | CoalesceColumns),86#[cfg(feature = "iejoin")]87IEJoin => false,88Cross => false,89#[cfg(feature = "semi_anti_join")]90Semi | Anti => false,91}92}93}9495#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash, Default, IntoStaticStr)]96#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]97#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]98#[strum(serialize_all = "snake_case")]99pub enum MaintainOrderJoin {100#[default]101None,102Left,103Right,104LeftRight,105RightLeft,106}107108impl MaintainOrderJoin {109pub(super) fn flip(&self) -> Self {110match self {111MaintainOrderJoin::None => MaintainOrderJoin::None,112MaintainOrderJoin::Left => MaintainOrderJoin::Right,113MaintainOrderJoin::Right => MaintainOrderJoin::Left,114MaintainOrderJoin::LeftRight => MaintainOrderJoin::RightLeft,115MaintainOrderJoin::RightLeft => MaintainOrderJoin::LeftRight,116}117}118}119120impl JoinArgs {121pub fn new(how: JoinType) -> Self {122Self {123how,124validation: Default::default(),125suffix: None,126slice: None,127nulls_equal: false,128coalesce: Default::default(),129maintain_order: Default::default(),130}131}132133pub fn with_coalesce(mut self, coalesce: JoinCoalesce) -> Self {134self.coalesce = coalesce;135self136}137138pub fn with_suffix(mut self, suffix: Option<PlSmallStr>) -> Self {139self.suffix = suffix;140self141}142143pub fn suffix(&self) -> &PlSmallStr {144const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");145self.suffix.as_ref().unwrap_or(DEFAULT)146}147}148149impl From<JoinType> for JoinArgs {150fn from(value: JoinType) -> Self {151JoinArgs::new(value)152}153}154155pub trait CrossJoinFilter: Send + Sync {156fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame>;157}158159impl<T> CrossJoinFilter for T160where161T: Fn(DataFrame) -> PolarsResult<DataFrame> + Send + Sync,162{163fn apply(&self, df: DataFrame) -> PolarsResult<DataFrame> {164self(df)165}166}167168#[derive(Clone)]169pub struct CrossJoinOptions {170pub predicate: Arc<dyn CrossJoinFilter>,171}172173impl CrossJoinOptions {174fn as_ptr_ref(&self) -> *const dyn CrossJoinFilter {175Arc::as_ptr(&self.predicate)176}177}178179impl Eq for CrossJoinOptions {}180181impl PartialEq for CrossJoinOptions {182fn eq(&self, other: &Self) -> bool {183std::ptr::addr_eq(self.as_ptr_ref(), other.as_ptr_ref())184}185}186187impl Hash for CrossJoinOptions {188fn hash<H: std::hash::Hasher>(&self, state: &mut H) {189self.as_ptr_ref().hash(state);190}191}192193impl Debug for CrossJoinOptions {194fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {195write!(f, "CrossJoinOptions",)196}197}198199#[derive(Clone, PartialEq, Eq, Hash, IntoStaticStr, Debug)]200#[strum(serialize_all = "snake_case")]201pub enum JoinTypeOptions {202#[cfg(feature = "iejoin")]203IEJoin(IEJoinOptions),204Cross(CrossJoinOptions),205}206207impl Display for JoinType {208fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {209use JoinType::*;210let val = match self {211Left => "LEFT",212Right => "RIGHT",213Inner => "INNER",214Full => "FULL",215#[cfg(feature = "asof_join")]216AsOf(_) => "ASOF",217#[cfg(feature = "iejoin")]218IEJoin => "IEJOIN",219Cross => "CROSS",220#[cfg(feature = "semi_anti_join")]221Semi => "SEMI",222#[cfg(feature = "semi_anti_join")]223Anti => "ANTI",224};225write!(f, "{val}")226}227}228229impl Debug for JoinType {230fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {231write!(f, "{self}")232}233}234235impl JoinType {236pub fn is_equi(&self) -> bool {237matches!(238self,239JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full240)241}242243pub fn is_semi_anti(&self) -> bool {244#[cfg(feature = "semi_anti_join")]245{246matches!(self, JoinType::Semi | JoinType::Anti)247}248#[cfg(not(feature = "semi_anti_join"))]249{250false251}252}253254pub fn is_semi(&self) -> bool {255#[cfg(feature = "semi_anti_join")]256{257matches!(self, JoinType::Semi)258}259#[cfg(not(feature = "semi_anti_join"))]260{261false262}263}264265pub fn is_anti(&self) -> bool {266#[cfg(feature = "semi_anti_join")]267{268matches!(self, JoinType::Anti)269}270#[cfg(not(feature = "semi_anti_join"))]271{272false273}274}275276pub fn is_asof(&self) -> bool {277#[cfg(feature = "asof_join")]278{279matches!(self, JoinType::AsOf(_))280}281#[cfg(not(feature = "asof_join"))]282{283false284}285}286287pub fn is_cross(&self) -> bool {288matches!(self, JoinType::Cross)289}290291pub fn is_ie(&self) -> bool {292#[cfg(feature = "iejoin")]293{294matches!(self, JoinType::IEJoin)295}296#[cfg(not(feature = "iejoin"))]297{298false299}300}301}302303#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]304#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]305#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]306pub enum JoinValidation {307/// No unique checks308#[default]309ManyToMany,310/// Check if join keys are unique in right dataset.311ManyToOne,312/// Check if join keys are unique in left dataset.313OneToMany,314/// Check if join keys are unique in both left and right datasets315OneToOne,316}317318impl JoinValidation {319pub fn needs_checks(&self) -> bool {320!matches!(self, JoinValidation::ManyToMany)321}322323fn swap(self, swap: bool) -> Self {324use JoinValidation::*;325if swap {326match self {327ManyToMany => ManyToMany,328ManyToOne => OneToMany,329OneToMany => ManyToOne,330OneToOne => OneToOne,331}332} else {333self334}335}336337pub fn is_valid_join(&self, join_type: &JoinType) -> PolarsResult<()> {338if !self.needs_checks() {339return Ok(());340}341polars_ensure!(matches!(join_type, JoinType::Inner | JoinType::Full | JoinType::Left),342ComputeError: "{self} validation on a {join_type} join is not supported");343Ok(())344}345346pub(super) fn validate_probe(347&self,348s_left: &Series,349s_right: &Series,350build_shortest_table: bool,351nulls_equal: bool,352) -> PolarsResult<()> {353// In default, probe is the left series.354//355// In inner join and outer join, the shortest relation will be used to create a hash table.356// In left join, always use the right side to create.357//358// If `build_shortest_table` and left is shorter, swap. Then rhs will be the probe.359// If left == right, swap too. (apply the same logic as `det_hash_prone_order`)360let should_swap = build_shortest_table && s_left.len() <= s_right.len();361let probe = if should_swap { s_right } else { s_left };362363use JoinValidation::*;364let valid = match self.swap(should_swap) {365// Only check the `build` side.366// The other side use `validate_build` to check367ManyToMany | ManyToOne => true,368OneToMany | OneToOne => {369if !nulls_equal && probe.null_count() > 0 {370probe.n_unique()? - 1 == probe.len() - probe.null_count()371} else {372probe.n_unique()? == probe.len()373}374},375};376polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);377Ok(())378}379380pub(super) fn validate_build(381&self,382build_size: usize,383expected_size: usize,384swapped: bool,385) -> PolarsResult<()> {386use JoinValidation::*;387388// In default, build is in rhs.389let valid = match self.swap(swapped) {390// Only check the `build` side.391// The other side use `validate_prone` to check392ManyToMany | OneToMany => true,393ManyToOne | OneToOne => build_size == expected_size,394};395polars_ensure!(valid, ComputeError: "join keys did not fulfill {} validation", self);396Ok(())397}398}399400impl Display for JoinValidation {401fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {402let s = match self {403JoinValidation::ManyToMany => "m:m",404JoinValidation::ManyToOne => "m:1",405JoinValidation::OneToMany => "1:m",406JoinValidation::OneToOne => "1:1",407};408write!(f, "{s}")409}410}411412impl Debug for JoinValidation {413fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {414write!(f, "JoinValidation: {self}")415}416}417418419