Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/lazyframe/visit.rs
7889 views
1
use std::sync::{Arc, Mutex};
2
3
use polars::prelude::PolarsError;
4
use polars::prelude::python_dsl::PythonScanSource;
5
use polars_plan::plans::{ExprToIRContext, IR, ToFieldContext, to_expr_ir};
6
use polars_plan::prelude::expr_ir::ExprIR;
7
use polars_plan::prelude::{AExpr, PythonOptions};
8
use polars_utils::arena::{Arena, Node};
9
use pyo3::prelude::*;
10
use pyo3::types::{PyDict, PyList};
11
12
use super::PyLazyFrame;
13
use super::visitor::{expr_nodes, nodes};
14
use crate::error::PyPolarsErr;
15
use crate::{PyExpr, Wrap, raise_err};
16
17
#[derive(Clone)]
18
#[pyclass(frozen)]
19
pub struct PyExprIR {
20
#[pyo3(get)]
21
node: usize,
22
#[pyo3(get)]
23
output_name: String,
24
}
25
26
/// Convert an owned [`ExprIR`] into its Python-facing view.
///
/// Delegates to the `From<&ExprIR>` impl so both conversions stay in sync.
impl From<ExprIR> for PyExprIR {
    fn from(value: ExprIR) -> Self {
        Self::from(&value)
    }
}
34
35
/// Build a [`PyExprIR`] from a borrowed [`ExprIR`], copying its node id and
/// output name.
impl From<&ExprIR> for PyExprIR {
    fn from(value: &ExprIR) -> Self {
        let node = value.node().0;
        let output_name = value.output_name().to_string();
        Self { node, output_name }
    }
}
43
44
/// IR version as `(major, minor)`; see `NodeTraverser::VERSION`.
type Version = (u16, u16);
45
46
#[pyclass]
47
pub struct NodeTraverser {
48
root: Node,
49
lp_arena: Arc<Mutex<Arena<IR>>>,
50
expr_arena: Arc<Mutex<Arena<AExpr>>>,
51
scratch: Vec<Node>,
52
expr_scratch: Vec<ExprIR>,
53
expr_mapping: Option<Vec<Node>>,
54
}
55
56
impl NodeTraverser {
57
// Versioning for IR, (major, minor)
58
// Increment major on breaking changes to the IR (e.g. renaming
59
// fields, reordering tuples), minor on backwards compatible
60
// changes (e.g. exposing a new expression node).
61
const VERSION: Version = (12, 0);
62
63
pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
64
Self {
65
root,
66
lp_arena: Arc::new(Mutex::new(lp_arena)),
67
expr_arena: Arc::new(Mutex::new(expr_arena)),
68
scratch: vec![],
69
expr_scratch: vec![],
70
expr_mapping: None,
71
}
72
}
73
74
#[allow(clippy::type_complexity)]
75
pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {
76
(self.lp_arena.clone(), self.expr_arena.clone())
77
}
78
79
fn fill_inputs(&mut self) {
80
let lp_arena = self.lp_arena.lock().unwrap();
81
let this_node = lp_arena.get(self.root);
82
self.scratch.clear();
83
this_node.copy_inputs(&mut self.scratch);
84
}
85
86
fn fill_expressions(&mut self) {
87
let lp_arena = self.lp_arena.lock().unwrap();
88
let this_node = lp_arena.get(self.root);
89
self.expr_scratch.clear();
90
this_node.copy_exprs(&mut self.expr_scratch);
91
}
92
93
fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
94
PyList::new(py, self.scratch.drain(..).map(|node| node.0))
95
}
96
97
fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
98
PyList::new(
99
py,
100
self.expr_scratch
101
.drain(..)
102
.map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),
103
)
104
}
105
}
106
107
#[pymethods]
108
impl NodeTraverser {
109
/// Get expression nodes
110
fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
111
self.fill_expressions();
112
self.expr_to_list(py)
113
}
114
115
/// Get input nodes
116
fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {
117
self.fill_inputs();
118
self.scratch_to_list(py)
119
}
120
121
/// The current version of the IR
122
fn version(&self) -> Version {
123
NodeTraverser::VERSION
124
}
125
126
/// Get Schema of current node as python dict<str, pl.DataType>
127
fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
128
let lp_arena = self.lp_arena.lock().unwrap();
129
let schema = lp_arena.get(self.root).schema(&lp_arena);
130
Wrap((**schema).clone()).into_pyobject(py)
131
}
132
133
/// Get expression dtype of expr_node, the schema used is that of the current root node
134
fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
135
let expr_node = Node(expr_node);
136
let lp_arena = self.lp_arena.lock().unwrap();
137
let schema = lp_arena.get(self.root).schema(&lp_arena);
138
let expr_arena = self.expr_arena.lock().unwrap();
139
let field = expr_arena
140
.get(expr_node)
141
.to_field(&ToFieldContext::new(&expr_arena, &schema))
142
.map_err(PyPolarsErr::from)?;
143
Wrap(field.dtype).into_pyobject(py)
144
}
145
146
/// Set the current node in the plan.
147
fn set_node(&mut self, node: usize) {
148
self.root = Node(node);
149
}
150
151
/// Get the current node in the plan.
152
fn get_node(&mut self) -> usize {
153
self.root.0
154
}
155
156
/// Set a python UDF that will replace the subtree location with this function src.
157
#[pyo3(signature = (function, is_pure = false))]
158
fn set_udf(&mut self, function: Py<PyAny>, is_pure: bool) {
159
let mut lp_arena = self.lp_arena.lock().unwrap();
160
let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
161
let ir = IR::PythonScan {
162
options: PythonOptions {
163
scan_fn: Some(function.into()),
164
schema,
165
output_schema: None,
166
with_columns: None,
167
python_source: PythonScanSource::Cuda,
168
predicate: Default::default(),
169
n_rows: None,
170
validate_schema: false,
171
is_pure,
172
},
173
};
174
lp_arena.replace(self.root, ir);
175
}
176
177
fn view_current_node(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
178
let lp_arena = self.lp_arena.lock().unwrap();
179
let lp_node = lp_arena.get(self.root);
180
nodes::into_py(py, lp_node)
181
}
182
183
fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<Py<PyAny>> {
184
let expr_arena = self.expr_arena.lock().unwrap();
185
let n = match &self.expr_mapping {
186
Some(mapping) => *mapping.get(node).unwrap(),
187
None => Node(node),
188
};
189
let expr = expr_arena.get(n);
190
expr_nodes::into_py(py, expr)
191
}
192
193
/// Add some expressions to the arena and return their new node ids as well
194
/// as the total number of nodes in the arena.
195
fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
196
let lp_arena = self.lp_arena.lock().unwrap();
197
let schema = lp_arena.get(self.root).schema(&lp_arena);
198
let mut expr_arena = self.expr_arena.lock().unwrap();
199
Ok((
200
expressions
201
.into_iter()
202
.map(|e| {
203
let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);
204
ctx.allow_unknown = true;
205
// NOTE: Probably throwing away the output names here is not okay?
206
to_expr_ir(e.inner, &mut ctx)
207
.map_err(PyPolarsErr::from)
208
.map(|v| v.node().0)
209
})
210
.collect::<Result<_, PyPolarsErr>>()?,
211
expr_arena.len(),
212
))
213
}
214
215
/// Set up a mapping of expression nodes used in `view_expression_node``.
216
/// With a mapping set, `view_expression_node(i)` produces the node for
217
/// `mapping[i]`.
218
fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
219
if mapping.len() != self.expr_arena.lock().unwrap().len() {
220
raise_err!("Invalid mapping length", ComputeError);
221
}
222
self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
223
Ok(())
224
}
225
226
/// Unset the expression mapping (reinstates the identity map)
227
fn unset_expr_mapping(&mut self) {
228
self.expr_mapping = None;
229
}
230
}
231
232
#[pymethods]
233
#[allow(clippy::should_implement_trait)]
234
impl PyLazyFrame {
235
fn visit(&self) -> PyResult<NodeTraverser> {
236
let mut lp_arena = Arena::with_capacity(16);
237
let mut expr_arena = Arena::with_capacity(16);
238
let root = self
239
.ldf
240
.read()
241
.clone()
242
.optimize(&mut lp_arena, &mut expr_arena)
243
.map_err(PyPolarsErr::from)?;
244
Ok(NodeTraverser {
245
root,
246
lp_arena: Arc::new(Mutex::new(lp_arena)),
247
expr_arena: Arc::new(Mutex::new(expr_arena)),
248
scratch: vec![],
249
expr_scratch: vec![],
250
expr_mapping: None,
251
})
252
}
253
}
254
255