Path: blob/main/crates/polars-mem-engine/src/executors/scan/python_scan.rs
8354 views
use polars_core::utils::accumulate_dataframes_vertical;1use pyo3::exceptions::PyStopIteration;2use pyo3::prelude::*;3use pyo3::types::{PyBytes, PyNone};4use pyo3::{IntoPyObjectExt, PyTypeInfo, intern};56use self::python_dsl::PythonScanSource;7use super::*;89pub(crate) struct PythonScanExec {10pub(crate) options: PythonOptions,11pub(crate) predicate: Option<Arc<dyn PhysicalExpr>>,12pub(crate) predicate_serialized: Option<Vec<u8>>,13}1415impl PythonScanExec {16/// Get the output schema. E.g. the schema the plugins produce, not consume.17fn get_schema(&self) -> &SchemaRef {18self.options19.output_schema20.as_ref()21.unwrap_or(&self.options.schema)22}2324fn check_schema(&self, df: &DataFrame) -> PolarsResult<()> {25if self.options.validate_schema {26let output_schema = self.get_schema();27polars_ensure!(df.schema() == output_schema, SchemaMismatch: "user provided schema: {:?} doesn't match the DataFrame schema: {:?}", output_schema, df.schema());28}29Ok(())30}3132fn finish_df(33&self,34py: Python,35df: Bound<'_, PyAny>,36state: &mut ExecutionState,37) -> PolarsResult<DataFrame> {38let df = python_df_to_rust(py, df)?;39py.detach(|| {40self.check_schema(&df)?;4142if let Some(pred) = &self.predicate {43let mask = pred.evaluate(&df, state)?;44df.filter(mask.bool()?)45} else {46Ok(df)47}48})49}50}5152impl Executor for PythonScanExec {53fn execute(&mut self, state: &mut ExecutionState) -> PolarsResult<DataFrame> {54state.should_stop()?;55#[cfg(debug_assertions)]56{57if state.verbose() {58eprintln!("run PythonScanExec")59}60}61let with_columns = self.options.with_columns.take();62let n_rows = self.options.n_rows.take();63Python::attach(|py| {64let pl = PyModule::import(py, intern!(py, "polars")).unwrap();65let utils = pl.getattr(intern!(py, "_utils")).unwrap();66let callable = utils.getattr(intern!(py, "_execute_from_rust")).unwrap();6768let python_scan_function = self.options.scan_fn.take().unwrap().0;6970let with_columns = with_columns.map(|cols| cols.iter().cloned().collect::<Vec<_>>());71let mut could_serialize_predicate = true;7273let predicate = match &self.options.predicate {74PythonPredicate::PyArrow(s) => s.into_bound_py_any(py).unwrap(),75PythonPredicate::None => None::<()>.into_bound_py_any(py).unwrap(),76PythonPredicate::Polars(_) => {77assert!(self.predicate.is_some(), "should be set");7879match &self.predicate_serialized {80None => {81could_serialize_predicate = false;82PyNone::get(py).to_owned().into_any()83},84Some(buf) => PyBytes::new(py, buf).into_any(),85}86},87};8889match self.options.python_source {90PythonScanSource::Cuda => {91let args = (92python_scan_function,93with_columns94.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),95predicate,96n_rows,97// If this boolean is true, callback should return98// a dataframe and list of timings [(start, end,99// name)]100state.has_node_timer(),101);102let result = callable.call1(args)?;103let df = if state.has_node_timer() {104let df = result.get_item(0);105let timing_info: Vec<(u64, u64, String)> = result.get_item(1)?.extract()?;106state.record_raw_timings(&timing_info);107df?108} else {109result110};111self.finish_df(py, df, state)112},113PythonScanSource::Pyarrow => {114let args = (115python_scan_function,116with_columns117.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),118predicate,119n_rows,120);121let df = callable.call1(args)?;122self.finish_df(py, df, state)123},124PythonScanSource::IOPlugin => {125// If there are filters, take smaller chunks to ensure we can keep memory126// pressure low.127let batch_size = if self.predicate.is_some() {128Some(100_000usize)129} else {130None131};132let args = (133python_scan_function,134with_columns135.map(|x| x.into_iter().map(|x| x.to_string()).collect::<Vec<_>>()),136predicate,137n_rows,138batch_size,139);140141let generator_init = callable.call1(args)?;142let generator = generator_init.get_item(0).map_err(143|_| polars_err!(ComputeError: "expected tuple got {}", generator_init),144)?;145let can_parse_predicate = generator_init.get_item(1).map_err(146|_| polars_err!(ComputeError: "expected tuple got {}", generator),147)?;148let can_parse_predicate = can_parse_predicate.extract::<bool>().map_err(149|_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate),150)? && could_serialize_predicate;151152let mut chunks = vec![];153loop {154match generator.call_method0(intern!(py, "__next__")) {155Ok(out) => {156let mut df = python_df_to_rust(py, out)?;157if let (Some(pred), false) = (&self.predicate, can_parse_predicate)158{159py.detach(|| {160let mask = pred.evaluate(&df, state)?;161df = df.filter(mask.bool()?)?;162PolarsResult::Ok(())163})?164}165chunks.push(df)166},167Err(err) if err.matches(py, PyStopIteration::type_object(py))? => break,168Err(err) => {169polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err)170},171}172}173if chunks.is_empty() {174return Ok(DataFrame::empty_with_schema(self.get_schema().as_ref()));175}176let df = accumulate_dataframes_vertical(chunks)?;177178self.check_schema(&df)?;179Ok(df)180},181}182})183}184}185186187