Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-python/src/cloud_server.rs
7884 views
1
use polars_core::error::{PolarsResult, polars_err};
2
use polars_expr::state::ExecutionState;
3
use polars_mem_engine::create_physical_plan;
4
use polars_plan::plans::{AExpr, IR, IRPlan};
5
use polars_plan::prelude::{Arena, Node};
6
use polars_utils::pl_serialize;
7
use pyo3::intern;
8
use pyo3::prelude::{PyAnyMethods, PyModule, Python, *};
9
use pyo3::types::IntoPyDict;
10
11
use crate::PyDataFrame;
12
use crate::error::PyPolarsErr;
13
use crate::lazyframe::visit::NodeTraverser;
14
use crate::utils::EnterPolarsExt;
15
16
/// Take a serialized `IRPlan` and execute it on the GPU engine.
17
///
18
/// This is done as a Python function because the `NodeTraverser` class created for this purpose
19
/// must exactly match the one expected by the `cudf_polars` package.
20
#[pyfunction]
21
pub fn _execute_ir_plan_with_gpu(ir_plan_ser: Vec<u8>, py: Python) -> PyResult<PyDataFrame> {
22
// Deserialize into IRPlan.
23
let mut ir_plan: IRPlan =
24
pl_serialize::deserialize_from_reader::<_, _, false>(ir_plan_ser.as_slice())
25
.map_err(PyPolarsErr::from)?;
26
27
// Edit for use with GPU engine.
28
gpu_post_opt(
29
py,
30
ir_plan.lp_top,
31
&mut ir_plan.lp_arena,
32
&mut ir_plan.expr_arena,
33
)
34
.map_err(PyPolarsErr::from)?;
35
36
// Convert to physical plan.
37
let mut physical_plan = create_physical_plan(
38
ir_plan.lp_top,
39
&mut ir_plan.lp_arena,
40
&mut ir_plan.expr_arena,
41
None,
42
)
43
.map_err(PyPolarsErr::from)?;
44
45
// Execute the plan.
46
let mut state = ExecutionState::new();
47
py.enter_polars_df(|| physical_plan.execute(&mut state))
48
}
49
50
/// Prepare the IR for execution by the Polars GPU engine.
51
fn gpu_post_opt(
52
py: Python<'_>,
53
root: Node,
54
lp_arena: &mut Arena<IR>,
55
expr_arena: &mut Arena<AExpr>,
56
) -> PolarsResult<()> {
57
// Get cuDF Python function.
58
let cudf = PyModule::import(py, intern!(py, "cudf_polars")).unwrap();
59
let lambda = cudf.getattr(intern!(py, "execute_with_cudf")).unwrap();
60
61
// Define cuDF config.
62
let polars = PyModule::import(py, intern!(py, "polars")).unwrap();
63
let engine = polars.getattr(intern!(py, "GPUEngine")).unwrap();
64
let kwargs = [("raise_on_fail", true)].into_py_dict(py).unwrap();
65
let engine = engine.call((), Some(&kwargs)).unwrap();
66
67
// Define node traverser.
68
let nt = NodeTraverser::new(root, std::mem::take(lp_arena), std::mem::take(expr_arena));
69
70
// Get a copy of the arenas.
71
let arenas = nt.get_arenas();
72
73
// Pass the node visitor which allows the Python callback to replace parts of the query plan.
74
// Remove "cuda" or specify better once we have multiple post-opt callbacks.
75
let kwargs = [("config", engine)].into_py_dict(py).unwrap();
76
lambda
77
.call((nt,), Some(&kwargs))
78
.map_err(|e| polars_err!(ComputeError: "'cuda' conversion failed: {}", e))?;
79
80
// Unpack the arena's.
81
// At this point the `nt` is useless.
82
std::mem::swap(lp_arena, &mut *arenas.0.lock().unwrap());
83
std::mem::swap(expr_arena, &mut *arenas.1.lock().unwrap());
84
85
Ok(())
86
}
87
88