Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-mem-engine/src/executors/scan/python_scan.rs
6940 views
1
use polars_core::utils::accumulate_dataframes_vertical;
2
use pyo3::exceptions::PyStopIteration;
3
use pyo3::prelude::*;
4
use pyo3::types::{PyBytes, PyNone};
5
use pyo3::{IntoPyObjectExt, PyTypeInfo, intern};
6
7
use self::python_dsl::PythonScanSource;
8
use super::*;
9
10
pub(crate) struct PythonScanExec {
11
pub(crate) options: PythonOptions,
12
pub(crate) predicate: Option<Arc<dyn PhysicalExpr>>,
13
pub(crate) predicate_serialized: Option<Vec<u8>>,
14
}
15
16
impl PythonScanExec {
17
/// Get the output schema. E.g. the schema the plugins produce, not consume.
18
fn get_schema(&self) -> &SchemaRef {
19
self.options
20
.output_schema
21
.as_ref()
22
.unwrap_or(&self.options.schema)
23
}
24
25
fn check_schema(&self, df: &DataFrame) -> PolarsResult<()> {
26
if self.options.validate_schema {
27
let output_schema = self.get_schema();
28
polars_ensure!(df.schema() == output_schema, SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}", output_schema, df.schema());
29
}
30
Ok(())
31
}
32
33
fn finish_df(
34
&self,
35
py: Python,
36
df: Bound<'_, PyAny>,
37
state: &mut ExecutionState,
38
) -> PolarsResult<DataFrame> {
39
let df = python_df_to_rust(py, df)?;
40
41
self.check_schema(&df)?;
42
43
if let Some(pred) = &self.predicate {
44
let mask = pred.evaluate(&df, state)?;
45
df.filter(mask.bool()?)
46
} else {
47
Ok(df)
48
}
49
}
50
}
51
52
impl Executor for PythonScanExec {
53
fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {
54
state.should_stop()?;
55
#[cfg(debug_assertions)]
56
{
57
if state.verbose() {
58
eprintln!("run PythonScanExec")
59
}
60
}
61
let with_columns = self.options.with_columns.take();
62
let n_rows = self.options.n_rows.take();
63
Python::with_gil(|py| {
64
let pl = PyModule::import(py, intern!(py, "polars")).unwrap();
65
let utils = pl.getattr(intern!(py, "_utils")).unwrap();
66
let callable = utils.getattr(intern!(py, "_execute_from_rust")).unwrap();
67
68
let python_scan_function = self.options.scan_fn.take().unwrap().0;
69
70
let with_columns = with_columns.map(|cols| cols.iter().cloned().collect::<Vec<_>>());
71
let mut could_serialize_predicate = true;
72
73
let predicate = match &self.options.predicate {
74
PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(),
75
PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(),
76
PythonPredicate::Polars(_) => {
77
assert!(self.predicate.is_some(), "should be set");
78
79
match &self.predicate_serialized {
80
None => {
81
could_serialize_predicate = false;
82
PyNone::get(py).to_owned().into_any()
83
},
84
Some(buf) => PyBytes::new(py, buf).into_any(),
85
}
86
},
87
};
88
89
match self.options.python_source {
90
PythonScanSource::Cuda => {
91
let args = (
92
python_scan_function,
93
with_columns
94
.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),
95
predicate,
96
n_rows,
97
// If this boolean is true, callback should return
98
// a dataframe and list of timings [(start, end,
99
// name)]
100
state.has_node_timer(),
101
);
102
let result = callable.call1(args)?;
103
let df = if state.has_node_timer() {
104
let df = result.get_item(0);
105
let timing_info: Vec<(u64, u64, String)> = result.get_item(1)?.extract()?;
106
state.record_raw_timings(&timing_info);
107
df?
108
} else {
109
result
110
};
111
self.finish_df(py, df, state)
112
},
113
PythonScanSource::Pyarrow => {
114
let args = (
115
python_scan_function,
116
with_columns
117
.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),
118
predicate,
119
n_rows,
120
);
121
let df = callable.call1(args)?;
122
self.finish_df(py, df, state)
123
},
124
PythonScanSource::IOPlugin => {
125
// If there are filters, take smaller chunks to ensure we can keep memory
126
// pressure low.
127
let batch_size = if self.predicate.is_some() {
128
Some(100_000usize)
129
} else {
130
None
131
};
132
let args = (
133
python_scan_function,
134
with_columns
135
.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),
136
predicate,
137
n_rows,
138
batch_size,
139
);
140
141
let generator_init = callable.call1(args)?;
142
let generator = generator_init.get_item(0).map_err(
143
|_| polars_err!(ComputeError: "expected tuple got {}", generator_init),
144
)?;
145
let can_parse_predicate = generator_init.get_item(1).map_err(
146
|_| polars_err!(ComputeError: "expected tuple got {}", generator),
147
)?;
148
let can_parse_predicate = can_parse_predicate.extract::<bool>().map_err(
149
|_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate),
150
)? && could_serialize_predicate;
151
152
let mut chunks = vec![];
153
loop {
154
match generator.call_method0(intern!(py, "__next__")) {
155
Ok(out) => {
156
let mut df = python_df_to_rust(py, out)?;
157
if let (Some(pred), false) = (&self.predicate, can_parse_predicate)
158
{
159
let mask = pred.evaluate(&df, state)?;
160
df = df.filter(mask.bool()?)?;
161
}
162
chunks.push(df)
163
},
164
Err(err) if err.matches(py, PyStopIteration::type_object(py))? => break,
165
Err(err) => {
166
polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err)
167
},
168
}
169
}
170
if chunks.is_empty() {
171
return Ok(DataFrame::empty_with_schema(self.get_schema().as_ref()));
172
}
173
let df = accumulate_dataframes_vertical(chunks)?;
174
175
self.check_schema(&df)?;
176
Ok(df)
177
},
178
}
179
})
180
}
181
}
182
183