Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/builder_ir.rs
8430 views
1
use std::borrow::Cow;
2
3
use super::*;
4
5
pub struct IRBuilder<'a> {
6
root: Node,
7
expr_arena: &'a mut Arena<AExpr>,
8
lp_arena: &'a mut Arena<IR>,
9
}
10
11
impl<'a> IRBuilder<'a> {
12
pub fn new(root: Node, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
13
IRBuilder {
14
root,
15
expr_arena,
16
lp_arena,
17
}
18
}
19
20
pub fn from_lp(lp: IR, expr_arena: &'a mut Arena<AExpr>, lp_arena: &'a mut Arena<IR>) -> Self {
21
let root = lp_arena.add(lp);
22
IRBuilder {
23
root,
24
expr_arena,
25
lp_arena,
26
}
27
}
28
29
pub fn add_alp(self, lp: IR) -> Self {
30
let node = self.lp_arena.add(lp);
31
IRBuilder::new(node, self.expr_arena, self.lp_arena)
32
}
33
34
/// Adds IR and runs optimizations on its expressions (simplify, coerce, type-check).
35
pub fn add_alp_optimize_exprs<F>(self, f: F) -> PolarsResult<Self>
36
where
37
F: FnOnce(Node) -> IR,
38
{
39
let lp = f(self.root);
40
let ir_name = lp.name();
41
42
let b = self.add_alp(lp);
43
44
// Run the optimizer
45
let mut conversion_optimizer = ConversionOptimizer::new(true, true, true);
46
conversion_optimizer.fill_scratch(b.lp_arena.get(b.root).exprs(), b.expr_arena);
47
conversion_optimizer
48
.optimize_exprs(b.expr_arena, b.lp_arena, b.root, false)
49
.map_err(|e| e.context(format!("optimizing '{ir_name}' failed").into()))?;
50
51
Ok(b)
52
}
53
54
/// An escape hatch to add an `Expr`. Working with IR is preferred.
55
pub fn add_expr(&mut self, expr: Expr) -> PolarsResult<ExprIR> {
56
let schema = self.lp_arena.get(self.root).schema(self.lp_arena);
57
let mut ctx = ExprToIRContext::new(self.expr_arena, &schema);
58
to_expr_ir(expr, &mut ctx)
59
}
60
61
pub fn project(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
62
// if len == 0, no projection has to be done. This is a select all operation.
63
if exprs.is_empty() {
64
self
65
} else {
66
let input_schema = self.schema();
67
let schema = expr_irs_to_schema(&exprs, &input_schema, self.expr_arena)
68
.expect("no valid schema can be derived for the query");
69
70
let lp = IR::Select {
71
expr: exprs,
72
input: self.root,
73
schema: Arc::new(schema),
74
options,
75
};
76
let node = self.lp_arena.add(lp);
77
IRBuilder::new(node, self.expr_arena, self.lp_arena)
78
}
79
}
80
81
pub fn project_simple_nodes<I, N>(self, nodes: I) -> PolarsResult<Self>
82
where
83
I: IntoIterator<Item = N>,
84
N: Into<Node>,
85
I::IntoIter: ExactSizeIterator,
86
{
87
let names = nodes
88
.into_iter()
89
.map(|node| match self.expr_arena.get(node.into()) {
90
AExpr::Column(name) => name,
91
_ => unreachable!(),
92
});
93
// This is a duplication of `project_simple` because we already borrow self.expr_arena :/
94
if names.size_hint().0 == 0 {
95
Ok(self)
96
} else {
97
let input_schema = self.schema();
98
let mut count = 0;
99
let schema = names
100
.map(|name| {
101
let dtype = input_schema.try_get(name)?;
102
count += 1;
103
Ok(Field::new(name.clone(), dtype.clone()))
104
})
105
.collect::<PolarsResult<Schema>>()?;
106
107
polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
108
109
let lp = IR::SimpleProjection {
110
input: self.root,
111
columns: Arc::new(schema),
112
};
113
let node = self.lp_arena.add(lp);
114
Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
115
}
116
}
117
118
pub fn project_simple<I, S>(self, names: I) -> PolarsResult<Self>
119
where
120
I: IntoIterator<Item = S>,
121
I::IntoIter: ExactSizeIterator,
122
S: Into<PlSmallStr>,
123
{
124
let names = names.into_iter();
125
// if len == 0, no projection has to be done. This is a select all operation.
126
if names.size_hint().0 == 0 {
127
Ok(self)
128
} else {
129
let input_schema = self.schema();
130
let mut count = 0;
131
let schema = names
132
.map(|name| {
133
let name: PlSmallStr = name.into();
134
let dtype = input_schema.try_get(name.as_str())?;
135
count += 1;
136
Ok(Field::new(name, dtype.clone()))
137
})
138
.collect::<PolarsResult<Schema>>()?;
139
140
polars_ensure!(count == schema.len(), Duplicate: "found duplicate columns");
141
142
let lp = IR::SimpleProjection {
143
input: self.root,
144
columns: Arc::new(schema),
145
};
146
let node = self.lp_arena.add(lp);
147
Ok(IRBuilder::new(node, self.expr_arena, self.lp_arena))
148
}
149
}
150
151
pub fn drop<I, S>(self, names: I) -> Self
152
where
153
I: IntoIterator<Item = S>,
154
I::IntoIter: ExactSizeIterator,
155
S: Into<PlSmallStr>,
156
{
157
let names = names.into_iter();
158
// if len == 0, no projection has to be done. This is a select all operation.
159
if names.size_hint().0 == 0 {
160
self
161
} else {
162
let mut schema = self.schema().as_ref().as_ref().clone();
163
164
for name in names {
165
let name: PlSmallStr = name.into();
166
schema.remove(&name);
167
}
168
169
let lp = IR::SimpleProjection {
170
input: self.root,
171
columns: Arc::new(schema),
172
};
173
let node = self.lp_arena.add(lp);
174
IRBuilder::new(node, self.expr_arena, self.lp_arena)
175
}
176
}
177
178
pub fn sort(
179
self,
180
by_column: Vec<ExprIR>,
181
slice: Option<(i64, usize)>,
182
sort_options: SortMultipleOptions,
183
) -> Self {
184
let ir = IR::Sort {
185
input: self.root,
186
by_column,
187
slice,
188
sort_options,
189
};
190
let node = self.lp_arena.add(ir);
191
IRBuilder::new(node, self.expr_arena, self.lp_arena)
192
}
193
194
pub fn node(self) -> Node {
195
self.root
196
}
197
198
pub fn build(self) -> IR {
199
if self.root.0 == self.lp_arena.len() {
200
self.lp_arena.pop().unwrap()
201
} else {
202
self.lp_arena.take(self.root)
203
}
204
}
205
206
pub fn schema(&'a self) -> Cow<'a, SchemaRef> {
207
self.lp_arena.get(self.root).schema(self.lp_arena)
208
}
209
210
pub fn with_columns(self, exprs: Vec<ExprIR>, options: ProjectionOptions) -> Self {
211
let schema = self.schema();
212
let mut new_schema = (**schema).clone();
213
214
let hstack_schema = expr_irs_to_schema(&exprs, &schema, self.expr_arena)
215
.expect("no valid schema can be derived for the query");
216
new_schema.merge(hstack_schema);
217
218
let lp = IR::HStack {
219
input: self.root,
220
exprs,
221
schema: Arc::new(new_schema),
222
options,
223
};
224
self.add_alp(lp)
225
}
226
227
pub fn with_columns_simple<I, J: Into<Node>>(self, exprs: I, options: ProjectionOptions) -> Self
228
where
229
I: IntoIterator<Item = J>,
230
{
231
let schema = self.schema();
232
let mut new_schema = (**schema).clone();
233
234
let iter = exprs.into_iter();
235
let mut expr_irs = Vec::with_capacity(iter.size_hint().0);
236
for node in iter {
237
let node = node.into();
238
let field = self
239
.expr_arena
240
.get(node)
241
.to_field(&ToFieldContext::new(self.expr_arena, &schema))
242
.unwrap();
243
244
expr_irs.push(
245
ExprIR::new(node, OutputName::ColumnLhs(field.name.clone()))
246
.with_dtype(field.dtype.clone()),
247
);
248
new_schema.with_column(field.name().clone(), field.dtype().clone());
249
}
250
251
let lp = IR::HStack {
252
input: self.root,
253
exprs: expr_irs,
254
schema: Arc::new(new_schema),
255
options,
256
};
257
self.add_alp(lp)
258
}
259
260
// call this if the schema needs to be updated
261
pub fn explode(self, columns: Arc<[PlSmallStr]>, options: ExplodeOptions) -> Self {
262
let lp = IR::MapFunction {
263
input: self.root,
264
function: FunctionIR::Explode {
265
columns,
266
options,
267
schema: Default::default(),
268
},
269
};
270
self.add_alp(lp)
271
}
272
273
pub fn group_by(
274
self,
275
keys: Vec<ExprIR>,
276
aggs: Vec<ExprIR>,
277
apply: Option<PlanCallback<DataFrame, DataFrame>>,
278
maintain_order: bool,
279
options: Arc<GroupbyOptions>,
280
) -> Self {
281
let current_schema = self.schema();
282
let mut schema = expr_irs_to_schema(&keys, &current_schema, self.expr_arena)
283
.expect("no valid schema can be derived for the key expression");
284
285
#[cfg(feature = "dynamic_group_by")]
286
{
287
if let Some(options) = options.rolling.as_ref() {
288
let name = &options.index_column;
289
let dtype = current_schema.get(name).unwrap();
290
schema.with_column(name.clone(), dtype.clone());
291
} else if let Some(options) = options.dynamic.as_ref() {
292
let name = &options.index_column;
293
let dtype = current_schema.get(name).unwrap();
294
if options.include_boundaries {
295
schema.with_column("_lower_boundary".into(), dtype.clone());
296
schema.with_column("_upper_boundary".into(), dtype.clone());
297
}
298
schema.with_column(name.clone(), dtype.clone());
299
}
300
}
301
302
let mut aggs_schema = expr_irs_to_schema(&aggs, &current_schema, self.expr_arena)
303
.expect("no valid schema can be derived for the agg expression");
304
305
// Coerce aggregation column(s) into List unless not needed (auto-implode)
306
debug_assert!(aggs_schema.len() == aggs.len());
307
for ((_name, dtype), expr) in aggs_schema.iter_mut().zip(&aggs) {
308
if !expr.is_scalar(self.expr_arena) {
309
*dtype = dtype.clone().implode();
310
}
311
}
312
313
schema.merge(aggs_schema);
314
315
let lp = IR::GroupBy {
316
input: self.root,
317
keys,
318
aggs,
319
schema: Arc::new(schema),
320
apply,
321
maintain_order,
322
options,
323
};
324
self.add_alp(lp)
325
}
326
327
pub fn join(
328
self,
329
other: Node,
330
left_on: Vec<ExprIR>,
331
right_on: Vec<ExprIR>,
332
options: Arc<JoinOptionsIR>,
333
) -> Self {
334
let schema_left = self.schema();
335
let schema_right = self.lp_arena.get(other).schema(self.lp_arena);
336
337
let schema = det_join_schema(
338
&schema_left,
339
&schema_right,
340
&left_on,
341
&right_on,
342
&options,
343
self.expr_arena,
344
)
345
.unwrap();
346
347
let lp = IR::Join {
348
input_left: self.root,
349
input_right: other,
350
schema,
351
left_on,
352
right_on,
353
options,
354
};
355
356
self.add_alp(lp)
357
}
358
359
#[cfg(feature = "pivot")]
360
pub fn unpivot(self, args: Arc<UnpivotArgsIR>) -> Self {
361
let lp = IR::MapFunction {
362
input: self.root,
363
function: FunctionIR::Unpivot {
364
args,
365
schema: Default::default(),
366
},
367
};
368
self.add_alp(lp)
369
}
370
371
pub fn row_index(self, name: PlSmallStr, offset: Option<IdxSize>) -> Self {
372
let lp = IR::MapFunction {
373
input: self.root,
374
function: FunctionIR::RowIndex {
375
name,
376
offset,
377
schema: Default::default(),
378
},
379
};
380
self.add_alp(lp)
381
}
382
383
pub fn hint(self, hint: HintIR) -> Self {
384
let lp = IR::MapFunction {
385
input: self.root,
386
function: FunctionIR::Hint(hint),
387
};
388
self.add_alp(lp)
389
}
390
}
391
392