Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
pola-rs
GitHub Repository: pola-rs/polars
Path: blob/main/crates/polars-mem-engine/src/executors/scan/python_scan.rs
8354 views
1
use polars_core::utils::accumulate_dataframes_vertical;
2
use pyo3::exceptions::PyStopIteration;
3
use pyo3::prelude::*;
4
use pyo3::types::{PyBytes, PyNone};
5
use pyo3::{IntoPyObjectExt, PyTypeInfo, intern};
6
7
use self::python_dsl::PythonScanSource;
8
use super::*;
9
10
pub(crate) struct PythonScanExec {
11
pub(crate) options: PythonOptions,
12
pub(crate) predicate: Option<Arc<dyn PhysicalExpr>>,
13
pub(crate) predicate_serialized: Option<Vec<u8>>,
14
}
15
16
impl PythonScanExec {
17
/// Get the output schema. E.g. the schema the plugins produce, not consume.
18
fn get_schema(&self) -> &SchemaRef {
19
self.options
20
.output_schema
21
.as_ref()
22
.unwrap_or(&self.options.schema)
23
}
24
25
fn check_schema(&self, df: &DataFrame) -> PolarsResult<()> {
26
if self.options.validate_schema {
27
let output_schema = self.get_schema();
28
polars_ensure!(df.schema() == output_schema, SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}", output_schema, df.schema());
29
}
30
Ok(())
31
}
32
33
fn finish_df(
34
&self,
35
py: Python,
36
df: Bound<'_, PyAny>,
37
state: &mut ExecutionState,
38
) -> PolarsResult<DataFrame> {
39
let df = python_df_to_rust(py, df)?;
40
py.detach(|| {
41
self.check_schema(&df)?;
42
43
if let Some(pred) = &self.predicate {
44
let mask = pred.evaluate(&df, state)?;
45
df.filter(mask.bool()?)
46
} else {
47
Ok(df)
48
}
49
})
50
}
51
}
52
53
impl Executor for PythonScanExec {
54
fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {
55
state.should_stop()?;
56
#[cfg(debug_assertions)]
57
{
58
if state.verbose() {
59
eprintln!("run PythonScanExec")
60
}
61
}
62
let with_columns = self.options.with_columns.take();
63
let n_rows = self.options.n_rows.take();
64
Python::attach(|py| {
65
let pl = PyModule::import(py, intern!(py, "polars")).unwrap();
66
let utils = pl.getattr(intern!(py, "_utils")).unwrap();
67
let callable = utils.getattr(intern!(py, "_execute_from_rust")).unwrap();
68
69
let python_scan_function = self.options.scan_fn.take().unwrap().0;
70
71
let with_columns = with_columns.map(|cols| cols.iter().cloned().collect::<Vec<_>>());
72
let mut could_serialize_predicate = true;
73
74
let predicate = match &self.options.predicate {
75
PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(),
76
PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(),
77
PythonPredicate::Polars(_) => {
78
assert!(self.predicate.is_some(), "should be set");
79
80
match &self.predicate_serialized {
81
None => {
82
could_serialize_predicate = false;
83
PyNone::get(py).to_owned().into_any()
84
},
85
Some(buf) => PyBytes::new(py, buf).into_any(),
86
}
87
},
88
};
89
90
match self.options.python_source {
91
PythonScanSource::Cuda => {
92
let args = (
93
python_scan_function,
94
with_columns
95
.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),
96
predicate,
97
n_rows,
98
// If this boolean is true, callback should return
99
// a dataframe and list of timings [(start, end,
100
// name)]
101
state.has_node_timer(),
102
);
103
let result = callable.call1(args)?;
104
let df = if state.has_node_timer() {
105
let df = result.get_item(0);
106
let timing_info: Vec<(u64, u64, String)> = result.get_item(1)?.extract()?;
107
state.record_raw_timings(&timing_info);
108
df?
109
} else {
110
result
111
};
112
self.finish_df(py, df, state)
113
},
114
PythonScanSource::Pyarrow => {
115
let args = (
116
python_scan_function,
117
with_columns
118
.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),
119
predicate,
120
n_rows,
121
);
122
let df = callable.call1(args)?;
123
self.finish_df(py, df, state)
124
},
125
PythonScanSource::IOPlugin => {
126
// If there are filters, take smaller chunks to ensure we can keep memory
127
// pressure low.
128
let batch_size = if self.predicate.is_some() {
129
Some(100_000usize)
130
} else {
131
None
132
};
133
let args = (
134
python_scan_function,
135
with_columns
136
.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),
137
predicate,
138
n_rows,
139
batch_size,
140
);
141
142
let generator_init = callable.call1(args)?;
143
let generator = generator_init.get_item(0).map_err(
144
|_| polars_err!(ComputeError: "expected tuple got {}", generator_init),
145
)?;
146
let can_parse_predicate = generator_init.get_item(1).map_err(
147
|_| polars_err!(ComputeError: "expected tuple got {}", generator),
148
)?;
149
let can_parse_predicate = can_parse_predicate.extract::<bool>().map_err(
150
|_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate),
151
)? && could_serialize_predicate;
152
153
let mut chunks = vec![];
154
loop {
155
match generator.call_method0(intern!(py, "__next__")) {
156
Ok(out) => {
157
let mut df = python_df_to_rust(py, out)?;
158
if let (Some(pred), false) = (&self.predicate, can_parse_predicate)
159
{
160
py.detach(|| {
161
let mask = pred.evaluate(&df, state)?;
162
df = df.filter(mask.bool()?)?;
163
PolarsResult::Ok(())
164
})?
165
}
166
chunks.push(df)
167
},
168
Err(err) if err.matches(py, PyStopIteration::type_object(py))? => break,
169
Err(err) => {
170
polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err)
171
},
172
}
173
}
174
if chunks.is_empty() {
175
return Ok(DataFrame::empty_with_schema(self.get_schema().as_ref()));
176
}
177
let df = accumulate_dataframes_vertical(chunks)?;
178
179
self.check_schema(&df)?;
180
Ok(df)
181
},
182
}
183
})
184
}
185
}
186
187