// Path: blob/main/crates/polars-plan/src/plans/python/pyarrow.rs
// 6940 views
use std::fmt::Write;12use polars_core::datatypes::AnyValue;3use polars_core::prelude::{TimeUnit, TimeZone};45use crate::prelude::*;67#[derive(Default, Copy, Clone)]8pub struct PyarrowArgs {9// pyarrow doesn't allow `filter([True, False])`10// but does allow `filter(field("a").isin([True, False]))`11allow_literal_series: bool,12}1314fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String {15// note: `to_py_datetime` and the `Datetime`16// dtype have to be in-scope on the python side17match tz {18None => format!("to_py_datetime({},'{}')", v, tu.to_ascii()),19Some(tz) => format!("to_py_datetime({},'{}',{})", v, tu.to_ascii(), tz),20}21}2223fn sanitize(name: &str) -> Option<&str> {24if name.chars().all(|c| match c {25' ' => true,26'-' => true,27'_' => true,28c => c.is_alphanumeric(),29}) {30Some(name)31} else {32None33}34}3536// convert to a pyarrow expression that can be evaluated with pythons eval37pub fn predicate_to_pa(38predicate: Node,39expr_arena: &Arena<AExpr>,40args: PyarrowArgs,41) -> Option<String> {42match expr_arena.get(predicate) {43AExpr::BinaryExpr { left, right, op } => {44if op.is_comparison_or_bitwise() {45let left = predicate_to_pa(*left, expr_arena, args)?;46let right = predicate_to_pa(*right, expr_arena, args)?;47Some(format!("({left} {op} {right})"))48} else {49None50}51},52AExpr::Column(name) => {53let name = sanitize(name)?;54Some(format!("pa.compute.field('{name}')"))55},56AExpr::Literal(LiteralValue::Series(s)) => {57if !args.allow_literal_series || s.is_empty() || s.len() > 100 {58None59} else {60let mut list_repr = String::with_capacity(s.len() * 5);61list_repr.push('[');62for av in s.rechunk().iter() {63match av {64AnyValue::Boolean(v) => {65let s = if v { "True" } else { "False" };66write!(list_repr, "{s},").unwrap();67},68#[cfg(feature = "dtype-datetime")]69AnyValue::Datetime(v, tu, tz) => {70let dtm = to_py_datetime(v, &tu, tz);71write!(list_repr, "{dtm},").unwrap();72},73#[cfg(feature = 
"dtype-date")]74AnyValue::Date(v) => {75write!(list_repr, "to_py_date({v}),").unwrap();76},77AnyValue::String(s) => {78let _ = sanitize(s)?;79write!(list_repr, "{av},").unwrap();80},81// Hard to sanitize82AnyValue::Binary(_)83| AnyValue::Struct(_, _, _)84| AnyValue::List(_)85| AnyValue::Array(_, _) => return None,86_ => {87write!(list_repr, "{av},").unwrap();88},89}90}91// pop last comma92list_repr.pop();93list_repr.push(']');94Some(list_repr)95}96},97AExpr::Literal(lv) => {98let av = lv.to_any_value()?;99let dtype = av.dtype();100match av.as_borrowed() {101AnyValue::String(s) => {102let s = sanitize(s)?;103Some(format!("'{s}'"))104},105AnyValue::Boolean(val) => {106// python bools are capitalized107if val {108Some("pa.compute.scalar(True)".to_string())109} else {110Some("pa.compute.scalar(False)".to_string())111}112},113#[cfg(feature = "dtype-date")]114AnyValue::Date(v) => {115// the function `to_py_date` and the `Date`116// dtype have to be in scope on the python side117Some(format!("to_py_date({v})"))118},119#[cfg(feature = "dtype-datetime")]120AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz)),121// Hard to sanitize122AnyValue::Binary(_)123| AnyValue::Struct(_, _, _)124| AnyValue::List(_)125| AnyValue::Array(_, _) => None,126// Activate once pyarrow supports them127// #[cfg(feature = "dtype-time")]128// AnyValue::Time(v) => {129// // the function `to_py_time` has to be in scope130// // on the python side131// Some(format!("to_py_time(value={v})"))132// }133// #[cfg(feature = "dtype-duration")]134// AnyValue::Duration(v, tu) => {135// // the function `to_py_timedelta` has to be in scope136// // on the python side137// Some(format!(138// "to_py_timedelta(value={}, tu='{}')",139// v,140// tu.to_ascii()141// ))142// }143av => {144if dtype.is_float() {145let val = av.extract::<f64>()?;146Some(format!("{val}"))147} else if dtype.is_integer() {148let val = av.extract::<i64>()?;149Some(format!("{val}"))150} else {151None152}153},154}155},156#[cfg(feature 
= "is_in")]157AExpr::Function {158function: IRFunctionExpr::Boolean(IRBooleanFunction::IsIn { .. }),159input,160..161} => {162let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;163let mut args = args;164args.allow_literal_series = true;165let values = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;166167Some(format!("({col}).isin({values})"))168},169#[cfg(feature = "is_between")]170AExpr::Function {171function: IRFunctionExpr::Boolean(IRBooleanFunction::IsBetween { closed }),172input,173..174} => {175if !matches!(expr_arena.get(input.first()?.node()), AExpr::Column(_)) {176None177} else {178let col = predicate_to_pa(input.first()?.node(), expr_arena, args)?;179let left_cmp_op = match closed {180ClosedInterval::None | ClosedInterval::Right => Operator::Gt,181ClosedInterval::Both | ClosedInterval::Left => Operator::GtEq,182};183let right_cmp_op = match closed {184ClosedInterval::None | ClosedInterval::Left => Operator::Lt,185ClosedInterval::Both | ClosedInterval::Right => Operator::LtEq,186};187188let lower = predicate_to_pa(input.get(1)?.node(), expr_arena, args)?;189let upper = predicate_to_pa(input.get(2)?.node(), expr_arena, args)?;190191Some(format!(192"(({col} {left_cmp_op} {lower}) & ({col} {right_cmp_op} {upper}))"193))194}195},196AExpr::Function {197function, input, ..198} => {199let input = input.first().unwrap().node();200let input = predicate_to_pa(input, expr_arena, args)?;201202match function {203IRFunctionExpr::Boolean(IRBooleanFunction::Not) => Some(format!("~({input})")),204IRFunctionExpr::Boolean(IRBooleanFunction::IsNull) => {205Some(format!("({input}).is_null()"))206},207IRFunctionExpr::Boolean(IRBooleanFunction::IsNotNull) => {208Some(format!("~({input}).is_null()"))209},210_ => None,211}212},213_ => None,214}215}216217218