Path: blob/main/crates/polars-plan/src/plans/ir/inputs.rs
6940 views
use std::iter;12use super::*;34impl IR {5/// Returns a node with updated expressions.6///7/// Panics if the expression count doesn't match8/// [`Self::exprs`]/[`Self::exprs_mut`]/[`Self::copy_exprs`].9pub fn with_exprs<E>(mut self, exprs: E) -> Self10where11E: IntoIterator<Item = ExprIR>,12{13let mut exprs_mut = self.exprs_mut();14let mut new_exprs = exprs.into_iter();1516for (expr, new_expr) in exprs_mut.by_ref().zip(new_exprs.by_ref()) {17*expr = new_expr;18}1920assert!(exprs_mut.next().is_none(), "not enough exprs");21assert!(new_exprs.next().is_none(), "too many exprs");2223drop(exprs_mut);2425self26}2728/// Returns a node with updated inputs.29///30/// Panics if the input count doesn't match31/// [`Self::inputs`]/[`Self::inputs_mut`]/[`Self::copy_inputs`]/[`Self::get_inputs`].32pub fn with_inputs<I>(mut self, inputs: I) -> Self33where34I: IntoIterator<Item = Node>,35{36let mut inputs_mut = self.inputs_mut();37let mut new_inputs = inputs.into_iter();3839for (input, new_input) in inputs_mut.by_ref().zip(new_inputs.by_ref()) {40*input = new_input;41}4243assert!(inputs_mut.next().is_none(), "not enough inputs");44assert!(new_inputs.next().is_none(), "too many inputs");4546drop(inputs_mut);4748self49}5051pub fn exprs(&'_ self) -> Exprs<'_> {52use IR::*;53match self {54Slice { .. } => Exprs::Empty,55Cache { .. } => Exprs::Empty,56Distinct { .. } => Exprs::Empty,57Union { .. } => Exprs::Empty,58MapFunction { .. } => Exprs::Empty,59DataFrameScan { .. } => Exprs::Empty,60HConcat { .. } => Exprs::Empty,61ExtContext { .. } => Exprs::Empty,62SimpleProjection { .. } => Exprs::Empty,63SinkMultiple { .. } => Exprs::Empty,64#[cfg(feature = "merge_sorted")]65MergeSorted { .. } => Exprs::Empty,6667#[cfg(feature = "python")]68PythonScan { options } => match &options.predicate {69PythonPredicate::Polars(predicate) => Exprs::single(predicate),70_ => Exprs::Empty,71},7273Scan { predicate, .. } => match predicate {74Some(predicate) => Exprs::single(predicate),75_ => Exprs::Empty,76},7778Filter { predicate, .. } => Exprs::single(predicate),7980Sort { by_column, .. } => Exprs::slice(by_column),81Select { expr, .. } => Exprs::slice(expr),82HStack { exprs, .. } => Exprs::slice(exprs),8384GroupBy { keys, aggs, .. } => Exprs::double_slice(keys, aggs),8586Join {87left_on,88right_on,89options,90..91} => match &options.options {92Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) => Exprs::Boxed(Box::new(93left_on94.iter()95.chain(right_on.iter())96.chain(iter::once(predicate)),97)),98_ => Exprs::double_slice(left_on, right_on),99},100101Sink { payload, .. } => match payload {102SinkTypeIR::Memory => Exprs::Empty,103SinkTypeIR::File(_) => Exprs::Empty,104SinkTypeIR::Partition(p) => {105let key_iter = match &p.variant {106PartitionVariantIR::Parted { key_exprs, .. }107| PartitionVariantIR::ByKey { key_exprs, .. } => key_exprs.iter(),108_ => [].iter(),109};110let sort_by_iter = match &p.per_partition_sort_by {111Some(sort_by) => sort_by.iter(),112_ => [].iter(),113}114.map(|s| &s.expr);115Exprs::Boxed(Box::new(key_iter.chain(sort_by_iter)))116},117},118119Invalid => unreachable!(),120}121}122123pub fn exprs_mut(&'_ mut self) -> ExprsMut<'_> {124use IR::*;125match self {126Slice { .. } => ExprsMut::Empty,127Cache { .. } => ExprsMut::Empty,128Distinct { .. } => ExprsMut::Empty,129Union { .. } => ExprsMut::Empty,130MapFunction { .. } => ExprsMut::Empty,131DataFrameScan { .. } => ExprsMut::Empty,132HConcat { .. } => ExprsMut::Empty,133ExtContext { .. } => ExprsMut::Empty,134SimpleProjection { .. } => ExprsMut::Empty,135SinkMultiple { .. } => ExprsMut::Empty,136#[cfg(feature = "merge_sorted")]137MergeSorted { .. } => ExprsMut::Empty,138139#[cfg(feature = "python")]140PythonScan { options } => match &mut options.predicate {141PythonPredicate::Polars(predicate) => ExprsMut::single(predicate),142_ => ExprsMut::Empty,143},144145Scan { predicate, .. } => match predicate {146Some(predicate) => ExprsMut::single(predicate),147_ => ExprsMut::Empty,148},149150Filter { predicate, .. } => ExprsMut::single(predicate),151152Sort { by_column, .. } => ExprsMut::slice(by_column),153Select { expr, .. } => ExprsMut::slice(expr),154HStack { exprs, .. } => ExprsMut::slice(exprs),155156GroupBy { keys, aggs, .. } => ExprsMut::double_slice(keys, aggs),157158Join {159left_on,160right_on,161options,162..163} => match Arc::make_mut(options).options.as_mut() {164Some(JoinTypeOptionsIR::CrossAndFilter { predicate }) => ExprsMut::Boxed(Box::new(165left_on166.iter_mut()167.chain(right_on.iter_mut())168.chain(iter::once(predicate)),169)),170_ => ExprsMut::double_slice(left_on, right_on),171},172173Sink { payload, .. } => match payload {174SinkTypeIR::Memory => ExprsMut::Empty,175SinkTypeIR::File(_) => ExprsMut::Empty,176SinkTypeIR::Partition(p) => {177let key_iter = match &mut p.variant {178PartitionVariantIR::Parted { key_exprs, .. }179| PartitionVariantIR::ByKey { key_exprs, .. } => key_exprs.iter_mut(),180_ => [].iter_mut(),181};182let sort_by_iter = match &mut p.per_partition_sort_by {183Some(sort_by) => sort_by.iter_mut(),184_ => [].iter_mut(),185}186.map(|s| &mut s.expr);187ExprsMut::Boxed(Box::new(key_iter.chain(sort_by_iter)))188},189},190191Invalid => unreachable!(),192}193}194195/// Copy the exprs in this LP node to an existing container.196pub fn copy_exprs<T>(&self, container: &mut T)197where198T: Extend<ExprIR>,199{200container.extend(self.exprs().cloned())201}202203pub fn inputs(&'_ self) -> Inputs<'_> {204use IR::*;205match self {206Union { inputs, .. } | HConcat { inputs, .. } | SinkMultiple { inputs } => {207Inputs::slice(inputs)208},209Slice { input, .. } => Inputs::single(*input),210Filter { input, .. } => Inputs::single(*input),211Select { input, .. } => Inputs::single(*input),212SimpleProjection { input, .. } => Inputs::single(*input),213Sort { input, .. } => Inputs::single(*input),214Cache { input, .. } => Inputs::single(*input),215GroupBy { input, .. } => Inputs::single(*input),216Join {217input_left,218input_right,219..220} => Inputs::double(*input_left, *input_right),221HStack { input, .. } => Inputs::single(*input),222Distinct { input, .. } => Inputs::single(*input),223MapFunction { input, .. } => Inputs::single(*input),224Sink { input, .. } => Inputs::single(*input),225ExtContext {226input, contexts, ..227} => Inputs::Boxed(Box::new(iter::once(*input).chain(contexts.iter().copied()))),228Scan { .. } => Inputs::Empty,229DataFrameScan { .. } => Inputs::Empty,230#[cfg(feature = "python")]231PythonScan { .. } => Inputs::Empty,232#[cfg(feature = "merge_sorted")]233MergeSorted {234input_left,235input_right,236..237} => Inputs::double(*input_left, *input_right),238Invalid => unreachable!(),239}240}241242pub fn inputs_mut(&'_ mut self) -> InputsMut<'_> {243use IR::*;244match self {245Union { inputs, .. } | HConcat { inputs, .. } | SinkMultiple { inputs } => {246InputsMut::slice(inputs)247},248Slice { input, .. } => InputsMut::single(input),249Filter { input, .. } => InputsMut::single(input),250Select { input, .. } => InputsMut::single(input),251SimpleProjection { input, .. } => InputsMut::single(input),252Sort { input, .. } => InputsMut::single(input),253Cache { input, .. } => InputsMut::single(input),254GroupBy { input, .. } => InputsMut::single(input),255Join {256input_left,257input_right,258..259} => InputsMut::double(input_left, input_right),260HStack { input, .. } => InputsMut::single(input),261Distinct { input, .. } => InputsMut::single(input),262MapFunction { input, .. } => InputsMut::single(input),263Sink { input, .. } => InputsMut::single(input),264ExtContext {265input, contexts, ..266} => InputsMut::Boxed(Box::new(iter::once(input).chain(contexts.iter_mut()))),267Scan { .. } => InputsMut::Empty,268DataFrameScan { .. } => InputsMut::Empty,269#[cfg(feature = "python")]270PythonScan { .. } => InputsMut::Empty,271#[cfg(feature = "merge_sorted")]272MergeSorted {273input_left,274input_right,275..276} => InputsMut::double(input_left, input_right),277Invalid => unreachable!(),278}279}280281/// Push inputs of the LP in of this node to an existing container.282/// Most plans have typically one input. A join has two and a scan (CsvScan)283/// or an in-memory DataFrame has none. A Union has multiple.284pub fn copy_inputs<T>(&self, container: &mut T)285where286T: Extend<Node>,287{288container.extend(self.inputs())289}290291pub fn get_inputs(&self) -> UnitVec<Node> {292self.inputs().collect()293}294295pub(crate) fn get_input(&self) -> Option<Node> {296self.inputs().next()297}298}299300pub enum Inputs<'a> {301Empty,302Single(iter::Once<Node>),303Double(std::array::IntoIter<Node, 2>),304Slice(iter::Copied<std::slice::Iter<'a, Node>>),305Boxed(Box<dyn Iterator<Item = Node> + 'a>),306}307308impl<'a> Inputs<'a> {309fn single(node: Node) -> Self {310Self::Single(iter::once(node))311}312313fn double(left: Node, right: Node) -> Self {314Self::Double([left, right].into_iter())315}316317fn slice(inputs: &'a [Node]) -> Self {318Self::Slice(inputs.iter().copied())319}320}321322impl<'a> Iterator for Inputs<'a> {323type Item = Node;324325fn next(&mut self) -> Option<Self::Item> {326match self {327Self::Empty => None,328Self::Single(it) => it.next(),329Self::Double(it) => it.next(),330Self::Slice(it) => it.next(),331Self::Boxed(it) => it.next(),332}333}334}335336pub enum InputsMut<'a> {337Empty,338Single(iter::Once<&'a mut Node>),339Double(std::array::IntoIter<&'a mut Node, 2>),340Slice(std::slice::IterMut<'a, Node>),341Boxed(Box<dyn Iterator<Item = &'a mut Node> + 'a>),342}343344impl<'a> InputsMut<'a> {345fn single(node: &'a mut Node) -> Self {346Self::Single(iter::once(node))347}348349fn double(left: &'a mut Node, right: &'a mut Node) -> Self {350Self::Double([left, right].into_iter())351}352353fn slice(inputs: &'a mut [Node]) -> Self {354Self::Slice(inputs.iter_mut())355}356}357358impl<'a> Iterator for InputsMut<'a> {359type Item = &'a mut Node;360361fn next(&mut self) -> Option<Self::Item> {362match self {363Self::Empty => None,364Self::Single(it) => it.next(),365Self::Double(it) => it.next(),366Self::Slice(it) => it.next(),367Self::Boxed(it) => it.next(),368}369}370}371372pub enum Exprs<'a> {373Empty,374Single(iter::Once<&'a ExprIR>),375Slice(std::slice::Iter<'a, ExprIR>),376DoubleSlice(iter::Chain<std::slice::Iter<'a, ExprIR>, std::slice::Iter<'a, ExprIR>>),377Boxed(Box<dyn Iterator<Item = &'a ExprIR> + 'a>),378}379380impl<'a> Exprs<'a> {381fn single(expr: &'a ExprIR) -> Self {382Self::Single(iter::once(expr))383}384385fn slice(inputs: &'a [ExprIR]) -> Self {386Self::Slice(inputs.iter())387}388389fn double_slice(left: &'a [ExprIR], right: &'a [ExprIR]) -> Self {390Self::DoubleSlice(left.iter().chain(right.iter()))391}392}393394impl<'a> Iterator for Exprs<'a> {395type Item = &'a ExprIR;396397fn next(&mut self) -> Option<Self::Item> {398match self {399Self::Empty => None,400Self::Single(it) => it.next(),401Self::Slice(it) => it.next(),402Self::DoubleSlice(it) => it.next(),403Self::Boxed(it) => it.next(),404}405}406}407408pub enum ExprsMut<'a> {409Empty,410Single(iter::Once<&'a mut ExprIR>),411Slice(std::slice::IterMut<'a, ExprIR>),412DoubleSlice(iter::Chain<std::slice::IterMut<'a, ExprIR>, std::slice::IterMut<'a, ExprIR>>),413Boxed(Box<dyn Iterator<Item = &'a mut ExprIR> + 'a>),414}415416impl<'a> ExprsMut<'a> {417fn single(expr: &'a mut ExprIR) -> Self {418Self::Single(iter::once(expr))419}420421fn slice(inputs: &'a mut [ExprIR]) -> Self {422Self::Slice(inputs.iter_mut())423}424425fn double_slice(left: &'a mut [ExprIR], right: &'a mut [ExprIR]) -> Self {426Self::DoubleSlice(left.iter_mut().chain(right.iter_mut()))427}428}429430impl<'a> Iterator for ExprsMut<'a> {431type Item = &'a mut ExprIR;432433fn next(&mut self) -> Option<Self::Item> {434match self {435Self::Empty => None,436Self::Single(it) => it.next(),437Self::Slice(it) => it.next(),438Self::DoubleSlice(it) => it.next(),439Self::Boxed(it) => it.next(),440}441}442}443444445