// Path: blob/main/crates/polars-plan/src/plans/builder_ir.rs
// 6940 views
use std::borrow::Cow;12use super::*;34pub struct IRBuilder<'a> {5root: Node,6expr_arena: &'a mut Arena<AExpr>,7lp_arena: &'a mut Arena<IR>,8}910impl<'a> IRBuilder<'a> {11pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {12IRBuilder {13root,14expr_arena,15lp_arena,16}17}1819pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {20let root = lp_arena.add(lp);21IRBuilder {22root,23expr_arena,24lp_arena,25}26}2728pub fn add_alp(self, lp: IR) -> Self {29let node = self.lp_arena.add(lp);30IRBuilder::new(node, self.expr_arena, self.lp_arena)31}3233/// Adds IR and runs optimizations on its expressions (simplify, coerce, type-check).34pub fn add_alp_optimize_exprs<F>(self, f: F) -> PolarsResult<Self>35where36F: FnOnce(Node) -> IR,37{38let lp = f(self.root);39let ir_name = lp.name();4041let b = self.add_alp(lp);4243// Run the optimizer44let mut conversion_optimizer = ConversionOptimizer::new(true, true, true);45conversion_optimizer.fill_scratch(b.lp_arena.get(b.root).exprs(), b.expr_arena);46conversion_optimizer47.optimize_exprs(b.expr_arena, b.lp_arena, b.root, false)48.map_err(|e| e.context(format!("optimizing '{ir_name}' failed").into()))?;4950Ok(b)51}5253/// An escape hatch to add an `Expr`. Working with IR is preferred.54pub fn add_expr(&mut self, expr: Expr) -> PolarsResult<ExprIR> {55let schema = self.lp_arena.get(self.root).schema(self.lp_arena);56let mut ctx = ExprToIRContext::new(self.expr_arena, &schema);57to_expr_ir(expr, &mut ctx)58}5960pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {61// if len == 0, no projection has to be done. 
This is a select all operation.62if exprs.is_empty() {63self64} else {65let input_schema = self.schema();66let schema = expr_irs_to_schema(&exprs, &input_schema, self.expr_arena);6768let lp = IR::Select {69expr: exprs,70input: self.root,71schema: Arc::new(schema),72options,73};74let node = self.lp_arena.add(lp);75IRBuilder::new(node, self.expr_arena, self.lp_arena)76}77}7879pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>80where81I: IntoIterator<Item = N>,82N: Into<Node>,83I::IntoIter: ExactSizeIterator,84{85let names = nodes86.into_iter()87.map(|node| match self.expr_arena.get(node.into()) {88AExpr::Column(name) => name,89_ => unreachable!(),90});91// This is a duplication of `project_simple` because we already borrow self.expr_arena :/92if names.size_hint().0 == 0 {93Ok(self)94} else {95let input_schema = self.schema();96let mut count = 0;97let schema = names98.map(|name| {99let dtype = input_schema.try_get(name)?;100count += 1;101Ok(Field::new(name.clone(), dtype.clone()))102})103.collect::<PolarsResult<Schema>>()?;104105polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");106107let lp = IR::SimpleProjection {108input: self.root,109columns: Arc::new(schema),110};111let node = self.lp_arena.add(lp);112Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))113}114}115116pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>117where118I: IntoIterator<Item = S>,119I::IntoIter: ExactSizeIterator,120S: Into<PlSmallStr>,121{122let names = names.into_iter();123// if len == 0, no projection has to be done. 
This is a select all operation.124if names.size_hint().0 == 0 {125Ok(self)126} else {127let input_schema = self.schema();128let mut count = 0;129let schema = names130.map(|name| {131let name: PlSmallStr = name.into();132let dtype = input_schema.try_get(name.as_str())?;133count += 1;134Ok(Field::new(name, dtype.clone()))135})136.collect::<PolarsResult<Schema>>()?;137138polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");139140let lp = IR::SimpleProjection {141input: self.root,142columns: Arc::new(schema),143};144let node = self.lp_arena.add(lp);145Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))146}147}148149pub fn drop<I, S>(self, names: I) -> Self150where151I: IntoIterator<Item = S>,152I::IntoIter: ExactSizeIterator,153S: Into<PlSmallStr>,154{155let names = names.into_iter();156// if len == 0, no projection has to be done. This is a select all operation.157if names.size_hint().0 == 0 {158self159} else {160let mut schema = self.schema().as_ref().as_ref().clone();161162for name in names {163let name: PlSmallStr = name.into();164schema.remove(&name);165}166167let lp = IR::SimpleProjection {168input: self.root,169columns: Arc::new(schema),170};171let node = self.lp_arena.add(lp);172IRBuilder::new(node, self.expr_arena, self.lp_arena)173}174}175176pub fn sort(177self,178by_column: Vec<ExprIR>,179slice: Option<(i64, usize)>,180sort_options: SortMultipleOptions,181) -> Self {182let ir = IR::Sort {183input: self.root,184by_column,185slice,186sort_options,187};188let node = self.lp_arena.add(ir);189IRBuilder::new(node, self.expr_arena, self.lp_arena)190}191192pub fn node(self) -> Node {193self.root194}195196pub fn build(self) -> IR {197if self.root.0 == self.lp_arena.len() {198self.lp_arena.pop().unwrap()199} else {200self.lp_arena.take(self.root)201}202}203204pub fn schema(&'a self) -> Cow<'a, SchemaRef> {205self.lp_arena.get(self.root).schema(self.lp_arena)206}207208pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) 
-> Self {209let schema = self.schema();210let mut new_schema = (**schema).clone();211212let hstack_schema = expr_irs_to_schema(&exprs, &schema, self.expr_arena);213new_schema.merge(hstack_schema);214215let lp = IR::HStack {216input: self.root,217exprs,218schema: Arc::new(new_schema),219options,220};221self.add_alp(lp)222}223224pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self225where226I: IntoIterator<Item = J>,227{228let schema = self.schema();229let mut new_schema = (**schema).clone();230231let iter = exprs.into_iter();232let mut expr_irs = Vec::with_capacity(iter.size_hint().0);233for node in iter {234let node = node.into();235let field = self236.expr_arena237.get(node)238.to_field(&schema, self.expr_arena)239.unwrap();240241expr_irs.push(242ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))243.with_dtype(field.dtype.clone()),244);245new_schema.with_column(field.name().clone(), field.dtype().clone());246}247248let lp = IR::HStack {249input: self.root,250exprs: expr_irs,251schema: Arc::new(new_schema),252options,253};254self.add_alp(lp)255}256257// call this if the schema needs to be updated258pub fn explode(self, columns: Arc<[PlSmallStr]>) -> Self {259let lp = IR::MapFunction {260input: self.root,261function: FunctionIR::Explode {262columns,263schema: Default::default(),264},265};266self.add_alp(lp)267}268269pub fn group_by(270self,271keys: Vec<ExprIR>,272aggs: Vec<ExprIR>,273apply: Option<PlanCallback<DataFrame, DataFrame>>,274maintain_order: bool,275options: Arc<GroupbyOptions>,276) -> Self {277let current_schema = self.schema();278let mut schema = expr_irs_to_schema(&keys, ¤t_schema, self.expr_arena);279280#[cfg(feature = "dynamic_group_by")]281{282if let Some(options) = options.rolling.as_ref() {283let name = &options.index_column;284let dtype = current_schema.get(name).unwrap();285schema.with_column(name.clone(), dtype.clone());286} else if let Some(options) = options.dynamic.as_ref() {287let name = 
&options.index_column;288let dtype = current_schema.get(name).unwrap();289if options.include_boundaries {290schema.with_column("_lower_boundary".into(), dtype.clone());291schema.with_column("_upper_boundary".into(), dtype.clone());292}293schema.with_column(name.clone(), dtype.clone());294}295}296297let mut aggs_schema = expr_irs_to_schema(&aggs, ¤t_schema, self.expr_arena);298299// Coerce aggregation column(s) into List unless not needed (auto-implode)300debug_assert!(aggs_schema.len() == aggs.len());301for ((_name, dtype), expr) in aggs_schema.iter_mut().zip(&aggs) {302if !expr.is_scalar(self.expr_arena) {303*dtype = dtype.clone().implode();304}305}306307schema.merge(aggs_schema);308309let lp = IR::GroupBy {310input: self.root,311keys,312aggs,313schema: Arc::new(schema),314apply,315maintain_order,316options,317};318self.add_alp(lp)319}320321pub fn join(322self,323other: Node,324left_on: Vec<ExprIR>,325right_on: Vec<ExprIR>,326options: Arc<JoinOptionsIR>,327) -> Self {328let schema_left = self.schema();329let schema_right = self.lp_arena.get(other).schema(self.lp_arena);330331let schema = det_join_schema(332&schema_left,333&schema_right,334&left_on,335&right_on,336&options,337self.expr_arena,338)339.unwrap();340341let lp = IR::Join {342input_left: self.root,343input_right: other,344schema,345left_on,346right_on,347options,348};349350self.add_alp(lp)351}352353#[cfg(feature = "pivot")]354pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {355let lp = IR::MapFunction {356input: self.root,357function: FunctionIR::Unpivot {358args,359schema: Default::default(),360},361};362self.add_alp(lp)363}364365pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {366let lp = IR::MapFunction {367input: self.root,368function: FunctionIR::RowIndex {369name,370offset,371schema: Default::default(),372},373};374self.add_alp(lp)375}376}377378379