// Path: blob/main/crates/polars-mem-engine/src/executors/scan/python_scan.rs
// 6940 views
use polars_core::utils::accumulate_dataframes_vertical;1use pyo3::exceptions::PyStopIteration;2use pyo3::prelude::*;3use pyo3::types::{PyBytes, PyNone};4use pyo3::{IntoPyObjectExt, PyTypeInfo, intern};56use self::python_dsl::PythonScanSource;7use super::*;89pub(crate) struct PythonScanExec {10pub(crate) options: PythonOptions,11pub(crate) predicate: Option<Arc<dyn PhysicalExpr>>,12pub(crate) predicate_serialized: Option<Vec<u8>>,13}1415impl PythonScanExec {16/// Get the output schema. E.g. the schema the plugins produce, not consume.17fn get_schema(&self) -> &SchemaRef {18self.options19.output_schema20.as_ref()21.unwrap_or(&self.options.schema)22}2324fn check_schema(&self, df: &DataFrame) -> PolarsResult<()> {25if self.options.validate_schema {26let output_schema = self.get_schema();27polars_ensure!(df.schema() == output_schema, SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}", output_schema, df.schema());28}29Ok(())30}3132fn finish_df(33&self,34py: Python,35df: Bound<'_, PyAny>,36state: &mut ExecutionState,37) -> PolarsResult<DataFrame> {38let df = python_df_to_rust(py, df)?;3940self.check_schema(&df)?;4142if let Some(pred) = &self.predicate {43let mask = pred.evaluate(&df, state)?;44df.filter(mask.bool()?)45} else {46Ok(df)47}48}49}5051impl Executor for PythonScanExec {52fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {53state.should_stop()?;54#[cfg(debug_assertions)]55{56if state.verbose() {57eprintln!("run PythonScanExec")58}59}60let with_columns = self.options.with_columns.take();61let n_rows = self.options.n_rows.take();62Python::with_gil(|py| {63let pl = PyModule::import(py, intern!(py, "polars")).unwrap();64let utils = pl.getattr(intern!(py, "_utils")).unwrap();65let callable = utils.getattr(intern!(py, "_execute_from_rust")).unwrap();6667let python_scan_function = self.options.scan_fn.take().unwrap().0;6869let with_columns = with_columns.map(|cols| 
cols.iter().cloned().collect::<Vec<_>>());70let mut could_serialize_predicate = true;7172let predicate = match &self.options.predicate {73PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(),74PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(),75PythonPredicate::Polars(_) => {76assert!(self.predicate.is_some(), "should be set");7778match &self.predicate_serialized {79None => {80could_serialize_predicate = false;81PyNone::get(py).to_owned().into_any()82},83Some(buf) => PyBytes::new(py, buf).into_any(),84}85},86};8788match self.options.python_source {89PythonScanSource::Cuda => {90let args = (91python_scan_function,92with_columns93.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),94predicate,95n_rows,96// If this boolean is true, callback should return97// a dataframe and list of timings [(start, end,98// name)]99state.has_node_timer(),100);101let result = callable.call1(args)?;102let df = if state.has_node_timer() {103let df = result.get_item(0);104let timing_info: Vec<(u64, u64, String)> = result.get_item(1)?.extract()?;105state.record_raw_timings(&timing_info);106df?107} else {108result109};110self.finish_df(py, df, state)111},112PythonScanSource::Pyarrow => {113let args = (114python_scan_function,115with_columns116.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),117predicate,118n_rows,119);120let df = callable.call1(args)?;121self.finish_df(py, df, state)122},123PythonScanSource::IOPlugin => {124// If there are filters, take smaller chunks to ensure we can keep memory125// pressure low.126let batch_size = if self.predicate.is_some() {127Some(100_000usize)128} else {129None130};131let args = (132python_scan_function,133with_columns134.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),135predicate,136n_rows,137batch_size,138);139140let generator_init = callable.call1(args)?;141let generator = generator_init.get_item(0).map_err(142|_| polars_err!(ComputeError: "expected tuple got {}", 
generator_init),143)?;144let can_parse_predicate = generator_init.get_item(1).map_err(145|_| polars_err!(ComputeError: "expected tuple got {}", generator),146)?;147let can_parse_predicate = can_parse_predicate.extract::<bool>().map_err(148|_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate),149)? && could_serialize_predicate;150151let mut chunks = vec![];152loop {153match generator.call_method0(intern!(py, "__next__")) {154Ok(out) => {155let mut df = python_df_to_rust(py, out)?;156if let (Some(pred), false) = (&self.predicate, can_parse_predicate)157{158let mask = pred.evaluate(&df, state)?;159df = df.filter(mask.bool()?)?;160}161chunks.push(df)162},163Err(err) if err.matches(py, PyStopIteration::type_object(py))? => break,164Err(err) => {165polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err)166},167}168}169if chunks.is_empty() {170return Ok(DataFrame::empty_with_schema(self.get_schema().as_ref()));171}172let df = accumulate_dataframes_vertical(chunks)?;173174self.check_schema(&df)?;175Ok(df)176},177}178})179}180}181182183