Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-plan/src/plans/expr_ir.rs
8424 views
1
use std::borrow::{Borrow, BorrowMut};
2
use std::hash::Hash;
3
#[cfg(feature = "cse")]
4
use std::hash::Hasher;
5
use std::sync::OnceLock;
6
7
use polars_utils::format_pl_smallstr;
8
#[cfg(feature = "ir_serde")]
9
use serde::{Deserialize, Serialize};
10
11
use super::*;
12
use crate::constants::{get_len_name, get_literal_name, get_pl_element_name};
13
14
#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
15
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
16
pub enum OutputName {
17
/// No not yet set.
18
#[default]
19
None,
20
/// The most left-hand-side literal will be the output name.
21
LiteralLhs(PlSmallStr),
22
/// The most left-hand-side column will be the output name.
23
ColumnLhs(PlSmallStr),
24
/// Rename the output as `PlSmallStr`.
25
Alias(PlSmallStr),
26
#[cfg(feature = "dtype-struct")]
27
/// A struct field.
28
Field(PlSmallStr),
29
}
30
31
impl OutputName {
32
pub fn get(&self) -> Option<&PlSmallStr> {
33
match self {
34
OutputName::Alias(name) => Some(name),
35
OutputName::ColumnLhs(name) => Some(name),
36
OutputName::LiteralLhs(name) => Some(name),
37
#[cfg(feature = "dtype-struct")]
38
OutputName::Field(name) => Some(name),
39
OutputName::None => None,
40
}
41
}
42
43
pub fn unwrap(&self) -> &PlSmallStr {
44
self.get().expect("no output name set")
45
}
46
47
pub fn into_inner(self) -> Option<PlSmallStr> {
48
match self {
49
OutputName::Alias(name) => Some(name),
50
OutputName::ColumnLhs(name) => Some(name),
51
OutputName::LiteralLhs(name) => Some(name),
52
#[cfg(feature = "dtype-struct")]
53
OutputName::Field(name) => Some(name),
54
OutputName::None => None,
55
}
56
}
57
58
pub(crate) fn is_none(&self) -> bool {
59
matches!(self, OutputName::None)
60
}
61
}
62
63
#[derive(Clone, Debug)]
64
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
65
pub struct ExprIR {
66
/// Output name of this expression.
67
output_name: OutputName,
68
/// Output dtype of this expression
69
/// Reduced expression.
70
/// This expression is pruned from `alias` and already expanded.
71
node: Node,
72
#[cfg_attr(feature = "ir_serde", serde(skip))]
73
output_dtype: OnceLock<DataType>,
74
}
75
76
impl Eq for ExprIR {}
77
78
impl PartialEq for ExprIR {
79
fn eq(&self, other: &Self) -> bool {
80
self.node == other.node && self.output_name == other.output_name
81
}
82
}
83
84
impl Borrow<Node> for ExprIR {
    /// Borrows the expression node held by this `ExprIR`.
    fn borrow(&self) -> &Node {
        let Self { node, .. } = self;
        node
    }
}
89
90
impl BorrowMut<Node> for ExprIR {
    /// Mutably borrows the expression node held by this `ExprIR`.
    ///
    /// Writing through this reference leaves `output_name` and `output_dtype` untouched.
    fn borrow_mut(&mut self) -> &mut Node {
        let Self { node, .. } = self;
        node
    }
}
95
96
impl Borrow<Node> for &ExprIR {
    /// Borrows the expression node of the referenced `ExprIR`.
    fn borrow(&self) -> &Node {
        // `self` is `&&ExprIR` here; peel one reference before accessing the field.
        let this: &ExprIR = self;
        &this.node
    }
}
101
102
impl ExprIR {
103
pub fn new(node: Node, output_name: OutputName) -> Self {
104
debug_assert!(!output_name.is_none());
105
ExprIR {
106
output_name,
107
node,
108
output_dtype: OnceLock::new(),
109
}
110
}
111
112
pub fn from_column_name(name: PlSmallStr, expr_arena: &mut Arena<AExpr>) -> Self {
113
let node = expr_arena.add(AExpr::Column(name.clone()));
114
ExprIR::new(node, OutputName::ColumnLhs(name))
115
}
116
117
pub fn with_dtype(self, dtype: DataType) -> Self {
118
let _ = self.output_dtype.set(dtype);
119
self
120
}
121
122
pub(crate) fn set_dtype(&mut self, dtype: DataType) {
123
self.output_dtype = OnceLock::from(dtype);
124
}
125
126
pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
127
let mut out = Self {
128
node,
129
output_name: OutputName::None,
130
output_dtype: OnceLock::new(),
131
};
132
out.node = node;
133
for (_, ae) in arena.iter(node) {
134
match ae {
135
AExpr::Element => {
136
out.output_name = OutputName::ColumnLhs(get_pl_element_name());
137
break;
138
},
139
AExpr::Column(name) => {
140
out.output_name = OutputName::ColumnLhs(name.clone());
141
break;
142
},
143
#[cfg(feature = "dtype-struct")]
144
AExpr::StructField(name) => {
145
out.output_name = OutputName::Field(name.clone());
146
break;
147
},
148
AExpr::Literal(lv) => {
149
if let LiteralValue::Series(s) = lv {
150
out.output_name = OutputName::LiteralLhs(s.name().clone());
151
} else {
152
out.output_name = OutputName::LiteralLhs(get_literal_name());
153
}
154
break;
155
},
156
AExpr::Function {
157
input, function, ..
158
} => {
159
match function {
160
#[cfg(feature = "dtype-struct")]
161
IRFunctionExpr::StructExpr(IRStructFunction::FieldByName(name)) => {
162
out.output_name = OutputName::Field(name.clone());
163
},
164
_ => {
165
if input.is_empty() {
166
out.output_name =
167
OutputName::LiteralLhs(format_pl_smallstr!("{}", function));
168
} else {
169
out.output_name = input[0].output_name.clone();
170
}
171
},
172
}
173
break;
174
},
175
AExpr::AnonymousFunction { input, fmt_str, .. } => {
176
if input.is_empty() {
177
out.output_name = OutputName::LiteralLhs(fmt_str.as_ref().clone());
178
} else {
179
out.output_name = input[0].output_name.clone();
180
}
181
break;
182
},
183
AExpr::Len => {
184
out.output_name = OutputName::LiteralLhs(get_len_name());
185
break;
186
},
187
_ => {},
188
}
189
}
190
debug_assert!(!out.output_name.is_none());
191
out
192
}
193
194
#[inline]
195
pub fn node(&self) -> Node {
196
self.node
197
}
198
199
/// Create a `ExprIR` structure that implements display
200
pub fn display<'a>(&'a self, expr_arena: &'a Arena<AExpr>) -> ExprIRDisplay<'a> {
201
ExprIRDisplay {
202
node: self.node(),
203
output_name: self.output_name_inner(),
204
expr_arena,
205
}
206
}
207
208
pub fn set_node(&mut self, node: Node) {
209
self.node = node;
210
self.output_dtype = OnceLock::new();
211
}
212
213
pub(crate) fn set_alias(&mut self, name: PlSmallStr) {
214
self.output_name = OutputName::Alias(name)
215
}
216
217
pub fn with_alias(&self, name: PlSmallStr) -> Self {
218
Self {
219
output_name: OutputName::Alias(name),
220
node: self.node,
221
output_dtype: self.output_dtype.clone(),
222
}
223
}
224
225
pub(crate) fn set_columnlhs(&mut self, name: PlSmallStr) {
226
debug_assert!(matches!(
227
self.output_name,
228
OutputName::ColumnLhs(_) | OutputName::None
229
));
230
self.output_name = OutputName::ColumnLhs(name)
231
}
232
233
pub fn output_name_inner(&self) -> &OutputName {
234
&self.output_name
235
}
236
237
pub fn output_name(&self) -> &PlSmallStr {
238
self.output_name.unwrap()
239
}
240
241
pub fn to_expr(&self, expr_arena: &Arena<AExpr>) -> Expr {
242
let out = node_to_expr(self.node, expr_arena);
243
244
match &self.output_name {
245
OutputName::Alias(name) if expr_arena.get(self.node).to_name(expr_arena) != name => {
246
out.alias(name.clone())
247
},
248
_ => out,
249
}
250
}
251
252
pub fn get_alias(&self) -> Option<&PlSmallStr> {
253
match &self.output_name {
254
OutputName::Alias(name) => Some(name),
255
_ => None,
256
}
257
}
258
259
// Utility for debugging.
260
#[cfg(debug_assertions)]
261
#[allow(dead_code)]
262
pub(crate) fn print(&self, expr_arena: &Arena<AExpr>) {
263
eprintln!("{:?}", self.to_expr(expr_arena))
264
}
265
266
pub(crate) fn has_alias(&self) -> bool {
267
matches!(self.output_name, OutputName::Alias(_))
268
}
269
270
#[cfg(feature = "cse")]
271
pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
272
traverse_and_hash_aexpr(self.node, expr_arena, state);
273
if let Some(alias) = self.get_alias() {
274
alias.hash(state)
275
}
276
}
277
278
pub fn is_scalar(&self, expr_arena: &Arena<AExpr>) -> bool {
279
is_scalar_ae(self.node, expr_arena)
280
}
281
282
pub fn is_length_preserving(&self, expr_arena: &Arena<AExpr>) -> bool {
283
is_length_preserving_ae(self.node, expr_arena)
284
}
285
286
pub fn dtype(&self, schema: &Schema, expr_arena: &Arena<AExpr>) -> PolarsResult<&DataType> {
287
match self.output_dtype.get() {
288
Some(dtype) => Ok(dtype),
289
None => {
290
let dtype = expr_arena
291
.get(self.node)
292
.to_dtype(&ToFieldContext::new(expr_arena, schema))?;
293
let _ = self.output_dtype.set(dtype);
294
Ok(self.output_dtype.get().unwrap())
295
},
296
}
297
}
298
299
pub fn field(&self, schema: &Schema, expr_arena: &Arena<AExpr>) -> PolarsResult<Field> {
300
let dtype = self.dtype(schema, expr_arena)?;
301
let name = self.output_name();
302
Ok(Field::new(name.clone(), dtype.clone()))
303
}
304
305
pub fn into_inner(self) -> (Node, OutputName) {
306
(self.node, self.output_name)
307
}
308
}
309
310
impl AsRef<ExprIR> for ExprIR {
    /// Identity conversion: an `ExprIR` is trivially a reference to itself.
    fn as_ref(&self) -> &Self {
        self
    }
}
315
316
/// A Node that is restricted to `AExpr::Column`
317
#[repr(transparent)]
318
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
319
pub struct ColumnNode(pub(crate) Node);
320
321
impl From<ColumnNode> for Node {
322
fn from(value: ColumnNode) -> Self {
323
value.0
324
}
325
}
326
impl From<&ExprIR> for Node {
    /// Copies the expression node out of the borrowed `ExprIR`.
    fn from(value: &ExprIR) -> Self {
        value.node
    }
}
331
332
pub(crate) fn name_to_expr_ir(name: PlSmallStr, expr_arena: &mut Arena<AExpr>) -> ExprIR {
333
ExprIR::from_column_name(name, expr_arena)
334
}
335
336
pub(crate) fn names_to_expr_irs<I, S>(names: I, expr_arena: &mut Arena<AExpr>) -> Vec<ExprIR>
337
where
338
I: IntoIterator<Item = S>,
339
S: Into<PlSmallStr>,
340
{
341
names
342
.into_iter()
343
.map(|name| {
344
let name = name.into();
345
name_to_expr_ir(name, expr_arena)
346
})
347
.collect()
348
}
349
350