Path: blob/main/crates/polars-python/src/lazyframe/visit.rs
7889 views
use std::sync::{Arc, Mutex};12use polars::prelude::PolarsError;3use polars::prelude::python_dsl::PythonScanSource;4use polars_plan::plans::{ExprToIRContext, IR, ToFieldContext, to_expr_ir};5use polars_plan::prelude::expr_ir::ExprIR;6use polars_plan::prelude::{AExpr, PythonOptions};7use polars_utils::arena::{Arena, Node};8use pyo3::prelude::*;9use pyo3::types::{PyDict, PyList};1011use super::PyLazyFrame;12use super::visitor::{expr_nodes, nodes};13use crate::error::PyPolarsErr;14use crate::{PyExpr, Wrap, raise_err};1516#[derive(Clone)]17#[pyclass(frozen)]18pub struct PyExprIR {19#[pyo3(get)]20node: usize,21#[pyo3(get)]22output_name: String,23}2425impl From<ExprIR> for PyExprIR {26fn from(value: ExprIR) -> Self {27Self {28node: value.node().0,29output_name: value.output_name().to_string(),30}31}32}3334impl From<&ExprIR> for PyExprIR {35fn from(value: &ExprIR) -> Self {36Self {37node: value.node().0,38output_name: value.output_name().to_string(),39}40}41}4243type Version = (u16, u16);4445#[pyclass]46pub struct NodeTraverser {47root: Node,48lp_arena: Arc<Mutex<Arena<IR>>>,49expr_arena: Arc<Mutex<Arena<AExpr>>>,50scratch: Vec<Node>,51expr_scratch: Vec<ExprIR>,52expr_mapping: Option<Vec<Node>>,53}5455impl NodeTraverser {56// Versioning for IR, (major, minor)57// Increment major on breaking changes to the IR (e.g. renaming58// fields, reordering tuples), minor on backwards compatible59// changes (e.g. exposing a new expression node).60const VERSION: Version = (12, 0);6162pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {63Self {64root,65lp_arena: Arc::new(Mutex::new(lp_arena)),66expr_arena: Arc::new(Mutex::new(expr_arena)),67scratch: vec![],68expr_scratch: vec![],69expr_mapping: None,70}71}7273#[allow(clippy::type_complexity)]74pub fn get_arenas(&self) -> (Arc<Mutex<Arena<IR>>>, Arc<Mutex<Arena<AExpr>>>) {75(self.lp_arena.clone(), self.expr_arena.clone())76}7778fn fill_inputs(&mut self) {79let lp_arena = self.lp_arena.lock().unwrap();80let this_node = lp_arena.get(self.root);81self.scratch.clear();82this_node.copy_inputs(&mut self.scratch);83}8485fn fill_expressions(&mut self) {86let lp_arena = self.lp_arena.lock().unwrap();87let this_node = lp_arena.get(self.root);88self.expr_scratch.clear();89this_node.copy_exprs(&mut self.expr_scratch);90}9192fn scratch_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {93PyList::new(py, self.scratch.drain(..).map(|node| node.0))94}9596fn expr_to_list<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {97PyList::new(98py,99self.expr_scratch100.drain(..)101.map(|e| PyExprIR::from(e).into_pyobject(py).unwrap()),102)103}104}105106#[pymethods]107impl NodeTraverser {108/// Get expression nodes109fn get_exprs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {110self.fill_expressions();111self.expr_to_list(py)112}113114/// Get input nodes115fn get_inputs<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyList>> {116self.fill_inputs();117self.scratch_to_list(py)118}119120/// The current version of the IR121fn version(&self) -> Version {122NodeTraverser::VERSION123}124125/// Get Schema of current node as python dict<str, pl.DataType>126fn get_schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {127let lp_arena = self.lp_arena.lock().unwrap();128let schema = lp_arena.get(self.root).schema(&lp_arena);129Wrap((**schema).clone()).into_pyobject(py)130}131132/// Get expression dtype of expr_node, the schema used is that of the current root node133fn get_dtype<'py>(&self, expr_node: usize, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {134let expr_node = Node(expr_node);135let lp_arena = self.lp_arena.lock().unwrap();136let schema = lp_arena.get(self.root).schema(&lp_arena);137let expr_arena = self.expr_arena.lock().unwrap();138let field = expr_arena139.get(expr_node)140.to_field(&ToFieldContext::new(&expr_arena, &schema))141.map_err(PyPolarsErr::from)?;142Wrap(field.dtype).into_pyobject(py)143}144145/// Set the current node in the plan.146fn set_node(&mut self, node: usize) {147self.root = Node(node);148}149150/// Get the current node in the plan.151fn get_node(&mut self) -> usize {152self.root.0153}154155/// Set a python UDF that will replace the subtree location with this function src.156#[pyo3(signature = (function, is_pure = false))]157fn set_udf(&mut self, function: Py<PyAny>, is_pure: bool) {158let mut lp_arena = self.lp_arena.lock().unwrap();159let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();160let ir = IR::PythonScan {161options: PythonOptions {162scan_fn: Some(function.into()),163schema,164output_schema: None,165with_columns: None,166python_source: PythonScanSource::Cuda,167predicate: Default::default(),168n_rows: None,169validate_schema: false,170is_pure,171},172};173lp_arena.replace(self.root, ir);174}175176fn view_current_node(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {177let lp_arena = self.lp_arena.lock().unwrap();178let lp_node = lp_arena.get(self.root);179nodes::into_py(py, lp_node)180}181182fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<Py<PyAny>> {183let expr_arena = self.expr_arena.lock().unwrap();184let n = match &self.expr_mapping {185Some(mapping) => *mapping.get(node).unwrap(),186None => Node(node),187};188let expr = expr_arena.get(n);189expr_nodes::into_py(py, expr)190}191192/// Add some expressions to the arena and return their new node ids as well193/// as the total number of nodes in the arena.194fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {195let lp_arena = self.lp_arena.lock().unwrap();196let schema = lp_arena.get(self.root).schema(&lp_arena);197let mut expr_arena = self.expr_arena.lock().unwrap();198Ok((199expressions200.into_iter()201.map(|e| {202let mut ctx = ExprToIRContext::new(&mut expr_arena, &schema);203ctx.allow_unknown = true;204// NOTE: Probably throwing away the output names here is not okay?205to_expr_ir(e.inner, &mut ctx)206.map_err(PyPolarsErr::from)207.map(|v| v.node().0)208})209.collect::<Result<_, PyPolarsErr>>()?,210expr_arena.len(),211))212}213214/// Set up a mapping of expression nodes used in `view_expression_node``.215/// With a mapping set, `view_expression_node(i)` produces the node for216/// `mapping[i]`.217fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {218if mapping.len() != self.expr_arena.lock().unwrap().len() {219raise_err!("Invalid mapping length", ComputeError);220}221self.expr_mapping = Some(mapping.into_iter().map(Node).collect());222Ok(())223}224225/// Unset the expression mapping (reinstates the identity map)226fn unset_expr_mapping(&mut self) {227self.expr_mapping = None;228}229}230231#[pymethods]232#[allow(clippy::should_implement_trait)]233impl PyLazyFrame {234fn visit(&self) -> PyResult<NodeTraverser> {235let mut lp_arena = Arena::with_capacity(16);236let mut expr_arena = Arena::with_capacity(16);237let root = self238.ldf239.read()240.clone()241.optimize(&mut lp_arena, &mut expr_arena)242.map_err(PyPolarsErr::from)?;243Ok(NodeTraverser {244root,245lp_arena: Arc::new(Mutex::new(lp_arena)),246expr_arena: Arc::new(Mutex::new(expr_arena)),247scratch: vec![],248expr_scratch: vec![],249expr_mapping: None,250})251}252}253254255