Path: blob/main/crates/polars-plan/src/plans/builder_ir.rs
8430 views
use std::borrow::Cow;12use super::*;34pub struct IRBuilder<'a> {5root: Node,6expr_arena: &'a mut Arena<AExpr>,7lp_arena: &'a mut Arena<IR>,8}910impl<'a> IRBuilder<'a> {11pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {12IRBuilder {13root,14expr_arena,15lp_arena,16}17}1819pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {20let root = lp_arena.add(lp);21IRBuilder {22root,23expr_arena,24lp_arena,25}26}2728pub fn add_alp(self, lp: IR) -> Self {29let node = self.lp_arena.add(lp);30IRBuilder::new(node, self.expr_arena, self.lp_arena)31}3233/// Adds IR and runs optimizations on its expressions (simplify, coerce, type-check).34pub fn add_alp_optimize_exprs<F>(self, f: F) -> PolarsResult<Self>35where36F: FnOnce(Node) -> IR,37{38let lp = f(self.root);39let ir_name = lp.name();4041let b = self.add_alp(lp);4243// Run the optimizer44let mut conversion_optimizer = ConversionOptimizer::new(true, true, true);45conversion_optimizer.fill_scratch(b.lp_arena.get(b.root).exprs(), b.expr_arena);46conversion_optimizer47.optimize_exprs(b.expr_arena, b.lp_arena, b.root, false)48.map_err(|e| e.context(format!("optimizing '{ir_name}' failed").into()))?;4950Ok(b)51}5253/// An escape hatch to add an `Expr`. Working with IR is preferred.54pub fn add_expr(&mut self, expr: Expr) -> PolarsResult<ExprIR> {55let schema = self.lp_arena.get(self.root).schema(self.lp_arena);56let mut ctx = ExprToIRContext::new(self.expr_arena, &schema);57to_expr_ir(expr, &mut ctx)58}5960pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {61// if len == 0, no projection has to be done. This is a select all operation.62if exprs.is_empty() {63self64} else {65let input_schema = self.schema();66let schema = expr_irs_to_schema(&exprs, &input_schema, self.expr_arena)67.expect("no valid schema can be derived for the query");6869let lp = IR::Select {70expr: exprs,71input: self.root,72schema: Arc::new(schema),73options,74};75let node = self.lp_arena.add(lp);76IRBuilder::new(node, self.expr_arena, self.lp_arena)77}78}7980pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>81where82I: IntoIterator<Item = N>,83N: Into<Node>,84I::IntoIter: ExactSizeIterator,85{86let names = nodes87.into_iter()88.map(|node| match self.expr_arena.get(node.into()) {89AExpr::Column(name) => name,90_ => unreachable!(),91});92// This is a duplication of `project_simple` because we already borrow self.expr_arena :/93if names.size_hint().0 == 0 {94Ok(self)95} else {96let input_schema = self.schema();97let mut count = 0;98let schema = names99.map(|name| {100let dtype = input_schema.try_get(name)?;101count += 1;102Ok(Field::new(name.clone(), dtype.clone()))103})104.collect::<PolarsResult<Schema>>()?;105106polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");107108let lp = IR::SimpleProjection {109input: self.root,110columns: Arc::new(schema),111};112let node = self.lp_arena.add(lp);113Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))114}115}116117pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>118where119I: IntoIterator<Item = S>,120I::IntoIter: ExactSizeIterator,121S: Into<PlSmallStr>,122{123let names = names.into_iter();124// if len == 0, no projection has to be done. This is a select all operation.125if names.size_hint().0 == 0 {126Ok(self)127} else {128let input_schema = self.schema();129let mut count = 0;130let schema = names131.map(|name| {132let name: PlSmallStr = name.into();133let dtype = input_schema.try_get(name.as_str())?;134count += 1;135Ok(Field::new(name, dtype.clone()))136})137.collect::<PolarsResult<Schema>>()?;138139polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");140141let lp = IR::SimpleProjection {142input: self.root,143columns: Arc::new(schema),144};145let node = self.lp_arena.add(lp);146Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))147}148}149150pub fn drop<I, S>(self, names: I) -> Self151where152I: IntoIterator<Item = S>,153I::IntoIter: ExactSizeIterator,154S: Into<PlSmallStr>,155{156let names = names.into_iter();157// if len == 0, no projection has to be done. This is a select all operation.158if names.size_hint().0 == 0 {159self160} else {161let mut schema = self.schema().as_ref().as_ref().clone();162163for name in names {164let name: PlSmallStr = name.into();165schema.remove(&name);166}167168let lp = IR::SimpleProjection {169input: self.root,170columns: Arc::new(schema),171};172let node = self.lp_arena.add(lp);173IRBuilder::new(node, self.expr_arena, self.lp_arena)174}175}176177pub fn sort(178self,179by_column: Vec<ExprIR>,180slice: Option<(i64, usize)>,181sort_options: SortMultipleOptions,182) -> Self {183let ir = IR::Sort {184input: self.root,185by_column,186slice,187sort_options,188};189let node = self.lp_arena.add(ir);190IRBuilder::new(node, self.expr_arena, self.lp_arena)191}192193pub fn node(self) -> Node {194self.root195}196197pub fn build(self) -> IR {198if self.root.0 == self.lp_arena.len() {199self.lp_arena.pop().unwrap()200} else {201self.lp_arena.take(self.root)202}203}204205pub fn schema(&'a self) -> Cow<'a, SchemaRef> {206self.lp_arena.get(self.root).schema(self.lp_arena)207}208209pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {210let schema = self.schema();211let mut new_schema = (**schema).clone();212213let hstack_schema = expr_irs_to_schema(&exprs, &schema, self.expr_arena)214.expect("no valid schema can be derived for the query");215new_schema.merge(hstack_schema);216217let lp = IR::HStack {218input: self.root,219exprs,220schema: Arc::new(new_schema),221options,222};223self.add_alp(lp)224}225226pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self227where228I: IntoIterator<Item = J>,229{230let schema = self.schema();231let mut new_schema = (**schema).clone();232233let iter = exprs.into_iter();234let mut expr_irs = Vec::with_capacity(iter.size_hint().0);235for node in iter {236let node = node.into();237let field = self238.expr_arena239.get(node)240.to_field(&ToFieldContext::new(self.expr_arena, &schema))241.unwrap();242243expr_irs.push(244ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))245.with_dtype(field.dtype.clone()),246);247new_schema.with_column(field.name().clone(), field.dtype().clone());248}249250let lp = IR::HStack {251input: self.root,252exprs: expr_irs,253schema: Arc::new(new_schema),254options,255};256self.add_alp(lp)257}258259// call this if the schema needs to be updated260pub fn explode(self, columns: Arc<[PlSmallStr]>, options: ExplodeOptions) -> Self {261let lp = IR::MapFunction {262input: self.root,263function: FunctionIR::Explode {264columns,265options,266schema: Default::default(),267},268};269self.add_alp(lp)270}271272pub fn group_by(273self,274keys: Vec<ExprIR>,275aggs: Vec<ExprIR>,276apply: Option<PlanCallback<DataFrame, DataFrame>>,277maintain_order: bool,278options: Arc<GroupbyOptions>,279) -> Self {280let current_schema = self.schema();281let mut schema = expr_irs_to_schema(&keys, ¤t_schema, self.expr_arena)282.expect("no valid schema can be derived for the key expression");283284#[cfg(feature = "dynamic_group_by")]285{286if let Some(options) = options.rolling.as_ref() {287let name = &options.index_column;288let dtype = current_schema.get(name).unwrap();289schema.with_column(name.clone(), dtype.clone());290} else if let Some(options) = options.dynamic.as_ref() {291let name = &options.index_column;292let dtype = current_schema.get(name).unwrap();293if options.include_boundaries {294schema.with_column("_lower_boundary".into(), dtype.clone());295schema.with_column("_upper_boundary".into(), dtype.clone());296}297schema.with_column(name.clone(), dtype.clone());298}299}300301let mut aggs_schema = expr_irs_to_schema(&aggs, ¤t_schema, self.expr_arena)302.expect("no valid schema can be derived for the agg expression");303304// Coerce aggregation column(s) into List unless not needed (auto-implode)305debug_assert!(aggs_schema.len() == aggs.len());306for ((_name, dtype), expr) in aggs_schema.iter_mut().zip(&aggs) {307if !expr.is_scalar(self.expr_arena) {308*dtype = dtype.clone().implode();309}310}311312schema.merge(aggs_schema);313314let lp = IR::GroupBy {315input: self.root,316keys,317aggs,318schema: Arc::new(schema),319apply,320maintain_order,321options,322};323self.add_alp(lp)324}325326pub fn join(327self,328other: Node,329left_on: Vec<ExprIR>,330right_on: Vec<ExprIR>,331options: Arc<JoinOptionsIR>,332) -> Self {333let schema_left = self.schema();334let schema_right = self.lp_arena.get(other).schema(self.lp_arena);335336let schema = det_join_schema(337&schema_left,338&schema_right,339&left_on,340&right_on,341&options,342self.expr_arena,343)344.unwrap();345346let lp = IR::Join {347input_left: self.root,348input_right: other,349schema,350left_on,351right_on,352options,353};354355self.add_alp(lp)356}357358#[cfg(feature = "pivot")]359pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {360let lp = IR::MapFunction {361input: self.root,362function: FunctionIR::Unpivot {363args,364schema: Default::default(),365},366};367self.add_alp(lp)368}369370pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {371let lp = IR::MapFunction {372input: self.root,373function: FunctionIR::RowIndex {374name,375offset,376schema: Default::default(),377},378};379self.add_alp(lp)380}381382pub fn hint(self, hint: HintIR) -> Self {383let lp = IR::MapFunction {384input: self.root,385function: FunctionIR::Hint(hint),386};387self.add_alp(lp)388}389}390391392